From 9a45f0797230f79c4160c6b432c355000e680cd9 Mon Sep 17 00:00:00 2001 From: rinfx Date: Wed, 9 Jul 2025 15:42:00 +0800 Subject: [PATCH] fix: [ai-load-balancer]move the logic of request count to HttpStreamDone phase (#2564) --- .../global_least_request/lb_policy.go | 33 +++++------ .../ai-load-balancer/least_busy/lb_policy.go | 2 + .../extensions/ai-load-balancer/main.go | 6 ++ .../prefix_cache/lb_policy.go | 55 ++++++++++--------- 4 files changed, 55 insertions(+), 41 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go index 2ac51b3da..23db26d2a 100644 --- a/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go +++ b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go @@ -37,14 +37,14 @@ local function is_healthy(addr) return false end -if redis.call('HEXISTS', hset_key, current_target) ~= 0 then +if redis.call('HEXISTS', hset_key, current_target) == 1 then current_count = redis.call('HGET', hset_key, current_target) local hash = redis.call('HGETALL', hset_key) for i = 1, #hash, 2 do local addr = hash[i] local count = hash[i+1] if is_healthy(addr) then - if count < current_count then + if tonumber(count) < tonumber(current_count) then current_target = addr current_count = count elseif count == current_count and randomBool() then @@ -125,7 +125,7 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC } randomIndex := rand.Intn(len(healthyHostArray)) hostSelected := healthyHostArray[randomIndex] - keys := []interface{}{time.Now().Unix(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected} + keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected} for _, v := range healthyHostArray { keys = append(keys, v) } @@ -157,22 +157,23 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.H } func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte { - if endOfStream { - isErr, _ := ctx.GetContext("error").(bool) - if !isErr { - routeName, _ := ctx.GetContext("routeName").(string) - clusterName, _ := ctx.GetContext("clusterName").(string) - host_selected, _ := ctx.GetContext("host_selected").(string) - if host_selected == "" { - log.Errorf("get host_selected failed") - } else { - lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil) - } - } - } return data } func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action { return types.ActionContinue } + +func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) { + isErr, _ := ctx.GetContext("error").(bool) + if !isErr { + routeName, _ := ctx.GetContext("routeName").(string) + clusterName, _ := ctx.GetContext("clusterName").(string) + host_selected, _ := ctx.GetContext("host_selected").(string) + if host_selected == "" { + log.Errorf("get host_selected failed") + } else { + lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil) + } + } +} diff --git a/plugins/wasm-go/extensions/ai-load-balancer/least_busy/lb_policy.go b/plugins/wasm-go/extensions/ai-load-balancer/least_busy/lb_policy.go index acdc4000c..6de789920 100644 --- a/plugins/wasm-go/extensions/ai-load-balancer/least_busy/lb_policy.go +++ b/plugins/wasm-go/extensions/ai-load-balancer/least_busy/lb_policy.go @@ -77,3 +77,5 @@ func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.Http func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action { return types.ActionContinue } + +func (lb LeastBusyLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {} diff --git a/plugins/wasm-go/extensions/ai-load-balancer/main.go b/plugins/wasm-go/extensions/ai-load-balancer/main.go index 9df8d4ca9..ddf85f9a1 100644 --- a/plugins/wasm-go/extensions/ai-load-balancer/main.go +++ b/plugins/wasm-go/extensions/ai-load-balancer/main.go @@ -23,6 +23,7 @@ func init() { wrapper.ProcessResponseHeaders(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody), wrapper.ProcessResponseBody(onHttpResponseBody), + wrapper.ProcessStreamDone(onHttpStreamDone), ) } @@ -32,6 +33,7 @@ type LoadBalancer interface { HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action + HandleHttpStreamDone(ctx wrapper.HttpContext) } type Config struct { @@ -80,3 +82,7 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config Config, data [] func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action { return config.lb.HandleHttpResponseBody(ctx, body) } + +func onHttpStreamDone(ctx wrapper.HttpContext, config Config) { + config.lb.HandleHttpStreamDone(ctx) +} diff --git a/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache/lb_policy.go b/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache/lb_policy.go index 074827d52..5663e759c 100644 --- a/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache/lb_policy.go +++ b/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache/lb_policy.go @@ -79,21 +79,20 @@ local function is_healthy(addr) return false end +local function randomBool() + return math.random() >= 0.5 +end + local target = "" local key = "" local current_key = "" -local count = #ARGV local ttl = KEYS[1] local hset_key = KEYS[2] local default_target = KEYS[3] -if count == 0 then - return target -end - -- find longest prefix local index = 1 -while index <= count do +while index <= #ARGV do if current_key == "" then current_key = ARGV[index] else @@ -120,15 +119,20 @@ if target == "" then index = 1 local current_count = 0 target = default_target - if redis.call('HEXISTS', hset_key, target) ~= 0 then + if redis.call('HEXISTS', hset_key, target) == 1 then current_count = redis.call('HGET', hset_key, target) local hash = redis.call('HGETALL', hset_key) for i = 1, #hash, 2 do local addr = hash[i] local count = hash[i+1] - if count < current_count and is_healthy(addr) then - target = addr - current_count = count + if is_healthy(addr) then + if tonumber(count) < tonumber(current_count) then + target = addr + current_count = count + elseif count == current_count and randomBool() then + target = addr + current_count = count + end end end end @@ -138,7 +142,7 @@ end redis.call("HINCRBY", hset_key, target, 1) -- add tree-path -while index <= count do +while index <= #ARGV do if key == "" then key = ARGV[index] else @@ -177,7 +181,7 @@ func NewPrefixCacheLoadBalancer(json gjson.Result) (PrefixCacheLoadBalancer, err } // database default is 0 database := json.Get("database").Int() - if json.Get("redisKeyTTL").Int() == 0 { + if json.Get("redisKeyTTL").Int() != 0 { lb.redisKeyTTL = int(json.Get("redisKeyTTL").Int()) } else { lb.redisKeyTTL = 1800 @@ -275,19 +279,6 @@ func (lb PrefixCacheLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpCont } func (lb PrefixCacheLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte { - if endOfStream { - isErr, _ := ctx.GetContext("error").(bool) - if !isErr { - routeName, _ := ctx.GetContext("routeName").(string) - clusterName, _ := ctx.GetContext("clusterName").(string) - host_selected, _ := ctx.GetContext("host_selected").(string) - if host_selected == "" { - log.Errorf("get host_selected failed") - } else { - lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil) - } - } - } return data } @@ -295,6 +286,20 @@ func (lb PrefixCacheLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext return types.ActionContinue } +func (lb PrefixCacheLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) { + isErr, _ := ctx.GetContext("error").(bool) + if !isErr { + routeName, _ := ctx.GetContext("routeName").(string) + clusterName, _ := ctx.GetContext("clusterName").(string) + host_selected, _ := ctx.GetContext("host_selected").(string) + if host_selected == "" { + log.Errorf("get host_selected failed") + } else { + lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil) + } + } +} + func computeSHA1(data string) string { hasher := sha1.New() hasher.Write([]byte(data))