mirror of
https://github.com/simon-ding/polaris.git
synced 2026-06-09 11:39:46 +08:00
udp stun proxy
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
"polaris/pkg/tmdb"
|
"polaris/pkg/tmdb"
|
||||||
"polaris/pkg/transmission"
|
"polaris/pkg/transmission"
|
||||||
"polaris/pkg/utils"
|
"polaris/pkg/utils"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/robfig/cron"
|
"github.com/robfig/cron"
|
||||||
@@ -50,6 +51,12 @@ func (c *Engine) Init() {
|
|||||||
go c.reloadTasks()
|
go c.reloadTasks()
|
||||||
c.addSysCron()
|
c.addSysCron()
|
||||||
go c.checkW500PosterOnStartup()
|
go c.checkW500PosterOnStartup()
|
||||||
|
go func() {
|
||||||
|
time.Sleep(10*time.Second)
|
||||||
|
if err := c.stunProxyDownloadClient(); err != nil {
|
||||||
|
log.Errorf("stun proxy error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Engine) GetTask(id int) (*Task, bool) {
|
func (c *Engine) GetTask(id int) (*Task, bool) {
|
||||||
|
|||||||
41
engine/stun.go
Normal file
41
engine/stun.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package engine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"polaris/ent/downloadclients"
|
||||||
|
"polaris/pkg/nat"
|
||||||
|
"polaris/pkg/qbittorrent"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Engine) stunProxyDownloadClient() error {
|
||||||
|
downloader, e, err := s.GetDownloadClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if e.Implementation != downloadclients.ImplementationQbittorrent {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
d, ok := downloader.(*qbittorrent.Client)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
n, err := nat.NewNatTraversal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
addr, err := n.StunAddr()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = d.SetListenPort(addr.Port)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u, err := url.Parse(d.URL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n.StartProxy(u.Hostname() + ":" + strconv.Itoa(addr.Port))
|
||||||
|
}
|
||||||
@@ -2,9 +2,8 @@ package nat
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"polaris/log"
|
||||||
|
|
||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
)
|
)
|
||||||
@@ -16,139 +15,109 @@ const (
|
|||||||
timeoutMillis = 500
|
timeoutMillis = 500
|
||||||
)
|
)
|
||||||
|
|
||||||
type natTraversal struct {
|
func NewNatTraversal() (*NatTraversal, error) {
|
||||||
peerAddr *net.UDPAddr
|
|
||||||
cancel chan struct{}
|
|
||||||
port <-chan int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *natTraversal) Port() int {
|
|
||||||
return <-s.port
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *natTraversal) Cancel() {
|
|
||||||
s.cancel <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NatTraversal(targetAddr string) (*natTraversal, error) { //nolint:gocognit,cyclop
|
|
||||||
|
|
||||||
srvAddr, err := net.ResolveUDPAddr(udp, getStunServers()[0])
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to resolve server addr: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP(udp, nil)
|
conn, err := net.ListenUDP(udp, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listen: %w", err)
|
return nil, fmt.Errorf("listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Listening on %s", conn.LocalAddr())
|
log.Infof("Listening on %s", conn.LocalAddr())
|
||||||
|
|
||||||
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("resolve peeraddr: %w", err)
|
|
||||||
}
|
|
||||||
err = sendBindingRequest(conn, srvAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("send binding request: %w", err)
|
|
||||||
}
|
|
||||||
nt := &natTraversal{
|
|
||||||
peerAddr: peerAddr,
|
|
||||||
cancel: make(chan struct{}),
|
|
||||||
port: make(chan int),
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
err := doTraversal(conn, peerAddr, nt.cancel)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("nat traversal error:", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nt, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func doTraversal(conn *net.UDPConn, peerAddr *net.UDPAddr, quit <-chan struct{}) error {
|
|
||||||
defer func() {
|
|
||||||
_ = conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var publicAddr stun.XORMappedAddress
|
|
||||||
|
|
||||||
messageChan := listen(conn)
|
messageChan := listen(conn)
|
||||||
//var peerAddrChan <-chan string
|
|
||||||
|
|
||||||
keepalive := time.Tick(timeoutMillis * time.Millisecond)
|
return &NatTraversal{
|
||||||
keepaliveMsg := pingMsg
|
conn: conn,
|
||||||
|
messageChan: messageChan,
|
||||||
|
cancel: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
gotPong := false
|
type NatTraversal struct {
|
||||||
sentPong := false
|
//peerAddr *net.UDPAddr
|
||||||
|
conn *net.UDPConn
|
||||||
|
messageChan <-chan []byte
|
||||||
|
cancel chan struct{}
|
||||||
|
|
||||||
for {
|
stunAddr *stun.XORMappedAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *NatTraversal) Cancel() {
|
||||||
|
|
||||||
|
close(s.cancel)
|
||||||
|
s.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
|
||||||
|
for _, srv := range getStunServers() {
|
||||||
|
log.Debugf("try to connect to stun server: %s", srv)
|
||||||
|
srvAddr, err := net.ResolveUDPAddr(udp, srv)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to resolve server addr: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = sendBindingRequest(s.conn, srvAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send binding request: %w", err)
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case message, ok := <-messageChan:
|
case message, ok := <-s.messageChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
continue
|
||||||
}
|
}
|
||||||
|
if stun.IsMessage(message) {
|
||||||
switch {
|
|
||||||
case string(message) == pingMsg:
|
|
||||||
keepaliveMsg = pongMsg
|
|
||||||
|
|
||||||
case string(message) == pongMsg:
|
|
||||||
if !gotPong {
|
|
||||||
log.Println("Received pong message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// One client may skip sending ping if it receives
|
|
||||||
// a ping message before knowning the peer address.
|
|
||||||
keepaliveMsg = pongMsg
|
|
||||||
|
|
||||||
gotPong = true
|
|
||||||
|
|
||||||
case stun.IsMessage(message):
|
|
||||||
m := new(stun.Message)
|
m := new(stun.Message)
|
||||||
m.Raw = message
|
m.Raw = message
|
||||||
decErr := m.Decode()
|
decErr := m.Decode()
|
||||||
if decErr != nil {
|
if decErr != nil {
|
||||||
log.Println("decode:", decErr)
|
log.Warnf("decode:", decErr)
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
var xorAddr stun.XORMappedAddress
|
var xorAddr stun.XORMappedAddress
|
||||||
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
||||||
log.Println("getFrom:", getErr)
|
log.Warnf("getFrom:", getErr)
|
||||||
|
|
||||||
break
|
continue
|
||||||
}
|
}
|
||||||
|
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
|
||||||
if publicAddr.String() != xorAddr.String() {
|
log.Warnf("My public address: %s\n", xorAddr)
|
||||||
log.Printf("My public address: %s\n", xorAddr)
|
s.stunAddr = &xorAddr
|
||||||
publicAddr = xorAddr
|
|
||||||
|
|
||||||
//peerAddrChan = getPeerAddr()
|
|
||||||
}
|
}
|
||||||
|
return &xorAddr, nil
|
||||||
|
|
||||||
default:
|
|
||||||
send(message, conn, peerAddr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-keepalive:
|
|
||||||
// Keep NAT binding alive using STUN server or the peer once it's known
|
|
||||||
err := sendStr(keepaliveMsg, conn, peerAddr)
|
|
||||||
if keepaliveMsg == pongMsg {
|
|
||||||
sentPong = true
|
|
||||||
}
|
|
||||||
_ = sentPong
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Panicln("keepalive:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-quit:
|
|
||||||
_ = conn.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get STUN address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *NatTraversal) StartProxy(targetAddr string) error {
|
||||||
|
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
|
||||||
|
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolve peeraddr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.stunAddr == nil {
|
||||||
|
addr, err := s.StunAddr()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get STUN address: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("STUN address: %s", addr)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.cancel:
|
||||||
|
log.Infof("cancelled")
|
||||||
|
return nil
|
||||||
|
case m := <-s.messageChan:
|
||||||
|
//log.Infof("Received message: %d", len(m))
|
||||||
|
send(m, s.conn, peerAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listen(conn *net.UDPConn) <-chan []byte {
|
func listen(conn *net.UDPConn) <-chan []byte {
|
||||||
@@ -157,13 +126,15 @@ func listen(conn *net.UDPConn) <-chan []byte {
|
|||||||
for {
|
for {
|
||||||
buf := make([]byte, 10240)
|
buf := make([]byte, 10240)
|
||||||
|
|
||||||
n, _, err := conn.ReadFromUDP(buf)
|
n, addr, err := conn.ReadFromUDP(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
close(messages)
|
close(messages)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Debugf("Received message from %s: %d", addr, n)
|
||||||
buf = buf[:n]
|
buf = buf[:n]
|
||||||
|
log.Debugf("recevied message %s", string(buf))
|
||||||
|
|
||||||
messages <- buf
|
messages <- buf
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user