Files
higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request/lb_policy.go

180 lines
5.6 KiB
Go

package global_least_request
import (
"errors"
"fmt"
"math/rand"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
// RedisKeyFormat is the name of the Redis hash that stores the per-host
// in-flight request counters. The two placeholders are filled with the route
// name and the cluster name, in that order.
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
// RedisLua is the Lua script that picks the target host and increments its
// counter in a single Redis call.
//
// Inputs (all passed through KEYS):
//   - KEYS[1]: seed for math.randomseed
//   - KEYS[2]: name of the counter hash
//   - KEYS[3]: initial candidate host
//   - KEYS[4..]: hosts that are currently considered healthy
//
// If the candidate already has a field in the hash, the script scans every
// field of the hash and, among the fields listed in KEYS[4..], switches to a
// host whose count is lower than the current one; on an equal count it
// switches with probability 0.5. If the candidate has no field yet, it is
// kept as is. The chosen host's counter is then incremented by 1 (HINCRBY)
// and the host is returned.
RedisLua = `local seed = KEYS[1]
local hset_key = KEYS[2]
local current_target = KEYS[3]
local current_count = 0
math.randomseed(seed)
local function randomBool()
return math.random() >= 0.5
end
local function is_healthy(addr)
for i = 4, #KEYS do
if addr == KEYS[i] then
return true
end
end
return false
end
if redis.call('HEXISTS', hset_key, current_target) == 1 then
current_count = redis.call('HGET', hset_key, current_target)
local hash = redis.call('HGETALL', hset_key)
for i = 1, #hash, 2 do
local addr = hash[i]
local count = hash[i+1]
if is_healthy(addr) then
if tonumber(count) < tonumber(current_count) then
current_target = addr
current_count = count
elseif count == current_count and randomBool() then
current_target = addr
current_count = count
end
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
return current_target`
)
// GlobalLeastRequestLoadBalancer selects an upstream host per request using
// request counters kept in Redis, so that the counters are shared by every
// instance that talks to the same Redis.
type GlobalLeastRequestLoadBalancer struct {
// redisClient is used to run the selection script and to decrement the
// counter of the selected host when the stream is done.
redisClient wrapper.RedisClient
}
// NewGlobalLeastRequestLoadBalancer builds a load balancer from the plugin
// configuration.
//
// Recognized fields of json:
//   - serviceFQDN (required): FQDN of the Redis service
//   - servicePort (required): port of the Redis service, 1-65535
//   - username, password (optional): credentials passed to the Redis client
//   - timeout (optional): passed to the Redis client; a missing or
//     non-positive value falls back to 3000 (presumably milliseconds —
//     confirm against wrapper.RedisClient.Init)
//   - database (optional): Redis database index, defaults to 0
//
// It returns an error when the service address is missing or invalid, or when
// initializing the Redis client fails. The returned value must not be used
// when the error is non-nil.
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
	lb := GlobalLeastRequestLoadBalancer{}
	serviceFQDN := json.Get("serviceFQDN").String()
	servicePort := json.Get("servicePort").Int()
	// Reject negative and out-of-range ports too, not only a missing one.
	if serviceFQDN == "" || servicePort <= 0 || servicePort > 65535 {
		log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
		return lb, errors.New("invalid redis service config")
	}
	lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
		FQDN: serviceFQDN,
		Port: servicePort,
	})
	username := json.Get("username").String()
	password := json.Get("password").String()
	timeout := json.Get("timeout").Int()
	// A negative timeout is as meaningless as a missing one; use the default.
	if timeout <= 0 {
		timeout = 3000
	}
	// database default is 0
	database := json.Get("database").Int()
	return lb, lb.redisClient.Init(username, password, timeout, wrapper.WithDataBase(int(database)))
}
// HandleHttpRequestHeaders always stops header iteration; it performs no
// selection itself. Host selection happens in HandleHttpRequestBody.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// Header iteration must be stopped here: if types.ActionContinue were
// returned, the later SetUpstreamOverrideHost call would not take effect.
return types.HeaderStopIteration
}
// HandleHttpRequestBody selects the upstream host for the current request.
//
// It collects the healthy upstream hosts, picks a random one as the initial
// candidate and runs RedisLua, which may replace the candidate with a healthy
// host that has fewer in-flight requests and increments the chosen host's
// counter. The request is paused (types.ActionPause) until the Redis callback
// overrides the upstream host and resumes it.
//
// On any failure the "error" context flag is set and the request continues
// with the default load-balancing policy (types.ActionContinue, or a resume
// from inside the callback). On success "routeName", "clusterName" and
// "host_selected" are stored in the context for HandleHttpStreamDone.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
	routeName, err := utils.GetRouteName()
	if err != nil || routeName == "" {
		log.Errorf("get route name failed, routeName: %q, error: %v", routeName, err)
		ctx.SetContext("error", true)
		return types.ActionContinue
	}
	ctx.SetContext("routeName", routeName)

	clusterName, err := utils.GetClusterName()
	if err != nil || clusterName == "" {
		log.Errorf("get cluster name failed, clusterName: %q, error: %v", clusterName, err)
		ctx.SetContext("error", true)
		return types.ActionContinue
	}
	ctx.SetContext("clusterName", clusterName)

	hostInfos, err := proxywasm.GetUpstreamHosts()
	if err != nil {
		log.Errorf("get upstream hosts failed: %v", err)
		ctx.SetContext("error", true)
		return types.ActionContinue
	}
	// Only healthy hosts can be selected.
	healthyHosts := make([]string, 0, len(hostInfos))
	for _, hostInfo := range hostInfos {
		// Each entry is read as [address, JSON metadata]; skip malformed
		// entries instead of panicking on the index below.
		if len(hostInfo) < 2 {
			continue
		}
		if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
			healthyHosts = append(healthyHosts, hostInfo[0])
		}
	}
	if len(healthyHosts) == 0 {
		ctx.SetContext("error", true)
		return types.ActionContinue
	}

	redisKey := fmt.Sprintf(RedisKeyFormat, routeName, clusterName)
	candidate := healthyHosts[rand.Intn(len(healthyHosts))]
	// KEYS layout expected by RedisLua: seed, hash key, candidate, healthy hosts...
	keys := make([]interface{}, 0, len(healthyHosts)+3)
	keys = append(keys, time.Now().UnixMicro(), redisKey, candidate)
	for _, host := range healthyHosts {
		keys = append(keys, host)
	}

	err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
		if err := response.Error(); err != nil {
			log.Errorf("redis eval failed: %+v", err)
			ctx.SetContext("error", true)
			proxywasm.ResumeHttpRequest()
			return
		}
		hostSelected := response.String()
		if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
			ctx.SetContext("error", true)
			log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
			// The script has already incremented the counter of hostSelected,
			// and HandleHttpStreamDone skips the decrement when "error" is
			// set. Roll the increment back here so the counter does not leak.
			if err := lb.redisClient.HIncrBy(redisKey, hostSelected, -1, nil); err != nil {
				log.Errorf("roll back request count of %s failed: %+v", hostSelected, err)
			}
		}
		log.Debugf("host_selected: %s", hostSelected)
		ctx.SetContext("host_selected", hostSelected)
		proxywasm.ResumeHttpRequest()
	})
	if err != nil {
		log.Errorf("dispatch redis eval failed: %v", err)
		ctx.SetContext("error", true)
		return types.ActionContinue
	}
	return types.ActionPause
}
// HandleHttpResponseHeaders does nothing for this policy and lets the
// response headers continue unchanged.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
// HandleHttpStreamingResponseBody does nothing for this policy and returns
// the received chunk unmodified.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
// HandleHttpResponseBody does nothing for this policy and lets the response
// body continue unchanged.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
// HandleHttpStreamDone releases the request slot taken in
// HandleHttpRequestBody by decrementing the Redis counter of the selected
// host by 1.
//
// Nothing is decremented when the "error" context flag is set, or when no
// "host_selected" value was stored in the context.
func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
	if isErr, _ := ctx.GetContext("error").(bool); isErr {
		return
	}
	routeName, _ := ctx.GetContext("routeName").(string)
	clusterName, _ := ctx.GetContext("clusterName").(string)
	hostSelected, _ := ctx.GetContext("host_selected").(string)
	if hostSelected == "" {
		log.Errorf("get host_selected failed")
		return
	}
	// No callback is needed for the result, but a failed dispatch would leave
	// the counter permanently too high, so at least report it.
	redisKey := fmt.Sprintf(RedisKeyFormat, routeName, clusterName)
	if err := lb.redisClient.HIncrBy(redisKey, hostSelected, -1, nil); err != nil {
		log.Errorf("decrease request count of %s failed: %+v", hostSelected, err)
	}
}