mirror of
https://github.com/simon-ding/polaris.git
synced 2026-02-06 23:21:00 +08:00
feat: update stun proxy logic
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"polaris/log"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun/v3"
|
||||
)
|
||||
@@ -15,7 +16,7 @@ const (
|
||||
timeoutMillis = 500
|
||||
)
|
||||
|
||||
func NewNatTraversal() (*NatTraversal, error) {
|
||||
func NewNatTraversal(addrCallback func(stun.XORMappedAddress) error, targetHost string) (*NatTraversal, error) {
|
||||
conn, err := net.ListenUDP(udp, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen: %w", err)
|
||||
@@ -24,31 +25,57 @@ func NewNatTraversal() (*NatTraversal, error) {
|
||||
log.Infof("Listening on %s", conn.LocalAddr())
|
||||
|
||||
messageChan := listen(conn)
|
||||
s := &NatTraversal{
|
||||
conn: conn,
|
||||
messageChan: messageChan,
|
||||
cancel: make(chan struct{}),
|
||||
addrChan: make(chan stun.XORMappedAddress),
|
||||
addrCallback: addrCallback,
|
||||
targetHost: targetHost,
|
||||
}
|
||||
|
||||
return &NatTraversal{
|
||||
conn: conn,
|
||||
messageChan: messageChan,
|
||||
cancel: make(chan struct{}),
|
||||
}, nil
|
||||
go s.updateNatAddr()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type NatTraversal struct {
|
||||
//peerAddr *net.UDPAddr
|
||||
conn *net.UDPConn
|
||||
messageChan <-chan []byte
|
||||
addrChan chan stun.XORMappedAddress
|
||||
cancel chan struct{}
|
||||
|
||||
stunAddr *stun.XORMappedAddress
|
||||
stunAddr *stun.XORMappedAddress
|
||||
addrCallback func(stun.XORMappedAddress) error
|
||||
targetHost string
|
||||
targetPort int
|
||||
}
|
||||
|
||||
func (s *NatTraversal) Cancel() {
|
||||
|
||||
|
||||
close(s.cancel)
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *NatTraversal) updateNatAddr() {
|
||||
for addr := range s.addrChan {
|
||||
if s.stunAddr == nil || s.stunAddr.String() != addr.String() { //new address
|
||||
log.Warnf("My public address: %s\n", addr)
|
||||
if s.addrCallback != nil { //execute callback
|
||||
if err := s.addrCallback(addr); err != nil {
|
||||
log.Warnf("callback error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
|
||||
s.targetPort = addr.Port
|
||||
log.Infof("now proxy to target host: %s:%d", s.targetHost, s.targetPort)
|
||||
s.stunAddr = &addr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NatTraversal) sendStunServerBindingMsg() error {
|
||||
for _, srv := range getStunServers() {
|
||||
log.Debugf("try to connect to stun server: %s", srv)
|
||||
srvAddr, err := net.ResolveUDPAddr(udp, srv)
|
||||
@@ -58,62 +85,74 @@ func (s *NatTraversal) StunAddr() (*stun.XORMappedAddress, error) {
|
||||
}
|
||||
err = sendBindingRequest(s.conn, srvAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send binding request: %w", err)
|
||||
}
|
||||
message, ok := <-s.messageChan
|
||||
if !ok {
|
||||
log.Warnf("send binding request: %w", err)
|
||||
continue
|
||||
}
|
||||
if stun.IsMessage(message) {
|
||||
m := new(stun.Message)
|
||||
m.Raw = message
|
||||
decErr := m.Decode()
|
||||
if decErr != nil {
|
||||
log.Warnf("decode:", decErr)
|
||||
|
||||
break
|
||||
}
|
||||
var xorAddr stun.XORMappedAddress
|
||||
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
||||
log.Warnf("getFrom:", getErr)
|
||||
|
||||
continue
|
||||
}
|
||||
if s.stunAddr == nil || s.stunAddr.String() != xorAddr.String() {
|
||||
log.Warnf("My public address: %s\n", xorAddr)
|
||||
s.stunAddr = &xorAddr
|
||||
}
|
||||
return &xorAddr, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get STUN address")
|
||||
return fmt.Errorf("failed to get STUN address")
|
||||
}
|
||||
|
||||
func (s *NatTraversal) StartProxy(targetAddr string) error {
|
||||
log.Infof("Starting NAT traversal proxy to %s", targetAddr)
|
||||
peerAddr, err := net.ResolveUDPAddr(udp, targetAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve peeraddr: %w", err)
|
||||
func (s *NatTraversal) getNatAddr(msg []byte) (*stun.XORMappedAddress, error) {
|
||||
if !stun.IsMessage(msg) {
|
||||
return nil, fmt.Errorf("not a stun message")
|
||||
}
|
||||
|
||||
if s.stunAddr == nil {
|
||||
addr, err := s.StunAddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get STUN address: %w", err)
|
||||
}
|
||||
log.Infof("STUN address: %s", addr)
|
||||
m := new(stun.Message)
|
||||
m.Raw = msg
|
||||
decErr := m.Decode()
|
||||
if decErr != nil {
|
||||
return nil, fmt.Errorf("decode: %w", decErr)
|
||||
}
|
||||
var xorAddr stun.XORMappedAddress
|
||||
if getErr := xorAddr.GetFrom(m); getErr != nil {
|
||||
return nil, fmt.Errorf("getFrom: %w", getErr)
|
||||
}
|
||||
s.addrChan <- xorAddr
|
||||
|
||||
return &xorAddr, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *NatTraversal) StartProxy() {
|
||||
|
||||
tick := time.NewTicker(10 * time.Second)
|
||||
|
||||
go func() { //tcker message to check public ip and port
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.cancel:
|
||||
log.Infof("stun nat proxy cancelled")
|
||||
return
|
||||
case <-tick.C:
|
||||
err := s.sendStunServerBindingMsg()
|
||||
if err != nil {
|
||||
log.Warnf("send stun server binding msg: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.cancel:
|
||||
log.Infof("stun nat proxy cancelled")
|
||||
return nil
|
||||
return
|
||||
case m := <-s.messageChan:
|
||||
//log.Infof("Received message: %d", len(m))
|
||||
send(m, s.conn, peerAddr)
|
||||
if stun.IsMessage(m) {
|
||||
s.getNatAddr(m)
|
||||
} else {
|
||||
peerAddr, err := net.ResolveUDPAddr(udp, fmt.Sprintf("%s:%d", s.targetHost, s.targetPort))
|
||||
if err != nil {
|
||||
log.Errorf("resolve peeraddr: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
send(m, s.conn, peerAddr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user