feat: update stun proxy logic

This commit is contained in:
Simon Ding
2025-05-08 16:13:30 +08:00
parent bb2c567da7
commit 992fa7ddd0
12 changed files with 310 additions and 73 deletions

View File

@@ -329,11 +329,11 @@ func (c *client) SaveDownloader(downloader *ent.DownloadClients) error {
count := c.ent.DownloadClients.Query().Where(downloadclients.Name(downloader.Name)).CountX(context.TODO())
if count != 0 {
err := c.ent.DownloadClients.Update().Where(downloadclients.Name(downloader.Name)).SetImplementation(downloader.Implementation).
SetURL(downloader.URL).SetUser(downloader.User).SetPassword(downloader.Password).SetPriority1(downloader.Priority1).Exec(context.TODO())
SetURL(downloader.URL).SetUser(downloader.User).SetUseNatTraversal(downloader.UseNatTraversal).SetPassword(downloader.Password).SetPriority1(downloader.Priority1).Exec(context.TODO())
return err
}
_, err := c.ent.DownloadClients.Create().SetEnable(true).SetImplementation(downloader.Implementation).
_, err := c.ent.DownloadClients.Create().SetEnable(true).SetImplementation(downloader.Implementation).SetUseNatTraversal(downloader.UseNatTraversal).
SetName(downloader.Name).SetURL(downloader.URL).SetUser(downloader.User).SetPriority1(downloader.Priority1).SetPassword(downloader.Password).Save(context.TODO())
return err
}

View File

@@ -5,7 +5,8 @@ import (
"polaris/ent/downloadclients"
"polaris/pkg/nat"
"polaris/pkg/qbittorrent"
"strconv"
"github.com/pion/stun/v3"
)
func (s *Engine) stunProxyDownloadClient() error {
@@ -13,6 +14,9 @@ func (s *Engine) stunProxyDownloadClient() error {
if err != nil {
return err
}
if !e.UseNatTraversal {
return nil
}
if e.Implementation != downloadclients.ImplementationQbittorrent {
return nil
}
@@ -20,22 +24,17 @@ func (s *Engine) stunProxyDownloadClient() error {
if !ok {
return nil
}
n, err := nat.NewNatTraversal()
if err != nil {
return err
}
addr, err := n.StunAddr()
if err != nil {
return err
}
err = d.SetListenPort(addr.Port)
if err != nil {
return err
}
u, err := url.Parse(d.URL)
if err != nil {
return err
}
return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port))
n, err := nat.NewNatTraversal(func(xa stun.XORMappedAddress) error {
return d.SetListenPort(xa.Port)
}, u.Hostname())
if err != nil {
return err
}
n.StartProxy()
return nil
}

View File

@@ -33,6 +33,8 @@ type DownloadClients struct {
Settings string `json:"settings,omitempty"`
// Priority1 holds the value of the "priority1" field.
Priority1 int `json:"priority1,omitempty"`
// use stun server to do nat traversal, enable download client to do uploading successfully
UseNatTraversal bool `json:"use_nat_traversal,omitempty"`
// RemoveCompletedDownloads holds the value of the "remove_completed_downloads" field.
RemoveCompletedDownloads bool `json:"remove_completed_downloads,omitempty"`
// RemoveFailedDownloads holds the value of the "remove_failed_downloads" field.
@@ -49,7 +51,7 @@ func (*DownloadClients) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case downloadclients.FieldEnable, downloadclients.FieldRemoveCompletedDownloads, downloadclients.FieldRemoveFailedDownloads:
case downloadclients.FieldEnable, downloadclients.FieldUseNatTraversal, downloadclients.FieldRemoveCompletedDownloads, downloadclients.FieldRemoveFailedDownloads:
values[i] = new(sql.NullBool)
case downloadclients.FieldID, downloadclients.FieldPriority1:
values[i] = new(sql.NullInt64)
@@ -126,6 +128,12 @@ func (dc *DownloadClients) assignValues(columns []string, values []any) error {
} else if value.Valid {
dc.Priority1 = int(value.Int64)
}
case downloadclients.FieldUseNatTraversal:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field use_nat_traversal", values[i])
} else if value.Valid {
dc.UseNatTraversal = value.Bool
}
case downloadclients.FieldRemoveCompletedDownloads:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field remove_completed_downloads", values[i])
@@ -210,6 +218,9 @@ func (dc *DownloadClients) String() string {
builder.WriteString("priority1=")
builder.WriteString(fmt.Sprintf("%v", dc.Priority1))
builder.WriteString(", ")
builder.WriteString("use_nat_traversal=")
builder.WriteString(fmt.Sprintf("%v", dc.UseNatTraversal))
builder.WriteString(", ")
builder.WriteString("remove_completed_downloads=")
builder.WriteString(fmt.Sprintf("%v", dc.RemoveCompletedDownloads))
builder.WriteString(", ")

View File

@@ -30,6 +30,8 @@ const (
FieldSettings = "settings"
// FieldPriority1 holds the string denoting the priority1 field in the database.
FieldPriority1 = "priority1"
// FieldUseNatTraversal holds the string denoting the use_nat_traversal field in the database.
FieldUseNatTraversal = "use_nat_traversal"
// FieldRemoveCompletedDownloads holds the string denoting the remove_completed_downloads field in the database.
FieldRemoveCompletedDownloads = "remove_completed_downloads"
// FieldRemoveFailedDownloads holds the string denoting the remove_failed_downloads field in the database.
@@ -53,6 +55,7 @@ var Columns = []string{
FieldPassword,
FieldSettings,
FieldPriority1,
FieldUseNatTraversal,
FieldRemoveCompletedDownloads,
FieldRemoveFailedDownloads,
FieldTags,
@@ -80,6 +83,8 @@ var (
DefaultPriority1 int
// Priority1Validator is a validator for the "priority1" field. It is called by the builders before save.
Priority1Validator func(int) error
// DefaultUseNatTraversal holds the default value on creation for the "use_nat_traversal" field.
DefaultUseNatTraversal bool
// DefaultRemoveCompletedDownloads holds the default value on creation for the "remove_completed_downloads" field.
DefaultRemoveCompletedDownloads bool
// DefaultRemoveFailedDownloads holds the default value on creation for the "remove_failed_downloads" field.
@@ -162,6 +167,11 @@ func ByPriority1(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority1, opts...).ToFunc()
}
// ByUseNatTraversal orders the results by the use_nat_traversal field.
func ByUseNatTraversal(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUseNatTraversal, opts...).ToFunc()
}
// ByRemoveCompletedDownloads orders the results by the remove_completed_downloads field.
func ByRemoveCompletedDownloads(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRemoveCompletedDownloads, opts...).ToFunc()

View File

@@ -89,6 +89,11 @@ func Priority1(v int) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldEQ(FieldPriority1, v))
}
// UseNatTraversal applies equality check predicate on the "use_nat_traversal" field. It's identical to UseNatTraversalEQ.
func UseNatTraversal(v bool) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldEQ(FieldUseNatTraversal, v))
}
// RemoveCompletedDownloads applies equality check predicate on the "remove_completed_downloads" field. It's identical to RemoveCompletedDownloadsEQ.
func RemoveCompletedDownloads(v bool) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldEQ(FieldRemoveCompletedDownloads, v))
@@ -504,6 +509,26 @@ func Priority1LTE(v int) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldLTE(FieldPriority1, v))
}
// UseNatTraversalEQ applies the EQ predicate on the "use_nat_traversal" field.
func UseNatTraversalEQ(v bool) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldEQ(FieldUseNatTraversal, v))
}
// UseNatTraversalNEQ applies the NEQ predicate on the "use_nat_traversal" field.
func UseNatTraversalNEQ(v bool) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldNEQ(FieldUseNatTraversal, v))
}
// UseNatTraversalIsNil applies the IsNil predicate on the "use_nat_traversal" field.
func UseNatTraversalIsNil() predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldIsNull(FieldUseNatTraversal))
}
// UseNatTraversalNotNil applies the NotNil predicate on the "use_nat_traversal" field.
func UseNatTraversalNotNil() predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldNotNull(FieldUseNatTraversal))
}
// RemoveCompletedDownloadsEQ applies the EQ predicate on the "remove_completed_downloads" field.
func RemoveCompletedDownloadsEQ(v bool) predicate.DownloadClients {
return predicate.DownloadClients(sql.FieldEQ(FieldRemoveCompletedDownloads, v))

View File

@@ -100,6 +100,20 @@ func (dcc *DownloadClientsCreate) SetNillablePriority1(i *int) *DownloadClientsC
return dcc
}
// SetUseNatTraversal sets the "use_nat_traversal" field.
func (dcc *DownloadClientsCreate) SetUseNatTraversal(b bool) *DownloadClientsCreate {
dcc.mutation.SetUseNatTraversal(b)
return dcc
}
// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil.
func (dcc *DownloadClientsCreate) SetNillableUseNatTraversal(b *bool) *DownloadClientsCreate {
if b != nil {
dcc.SetUseNatTraversal(*b)
}
return dcc
}
// SetRemoveCompletedDownloads sets the "remove_completed_downloads" field.
func (dcc *DownloadClientsCreate) SetRemoveCompletedDownloads(b bool) *DownloadClientsCreate {
dcc.mutation.SetRemoveCompletedDownloads(b)
@@ -207,6 +221,10 @@ func (dcc *DownloadClientsCreate) defaults() {
v := downloadclients.DefaultPriority1
dcc.mutation.SetPriority1(v)
}
if _, ok := dcc.mutation.UseNatTraversal(); !ok {
v := downloadclients.DefaultUseNatTraversal
dcc.mutation.SetUseNatTraversal(v)
}
if _, ok := dcc.mutation.RemoveCompletedDownloads(); !ok {
v := downloadclients.DefaultRemoveCompletedDownloads
dcc.mutation.SetRemoveCompletedDownloads(v)
@@ -328,6 +346,10 @@ func (dcc *DownloadClientsCreate) createSpec() (*DownloadClients, *sqlgraph.Crea
_spec.SetField(downloadclients.FieldPriority1, field.TypeInt, value)
_node.Priority1 = value
}
if value, ok := dcc.mutation.UseNatTraversal(); ok {
_spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value)
_node.UseNatTraversal = value
}
if value, ok := dcc.mutation.RemoveCompletedDownloads(); ok {
_spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value)
_node.RemoveCompletedDownloads = value

View File

@@ -146,6 +146,26 @@ func (dcu *DownloadClientsUpdate) AddPriority1(i int) *DownloadClientsUpdate {
return dcu
}
// SetUseNatTraversal sets the "use_nat_traversal" field.
func (dcu *DownloadClientsUpdate) SetUseNatTraversal(b bool) *DownloadClientsUpdate {
dcu.mutation.SetUseNatTraversal(b)
return dcu
}
// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil.
func (dcu *DownloadClientsUpdate) SetNillableUseNatTraversal(b *bool) *DownloadClientsUpdate {
if b != nil {
dcu.SetUseNatTraversal(*b)
}
return dcu
}
// ClearUseNatTraversal clears the value of the "use_nat_traversal" field.
func (dcu *DownloadClientsUpdate) ClearUseNatTraversal() *DownloadClientsUpdate {
dcu.mutation.ClearUseNatTraversal()
return dcu
}
// SetRemoveCompletedDownloads sets the "remove_completed_downloads" field.
func (dcu *DownloadClientsUpdate) SetRemoveCompletedDownloads(b bool) *DownloadClientsUpdate {
dcu.mutation.SetRemoveCompletedDownloads(b)
@@ -274,6 +294,12 @@ func (dcu *DownloadClientsUpdate) sqlSave(ctx context.Context) (n int, err error
if value, ok := dcu.mutation.AddedPriority1(); ok {
_spec.AddField(downloadclients.FieldPriority1, field.TypeInt, value)
}
if value, ok := dcu.mutation.UseNatTraversal(); ok {
_spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value)
}
if dcu.mutation.UseNatTraversalCleared() {
_spec.ClearField(downloadclients.FieldUseNatTraversal, field.TypeBool)
}
if value, ok := dcu.mutation.RemoveCompletedDownloads(); ok {
_spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value)
}
@@ -425,6 +451,26 @@ func (dcuo *DownloadClientsUpdateOne) AddPriority1(i int) *DownloadClientsUpdate
return dcuo
}
// SetUseNatTraversal sets the "use_nat_traversal" field.
func (dcuo *DownloadClientsUpdateOne) SetUseNatTraversal(b bool) *DownloadClientsUpdateOne {
dcuo.mutation.SetUseNatTraversal(b)
return dcuo
}
// SetNillableUseNatTraversal sets the "use_nat_traversal" field if the given value is not nil.
func (dcuo *DownloadClientsUpdateOne) SetNillableUseNatTraversal(b *bool) *DownloadClientsUpdateOne {
if b != nil {
dcuo.SetUseNatTraversal(*b)
}
return dcuo
}
// ClearUseNatTraversal clears the value of the "use_nat_traversal" field.
func (dcuo *DownloadClientsUpdateOne) ClearUseNatTraversal() *DownloadClientsUpdateOne {
dcuo.mutation.ClearUseNatTraversal()
return dcuo
}
// SetRemoveCompletedDownloads sets the "remove_completed_downloads" field.
func (dcuo *DownloadClientsUpdateOne) SetRemoveCompletedDownloads(b bool) *DownloadClientsUpdateOne {
dcuo.mutation.SetRemoveCompletedDownloads(b)
@@ -583,6 +629,12 @@ func (dcuo *DownloadClientsUpdateOne) sqlSave(ctx context.Context) (_node *Downl
if value, ok := dcuo.mutation.AddedPriority1(); ok {
_spec.AddField(downloadclients.FieldPriority1, field.TypeInt, value)
}
if value, ok := dcuo.mutation.UseNatTraversal(); ok {
_spec.SetField(downloadclients.FieldUseNatTraversal, field.TypeBool, value)
}
if dcuo.mutation.UseNatTraversalCleared() {
_spec.ClearField(downloadclients.FieldUseNatTraversal, field.TypeBool)
}
if value, ok := dcuo.mutation.RemoveCompletedDownloads(); ok {
_spec.SetField(downloadclients.FieldRemoveCompletedDownloads, field.TypeBool, value)
}

View File

@@ -35,6 +35,7 @@ var (
{Name: "password", Type: field.TypeString, Default: ""},
{Name: "settings", Type: field.TypeString, Default: ""},
{Name: "priority1", Type: field.TypeInt, Default: 1},
{Name: "use_nat_traversal", Type: field.TypeBool, Nullable: true, Default: false},
{Name: "remove_completed_downloads", Type: field.TypeBool, Default: true},
{Name: "remove_failed_downloads", Type: field.TypeBool, Default: true},
{Name: "tags", Type: field.TypeString, Default: ""},

View File

@@ -792,6 +792,7 @@ type DownloadClientsMutation struct {
settings *string
priority1 *int
addpriority1 *int
use_nat_traversal *bool
remove_completed_downloads *bool
remove_failed_downloads *bool
tags *string
@@ -1208,6 +1209,55 @@ func (m *DownloadClientsMutation) ResetPriority1() {
m.addpriority1 = nil
}
// SetUseNatTraversal sets the "use_nat_traversal" field.
func (m *DownloadClientsMutation) SetUseNatTraversal(b bool) {
m.use_nat_traversal = &b
}
// UseNatTraversal returns the value of the "use_nat_traversal" field in the mutation.
func (m *DownloadClientsMutation) UseNatTraversal() (r bool, exists bool) {
v := m.use_nat_traversal
if v == nil {
return
}
return *v, true
}
// OldUseNatTraversal returns the old "use_nat_traversal" field's value of the DownloadClients entity.
// If the DownloadClients object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *DownloadClientsMutation) OldUseNatTraversal(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUseNatTraversal is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUseNatTraversal requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUseNatTraversal: %w", err)
}
return oldValue.UseNatTraversal, nil
}
// ClearUseNatTraversal clears the value of the "use_nat_traversal" field.
func (m *DownloadClientsMutation) ClearUseNatTraversal() {
m.use_nat_traversal = nil
m.clearedFields[downloadclients.FieldUseNatTraversal] = struct{}{}
}
// UseNatTraversalCleared returns if the "use_nat_traversal" field was cleared in this mutation.
func (m *DownloadClientsMutation) UseNatTraversalCleared() bool {
_, ok := m.clearedFields[downloadclients.FieldUseNatTraversal]
return ok
}
// ResetUseNatTraversal resets all changes to the "use_nat_traversal" field.
func (m *DownloadClientsMutation) ResetUseNatTraversal() {
m.use_nat_traversal = nil
delete(m.clearedFields, downloadclients.FieldUseNatTraversal)
}
// SetRemoveCompletedDownloads sets the "remove_completed_downloads" field.
func (m *DownloadClientsMutation) SetRemoveCompletedDownloads(b bool) {
m.remove_completed_downloads = &b
@@ -1399,7 +1449,7 @@ func (m *DownloadClientsMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *DownloadClientsMutation) Fields() []string {
fields := make([]string, 0, 12)
fields := make([]string, 0, 13)
if m.enable != nil {
fields = append(fields, downloadclients.FieldEnable)
}
@@ -1424,6 +1474,9 @@ func (m *DownloadClientsMutation) Fields() []string {
if m.priority1 != nil {
fields = append(fields, downloadclients.FieldPriority1)
}
if m.use_nat_traversal != nil {
fields = append(fields, downloadclients.FieldUseNatTraversal)
}
if m.remove_completed_downloads != nil {
fields = append(fields, downloadclients.FieldRemoveCompletedDownloads)
}
@@ -1460,6 +1513,8 @@ func (m *DownloadClientsMutation) Field(name string) (ent.Value, bool) {
return m.Settings()
case downloadclients.FieldPriority1:
return m.Priority1()
case downloadclients.FieldUseNatTraversal:
return m.UseNatTraversal()
case downloadclients.FieldRemoveCompletedDownloads:
return m.RemoveCompletedDownloads()
case downloadclients.FieldRemoveFailedDownloads:
@@ -1493,6 +1548,8 @@ func (m *DownloadClientsMutation) OldField(ctx context.Context, name string) (en
return m.OldSettings(ctx)
case downloadclients.FieldPriority1:
return m.OldPriority1(ctx)
case downloadclients.FieldUseNatTraversal:
return m.OldUseNatTraversal(ctx)
case downloadclients.FieldRemoveCompletedDownloads:
return m.OldRemoveCompletedDownloads(ctx)
case downloadclients.FieldRemoveFailedDownloads:
@@ -1566,6 +1623,13 @@ func (m *DownloadClientsMutation) SetField(name string, value ent.Value) error {
}
m.SetPriority1(v)
return nil
case downloadclients.FieldUseNatTraversal:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUseNatTraversal(v)
return nil
case downloadclients.FieldRemoveCompletedDownloads:
v, ok := value.(bool)
if !ok {
@@ -1639,6 +1703,9 @@ func (m *DownloadClientsMutation) AddField(name string, value ent.Value) error {
// mutation.
func (m *DownloadClientsMutation) ClearedFields() []string {
var fields []string
if m.FieldCleared(downloadclients.FieldUseNatTraversal) {
fields = append(fields, downloadclients.FieldUseNatTraversal)
}
if m.FieldCleared(downloadclients.FieldCreateTime) {
fields = append(fields, downloadclients.FieldCreateTime)
}
@@ -1656,6 +1723,9 @@ func (m *DownloadClientsMutation) FieldCleared(name string) bool {
// error if the field is not defined in the schema.
func (m *DownloadClientsMutation) ClearField(name string) error {
switch name {
case downloadclients.FieldUseNatTraversal:
m.ClearUseNatTraversal()
return nil
case downloadclients.FieldCreateTime:
m.ClearCreateTime()
return nil
@@ -1691,6 +1761,9 @@ func (m *DownloadClientsMutation) ResetField(name string) error {
case downloadclients.FieldPriority1:
m.ResetPriority1()
return nil
case downloadclients.FieldUseNatTraversal:
m.ResetUseNatTraversal()
return nil
case downloadclients.FieldRemoveCompletedDownloads:
m.ResetRemoveCompletedDownloads()
return nil

View File

@@ -45,20 +45,24 @@ func init() {
downloadclients.DefaultPriority1 = downloadclientsDescPriority1.Default.(int)
// downloadclients.Priority1Validator is a validator for the "priority1" field. It is called by the builders before save.
downloadclients.Priority1Validator = downloadclientsDescPriority1.Validators[0].(func(int) error)
// downloadclientsDescUseNatTraversal is the schema descriptor for use_nat_traversal field.
downloadclientsDescUseNatTraversal := downloadclientsFields[8].Descriptor()
// downloadclients.DefaultUseNatTraversal holds the default value on creation for the use_nat_traversal field.
downloadclients.DefaultUseNatTraversal = downloadclientsDescUseNatTraversal.Default.(bool)
// downloadclientsDescRemoveCompletedDownloads is the schema descriptor for remove_completed_downloads field.
downloadclientsDescRemoveCompletedDownloads := downloadclientsFields[8].Descriptor()
downloadclientsDescRemoveCompletedDownloads := downloadclientsFields[9].Descriptor()
// downloadclients.DefaultRemoveCompletedDownloads holds the default value on creation for the remove_completed_downloads field.
downloadclients.DefaultRemoveCompletedDownloads = downloadclientsDescRemoveCompletedDownloads.Default.(bool)
// downloadclientsDescRemoveFailedDownloads is the schema descriptor for remove_failed_downloads field.
downloadclientsDescRemoveFailedDownloads := downloadclientsFields[9].Descriptor()
downloadclientsDescRemoveFailedDownloads := downloadclientsFields[10].Descriptor()
// downloadclients.DefaultRemoveFailedDownloads holds the default value on creation for the remove_failed_downloads field.
downloadclients.DefaultRemoveFailedDownloads = downloadclientsDescRemoveFailedDownloads.Default.(bool)
// downloadclientsDescTags is the schema descriptor for tags field.
downloadclientsDescTags := downloadclientsFields[10].Descriptor()
downloadclientsDescTags := downloadclientsFields[11].Descriptor()
// downloadclients.DefaultTags holds the default value on creation for the tags field.
downloadclients.DefaultTags = downloadclientsDescTags.Default.(string)
// downloadclientsDescCreateTime is the schema descriptor for create_time field.
downloadclientsDescCreateTime := downloadclientsFields[11].Descriptor()
downloadclientsDescCreateTime := downloadclientsFields[12].Descriptor()
// downloadclients.DefaultCreateTime holds the default value on creation for the create_time field.
downloadclients.DefaultCreateTime = downloadclientsDescCreateTime.Default.(func() time.Time)
episodeFields := schema.Episode{}.Fields()

View File

@@ -32,6 +32,7 @@ func (DownloadClients) Fields() []ent.Field {
}
return nil
}),
field.Bool("use_nat_traversal").Optional().Default(false).Comment("use stun server to do nat traversal, enable download client to do uploading successfully"),
field.Bool("remove_completed_downloads").Default(true),
field.Bool("remove_failed_downloads").Default(true),
field.String("tags").Default(""),

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"polaris/log"
"time"
"github.com/pion/stun/v3"
)
@@ -15,7 +16,7 @@ const (
timeoutMillis = 500
)
func NewNatTraversal() (*NatTraversal, error) {
func NewNatTraversal(addrCallback func(stun.XORMappedAddress) error, targetHost string) (*NatTraversal, error) {
conn, err := net.ListenUDP(udp, nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
@@ -24,31 +25,57 @@ func NewNatTraversal() (*NatTraversal, error) {
log.Infof("Listening on %s", conn.LocalAddr())
messageChan := listen(conn)
s := &NatTraversal{
conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
addrChan: make(chan stun.XORMappedAddress),
addrCallback: addrCallback,
targetHost: targetHost,
}
return &NatTraversal{
conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
}, nil
go s.updateNatAddr()
return s, nil
}
type NatTraversal struct {
//peerAddr *net.UDPAddr
conn *net.UDPConn
messageChan <-chan []byte
addrChan chan stun.XORMappedAddress
cancel chan struct{}
stunAddr *stun.XORMappedAddress
stunAddr *stun.XORMappedAddress
addrCallback func(stun.XORMappedAddress) error
targetHost string
targetPort int
}
func (s *NatTraversal) Cancel() {
close(s.cancel)
s.conn.Close()
}
func (s *NatTraversal) updateNatAddr() {
for addr := range s.addrChan {
if s.stunAddr == nil || s.stunAddr.String() != addr.String() { //new address
log.Warnf("My public address: %s\n", addr)
if s.addrCallback != nil { //execute callback
if err := s.addrCallback(addr); err != nil {
log.Warnf("callback error: %v", err)
}
}
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
s.targetPort = addr.Port
log.Infof("now proxy to target host: %s:%d", s.targetHost, s.targetPort)
s.stunAddr = &addr
}
}
}
func (s *NatTraversal) sendStunServerBindingMsg() error {
for _, srv := range getStunServers() {
log.Debugf("try to connect to stun server: %s", srv)
srvAddr, err := net.ResolveUDPAddr(udp, srv)
@@ -58,62 +85,74 @@ func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
}
err = sendBindingRequest(s.conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
message, ok := <-s.messageChan
if !ok {
log.Warnf("send binding request: %w", err)
continue
}
if stun.IsMessage(message) {
m := new(stun.Message)
m.Raw = message
decErr := m.Decode()
if decErr != nil {
log.Warnf("decode:", decErr)
break
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Warnf("getFrom:", getErr)
continue
}
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
log.Warnf("My public address: %s\n", xorAddr)
s.stunAddr = &xorAddr
}
return &xorAddr, nil
}
return nil
}
return nil, fmt.Errorf("failed to get STUN address")
return fmt.Errorf("failed to get STUN address")
}
func (s *NatTraversal) StartProxy(targetAddr string) error {
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return fmt.Errorf("resolve peeraddr: %w", err)
func (s *NatTraversal) getNatAddr(msg []byte) (*stun.XORMappedAddress, error) {
if !stun.IsMessage(msg) {
return nil, fmt.Errorf("not a stun message")
}
if s.stunAddr == nil {
addr, err := s.StunAddr()
if err != nil {
return fmt.Errorf("get STUN address: %w", err)
}
log.Infof("STUN address: %s", addr)
m := new(stun.Message)
m.Raw = msg
decErr := m.Decode()
if decErr != nil {
return nil, fmt.Errorf("decode: %w", decErr)
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
return nil, fmt.Errorf("getFrom: %w", getErr)
}
s.addrChan <- xorAddr
return &xorAddr, nil
}
func (s *NatTraversal) StartProxy() {
tick := time.NewTicker(10 * time.Second)
go func() { //tcker message to check public ip and port
defer tick.Stop()
for {
select {
case <-s.cancel:
log.Infof("stun nat proxy cancelled")
return
case <-tick.C:
err := s.sendStunServerBindingMsg()
if err != nil {
log.Warnf("send stun server binding msg: %w", err)
}
}
}
}()
for {
select {
case <-s.cancel:
log.Infof("stun nat proxy cancelled")
return nil
return
case m := <-s.messageChan:
//log.Infof("Received message: %d", len(m))
send(m, s.conn, peerAddr)
if stun.IsMessage(m) {
s.getNatAddr(m)
} else {
peerAddr, err := net.ResolveUDPAddr(udp, fmt.Sprintf("%s:%d", s.targetHost, s.targetPort))
if err != nil {
log.Errorf("resolve peeraddr: %w", err)
continue
}
send(m, s.conn, peerAddr)
}
}
}
}