feat(ai-load-balancer): enhance global least request load balancer (#3255)

This commit is contained in:
nixidexiangjiao
2025-12-29 09:28:56 +08:00
committed by jingze
parent 4f04ac067b
commit d55b9a0837
3 changed files with 401 additions and 25 deletions

View File

@@ -16,40 +16,91 @@ import (
)
const (
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLua = `local seed = KEYS[1]
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
RedisLua = `local seed = tonumber(KEYS[1])
local hset_key = KEYS[2]
local current_target = KEYS[3]
local current_count = 0
local last_clean_key = KEYS[3]
local clean_interval = tonumber(KEYS[4])
local current_target = KEYS[5]
local healthy_count = tonumber(KEYS[6])
local enable_detail_log = KEYS[7]
math.randomseed(seed)
local function randomBool()
return math.random() >= 0.5
end
-- 1. Selection
local current_count = 0
local same_count_hits = 0
if redis.call('HEXISTS', hset_key, current_target) == 1 then
current_count = redis.call('HGET', hset_key, current_target)
for i = 4, #KEYS do
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
local count = redis.call('HGET', hset_key, KEYS[i])
if tonumber(count) < tonumber(current_count) then
current_target = KEYS[i]
current_count = count
elseif count == current_count and randomBool() then
current_target = KEYS[i]
end
end
end
for i = 8, 8 + healthy_count - 1 do
local host = KEYS[i]
local count = 0
local val = redis.call('HGET', hset_key, host)
if val then
count = tonumber(val) or 0
end
if same_count_hits == 0 or count < current_count then
current_target = host
current_count = count
same_count_hits = 1
elseif count == current_count then
same_count_hits = same_count_hits + 1
if math.random(same_count_hits) == 1 then
current_target = host
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
local new_count = redis.call("HGET", hset_key, current_target)
return current_target`
-- Collect host counts for logging
local host_details = {}
if enable_detail_log == "1" then
local fields = {}
for i = 8, #KEYS do
table.insert(fields, KEYS[i])
end
if #fields > 0 then
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
for i, val in ipairs(values) do
table.insert(host_details, fields[i])
table.insert(host_details, tostring(val or 0))
end
end
end
-- 2. Cleanup
local current_time = math.floor(seed / 1000000)
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
if current_time - last_clean_time >= clean_interval then
local all_keys = redis.call('HKEYS', hset_key)
if #all_keys > 0 then
-- Create a lookup table for current hosts (from index 8 onwards)
local current_hosts = {}
for i = 8, #KEYS do
current_hosts[KEYS[i]] = true
end
-- Remove keys not in current hosts
for _, host in ipairs(all_keys) do
if not current_hosts[host] then
redis.call('HDEL', hset_key, host)
end
end
end
redis.call('SET', last_clean_key, current_time)
end
return {current_target, new_count, host_details}`
)
type GlobalLeastRequestLoadBalancer struct {
redisClient wrapper.RedisClient
redisClient wrapper.RedisClient
maxRequestCount int64
cleanInterval int64 // seconds
enableDetailLog bool
}
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
}
// database default is 0
database := json.Get("database").Int()
lb.maxRequestCount = json.Get("maxRequestCount").Int()
lb.cleanInterval = json.Get("cleanInterval").Int()
if lb.cleanInterval == 0 {
lb.cleanInterval = 60 * 60 // default 60 minutes
} else {
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
}
lb.enableDetailLog = true
if val := json.Get("enableDetailLog"); val.Exists() {
lb.enableDetailLog = val.Bool()
}
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
}
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
ctx.SetContext("error", true)
return types.ActionContinue
}
allHostMap := make(map[string]struct{})
// Only healthy host can be selected
healthyHostArray := []string{}
for _, hostInfo := range hostInfos {
allHostMap[hostInfo[0]] = struct{}{}
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
healthyHostArray = append(healthyHostArray, hostInfo[0])
}
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
}
randomIndex := rand.Intn(len(healthyHostArray))
hostSelected := healthyHostArray[randomIndex]
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
keys := []interface{}{
time.Now().UnixMicro(),
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
lb.cleanInterval,
hostSelected,
len(healthyHostArray),
"0",
}
if lb.enableDetailLog {
keys[6] = "1"
}
for _, v := range healthyHostArray {
keys = append(keys, v)
}
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
for host := range allHostMap {
isHealthy := false
for _, hh := range healthyHostArray {
if host == hh {
isHealthy = true
break
}
}
if !isHealthy {
keys = append(keys, host)
}
}
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("HGetAll failed: %+v", err)
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
proxywasm.ResumeHttpRequest()
return
}
hostSelected = response.String()
valArray := response.Array()
if len(valArray) < 2 {
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
ctx.SetContext("error", true)
proxywasm.ResumeHttpRequest()
return
}
hostSelected = valArray[0].String()
currentCount := valArray[1].Integer()
// detail log
if lb.enableDetailLog && len(valArray) >= 3 {
detailLogStr := "host and count: "
details := valArray[2].Array()
for i := 0; i+1 < len(details); i += 2 {
h := details[i].String()
c := details[i+1].String()
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
}
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
}
// check rate limit
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
ctx.SetContext("error", true)
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
// return 429
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
ctx.DontReadResponseBody()
return
}
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
ctx.SetContext("error", true)
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
proxywasm.ResumeHttpRequest()
return
}
log.Debugf("host_selected: %s", hostSelected)
// finally resume the request
ctx.SetContext("host_selected", hostSelected)
proxywasm.ResumeHttpRequest()
})
if err != nil {
ctx.SetContext("error", true)
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
return types.ActionContinue
}
return types.ActionPause
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
if host_selected == "" {
log.Errorf("get host_selected failed")
} else {
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
if err != nil {
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
}
}
}
}