udp stun proxy

This commit is contained in:
Simon Ding
2025-05-07 22:09:35 +08:00
parent 9719c6a7c9
commit 2dae168cb2
3 changed files with 125 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ import (
"polaris/pkg/tmdb" "polaris/pkg/tmdb"
"polaris/pkg/transmission" "polaris/pkg/transmission"
"polaris/pkg/utils" "polaris/pkg/utils"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/robfig/cron" "github.com/robfig/cron"
@@ -50,6 +51,12 @@ func (c *Engine) Init() {
go c.reloadTasks() go c.reloadTasks()
c.addSysCron() c.addSysCron()
go c.checkW500PosterOnStartup() go c.checkW500PosterOnStartup()
go func() {
time.Sleep(10*time.Second)
if err := c.stunProxyDownloadClient(); err != nil {
log.Errorf("stun proxy error: %v", err)
}
}()
} }
func (c *Engine) GetTask(id int) (*Task, bool) { func (c *Engine) GetTask(id int) (*Task, bool) {

41
engine/stun.go Normal file
View File

@@ -0,0 +1,41 @@
package engine
import (
"net/url"
"polaris/ent/downloadclients"
"polaris/pkg/nat"
"polaris/pkg/qbittorrent"
"strconv"
)
func (s *Engine) stunProxyDownloadClient() error {
downloader, e, err := s.GetDownloadClient()
if err != nil {
return err
}
if e.Implementation != downloadclients.ImplementationQbittorrent {
return nil
}
d, ok := downloader.(*qbittorrent.Client)
if !ok {
return nil
}
n, err := nat.NewNatTraversal()
if err != nil {
return err
}
addr, err := n.StunAddr()
if err != nil {
return err
}
err = d.SetListenPort(addr.Port)
if err != nil {
return err
}
u, err := url.Parse(d.URL)
if err != nil {
return err
}
return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port))
}

View File

@@ -2,9 +2,8 @@ package nat
import ( import (
"fmt" "fmt"
"log"
"net" "net"
"time" "polaris/log"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
) )
@@ -16,139 +15,109 @@ const (
timeoutMillis = 500 timeoutMillis = 500
) )
type natTraversal struct { func NewNatTraversal() (*NatTraversal, error) {
peerAddr *net.UDPAddr
cancel chan struct{}
port <-chan int
}
func (s *natTraversal) Port() int {
return <-s.port
}
func (s *natTraversal) Cancel() {
s.cancel <- struct{}{}
}
func NatTraversal(targetAddr string) (*natTraversal, error) { //nolint:gocognit,cyclop
srvAddr, err := net.ResolveUDPAddr(udp, getStunServers()[0])
if err != nil {
log.Fatalf("Failed to resolve server addr: %s", err)
}
conn, err := net.ListenUDP(udp, nil) conn, err := net.ListenUDP(udp, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("listen: %w", err) return nil, fmt.Errorf("listen: %w", err)
} }
log.Printf("Listening on %s", conn.LocalAddr()) log.Infof("Listening on %s", conn.LocalAddr())
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return nil, fmt.Errorf("resolve peeraddr: %w", err)
}
err = sendBindingRequest(conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
nt := &natTraversal{
peerAddr: peerAddr,
cancel: make(chan struct{}),
port: make(chan int),
}
go func() {
err := doTraversal(conn, peerAddr, nt.cancel)
if err != nil {
log.Println("nat traversal error:", err)
}
}()
return nt, nil
}
func doTraversal(conn *net.UDPConn, peerAddr *net.UDPAddr, quit <-chan struct{}) error {
defer func() {
_ = conn.Close()
}()
var publicAddr stun.XORMappedAddress
messageChan := listen(conn) messageChan := listen(conn)
//var peerAddrChan <-chan string
keepalive := time.Tick(timeoutMillis * time.Millisecond) return &NatTraversal{
keepaliveMsg := pingMsg conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
}, nil
}
gotPong := false type NatTraversal struct {
sentPong := false //peerAddr *net.UDPAddr
conn *net.UDPConn
messageChan <-chan []byte
cancel chan struct{}
for { stunAddr *stun.XORMappedAddress
}
func (s *NatTraversal) Cancel() {
close(s.cancel)
s.conn.Close()
}
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
for _, srv := range getStunServers() {
log.Debugf("try to connect to stun server: %s", srv)
srvAddr, err := net.ResolveUDPAddr(udp, srv)
if err != nil {
log.Warnf("Failed to resolve server addr: %s", err)
continue
}
err = sendBindingRequest(s.conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
select { select {
case message, ok := <-messageChan: case message, ok := <-s.messageChan:
if !ok { if !ok {
return nil continue
} }
if stun.IsMessage(message) {
switch {
case string(message) == pingMsg:
keepaliveMsg = pongMsg
case string(message) == pongMsg:
if !gotPong {
log.Println("Received pong message.")
}
// One client may skip sending ping if it receives
// a ping message before knowning the peer address.
keepaliveMsg = pongMsg
gotPong = true
case stun.IsMessage(message):
m := new(stun.Message) m := new(stun.Message)
m.Raw = message m.Raw = message
decErr := m.Decode() decErr := m.Decode()
if decErr != nil { if decErr != nil {
log.Println("decode:", decErr) log.Warnf("decode:", decErr)
break break
} }
var xorAddr stun.XORMappedAddress var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil { if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Println("getFrom:", getErr) log.Warnf("getFrom:", getErr)
break continue
} }
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
if publicAddr.String() != xorAddr.String() { log.Warnf("My public address: %s\n", xorAddr)
log.Printf("My public address: %s\n", xorAddr) s.stunAddr = &xorAddr
publicAddr = xorAddr
//peerAddrChan = getPeerAddr()
} }
return &xorAddr, nil
default:
send(message, conn, peerAddr)
} }
case <-keepalive:
// Keep NAT binding alive using STUN server or the peer once it's known
err := sendStr(keepaliveMsg, conn, peerAddr)
if keepaliveMsg == pongMsg {
sentPong = true
}
_ = sentPong
if err != nil {
log.Panicln("keepalive:", err)
}
case <-quit:
_ = conn.Close()
} }
}
return nil, fmt.Errorf("failed to get STUN address")
}
func (s *NatTraversal) StartProxy(targetAddr string) error {
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return fmt.Errorf("resolve peeraddr: %w", err)
} }
if s.stunAddr == nil {
addr, err := s.StunAddr()
if err != nil {
return fmt.Errorf("get STUN address: %w", err)
}
log.Infof("STUN address: %s", addr)
}
for {
select {
case <-s.cancel:
log.Infof("cancelled")
return nil
case m := <-s.messageChan:
//log.Infof("Received message: %d", len(m))
send(m, s.conn, peerAddr)
}
}
} }
func listen(conn *net.UDPConn) <-chan []byte { func listen(conn *net.UDPConn) <-chan []byte {
@@ -157,13 +126,15 @@ func listen(conn *net.UDPConn) <-chan []byte {
for { for {
buf := make([]byte, 10240) buf := make([]byte, 10240)
n, _, err := conn.ReadFromUDP(buf) n, addr, err := conn.ReadFromUDP(buf)
if err != nil { if err != nil {
close(messages) close(messages)
return return
} }
log.Debugf("Received message from %s: %d", addr, n)
buf = buf[:n] buf = buf[:n]
log.Debugf("recevied message %s", string(buf))
messages <- buf messages <- buf
} }