mirror of
https://github.com/alibaba/higress.git
synced 2026-03-09 03:00:54 +08:00
feat(ai-load-balancer): enhance global least request load balancer (#3255)
This commit is contained in:
@@ -16,40 +16,91 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `local seed = KEYS[1]
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
|
||||
RedisLua = `local seed = tonumber(KEYS[1])
|
||||
local hset_key = KEYS[2]
|
||||
local current_target = KEYS[3]
|
||||
local current_count = 0
|
||||
local last_clean_key = KEYS[3]
|
||||
local clean_interval = tonumber(KEYS[4])
|
||||
local current_target = KEYS[5]
|
||||
local healthy_count = tonumber(KEYS[6])
|
||||
local enable_detail_log = KEYS[7]
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
local function randomBool()
|
||||
return math.random() >= 0.5
|
||||
end
|
||||
-- 1. Selection
|
||||
local current_count = 0
|
||||
local same_count_hits = 0
|
||||
|
||||
if redis.call('HEXISTS', hset_key, current_target) == 1 then
|
||||
current_count = redis.call('HGET', hset_key, current_target)
|
||||
for i = 4, #KEYS do
|
||||
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
|
||||
local count = redis.call('HGET', hset_key, KEYS[i])
|
||||
if tonumber(count) < tonumber(current_count) then
|
||||
current_target = KEYS[i]
|
||||
current_count = count
|
||||
elseif count == current_count and randomBool() then
|
||||
current_target = KEYS[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
for i = 8, 8 + healthy_count - 1 do
|
||||
local host = KEYS[i]
|
||||
local count = 0
|
||||
local val = redis.call('HGET', hset_key, host)
|
||||
if val then
|
||||
count = tonumber(val) or 0
|
||||
end
|
||||
|
||||
if same_count_hits == 0 or count < current_count then
|
||||
current_target = host
|
||||
current_count = count
|
||||
same_count_hits = 1
|
||||
elseif count == current_count then
|
||||
same_count_hits = same_count_hits + 1
|
||||
if math.random(same_count_hits) == 1 then
|
||||
current_target = host
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
local new_count = redis.call("HGET", hset_key, current_target)
|
||||
|
||||
return current_target`
|
||||
-- Collect host counts for logging
|
||||
local host_details = {}
|
||||
if enable_detail_log == "1" then
|
||||
local fields = {}
|
||||
for i = 8, #KEYS do
|
||||
table.insert(fields, KEYS[i])
|
||||
end
|
||||
if #fields > 0 then
|
||||
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
|
||||
for i, val in ipairs(values) do
|
||||
table.insert(host_details, fields[i])
|
||||
table.insert(host_details, tostring(val or 0))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. Cleanup
|
||||
local current_time = math.floor(seed / 1000000)
|
||||
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
|
||||
|
||||
if current_time - last_clean_time >= clean_interval then
|
||||
local all_keys = redis.call('HKEYS', hset_key)
|
||||
if #all_keys > 0 then
|
||||
-- Create a lookup table for current hosts (from index 8 onwards)
|
||||
local current_hosts = {}
|
||||
for i = 8, #KEYS do
|
||||
current_hosts[KEYS[i]] = true
|
||||
end
|
||||
-- Remove keys not in current hosts
|
||||
for _, host in ipairs(all_keys) do
|
||||
if not current_hosts[host] then
|
||||
redis.call('HDEL', hset_key, host)
|
||||
end
|
||||
end
|
||||
end
|
||||
redis.call('SET', last_clean_key, current_time)
|
||||
end
|
||||
|
||||
return {current_target, new_count, host_details}`
|
||||
)
|
||||
|
||||
type GlobalLeastRequestLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
redisClient wrapper.RedisClient
|
||||
maxRequestCount int64
|
||||
cleanInterval int64 // seconds
|
||||
enableDetailLog bool
|
||||
}
|
||||
|
||||
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
|
||||
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
lb.maxRequestCount = json.Get("maxRequestCount").Int()
|
||||
lb.cleanInterval = json.Get("cleanInterval").Int()
|
||||
if lb.cleanInterval == 0 {
|
||||
lb.cleanInterval = 60 * 60 // default 60 minutes
|
||||
} else {
|
||||
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
|
||||
}
|
||||
lb.enableDetailLog = true
|
||||
if val := json.Get("enableDetailLog"); val.Exists() {
|
||||
lb.enableDetailLog = val.Bool()
|
||||
}
|
||||
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
allHostMap := make(map[string]struct{})
|
||||
// Only healthy host can be selected
|
||||
healthyHostArray := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
allHostMap[hostInfo[0]] = struct{}{}
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHostArray = append(healthyHostArray, hostInfo[0])
|
||||
}
|
||||
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
}
|
||||
randomIndex := rand.Intn(len(healthyHostArray))
|
||||
hostSelected := healthyHostArray[randomIndex]
|
||||
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
|
||||
|
||||
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
|
||||
keys := []interface{}{
|
||||
time.Now().UnixMicro(),
|
||||
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
|
||||
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
|
||||
lb.cleanInterval,
|
||||
hostSelected,
|
||||
len(healthyHostArray),
|
||||
"0",
|
||||
}
|
||||
if lb.enableDetailLog {
|
||||
keys[6] = "1"
|
||||
}
|
||||
for _, v := range healthyHostArray {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
|
||||
for host := range allHostMap {
|
||||
isHealthy := false
|
||||
for _, hh := range healthyHostArray {
|
||||
if host == hh {
|
||||
isHealthy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isHealthy {
|
||||
keys = append(keys, host)
|
||||
}
|
||||
}
|
||||
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("HGetAll failed: %+v", err)
|
||||
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = response.String()
|
||||
valArray := response.Array()
|
||||
if len(valArray) < 2 {
|
||||
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
|
||||
ctx.SetContext("error", true)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = valArray[0].String()
|
||||
currentCount := valArray[1].Integer()
|
||||
|
||||
// detail log
|
||||
if lb.enableDetailLog && len(valArray) >= 3 {
|
||||
detailLogStr := "host and count: "
|
||||
details := valArray[2].Array()
|
||||
for i := 0; i+1 < len(details); i += 2 {
|
||||
h := details[i].String()
|
||||
c := details[i+1].String()
|
||||
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
|
||||
}
|
||||
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
|
||||
ctx.SetContext("error", true)
|
||||
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
|
||||
// return 429
|
||||
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
return
|
||||
}
|
||||
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
|
||||
// finally resume the request
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
if err != nil {
|
||||
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
-- Mocking Redis environment
-- Backing stores for the mock: `hset` holds hash fields (field -> string
-- value; the hash key argument is ignored because the tests only ever use a
-- single hash), `kv` holds plain string keys.
local redis_data = {
    hset = {},
    kv = {}
}

-- Minimal stand-in for the `redis` object that Redis exposes to server-side
-- Lua scripts. Only the commands used by the load-balancer script are
-- implemented, with the same reply shapes the real server produces.
local redis = {
    call = function(cmd, ...)
        local args = {...}
        if cmd == "HGET" then
            local key, field = args[1], args[2]
            return redis_data.hset[field]
        elseif cmd == "HSET" then
            local key, field, val = args[1], args[2], args[3]
            redis_data.hset[field] = val
        elseif cmd == "HINCRBY" then
            local key, field, increment = args[1], args[2], args[3]
            local val = tonumber(redis_data.hset[field] or 0)
            -- Redis stores hash values as strings; keep that invariant.
            redis_data.hset[field] = tostring(val + increment)
            return redis_data.hset[field]
        elseif cmd == "HKEYS" then
            local keys = {}
            for k, _ in pairs(redis_data.hset) do
                table.insert(keys, k)
            end
            return keys
        elseif cmd == "HDEL" then
            local key, field = args[1], args[2]
            redis_data.hset[field] = nil
        elseif cmd == "GET" then
            return redis_data.kv[args[1]]
        elseif cmd == "HMGET" then
            local key = args[1]
            local res = {}
            for i = 2, #args do
                -- Match real Redis Lua semantics: a missing field comes back
                -- as `false`, never nil. The previous table.insert of a nil
                -- value left holes in the reply array, misaligning fields and
                -- values and truncating ipairs() walks over the result.
                res[i - 1] = redis_data.hset[args[i]] or false
            end
            return res
        elseif cmd == "SET" then
            redis_data.kv[args[1]] = args[2]
        end
    end
}
|
||||
|
||||
-- The actual logic from lb_policy.go
-- NOTE(review): this function mirrors the RedisLua script embedded in
-- lb_policy.go — keep the two in sync when either changes.
--
-- KEYS layout (1-based):
--   [1] seed            -- request timestamp in microseconds, also the RNG seed
--   [2] hset_key        -- Redis hash holding per-host in-flight counts
--   [3] last_clean_key  -- key storing the last cleanup time (seconds)
--   [4] clean_interval  -- cleanup interval in seconds
--   [5] current_target  -- caller's randomly pre-selected fallback host
--   [6] healthy_count   -- number of healthy hosts that follow
--   [7] enable_detail_log -- "1" to collect per-host counts for logging
--   [8..] hosts         -- healthy hosts first, then any unhealthy hosts
--
-- Returns {selected_host, post-increment count, host_details} where
-- host_details is a flat {host1, count1, host2, count2, ...} list (empty
-- unless detail logging is enabled).
local function run_lb_logic(KEYS)
    local seed = tonumber(KEYS[1])
    local hset_key = KEYS[2]
    local last_clean_key = KEYS[3]
    local clean_interval = tonumber(KEYS[4])
    local current_target = KEYS[5]
    local healthy_count = tonumber(KEYS[6])
    local enable_detail_log = KEYS[7]

    -- Deterministic per-request randomness: the seed is the caller-supplied
    -- microsecond timestamp.
    math.randomseed(seed)

    -- 1. Selection
    -- Scan only the healthy hosts (indices 8 .. 8+healthy_count-1) for the
    -- smallest in-flight count. Ties are broken uniformly via reservoir
    -- sampling: the k-th host seen with the minimal count replaces the
    -- current pick with probability 1/k.
    local current_count = 0
    local same_count_hits = 0

    for i = 8, 8 + healthy_count - 1 do
        local host = KEYS[i]
        local count = 0
        -- A host absent from the hash counts as 0 in-flight requests.
        local val = redis.call('HGET', hset_key, host)
        if val then
            count = tonumber(val) or 0
        end

        if same_count_hits == 0 or count < current_count then
            -- First host examined, or a strictly smaller count: take it.
            current_target = host
            current_count = count
            same_count_hits = 1
        elseif count == current_count then
            same_count_hits = same_count_hits + 1
            if math.random(same_count_hits) == 1 then
                current_target = host
            end
        end
    end

    -- Reserve a slot for the selected host before returning.
    redis.call("HINCRBY", hset_key, current_target, 1)
    local new_count = redis.call("HGET", hset_key, current_target)

    -- Collect host counts for logging
    local host_details = {}
    if enable_detail_log == "1" then
        -- Include every host from index 8 onward (healthy and unhealthy).
        local fields = {}
        for i = 8, #KEYS do
            table.insert(fields, KEYS[i])
        end
        if #fields > 0 then
            -- table.unpack on 5.2+, unpack on 5.1/LuaJIT.
            local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
            -- NOTE(review): relies on the HMGET reply having no nil holes;
            -- real Redis returns false for missing fields, which `val or 0`
            -- maps to 0 — confirm any mock preserves that shape.
            for i, val in ipairs(values) do
                table.insert(host_details, fields[i])
                table.insert(host_details, tostring(val or 0))
            end
        end
    end

    -- 2. Cleanup
    -- Convert the microsecond seed to seconds for interval comparison.
    local current_time = math.floor(seed / 1000000)
    local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)

    if current_time - last_clean_time >= clean_interval then
        local all_keys = redis.call('HKEYS', hset_key)
        if #all_keys > 0 then
            -- Create a lookup table for current hosts (from index 8 onwards)
            local current_hosts = {}
            for i = 8, #KEYS do
                current_hosts[KEYS[i]] = true
            end
            -- Remove keys not in current hosts
            for _, host in ipairs(all_keys) do
                if not current_hosts[host] then
                    redis.call('HDEL', hset_key, host)
                end
            end
        end
        redis.call('SET', last_clean_key, current_time)
    end

    return {current_target, new_count, host_details}
end
|
||||
|
||||
-- --- Test 1: Load Balancing Distribution ---
-- Drive the selection logic many times starting from an all-zero count
-- table and report how evenly picks spread across the candidate hosts.
print("--- Test 1: Load Balancing Distribution ---")
local host_list = {"host1", "host2", "host3", "host4", "host5"}
local total_runs = 100000
local pick_counts = {}
for idx = 1, #host_list do pick_counts[host_list[idx]] = 0 end

-- Reset redis
redis_data.hset = {}
for idx = 1, #host_list do redis_data.hset[host_list[idx]] = "0" end

print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", total_runs, #host_list))

for run = 1, total_runs do
    -- Caller-side fallback pick, as the Go code does before invoking Lua.
    local fallback_host = host_list[math.random(#host_list)]
    -- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
    local lua_keys = {run * 1000000, "table_key", "clean_key", 3600, fallback_host, #host_list, "1"}
    for idx = 1, #host_list do lua_keys[#lua_keys + 1] = host_list[idx] end

    local reply = run_lb_logic(lua_keys)
    local picked = reply[1]
    pick_counts[picked] = pick_counts[picked] + 1
end

for idx = 1, #host_list do
    local name = host_list[idx]
    local share = (pick_counts[name] / total_runs) * 100
    print(string.format("%s: %6d (%.2f%%)", name, pick_counts[name], share))
end
|
||||
|
||||
-- --- Test 2: IP Cleanup Logic ---
print("\n--- Test 2: IP Cleanup Logic ---")

-- Verifies that run_lb_logic evicts hash fields for hosts no longer in the
-- cluster once the cleanup interval has elapsed, and advances the stored
-- last-clean timestamp.
local function test_cleanup()
    -- Two live hosts plus two stale entries that should be removed.
    redis_data.hset = {
        ["host1"] = "10",
        ["host2"] = "5",
        ["old_ip_1"] = "1",
        ["old_ip_2"] = "1",
    }
    redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s

    -- Sorted field names of the mock hash, so printed output is
    -- deterministic (the previous "Initial" print iterated pairs() unsorted,
    -- making its output order unspecified).
    local function sorted_fields()
        local names = {}
        for name in pairs(redis_data.hset) do
            names[#names + 1] = name
        end
        table.sort(names)
        return names
    end

    local current_hosts = {"host1", "host2"}
    -- Seed is microseconds (renamed from current_time_ms, which it never
    -- was): 1500s is well past last clean (1000s) + interval (300s).
    local current_time_us = 1000 * 1000000 + 500 * 1000000
    local clean_interval = 300

    print("Initial Redis IPs:", table.concat(sorted_fields(), ", "))

    -- Run logic (seed is microtime)
    local keys = {current_time_us, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
    for _, h in ipairs(current_hosts) do table.insert(keys, h) end

    run_lb_logic(keys)

    print("After Cleanup Redis IPs:", table.concat(sorted_fields(), ", "))

    local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
    local exists_old2 = redis_data.hset["old_ip_2"] ~= nil

    if not exists_old1 and not exists_old2 then
        print("Success: Outdated IPs removed.")
    else
        print("Failure: Outdated IPs still exist.")
    end

    print("New last_clean_time:", redis_data.kv["clean_key"])
end

test_cleanup()
|
||||
|
||||
-- --- Test 3: No Cleanup if Interval Not Reached ---
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")

-- Confirms stale hash fields survive when the time elapsed since the last
-- cleanup is still below the configured interval.
local function test_no_cleanup()
    redis_data.hset = {
        ["host1"] = "10",
        ["old_ip_1"] = "1",
    }
    redis_data.kv["clean_key"] = "1000"

    local live_hosts = {"host1"}
    -- Seed in microseconds: 1100s is only 100s past the last clean at
    -- 1000s, below the 300s interval, so no eviction should happen.
    local seed_us = 1000 * 1000000 + 100 * 1000000
    local interval_seconds = 300

    local lua_keys = {seed_us, "table_key", "clean_key", interval_seconds, "host1", #live_hosts, "0"}
    for idx = 1, #live_hosts do lua_keys[#lua_keys + 1] = live_hosts[idx] end

    run_lb_logic(lua_keys)

    if redis_data.hset["old_ip_1"] then
        print("Success: Cleanup not triggered as expected.")
    else
        print("Failure: Cleanup triggered unexpectedly.")
    end
end

test_no_cleanup()
|
||||
@@ -0,0 +1,24 @@
|
||||
package global_least_request
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// checkRateLimit enforces the per-host in-flight request cap after the Lua
// script has already incremented the selected host's counter.
// It returns true when the request may proceed and false when it must be
// rejected (the caller then responds with 429).
func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool {
	// No maximum configured (0 or negative): rate limiting is disabled.
	if lb.maxRequestCount <= 0 {
		return true
	}

	// Reject when the count exceeds the configured maximum.
	// Note: the Lua script already added 1, so currentCount is the
	// post-increment value.
	if currentCount > lb.maxRequestCount {
		// Roll back the increment so the rejected request does not keep
		// occupying a slot in Redis.
		// NOTE(review): the HIncrBy error is ignored here, unlike the
		// stream-done path — confirm whether a failed rollback should be
		// logged.
		lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil)
		return false
	}

	return true
}
|
||||
Reference in New Issue
Block a user