udp stun proxy

This commit is contained in:
Simon Ding
2025-05-07 22:09:35 +08:00
parent 9719c6a7c9
commit 2dae168cb2
3 changed files with 125 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ import (
"polaris/pkg/tmdb"
"polaris/pkg/transmission"
"polaris/pkg/utils"
"time"
"github.com/pkg/errors"
"github.com/robfig/cron"
@@ -50,6 +51,12 @@ func (c *Engine) Init() {
go c.reloadTasks()
c.addSysCron()
go c.checkW500PosterOnStartup()
go func() {
time.Sleep(10*time.Second)
if err := c.stunProxyDownloadClient(); err != nil {
log.Errorf("stun proxy error: %v", err)
}
}()
}
func (c *Engine) GetTask(id int) (*Task, bool) {

41
engine/stun.go Normal file
View File

@@ -0,0 +1,41 @@
package engine
import (
"net/url"
"polaris/ent/downloadclients"
"polaris/pkg/nat"
"polaris/pkg/qbittorrent"
"strconv"
)
func (s *Engine) stunProxyDownloadClient() error {
downloader, e, err := s.GetDownloadClient()
if err != nil {
return err
}
if e.Implementation != downloadclients.ImplementationQbittorrent {
return nil
}
d, ok := downloader.(*qbittorrent.Client)
if !ok {
return nil
}
n, err := nat.NewNatTraversal()
if err != nil {
return err
}
addr, err := n.StunAddr()
if err != nil {
return err
}
err = d.SetListenPort(addr.Port)
if err != nil {
return err
}
u, err := url.Parse(d.URL)
if err != nil {
return err
}
return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port))
}

View File

@@ -2,9 +2,8 @@ package nat
import (
"fmt"
"log"
"net"
"time"
"polaris/log"
"github.com/pion/stun/v3"
)
@@ -16,139 +15,109 @@ const (
timeoutMillis = 500
)
type natTraversal struct {
peerAddr *net.UDPAddr
cancel chan struct{}
port <-chan int
}
func (s *natTraversal) Port() int {
return <-s.port
}
func (s *natTraversal) Cancel() {
s.cancel <- struct{}{}
}
func NatTraversal(targetAddr string) (*natTraversal, error) { //nolint:gocognit,cyclop
srvAddr, err := net.ResolveUDPAddr(udp, getStunServers()[0])
if err != nil {
log.Fatalf("Failed to resolve server addr: %s", err)
}
func NewNatTraversal() (*NatTraversal, error) {
conn, err := net.ListenUDP(udp, nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
log.Printf("Listening on %s", conn.LocalAddr())
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return nil, fmt.Errorf("resolve peeraddr: %w", err)
}
err = sendBindingRequest(conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
nt := &natTraversal{
peerAddr: peerAddr,
cancel: make(chan struct{}),
port: make(chan int),
}
go func() {
err := doTraversal(conn, peerAddr, nt.cancel)
if err != nil {
log.Println("nat traversal error:", err)
}
}()
return nt, nil
}
func doTraversal(conn *net.UDPConn, peerAddr *net.UDPAddr, quit <-chan struct{}) error {
defer func() {
_ = conn.Close()
}()
var publicAddr stun.XORMappedAddress
log.Infof("Listening on %s", conn.LocalAddr())
messageChan := listen(conn)
//var peerAddrChan <-chan string
keepalive := time.Tick(timeoutMillis * time.Millisecond)
keepaliveMsg := pingMsg
return &NatTraversal{
conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
}, nil
}
gotPong := false
sentPong := false
type NatTraversal struct {
//peerAddr *net.UDPAddr
conn *net.UDPConn
messageChan <-chan []byte
cancel chan struct{}
for {
stunAddr *stun.XORMappedAddress
}
func (s *NatTraversal) Cancel() {
close(s.cancel)
s.conn.Close()
}
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
for _, srv := range getStunServers() {
log.Debugf("try to connect to stun server: %s", srv)
srvAddr, err := net.ResolveUDPAddr(udp, srv)
if err != nil {
log.Warnf("Failed to resolve server addr: %s", err)
continue
}
err = sendBindingRequest(s.conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
select {
case message, ok := <-messageChan:
case message, ok := <-s.messageChan:
if !ok {
return nil
continue
}
switch {
case string(message) == pingMsg:
keepaliveMsg = pongMsg
case string(message) == pongMsg:
if !gotPong {
log.Println("Received pong message.")
}
// One client may skip sending ping if it receives
// a ping message before knowning the peer address.
keepaliveMsg = pongMsg
gotPong = true
case stun.IsMessage(message):
if stun.IsMessage(message) {
m := new(stun.Message)
m.Raw = message
decErr := m.Decode()
if decErr != nil {
log.Println("decode:", decErr)
log.Warnf("decode:", decErr)
break
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Println("getFrom:", getErr)
log.Warnf("getFrom:", getErr)
break
continue
}
if publicAddr.String() != xorAddr.String() {
log.Printf("My public address: %s\n", xorAddr)
publicAddr = xorAddr
//peerAddrChan = getPeerAddr()
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
log.Warnf("My public address: %s\n", xorAddr)
s.stunAddr = &xorAddr
}
return &xorAddr, nil
default:
send(message, conn, peerAddr)
}
case <-keepalive:
// Keep NAT binding alive using STUN server or the peer once it's known
err := sendStr(keepaliveMsg, conn, peerAddr)
if keepaliveMsg == pongMsg {
sentPong = true
}
_ = sentPong
if err != nil {
log.Panicln("keepalive:", err)
}
case <-quit:
_ = conn.Close()
}
}
return nil, fmt.Errorf("failed to get STUN address")
}
func (s *NatTraversal) StartProxy(targetAddr string) error {
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return fmt.Errorf("resolve peeraddr: %w", err)
}
if s.stunAddr == nil {
addr, err := s.StunAddr()
if err != nil {
return fmt.Errorf("get STUN address: %w", err)
}
log.Infof("STUN address: %s", addr)
}
for {
select {
case <-s.cancel:
log.Infof("cancelled")
return nil
case m := <-s.messageChan:
//log.Infof("Received message: %d", len(m))
send(m, s.conn, peerAddr)
}
}
}
func listen(conn *net.UDPConn) <-chan []byte {
@@ -157,13 +126,15 @@ func listen(conn *net.UDPConn) <-chan []byte {
for {
buf := make([]byte, 10240)
n, _, err := conn.ReadFromUDP(buf)
n, addr, err := conn.ReadFromUDP(buf)
if err != nil {
close(messages)
return
}
log.Debugf("Received message from %s: %d", addr, n)
buf = buf[:n]
log.Debugf("recevied message %s", string(buf))
messages <- buf
}