diff --git a/engine/client.go b/engine/client.go index 5f4a3e6..61c3fe0 100644 --- a/engine/client.go +++ b/engine/client.go @@ -11,6 +11,7 @@ import ( "polaris/pkg/tmdb" "polaris/pkg/transmission" "polaris/pkg/utils" + "time" "github.com/pkg/errors" "github.com/robfig/cron" @@ -50,6 +51,12 @@ func (c *Engine) Init() { go c.reloadTasks() c.addSysCron() go c.checkW500PosterOnStartup() + go func() { + time.Sleep(10*time.Second) + if err := c.stunProxyDownloadClient(); err != nil { + log.Errorf("stun proxy error: %v", err) + } + }() } func (c *Engine) GetTask(id int) (*Task, bool) { diff --git a/engine/stun.go b/engine/stun.go new file mode 100644 index 0000000..0c451ce --- /dev/null +++ b/engine/stun.go @@ -0,0 +1,41 @@ +package engine + +import ( + "net/url" + "polaris/ent/downloadclients" + "polaris/pkg/nat" + "polaris/pkg/qbittorrent" + "strconv" +) + +func (s *Engine) stunProxyDownloadClient() error { + downloader, e, err := s.GetDownloadClient() + if err != nil { + return err + } + if e.Implementation != downloadclients.ImplementationQbittorrent { + return nil + } + d, ok := downloader.(*qbittorrent.Client) + if !ok { + return nil + } + n, err := nat.NewNatTraversal() + if err != nil { + return err + } + addr, err := n.StunAddr() + if err != nil { + return err + } + err = d.SetListenPort(addr.Port) + if err != nil { + return err + } + u, err := url.Parse(d.URL) + if err != nil { + return err + } + + return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port)) +} diff --git a/pkg/nat/traversal.go b/pkg/nat/traversal.go index b69a4f0..22d46fa 100644 --- a/pkg/nat/traversal.go +++ b/pkg/nat/traversal.go @@ -2,9 +2,8 @@ package nat import ( "fmt" - "log" "net" - "time" + "polaris/log" "github.com/pion/stun/v3" ) @@ -16,139 +15,109 @@ const ( timeoutMillis = 500 ) -type natTraversal struct { - peerAddr *net.UDPAddr - cancel chan struct{} - port <-chan int -} - -func (s *natTraversal) Port() int { - return <-s.port -} - -func (s *natTraversal) Cancel() { - s.cancel <- struct{}{} -} - -func NatTraversal(targetAddr string) (*natTraversal, error) { //nolint:gocognit,cyclop - - srvAddr, err := net.ResolveUDPAddr(udp, getStunServers()[0]) - if err != nil { - log.Fatalf("Failed to resolve server addr: %s", err) - } - +func NewNatTraversal() (*NatTraversal, error) { conn, err := net.ListenUDP(udp, nil) if err != nil { return nil, fmt.Errorf("listen: %w", err) } - log.Printf("Listening on %s", conn.LocalAddr()) - - peerAddr, err := net.ResolveUDPAddr(udp, targetAddr) - if err != nil { - return nil, fmt.Errorf("resolve peeraddr: %w", err) - } - err = sendBindingRequest(conn, srvAddr) - if err != nil { - return nil, fmt.Errorf("send binding request: %w", err) - } - nt := &natTraversal{ - peerAddr: peerAddr, - cancel: make(chan struct{}), - port: make(chan int), - } - go func() { - err := doTraversal(conn, peerAddr, nt.cancel) - if err != nil { - log.Println("nat traversal error:", err) - } - }() - return nt, nil - -} - -func doTraversal(conn *net.UDPConn, peerAddr *net.UDPAddr, quit <-chan struct{}) error { - defer func() { - _ = conn.Close() - }() - - var publicAddr stun.XORMappedAddress + log.Infof("Listening on %s", conn.LocalAddr()) messageChan := listen(conn) - //var peerAddrChan <-chan string - keepalive := time.Tick(timeoutMillis * time.Millisecond) - keepaliveMsg := pingMsg + return &NatTraversal{ + conn: conn, + messageChan: messageChan, + cancel: make(chan struct{}), + }, nil +} - gotPong := false - sentPong := false +type NatTraversal struct { + //peerAddr *net.UDPAddr + conn *net.UDPConn + messageChan <-chan []byte + cancel chan struct{} - for { + stunAddr *stun.XORMappedAddress +} + +func (s *NatTraversal) Cancel() { + + close(s.cancel) + s.conn.Close() +} + + +func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) { + for _, srv := range getStunServers() { + log.Debugf("try to connect to stun server: %s", srv) + srvAddr, err := net.ResolveUDPAddr(udp, srv) + if err != nil { + log.Warnf("Failed to resolve server addr: %s", err) + continue + } + err = sendBindingRequest(s.conn, srvAddr) + if err != nil { + return nil, fmt.Errorf("send binding request: %w", err) + } select { - case message, ok := <-messageChan: + case message, ok := <-s.messageChan: if !ok { - return nil + continue } - - switch { - case string(message) == pingMsg: - keepaliveMsg = pongMsg - - case string(message) == pongMsg: - if !gotPong { - log.Println("Received pong message.") - } - - // One client may skip sending ping if it receives - // a ping message before knowning the peer address. - keepaliveMsg = pongMsg - - gotPong = true - - case stun.IsMessage(message): + if stun.IsMessage(message) { m := new(stun.Message) m.Raw = message decErr := m.Decode() if decErr != nil { - log.Println("decode:", decErr) + log.Warnf("decode:", decErr) break } var xorAddr stun.XORMappedAddress if getErr := xorAddr.GetFrom(m); getErr != nil { - log.Println("getFrom:", getErr) + log.Warnf("getFrom:", getErr) - break + continue } - - if publicAddr.String() != xorAddr.String() { - log.Printf("My public address: %s\n", xorAddr) - publicAddr = xorAddr - - //peerAddrChan = getPeerAddr() + if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() { + log.Warnf("My public address: %s\n", xorAddr) + s.stunAddr = &xorAddr } + return &xorAddr, nil - default: - send(message, conn, peerAddr) } - - case <-keepalive: - // Keep NAT binding alive using STUN server or the peer once it's known - err := sendStr(keepaliveMsg, conn, peerAddr) - if keepaliveMsg == pongMsg { - sentPong = true - } - _ = sentPong - - if err != nil { - log.Panicln("keepalive:", err) - } - - case <-quit: - _ = conn.Close() } + + } + return nil, fmt.Errorf("failed to get STUN address") +} + +func (s *NatTraversal) StartProxy(targetAddr string) error { + log.Infof("Starting NAT traversal proxy to %s", targetAddr) + peerAddr, err := net.ResolveUDPAddr(udp, targetAddr) + if err != nil { + return fmt.Errorf("resolve peeraddr: %w", err) } + if s.stunAddr == nil { + addr, err := s.StunAddr() + if err != nil { + return fmt.Errorf("get STUN address: %w", err) + } + log.Infof("STUN address: %s", addr) + } + for { + select { + case <-s.cancel: + log.Infof("cancelled") + return nil + case m := <-s.messageChan: + //log.Infof("Received message: %d", len(m)) + send(m, s.conn, peerAddr) + } + + } } func listen(conn *net.UDPConn) <-chan []byte { @@ -157,13 +126,15 @@ func listen(conn *net.UDPConn) <-chan []byte { for { buf := make([]byte, 10240) - n, _, err := conn.ReadFromUDP(buf) + n, addr, err := conn.ReadFromUDP(buf) if err != nil { close(messages) return } + log.Debugf("Received message from %s: %d", addr, n) buf = buf[:n] + log.Debugf("recevied message %s", string(buf)) messages <- buf }