feat: update stun proxy logic

This commit is contained in:
Simon Ding
2025-05-08 16:13:30 +08:00
parent bb2c567da7
commit 992fa7ddd0
12 changed files with 310 additions and 73 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"polaris/log"
"time"
"github.com/pion/stun/v3"
)
@@ -15,7 +16,7 @@ const (
timeoutMillis = 500
)
func NewNatTraversal() (*NatTraversal, error) {
func NewNatTraversal(addrCallback func(stun.XORMappedAddress) error, targetHost string) (*NatTraversal, error) {
conn, err := net.ListenUDP(udp, nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
@@ -24,31 +25,57 @@ func NewNatTraversal() (*NatTraversal, error) {
log.Infof("Listening on %s", conn.LocalAddr())
messageChan := listen(conn)
s := &NatTraversal{
conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
addrChan: make(chan stun.XORMappedAddress),
addrCallback: addrCallback,
targetHost: targetHost,
}
return &NatTraversal{
conn: conn,
messageChan: messageChan,
cancel: make(chan struct{}),
}, nil
go s.updateNatAddr()
return s, nil
}
type NatTraversal struct {
//peerAddr *net.UDPAddr
conn *net.UDPConn
messageChan <-chan []byte
addrChan chan stun.XORMappedAddress
cancel chan struct{}
stunAddr *stun.XORMappedAddress
stunAddr *stun.XORMappedAddress
addrCallback func(stun.XORMappedAddress) error
targetHost string
targetPort int
}
func (s *NatTraversal) Cancel() {
close(s.cancel)
s.conn.Close()
}
func (s *NatTraversal) updateNatAddr() {
for addr := range s.addrChan {
if s.stunAddr == nil || s.stunAddr.String() != addr.String() { //new address
log.Warnf("My public address: %s\n", addr)
if s.addrCallback != nil { //execute callback
if err := s.addrCallback(addr); err != nil {
log.Warnf("callback error: %v", err)
}
}
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
s.targetPort = addr.Port
log.Infof("now proxy to target host: %s:%d", s.targetHost, s.targetPort)
s.stunAddr = &addr
}
}
}
func (s *NatTraversal) sendStunServerBindingMsg() error {
for _, srv := range getStunServers() {
log.Debugf("try to connect to stun server: %s", srv)
srvAddr, err := net.ResolveUDPAddr(udp, srv)
@@ -58,62 +85,74 @@ func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
}
err = sendBindingRequest(s.conn, srvAddr)
if err != nil {
return nil, fmt.Errorf("send binding request: %w", err)
}
message, ok := <-s.messageChan
if !ok {
log.Warnf("send binding request: %w", err)
continue
}
if stun.IsMessage(message) {
m := new(stun.Message)
m.Raw = message
decErr := m.Decode()
if decErr != nil {
log.Warnf("decode:", decErr)
break
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
log.Warnf("getFrom:", getErr)
continue
}
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
log.Warnf("My public address: %s\n", xorAddr)
s.stunAddr = &xorAddr
}
return &xorAddr, nil
}
return nil
}
return nil, fmt.Errorf("failed to get STUN address")
return fmt.Errorf("failed to get STUN address")
}
func (s *NatTraversal) StartProxy(targetAddr string) error {
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
if err != nil {
return fmt.Errorf("resolve peeraddr: %w", err)
func (s *NatTraversal) getNatAddr(msg []byte) (*stun.XORMappedAddress, error) {
if !stun.IsMessage(msg) {
return nil, fmt.Errorf("not a stun message")
}
if s.stunAddr == nil {
addr, err := s.StunAddr()
if err != nil {
return fmt.Errorf("get STUN address: %w", err)
}
log.Infof("STUN address: %s", addr)
m := new(stun.Message)
m.Raw = msg
decErr := m.Decode()
if decErr != nil {
return nil, fmt.Errorf("decode: %w", decErr)
}
var xorAddr stun.XORMappedAddress
if getErr := xorAddr.GetFrom(m); getErr != nil {
return nil, fmt.Errorf("getFrom: %w", getErr)
}
s.addrChan <- xorAddr
return &xorAddr, nil
}
func (s *NatTraversal) StartProxy() {
tick := time.NewTicker(10 * time.Second)
go func() { //tcker message to check public ip and port
defer tick.Stop()
for {
select {
case <-s.cancel:
log.Infof("stun nat proxy cancelled")
return
case <-tick.C:
err := s.sendStunServerBindingMsg()
if err != nil {
log.Warnf("send stun server binding msg: %w", err)
}
}
}
}()
for {
select {
case <-s.cancel:
log.Infof("stun nat proxy cancelled")
return nil
return
case m := <-s.messageChan:
//log.Infof("Received message: %d", len(m))
send(m, s.conn, peerAddr)
if stun.IsMessage(m) {
s.getNatAddr(m)
} else {
peerAddr, err := net.ResolveUDPAddr(udp, fmt.Sprintf("%s:%d", s.targetHost, s.targetPort))
if err != nil {
log.Errorf("resolve peeraddr: %w", err)
continue
}
send(m, s.conn, peerAddr)
}
}
}
}