From 992fa7ddd0b467e0c590d2c5eee87697698ec067 Mon Sep 17 00:00:00 2001 From: Simon Ding Date: Thu, 8 May 2025 16:13:30 +0800 Subject: [PATCH] feat: update stun proxy logic --- db/db.go | 4 +- engine/stun.go | 27 +++-- ent/downloadclients.go | 13 ++- ent/downloadclients/downloadclients.go | 10 ++ ent/downloadclients/where.go | 25 +++++ ent/downloadclients_create.go | 22 ++++ ent/downloadclients_update.go | 52 +++++++++ ent/migrate/schema.go | 1 + ent/mutation.go | 75 ++++++++++++- ent/runtime.go | 12 ++- ent/schema/downloadclients.go | 1 + pkg/nat/traversal.go | 141 ++++++++++++++++--------- 12 files changed, 310 insertions(+), 73 deletions(-) diff --git a/db/db.go b/db/db.go index fc392f3..54cc45c 100644 --- a/db/db.go +++ b/db/db.go @@ -329,11 +329,11 @@ func (c *client) SaveDownloader(downloader *ent.DownloadClients) error { count := c.ent.DownloadClients.Query().Where(downloadclients.Name(downloader.Name)).CountX(context.TODO()) if count != 0 { err := c.ent.DownloadClients.Update().Where(downloadclients.Name(downloader.Name)).SetImplementation(downloader.Implementation). - SetURL(downloader.URL).SetUser(downloader.User).SetPassword(downloader.Password).SetPriority1(downloader.Priority1).Exec(context.TODO()) + SetURL(downloader.URL).SetUser(downloader.User).SetUseNatTraversal(downloader.UseNatTraversal).SetPassword(downloader.Password).SetPriority1(downloader.Priority1).Exec(context.TODO()) return err } - _, err := c.ent.DownloadClients.Create().SetEnable(true).SetImplementation(downloader.Implementation). + _, err := c.ent.DownloadClients.Create().SetEnable(true).SetImplementation(downloader.Implementation).SetUseNatTraversal(downloader.UseNatTraversal). SetName(downloader.Name).SetURL(downloader.URL).SetUser(downloader.User).SetPriority1(downloader.Priority1).SetPassword(downloader.Password).Save(context.TODO()) return err } diff --git a/engine/stun.go b/engine/stun.go index 0c451ce..c66dd75 100644 --- a/engine/stun.go +++ b/engine/stun.go @@ -5,7 +5,8 @@ import ( "polaris/ent/downloadclients" "polaris/pkg/nat" "polaris/pkg/qbittorrent" - "strconv" + + "github.com/pion/stun/v3" ) func (s *Engine) stunProxyDownloadClient() error { @@ -13,6 +14,9 @@ func (s *Engine) stunProxyDownloadClient() error { if err != nil { return err } + if !e.UseNatTraversal { + return nil + } if e.Implementation != downloadclients.ImplementationQbittorrent { return nil } @@ -20,22 +24,17 @@ func (s *Engine) stunProxyDownloadClient() error { if !ok { return nil } - n, err := nat.NewNatTraversal() - if err != nil { - return err - } - addr, err := n.StunAddr() - if err != nil { - return err - } - err = d.SetListenPort(addr.Port) - if err != nil { - return err - } u, err := url.Parse(d.URL) if err != nil { return err } - return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port)) + n, err := nat.NewNatTraversal(func(xa stun.XORMappedAddress) error { + return d.SetListenPort(xa.Port) + }, u.Hostname()) + if err != nil { + return err + } + n.StartProxy() + return nil } diff --git a/ent/downloadclients.go b/ent/downloadclients.go index e2f66fa..f4bade2 100644 --- a/ent/downloadclients.go +++ b/ent/downloadclients.go @@ -33,6 +33,8 @@ type DownloadClients struct { Settings string `json:"settings,omitempty"` // Priority1 holds the value of the "priority1" field. Priority1 int `json:"priority1,omitempty"` + // use stun server to do nat traversal, enable download client to do uploading successfully + UseNatTraversal bool `json:"use_nat_traversal,omitempty"` // RemoveCompletedDownloads holds the value of the "remove_completed_downloads" field. RemoveCompletedDownloads bool `json:"remove_completed_downloads,omitempty"` // RemoveFailedDownloads holds the value of the "remove_failed_downloads" field. @@ -49,7 +51,7 @@ func (*DownloadClients) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case downloadclients.FieldEnable, downloadclients.FieldRemoveCompletedDownloads, downloadclients.FieldRemoveFailedDownloads: + case downloadclients.FieldEnable, downloadclients.FieldUseNatTraversal, downloadclients.FieldRemoveCompletedDownloads, downloadclients.FieldRemoveFailedDownloads: values[i] = new(sql.NullBool) case downloadclients.FieldID, downloadclients.FieldPriority1: values[i] = new(sql.NullInt64) @@ -126,6 +128,12 @@ func (dc *DownloadClients) assignValues(columns []string, values []any) error { } else if value.Valid { dc.Priority1 = int(value.Int64) } + case downloadclients.FieldUseNatTraversal: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field use_nat_traversal", values[i]) + } else if value.Valid { + dc.UseNatTraversal = value.Bool + } case downloadclients.FieldRemoveCompletedDownloads: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field remove_completed_downloads", values[i]) @@ -210,6 +218,9 @@ func (dc *DownloadClients) String() string { builder.WriteString("priority1=") builder.WriteString(fmt.Sprintf("%v", dc.Priority1)) builder.WriteString(", ") + builder.WriteString("use_nat_traversal=") + builder.WriteString(fmt.Sprintf("%v", dc.UseNatTraversal)) + builder.WriteString(", ") builder.WriteString("remove_completed_downloads=") builder.WriteString(fmt.Sprintf("%v", dc.RemoveCompletedDownloads)) builder.WriteString(", ") diff --git a/ent/downloadclients/downloadclients.go b/ent/downloadclients/downloadclients.go index 08c9c89..1dd8fe2 100644 --- a/ent/downloadclients/downloadclients.go +++ b/ent/downloadclients/downloadclients.go @@ -30,6 +30,8 @@ const ( FieldSettings = "settings" // FieldPriority1 holds the string denoting the priority1 field in the database. FieldPriority1 = "priority1" + // FieldUseNatTraversal holds the string denoting the use_nat_traversal field in the database. + FieldUseNatTraversal = "use_nat_traversal" // FieldRemoveCompletedDownloads holds the string denoting the remove_completed_downloads field in the database. FieldRemoveCompletedDownloads = "remove_completed_downloads" // FieldRemoveFailedDownloads holds the string denoting the remove_failed_downloads field in the database. @@ -53,6 +55,7 @@ var Columns = []string{ FieldPassword, FieldSettings, FieldPriority1, + FieldUseNatTraversal, FieldRemoveCompletedDownloads, FieldRemoveFailedDownloads, FieldTags, @@ -80,6 +83,8 @@ var ( DefaultPriority1 int // Priority1Validator is a validator for the "priority1" field. It is called by the builders before save. Priority1Validator func(int) error + // DefaultUseNatTraversal holds the default value on creation for the "use_nat_traversal" field. + DefaultUseNatTraversal bool // DefaultRemoveCompletedDownloads holds the default value on creation for the "remove_completed_downloads" field. DefaultRemoveCompletedDownloads bool // DefaultRemoveFailedDownloads holds the default value on creation for the "remove_failed_downloads" field. @@ -162,6 +167,11 @@ func ByPriority1(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldPriority1, opts...).ToFunc() } +// ByUseNatTraversal orders the results by the use_nat_traversal field. +func ByUseNatTraversal(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUseNatTraversal, opts...).ToFunc() +} + // ByRemoveCompletedDownloads orders the results by the remove_completed_downloads field. func ByRemoveCompletedDownloads(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRemoveCompletedDownloads, opts...).ToFunc() diff --git a/ent/downloadclients/where.go b/ent/downloadclients/where.go index 0827c79..8b55964 100644 --- a/ent/downloadclients/where.go +++ b/ent/downloadclients/where.go @@ -89,6 +89,11 @@ func Priority1(v int) predicate.DownloadClients { return predicate.DownloadClients(sql.FieldEQ(FieldPriority1, v)) } +// UseNatTraversal applies equality check predicate on the "use_nat_traversal" field. It's identical to UseNatTraversalEQ. +func UseNatTraversal(v bool) predicate.DownloadClients { + return predicate.DownloadClients(sql.FieldEQ(FieldUseNatTraversal, v)) +} + // RemoveCompletedDownloads applies equality check predicate on the "remove_completed_downloads" field. It's identical to RemoveCompletedDownloadsEQ. func RemoveCompletedDownloads(v bool) predicate.DownloadClients { return predicate.DownloadClients(sql.FieldEQ(FieldRemoveCompletedDownloads, v)) @@ -504,6 +509,26 @@ func Priority1LTE(v int) predicate.DownloadClients { return predicate.DownloadClients(sql.FieldLTE(FieldPriority1, v)) } +// UseNatTraversalEQ applies the EQ predicate on the "use_nat_traversal" field. +func UseNatTraversalEQ(v bool) predicate.DownloadClients { + return predicate.DownloadClients(sql.FieldEQ(FieldUseNatTraversal, v)) +} + +// UseNatTraversalNEQ applies the NEQ predicate on the "use_nat_traversal" field. +func UseNatTraversalNEQ(v bool) predicate.DownloadClients { + return predicate.DownloadClients(sql.FieldNEQ(FieldUseNatTraversal, v)) +} + +// UseNatTraversalIsNil applies the IsNil predicate on the "use_nat_traversal" field. +func UseNatTraversalIsNil() predicate.DownloadClients { + return predicate.DownloadClients(sql.FieldIsNull(FieldUseNatTraversal)) +} + +// UseNatTraversalNotNil applies the NotNil predicate on the "use_nat_traversal" field. +func UseNatTraversalNotNil() predicate.DownloadClients { + return predicate.DownloadClients(sql.FieldNotNull(FieldUseNatTraversal)) +} + // RemoveCompletedDownloadsEQ applies the EQ predicate on the "remove_completed_downloads" field. func RemoveCompletedDownloadsEQ(v bool) predicate.DownloadClients { return predicate.DownloadClients(sql.FieldEQ(FieldRemoveCompletedDownloads, v)) diff --git a/ent/downloadclients_create.go b/ent/downloadclients_create.go index 2091c67..368d50a 100644 --- a/ent/downloadclients_create.go +++ b/ent/downloadclients_create.go @@ -100,6 +100,20 @@ func (dcc *DownloadClientsCreate) SetNillablePriority1(i *int) *DownloadClientsC return dcc } +// SetUseNatTraversal sets the "use_nat_traversal" field. +func (dcc *DownloadClientsCreate) SetUseNatTraversal(b bool) *DownloadClientsCreate { + dcc.mutation.SetUseNatTraversal(b) + return dcc +} + +// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil. +func (dcc *DownloadClientsCreate) SetNillableUseNatTraversal(b *bool) *DownloadClientsCreate { + if b != nil { + dcc.SetUseNatTraversal(*b) + } + return dcc +} + // SetRemoveCompletedDownloads sets the "remove_completed_downloads" field. func (dcc *DownloadClientsCreate) SetRemoveCompletedDownloads(b bool) *DownloadClientsCreate { dcc.mutation.SetRemoveCompletedDownloads(b) @@ -207,6 +221,10 @@ func (dcc *DownloadClientsCreate) defaults() { v := downloadclients.DefaultPriority1 dcc.mutation.SetPriority1(v) } + if _, ok := dcc.mutation.UseNatTraversal(); !ok { + v := downloadclients.DefaultUseNatTraversal + dcc.mutation.SetUseNatTraversal(v) + } if _, ok := dcc.mutation.RemoveCompletedDownloads(); !ok { v := downloadclients.DefaultRemoveCompletedDownloads dcc.mutation.SetRemoveCompletedDownloads(v) @@ -328,6 +346,10 @@ func (dcc *DownloadClientsCreate) createSpec() (*DownloadClients, *sqlgraph.Crea _spec.SetField(downloadclients.FieldPriority1, field.TypeInt, value) _node.Priority1 = value } + if value, ok := dcc.mutation.UseNatTraversal(); ok { + _spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value) + _node.UseNatTraversal = value + } if value, ok := dcc.mutation.RemoveCompletedDownloads(); ok { _spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value) _node.RemoveCompletedDownloads = value diff --git a/ent/downloadclients_update.go b/ent/downloadclients_update.go index 2df820f..b371abe 100644 --- a/ent/downloadclients_update.go +++ b/ent/downloadclients_update.go @@ -146,6 +146,26 @@ func (dcu *DownloadClientsUpdate) AddPriority1(i int) *DownloadClientsUpdate { return dcu } +// SetUseNatTraversal sets the "use_nat_traversal" field. +func (dcu *DownloadClientsUpdate) SetUseNatTraversal(b bool) *DownloadClientsUpdate { + dcu.mutation.SetUseNatTraversal(b) + return dcu +} + +// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil. +func (dcu *DownloadClientsUpdate) SetNillableUseNatTraversal(b *bool) *DownloadClientsUpdate { + if b != nil { + dcu.SetUseNatTraversal(*b) + } + return dcu +} + +// ClearUseNatTraversal clears the value of the "use_nat_traversal" field. +func (dcu *DownloadClientsUpdate) ClearUseNatTraversal() *DownloadClientsUpdate { + dcu.mutation.ClearUseNatTraversal() + return dcu +} + // SetRemoveCompletedDownloads sets the "remove_completed_downloads" field. func (dcu *DownloadClientsUpdate) SetRemoveCompletedDownloads(b bool) *DownloadClientsUpdate { dcu.mutation.SetRemoveCompletedDownloads(b) @@ -274,6 +294,12 @@ func (dcu *DownloadClientsUpdate) sqlSave(ctx context.Context) (n int, err error if value, ok := dcu.mutation.AddedPriority1(); ok { _spec.AddField(downloadclients.FieldPriority1, field.TypeInt, value) } + if value, ok := dcu.mutation.UseNatTraversal(); ok { + _spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value) + } + if dcu.mutation.UseNatTraversalCleared() { + _spec.ClearField(downloadclients.FieldUseNatTraversal, field.TypeBool) + } if value, ok := dcu.mutation.RemoveCompletedDownloads(); ok { _spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value) } @@ -425,6 +451,26 @@ func (dcuo *DownloadClientsUpdateOne) AddPriority1(i int) *DownloadClientsUpdate return dcuo } +// SetUseNatTraversal sets the "use_nat_traversal" field. +func (dcuo *DownloadClientsUpdateOne) SetUseNatTraversal(b bool) *DownloadClientsUpdateOne { + dcuo.mutation.SetUseNatTraversal(b) + return dcuo +} + +// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil. +func (dcuo *DownloadClientsUpdateOne) SetNillableUseNatTraversal(b *bool) *DownloadClientsUpdateOne { + if b != nil { + dcuo.SetUseNatTraversal(*b) + } + return dcuo +} + +// ClearUseNatTraversal clears the value of the "use_nat_traversal" field. +func (dcuo *DownloadClientsUpdateOne) ClearUseNatTraversal() *DownloadClientsUpdateOne { + dcuo.mutation.ClearUseNatTraversal() + return dcuo +} + // SetRemoveCompletedDownloads sets the "remove_completed_downloads" field. func (dcuo *DownloadClientsUpdateOne) SetRemoveCompletedDownloads(b bool) *DownloadClientsUpdateOne { dcuo.mutation.SetRemoveCompletedDownloads(b) @@ -583,6 +629,12 @@ func (dcuo *DownloadClientsUpdateOne) sqlSave(ctx context.Context) (_node *Downl if value, ok := dcuo.mutation.AddedPriority1(); ok { _spec.AddField(downloadclients.FieldPriority1, field.TypeInt, value) } + if value, ok := dcuo.mutation.UseNatTraversal(); ok { + _spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value) + } + if dcuo.mutation.UseNatTraversalCleared() { + _spec.ClearField(downloadclients.FieldUseNatTraversal, field.TypeBool) + } if value, ok := dcuo.mutation.RemoveCompletedDownloads(); ok { _spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value) } diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go index ea5b560..504a032 100644 --- a/ent/migrate/schema.go +++ b/ent/migrate/schema.go @@ -35,6 +35,7 @@ var ( {Name: "password", Type: field.TypeString, Default: ""}, {Name: "settings", Type: field.TypeString, Default: ""}, {Name: "priority1", Type: field.TypeInt, Default: 1}, + {Name: "use_nat_traversal", Type: field.TypeBool, Nullable: true, Default: false}, {Name: "remove_completed_downloads", Type: field.TypeBool, Default: true}, {Name: "remove_failed_downloads", Type: field.TypeBool, Default: true}, {Name: "tags", Type: field.TypeString, Default: ""}, diff --git a/ent/mutation.go b/ent/mutation.go index 6e9cd5d..4071032 100644 --- a/ent/mutation.go +++ b/ent/mutation.go @@ -792,6 +792,7 @@ type DownloadClientsMutation struct { settings *string priority1 *int addpriority1 *int + use_nat_traversal *bool remove_completed_downloads *bool remove_failed_downloads *bool tags *string @@ -1208,6 +1209,55 @@ func (m *DownloadClientsMutation) ResetPriority1() { m.addpriority1 = nil } +// SetUseNatTraversal sets the "use_nat_traversal" field. +func (m *DownloadClientsMutation) SetUseNatTraversal(b bool) { + m.use_nat_traversal = &b +} + +// UseNatTraversal returns the value of the "use_nat_traversal" field in the mutation. +func (m *DownloadClientsMutation) UseNatTraversal() (r bool, exists bool) { + v := m.use_nat_traversal + if v == nil { + return + } + return *v, true +} + +// OldUseNatTraversal returns the old "use_nat_traversal" field's value of the DownloadClients entity. +// If the DownloadClients object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DownloadClientsMutation) OldUseNatTraversal(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUseNatTraversal is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUseNatTraversal requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUseNatTraversal: %w", err) + } + return oldValue.UseNatTraversal, nil +} + +// ClearUseNatTraversal clears the value of the "use_nat_traversal" field. +func (m *DownloadClientsMutation) ClearUseNatTraversal() { + m.use_nat_traversal = nil + m.clearedFields[downloadclients.FieldUseNatTraversal] = struct{}{} +} + +// UseNatTraversalCleared returns if the "use_nat_traversal" field was cleared in this mutation. +func (m *DownloadClientsMutation) UseNatTraversalCleared() bool { + _, ok := m.clearedFields[downloadclients.FieldUseNatTraversal] + return ok +} + +// ResetUseNatTraversal resets all changes to the "use_nat_traversal" field. +func (m *DownloadClientsMutation) ResetUseNatTraversal() { + m.use_nat_traversal = nil + delete(m.clearedFields, downloadclients.FieldUseNatTraversal) +} + // SetRemoveCompletedDownloads sets the "remove_completed_downloads" field. func (m *DownloadClientsMutation) SetRemoveCompletedDownloads(b bool) { m.remove_completed_downloads = &b @@ -1399,7 +1449,7 @@ func (m *DownloadClientsMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *DownloadClientsMutation) Fields() []string { - fields := make([]string, 0, 12) + fields := make([]string, 0, 13) if m.enable != nil { fields = append(fields, downloadclients.FieldEnable) } @@ -1424,6 +1474,9 @@ func (m *DownloadClientsMutation) Fields() []string { if m.priority1 != nil { fields = append(fields, downloadclients.FieldPriority1) } + if m.use_nat_traversal != nil { + fields = append(fields, downloadclients.FieldUseNatTraversal) + } if m.remove_completed_downloads != nil { fields = append(fields, downloadclients.FieldRemoveCompletedDownloads) } @@ -1460,6 +1513,8 @@ func (m *DownloadClientsMutation) Field(name string) (ent.Value, bool) { return m.Settings() case downloadclients.FieldPriority1: return m.Priority1() + case downloadclients.FieldUseNatTraversal: + return m.UseNatTraversal() case downloadclients.FieldRemoveCompletedDownloads: return m.RemoveCompletedDownloads() case downloadclients.FieldRemoveFailedDownloads: @@ -1493,6 +1548,8 @@ func (m *DownloadClientsMutation) OldField(ctx context.Context, name string) (en return m.OldSettings(ctx) case downloadclients.FieldPriority1: return m.OldPriority1(ctx) + case downloadclients.FieldUseNatTraversal: + return m.OldUseNatTraversal(ctx) case downloadclients.FieldRemoveCompletedDownloads: return m.OldRemoveCompletedDownloads(ctx) case downloadclients.FieldRemoveFailedDownloads: @@ -1566,6 +1623,13 @@ func (m *DownloadClientsMutation) SetField(name string, value ent.Value) error { } m.SetPriority1(v) return nil + case downloadclients.FieldUseNatTraversal: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUseNatTraversal(v) + return nil case downloadclients.FieldRemoveCompletedDownloads: v, ok := value.(bool) if !ok { @@ -1639,6 +1703,9 @@ func (m *DownloadClientsMutation) AddField(name string, value ent.Value) error { // mutation. func (m *DownloadClientsMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(downloadclients.FieldUseNatTraversal) { + fields = append(fields, downloadclients.FieldUseNatTraversal) + } if m.FieldCleared(downloadclients.FieldCreateTime) { fields = append(fields, downloadclients.FieldCreateTime) } @@ -1656,6 +1723,9 @@ func (m *DownloadClientsMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *DownloadClientsMutation) ClearField(name string) error { switch name { + case downloadclients.FieldUseNatTraversal: + m.ClearUseNatTraversal() + return nil case downloadclients.FieldCreateTime: m.ClearCreateTime() return nil @@ -1691,6 +1761,9 @@ func (m *DownloadClientsMutation) ResetField(name string) error { case downloadclients.FieldPriority1: m.ResetPriority1() return nil + case downloadclients.FieldUseNatTraversal: + m.ResetUseNatTraversal() + return nil case downloadclients.FieldRemoveCompletedDownloads: m.ResetRemoveCompletedDownloads() return nil diff --git a/ent/runtime.go b/ent/runtime.go index d304f2f..00a8996 100644 --- a/ent/runtime.go +++ b/ent/runtime.go @@ -45,20 +45,24 @@ func init() { downloadclients.DefaultPriority1 = downloadclientsDescPriority1.Default.(int) // downloadclients.Priority1Validator is a validator for the "priority1" field. It is called by the builders before save. downloadclients.Priority1Validator = downloadclientsDescPriority1.Validators[0].(func(int) error) + // downloadclientsDescUseNatTraversal is the schema descriptor for use_nat_traversal field. + downloadclientsDescUseNatTraversal := downloadclientsFields[8].Descriptor() + // downloadclients.DefaultUseNatTraversal holds the default value on creation for the use_nat_traversal field. + downloadclients.DefaultUseNatTraversal = downloadclientsDescUseNatTraversal.Default.(bool) // downloadclientsDescRemoveCompletedDownloads is the schema descriptor for remove_completed_downloads field. - downloadclientsDescRemoveCompletedDownloads := downloadclientsFields[8].Descriptor() + downloadclientsDescRemoveCompletedDownloads := downloadclientsFields[9].Descriptor() // downloadclients.DefaultRemoveCompletedDownloads holds the default value on creation for the remove_completed_downloads field. downloadclients.DefaultRemoveCompletedDownloads = downloadclientsDescRemoveCompletedDownloads.Default.(bool) // downloadclientsDescRemoveFailedDownloads is the schema descriptor for remove_failed_downloads field. - downloadclientsDescRemoveFailedDownloads := downloadclientsFields[9].Descriptor() + downloadclientsDescRemoveFailedDownloads := downloadclientsFields[10].Descriptor() // downloadclients.DefaultRemoveFailedDownloads holds the default value on creation for the remove_failed_downloads field. downloadclients.DefaultRemoveFailedDownloads = downloadclientsDescRemoveFailedDownloads.Default.(bool) // downloadclientsDescTags is the schema descriptor for tags field. - downloadclientsDescTags := downloadclientsFields[10].Descriptor() + downloadclientsDescTags := downloadclientsFields[11].Descriptor() // downloadclients.DefaultTags holds the default value on creation for the tags field. downloadclients.DefaultTags = downloadclientsDescTags.Default.(string) // downloadclientsDescCreateTime is the schema descriptor for create_time field. - downloadclientsDescCreateTime := downloadclientsFields[11].Descriptor() + downloadclientsDescCreateTime := downloadclientsFields[12].Descriptor() // downloadclients.DefaultCreateTime holds the default value on creation for the create_time field. downloadclients.DefaultCreateTime = downloadclientsDescCreateTime.Default.(func() time.Time) episodeFields := schema.Episode{}.Fields() diff --git a/ent/schema/downloadclients.go b/ent/schema/downloadclients.go index 9648ae2..de3bcd3 100644 --- a/ent/schema/downloadclients.go +++ b/ent/schema/downloadclients.go @@ -32,6 +32,7 @@ func (DownloadClients) Fields() []ent.Field { } return nil }), + field.Bool("use_nat_traversal").Optional().Default(false).Comment("use stun server to do nat traversal, enable download client to do uploading successfully"), field.Bool("remove_completed_downloads").Default(true), field.Bool("remove_failed_downloads").Default(true), field.String("tags").Default(""), diff --git a/pkg/nat/traversal.go b/pkg/nat/traversal.go index 84c651b..5497b11 100644 --- a/pkg/nat/traversal.go +++ b/pkg/nat/traversal.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "polaris/log" + "time" "github.com/pion/stun/v3" ) @@ -15,7 +16,7 @@ const ( timeoutMillis = 500 ) -func NewNatTraversal() (*NatTraversal, error) { +func NewNatTraversal(addrCallback func(stun.XORMappedAddress) error, targetHost string) (*NatTraversal, error) { conn, err := net.ListenUDP(udp, nil) if err != nil { return nil, fmt.Errorf("listen: %w", err) @@ -24,31 +25,57 @@ func NewNatTraversal() (*NatTraversal, error) { log.Infof("Listening on %s", conn.LocalAddr()) messageChan := listen(conn) + s := &NatTraversal{ + conn: conn, + messageChan: messageChan, + cancel: make(chan struct{}), + addrChan: make(chan stun.XORMappedAddress), + addrCallback: addrCallback, + targetHost: targetHost, + } - return &NatTraversal{ - conn: conn, - messageChan: messageChan, - cancel: make(chan struct{}), - }, nil + go s.updateNatAddr() + + return s, nil } type NatTraversal struct { //peerAddr *net.UDPAddr conn *net.UDPConn messageChan <-chan []byte + addrChan chan stun.XORMappedAddress cancel chan struct{} - stunAddr *stun.XORMappedAddress + stunAddr *stun.XORMappedAddress + addrCallback func(stun.XORMappedAddress) error + targetHost string + targetPort int } func (s *NatTraversal) Cancel() { - + close(s.cancel) s.conn.Close() } +func (s *NatTraversal) updateNatAddr() { + for addr := range s.addrChan { + if s.stunAddr == nil || s.stunAddr.String() != addr.String() { //new address + log.Warnf("My public address: %s\n", addr) + if s.addrCallback != nil { //execute callback + if err := s.addrCallback(addr); err != nil { + log.Warnf("callback error: %v", err) + } + } -func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) { + s.targetPort = addr.Port + log.Infof("now proxy to target host: %s:%d", s.targetHost, s.targetPort) + s.stunAddr = &addr + } + } +} + +func (s *NatTraversal) sendStunServerBindingMsg() error { for _, srv := range getStunServers() { log.Debugf("try to connect to stun server: %s", srv) srvAddr, err := net.ResolveUDPAddr(udp, srv) @@ -58,62 +85,74 @@ func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) { } err = sendBindingRequest(s.conn, srvAddr) if err != nil { - return nil, fmt.Errorf("send binding request: %w", err) - } - message, ok := <-s.messageChan - if !ok { + log.Warnf("send binding request: %w", err) continue } - if stun.IsMessage(message) { - m := new(stun.Message) - m.Raw = message - decErr := m.Decode() - if decErr != nil { - log.Warnf("decode:", decErr) - - break - } - var xorAddr stun.XORMappedAddress - if getErr := xorAddr.GetFrom(m); getErr != nil { - log.Warnf("getFrom:", getErr) - - continue - } - if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() { - log.Warnf("My public address: %s\n", xorAddr) - s.stunAddr = &xorAddr - } - return &xorAddr, nil - } - + return nil } - return nil, fmt.Errorf("failed to get STUN address") + return fmt.Errorf("failed to get STUN address") } -func (s *NatTraversal) StartProxy(targetAddr string) error { - log.Infof("Starting NAT traversal proxy to %s", targetAddr) - peerAddr, err := net.ResolveUDPAddr(udp, targetAddr) - if err != nil { - return fmt.Errorf("resolve peeraddr: %w", err) +func (s *NatTraversal) getNatAddr(msg []byte) (*stun.XORMappedAddress, error) { + if !stun.IsMessage(msg) { + return nil, fmt.Errorf("not a stun message") } - if s.stunAddr == nil { - addr, err := s.StunAddr() - if err != nil { - return fmt.Errorf("get STUN address: %w", err) - } - log.Infof("STUN address: %s", addr) + m := new(stun.Message) + m.Raw = msg + decErr := m.Decode() + if decErr != nil { + return nil, fmt.Errorf("decode: %w", decErr) } + var xorAddr stun.XORMappedAddress + if getErr := xorAddr.GetFrom(m); getErr != nil { + return nil, fmt.Errorf("getFrom: %w", getErr) + } + s.addrChan <- xorAddr + + return &xorAddr, nil + +} + +func (s *NatTraversal) StartProxy() { + + tick := time.NewTicker(10 * time.Second) + + go func() { //tcker message to check public ip and port + defer tick.Stop() + for { + select { + case <-s.cancel: + log.Infof("stun nat proxy cancelled") + return + case <-tick.C: + err := s.sendStunServerBindingMsg() + if err != nil { + log.Warnf("send stun server binding msg: %w", err) + } + } + } + }() + for { select { case <-s.cancel: log.Infof("stun nat proxy cancelled") - return nil + return case m := <-s.messageChan: - //log.Infof("Received message: %d", len(m)) - send(m, s.conn, peerAddr) + if stun.IsMessage(m) { + s.getNatAddr(m) + } else { + peerAddr, err := net.ResolveUDPAddr(udp, fmt.Sprintf("%s:%d", s.targetHost, s.targetPort)) + if err != nil { + log.Errorf("resolve peeraddr: %w", err) + continue + } + + send(m, s.conn, peerAddr) + } } - + } }