diff --git a/plugins/wasm-go/extensions/ip-restriction/main.go b/plugins/wasm-go/extensions/ip-restriction/main.go index 775263981..c9e8b04a2 100644 --- a/plugins/wasm-go/extensions/ip-restriction/main.go +++ b/plugins/wasm-go/extensions/ip-restriction/main.go @@ -3,13 +3,13 @@ package main import ( "encoding/json" "fmt" + "net" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/zmap/go-iptree/iptree" - "net" - "strings" ) const ( @@ -101,9 +101,6 @@ func getDownStreamIp(config RestrictionConfig) (net.IP, error) { if config.IPSourceType == HeaderSourceType { s, err = proxywasm.GetHttpRequestHeader(config.IPHeaderName) - if err == nil { - s = strings.Split(strings.Trim(s, " "), ",")[0] - } } else { var bs []byte bs, err = proxywasm.GetProperty([]string{"source", "address"}) @@ -112,7 +109,7 @@ func getDownStreamIp(config RestrictionConfig) (net.IP, error) { if err != nil { return nil, err } - ip := parseIP(s) + ip := parseIP(s, config.IPSourceType == HeaderSourceType) realIP := net.ParseIP(ip) if realIP == nil { return nil, fmt.Errorf("invalid ip[%s]", ip) diff --git a/plugins/wasm-go/extensions/ip-restriction/utils.go b/plugins/wasm-go/extensions/ip-restriction/utils.go index 869ca063f..7706faa3d 100644 --- a/plugins/wasm-go/extensions/ip-restriction/utils.go +++ b/plugins/wasm-go/extensions/ip-restriction/utils.go @@ -2,9 +2,10 @@ package main import ( "fmt" + "strings" + "github.com/tidwall/gjson" "github.com/zmap/go-iptree/iptree" - "strings" ) // parseIPNets 解析Ip段配置 @@ -24,7 +25,12 @@ func parseIPNets(array []gjson.Result) (*iptree.IPTree, error) { } // parseIP 解析IP -func parseIP(source string) string { +func parseIP(source string, fromHeader bool) string { + + if fromHeader { + source = strings.Split(source, ",")[0] + } + source = strings.Trim(source, " ") if strings.Contains(source, ".") { // parse ipv4 return strings.Split(source, ":")[0] diff --git a/plugins/wasm-go/extensions/ip-restriction/utils_test.go b/plugins/wasm-go/extensions/ip-restriction/utils_test.go index 1c42eea3a..c1acb3671 100644 --- a/plugins/wasm-go/extensions/ip-restriction/utils_test.go +++ b/plugins/wasm-go/extensions/ip-restriction/utils_test.go @@ -1,8 +1,9 @@ package main import ( - "github.com/tidwall/gjson" "testing" + + "github.com/tidwall/gjson" ) func Test_parseIPNets(t *testing.T) { @@ -52,7 +53,8 @@ func Test_parseIPNets(t *testing.T) { func Test_parseIP(t *testing.T) { type args struct { - source string + source string + fromHeader bool } tests := []struct { name string @@ -64,6 +66,7 @@ func Test_parseIP(t *testing.T) { name: "case 1", args: args{ "127.0.0.1", + false, }, want: "127.0.0.1", }, @@ -71,6 +74,7 @@ func Test_parseIP(t *testing.T) { name: "case 2", args: args{ "127.0.0.1:12", + false, }, want: "127.0.0.1", }, @@ -78,6 +82,7 @@ func Test_parseIP(t *testing.T) { name: "case 3", args: args{ "fe80::14d5:8aff:fed9:2114", + false, }, want: "fe80::14d5:8aff:fed9:2114", }, @@ -85,6 +90,7 @@ func Test_parseIP(t *testing.T) { name: "case 4", args: args{ "[fe80::14d5:8aff:fed9:2114]:123", + false, }, want: "fe80::14d5:8aff:fed9:2114", }, @@ -92,13 +98,38 @@ func Test_parseIP(t *testing.T) { name: "case 5", args: args{ "127.0.0.1:12,[fe80::14d5:8aff:fed9:2114]:123", + true, + }, + want: "127.0.0.1", + }, + { + name: "case 6", + args: args{ + "127.0.0.1,[fe80::14d5:8aff:fed9:2114]:123", + true, + }, + want: "127.0.0.1", + }, + { + name: "case 7", + args: args{ + "[fe80::14d5:8aff:fed9:2114]:123,127.0.0.1", + true, + }, + want: "fe80::14d5:8aff:fed9:2114", + }, + { + name: "case 8", + args: args{ + "127.0.0.1 , [fe80::14d5:8aff:fed9:2114]:123", + true, }, want: "127.0.0.1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := parseIP(tt.args.source); got != tt.want { + if got := parseIP(tt.args.source, tt.args.fromHeader); got != tt.want { t.Errorf("parseIP() = %v, want %v", got, tt.want) } })