diff --git a/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go index 46fcb819c..daf87d04d 100644 --- a/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go +++ b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go @@ -16,40 +16,91 @@ import ( ) const ( - RedisKeyFormat = "higress:global_least_request_table:%s:%s" - RedisLua = `local seed = KEYS[1] + RedisKeyFormat = "higress:global_least_request_table:%s:%s" + RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s" + RedisLua = `local seed = tonumber(KEYS[1]) local hset_key = KEYS[2] -local current_target = KEYS[3] -local current_count = 0 +local last_clean_key = KEYS[3] +local clean_interval = tonumber(KEYS[4]) +local current_target = KEYS[5] +local healthy_count = tonumber(KEYS[6]) +local enable_detail_log = KEYS[7] math.randomseed(seed) -local function randomBool() - return math.random() >= 0.5 -end +-- 1. 
Selection +local current_count = 0 +local same_count_hits = 0 -if redis.call('HEXISTS', hset_key, current_target) == 1 then - current_count = redis.call('HGET', hset_key, current_target) - for i = 4, #KEYS do - if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then - local count = redis.call('HGET', hset_key, KEYS[i]) - if tonumber(count) < tonumber(current_count) then - current_target = KEYS[i] - current_count = count - elseif count == current_count and randomBool() then - current_target = KEYS[i] - end - end - end +for i = 8, 8 + healthy_count - 1 do + local host = KEYS[i] + local count = 0 + local val = redis.call('HGET', hset_key, host) + if val then + count = tonumber(val) or 0 + end + + if same_count_hits == 0 or count < current_count then + current_target = host + current_count = count + same_count_hits = 1 + elseif count == current_count then + same_count_hits = same_count_hits + 1 + if math.random(same_count_hits) == 1 then + current_target = host + end + end end redis.call("HINCRBY", hset_key, current_target, 1) +local new_count = redis.call("HGET", hset_key, current_target) -return current_target` +-- Collect host counts for logging +local host_details = {} +if enable_detail_log == "1" then + local fields = {} + for i = 8, #KEYS do + table.insert(fields, KEYS[i]) + end + if #fields > 0 then + local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields)) + for i, val in ipairs(values) do + table.insert(host_details, fields[i]) + table.insert(host_details, tostring(val or 0)) + end + end +end + +-- 2. 
Cleanup +local current_time = math.floor(seed / 1000000) +local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0) + +if current_time - last_clean_time >= clean_interval then + local all_keys = redis.call('HKEYS', hset_key) + if #all_keys > 0 then + -- Create a lookup table for current hosts (from index 8 onwards) + local current_hosts = {} + for i = 8, #KEYS do + current_hosts[KEYS[i]] = true + end + -- Remove keys not in current hosts + for _, host in ipairs(all_keys) do + if not current_hosts[host] then + redis.call('HDEL', hset_key, host) + end + end + end + redis.call('SET', last_clean_key, current_time) +end + +return {current_target, new_count, host_details}` ) type GlobalLeastRequestLoadBalancer struct { - redisClient wrapper.RedisClient + redisClient wrapper.RedisClient + maxRequestCount int64 + cleanInterval int64 // seconds + enableDetailLog bool } func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) { @@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa } // database default is 0 database := json.Get("database").Int() + lb.maxRequestCount = json.Get("maxRequestCount").Int() + lb.cleanInterval = json.Get("cleanInterval").Int() + if lb.cleanInterval == 0 { + lb.cleanInterval = 60 * 60 // default 60 minutes + } else { + lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds + } + lb.enableDetailLog = true + if val := json.Get("enableDetailLog"); val.Exists() { + lb.enableDetailLog = val.Bool() + } + log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog) return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database))) } @@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) 
HandleHttpRequestBody(ctx wrapper.HttpC ctx.SetContext("error", true) return types.ActionContinue } + allHostMap := make(map[string]struct{}) // Only healthy host can be selected healthyHostArray := []string{} for _, hostInfo := range hostInfos { + allHostMap[hostInfo[0]] = struct{}{} if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" { healthyHostArray = append(healthyHostArray, hostInfo[0]) } @@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC } randomIndex := rand.Intn(len(healthyHostArray)) hostSelected := healthyHostArray[randomIndex] - keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected} + + // KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts] + keys := []interface{}{ + time.Now().UnixMicro(), + fmt.Sprintf(RedisKeyFormat, routeName, clusterName), + fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName), + lb.cleanInterval, + hostSelected, + len(healthyHostArray), + "0", + } + if lb.enableDetailLog { + keys[6] = "1" + } for _, v := range healthyHostArray { keys = append(keys, v) } + // Append unhealthy hosts (those in allHostMap but not in healthyHostArray) + for host := range allHostMap { + isHealthy := false + for _, hh := range healthyHostArray { + if host == hh { + isHealthy = true + break + } + } + if !isHealthy { + keys = append(keys, host) + } + } + err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) { if err := response.Error(); err != nil { log.Errorf("HGetAll failed: %+v", err) @@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC proxywasm.ResumeHttpRequest() return } - hostSelected = response.String() + valArray := response.Array() + if len(valArray) < 2 { + log.Errorf("redis eval lua result format error, expect at least [host, 
count], got: %+v", valArray) + ctx.SetContext("error", true) + proxywasm.ResumeHttpRequest() + return + } + hostSelected = valArray[0].String() + currentCount := valArray[1].Integer() + + // detail log + if lb.enableDetailLog && len(valArray) >= 3 { + detailLogStr := "host and count: " + details := valArray[2].Array() + for i := 0; i+1 < len(details); i += 2 { + h := details[i].String() + c := details[i+1].String() + detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c) + } + log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr) + } + + // check rate limit + if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) { + ctx.SetContext("error", true) + log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount) + // return 429 + proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1) + ctx.DontReadResponseBody() + return + } + if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil { ctx.SetContext("error", true) log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err) + proxywasm.ResumeHttpRequest() + return } + log.Debugf("host_selected: %s", hostSelected) + + // finally resume the request ctx.SetContext("host_selected", hostSelected) proxywasm.ResumeHttpRequest() }) if err != nil { ctx.SetContext("error", true) + log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err) return types.ActionContinue } return types.ActionPause @@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo if host_selected == "" { log.Errorf("get host_selected failed") } else { - lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil) + err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, 
-- Standalone test harness for the least-request Lua script embedded in
-- lb_policy.go. Mocks the subset of Redis commands the script uses, then
-- checks selection distribution and stale-host cleanup behaviour.

-- Mocking Redis environment
local redis_data = {
    hset = {},
    kv = {}
}

local redis = {
    call = function(cmd, ...)
        local args = {...}
        if cmd == "HGET" then
            local key, field = args[1], args[2]
            -- Real Redis converts a missing field to boolean false (not nil)
            -- inside Lua scripts; mirror that for fidelity.
            return redis_data.hset[field] or false
        elseif cmd == "HSET" then
            local key, field, val = args[1], args[2], args[3]
            redis_data.hset[field] = val
        elseif cmd == "HINCRBY" then
            local key, field, increment = args[1], args[2], args[3]
            local val = tonumber(redis_data.hset[field] or 0)
            redis_data.hset[field] = tostring(val + increment)
            -- Real Redis returns the post-increment value as an integer.
            return val + increment
        elseif cmd == "HKEYS" then
            local keys = {}
            for k, _ in pairs(redis_data.hset) do
                table.insert(keys, k)
            end
            return keys
        elseif cmd == "HDEL" then
            local key, field = args[1], args[2]
            redis_data.hset[field] = nil
        elseif cmd == "GET" then
            -- Missing keys are false in real Redis scripting, not nil.
            return redis_data.kv[args[1]] or false
        elseif cmd == "HMGET" then
            local key = args[1]
            local res = {}
            for i = 2, #args do
                -- Use false for missing fields: a nil would leave a hole in
                -- the reply table and make ipairs() stop early, which real
                -- Redis never does.
                table.insert(res, redis_data.hset[args[i]] or false)
            end
            return res
        elseif cmd == "SET" then
            redis_data.kv[args[1]] = args[2]
        end
    end
}

-- The actual logic from lb_policy.go (kept byte-equivalent to the script
-- in RedisLua so this harness tests what production runs).
local function run_lb_logic(KEYS)
    local seed = tonumber(KEYS[1])
    local hset_key = KEYS[2]
    local last_clean_key = KEYS[3]
    local clean_interval = tonumber(KEYS[4])
    local current_target = KEYS[5]
    local healthy_count = tonumber(KEYS[6])
    local enable_detail_log = KEYS[7]

    math.randomseed(seed)

    -- 1. Selection
    local current_count = 0
    local same_count_hits = 0

    for i = 8, 8 + healthy_count - 1 do
        local host = KEYS[i]
        local count = 0
        local val = redis.call('HGET', hset_key, host)
        if val then
            count = tonumber(val) or 0
        end

        if same_count_hits == 0 or count < current_count then
            current_target = host
            current_count = count
            same_count_hits = 1
        elseif count == current_count then
            same_count_hits = same_count_hits + 1
            if math.random(same_count_hits) == 1 then
                current_target = host
            end
        end
    end

    redis.call("HINCRBY", hset_key, current_target, 1)
    local new_count = redis.call("HGET", hset_key, current_target)

    -- Collect host counts for logging
    local host_details = {}
    if enable_detail_log == "1" then
        local fields = {}
        for i = 8, #KEYS do
            table.insert(fields, KEYS[i])
        end
        if #fields > 0 then
            local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
            for i, val in ipairs(values) do
                table.insert(host_details, fields[i])
                table.insert(host_details, tostring(val or 0))
            end
        end
    end

    -- 2. Cleanup
    local current_time = math.floor(seed / 1000000)
    local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)

    if current_time - last_clean_time >= clean_interval then
        local all_keys = redis.call('HKEYS', hset_key)
        if #all_keys > 0 then
            -- Create a lookup table for current hosts (from index 8 onwards)
            local current_hosts = {}
            for i = 8, #KEYS do
                current_hosts[KEYS[i]] = true
            end
            -- Remove keys not in current hosts
            for _, host in ipairs(all_keys) do
                if not current_hosts[host] then
                    redis.call('HDEL', hset_key, host)
                end
            end
        end
        redis.call('SET', last_clean_key, current_time)
    end

    return {current_target, new_count, host_details}
end

-- --- Test 1: Load Balancing Distribution ---
print("--- Test 1: Load Balancing Distribution ---")
local hosts = {"host1", "host2", "host3", "host4", "host5"}
local iterations = 100000
local results = {}
for _, h in ipairs(hosts) do results[h] = 0 end

-- Reset redis
redis_data.hset = {}
for _, h in ipairs(hosts) do redis_data.hset[h] = "0" end

print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", iterations, #hosts))

for i = 1, iterations do
    local initial_host = hosts[math.random(#hosts)]
    -- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
    local keys = {i * 1000000, "table_key", "clean_key", 3600, initial_host, #hosts, "1"}
    for _, h in ipairs(hosts) do table.insert(keys, h) end

    local res = run_lb_logic(keys)
    local selected = res[1]
    results[selected] = results[selected] + 1
end

for _, h in ipairs(hosts) do
    local percentage = (results[h] / iterations) * 100
    print(string.format("%s: %6d (%.2f%%)", h, results[h], percentage))
end

-- --- Test 2: IP Cleanup Logic ---
print("\n--- Test 2: IP Cleanup Logic ---")

local function test_cleanup()
    redis_data.hset = {
        ["host1"] = "10",
        ["host2"] = "5",
        ["old_ip_1"] = "1",
        ["old_ip_2"] = "1",
    }
    redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s

    local current_hosts = {"host1", "host2"}
    -- Seed is in microseconds: 1500s total; interval 300s, so cleanup is due.
    local current_time_us = 1000 * 1000000 + 500 * 1000000
    local clean_interval = 300

    print("Initial Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end return res end)(), ", "))

    -- Run logic (seed is microtime)
    local keys = {current_time_us, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
    for _, h in ipairs(current_hosts) do table.insert(keys, h) end

    run_lb_logic(keys)

    print("After Cleanup Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end table.sort(res) return res end)(), ", "))

    local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
    local exists_old2 = redis_data.hset["old_ip_2"] ~= nil

    if not exists_old1 and not exists_old2 then
        print("Success: Outdated IPs removed.")
    else
        print("Failure: Outdated IPs still exist.")
    end

    print("New last_clean_time:", redis_data.kv["clean_key"])
end

test_cleanup()

-- --- Test 3: No Cleanup if Interval Not Reached ---
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")

local function test_no_cleanup()
    redis_data.hset = {
        ["host1"] = "10",
        ["old_ip_1"] = "1",
    }
    redis_data.kv["clean_key"] = "1000"

    local current_hosts = {"host1"}
    -- 1100s total; interval 300s is NOT yet reached, so no cleanup expected.
    local current_time_us = 1000 * 1000000 + 100 * 1000000
    local clean_interval = 300

    local keys = {current_time_us, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "0"}
    for _, h in ipairs(current_hosts) do table.insert(keys, h) end

    run_lb_logic(keys)

    if redis_data.hset["old_ip_1"] then
        print("Success: Cleanup not triggered as expected.")
    else
        print("Failure: Cleanup triggered unexpectedly.")
    end
end

test_no_cleanup()
a/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/rate_limit.go b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/rate_limit.go new file mode 100644 index 000000000..bae3ad777 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/rate_limit.go @@ -0,0 +1,24 @@ +package global_least_request + +import ( + "fmt" + + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool { + // 如果没有配置最大请求数,直接通过 + if lb.maxRequestCount <= 0 { + return true + } + + // 如果当前请求数大于最大请求数,则限流 + // 注意:Lua脚本已经加了1,所以这里比较的是加1后的值 + if currentCount > lb.maxRequestCount { + // 恢复 Redis 计数 + lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil) + return false + } + + return true +}