mirror of
https://github.com/alibaba/higress.git
synced 2026-03-05 17:10:55 +08:00
305 lines
8.3 KiB
Go
305 lines
8.3 KiB
Go
package prefix_cache
|
|
|
|
import (
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"strings"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/resp"
|
|
)
|
|
|
|
const (
	// RedisKeyFormat builds the name of the Redis hash that stores per-host
	// request counters for one (routeName, clusterName) pair.
	RedisKeyFormat = "higress:global_least_request_table:%s:%s"

	// RedisLua selects an upstream host and maintains the prefix tree.
	//
	// KEYS[1]    TTL passed to EXPIRE for prefix-tree keys
	// KEYS[2]    hash of per-host request counters (see RedisKeyFormat)
	// KEYS[3]    fallback host used when no cached prefix matches
	// KEYS[4..]  currently healthy hosts
	// ARGV[1..]  hex SHA-1 digests of successive conversation segments
	//
	// The script walks the prefix tree (node i is the XOR of ARGV[1..i]) to find
	// the deepest cached node whose host is still healthy. If none is found, it
	// picks the healthy host with the fewest counted requests. It then
	// increments that host's counter, points every remaining node of the path
	// at the host, and returns the host.
	RedisLua = `-- NOTE(review): non-key values are passed in KEYS and prefix-tree keys are
-- derived inside the script, so this script is not Redis Cluster slot-safe.

-- hex string => array of byte values
local function hex_to_bytes(hex)
  local bytes = {}
  for i = 1, #hex, 2 do
    bytes[#bytes + 1] = tonumber(hex:sub(i, i + 1), 16)
  end
  return bytes
end

-- array of byte values => upper-case hex string
local function bytes_to_hex(bytes)
  local parts = {}
  for i, byte in ipairs(bytes) do
    parts[i] = string.format("%02X", byte)
  end
  return table.concat(parts)
end

-- byte XOR implemented with arithmetic (no bit library required)
local function byte_xor(a, b)
  local result = 0
  for i = 0, 7 do
    local bit_val = 2^i
    if ((a % (bit_val * 2)) >= bit_val) ~= ((b % (bit_val * 2)) >= bit_val) then
      result = result + bit_val
    end
  end
  return result
end

-- XOR of two equal-length hex strings
local function hex_xor(a, b)
  if #a ~= #b then
    error("Hex strings must be of equal length, first is " .. a .. " second is " .. b)
  end
  local a_bytes = hex_to_bytes(a)
  local b_bytes = hex_to_bytes(b)
  local result_bytes = {}
  for i = 1, #a_bytes do
    result_bytes[i] = byte_xor(a_bytes[i], b_bytes[i])
  end
  return bytes_to_hex(result_bytes)
end

-- set of healthy hosts (KEYS[4..]) for O(1) membership checks
local healthy_hosts = {}
for i = 4, #KEYS do
  healthy_hosts[KEYS[i]] = true
end

-- check host whether healthy; addr may be false when a GET found no key
local function is_healthy(addr)
  return healthy_hosts[addr] == true
end

local function randomBool()
  return math.random() >= 0.5
end

local ttl = KEYS[1]
local hset_key = KEYS[2]
local default_target = KEYS[3]

local target = ""
-- key: tree key of the deepest node confirmed to map to a healthy host
local key = ""
local current_key = ""

-- find longest prefix whose cached host is still healthy
local index = 1
while index <= #ARGV do
  if current_key == "" then
    current_key = ARGV[index]
  else
    current_key = hex_xor(current_key, ARGV[index])
  end
  local cached_target = redis.call("GET", current_key)
  if not cached_target or not is_healthy(cached_target) then
    -- stop at the first missing or unhealthy node; the tree-path loop below
    -- rewrites it (and everything deeper) to point at the chosen target
    break
  end
  -- advance key only after the node is confirmed healthy, so that
  -- key XOR ARGV[index] below reproduces current_key for the next node
  key = current_key
  target = cached_target
  -- update ttl for exist keys
  redis.call("EXPIRE", key, ttl)
  index = index + 1
end

-- global least request
if target == "" then
  index = 1
  key = ""
  -- a host without a counter field has no recorded requests, i.e. 0
  local function request_count(host)
    return tonumber(redis.call("HGET", hset_key, host)) or 0
  end
  target = default_target
  local current_count = request_count(target)
  for i = 4, #KEYS do
    local count = request_count(KEYS[i])
    if count < current_count then
      target = KEYS[i]
      current_count = count
    elseif count == current_count and randomBool() then
      target = KEYS[i]
    end
  end
end

-- update request count
redis.call("HINCRBY", hset_key, target, 1)

-- add tree-path
while index <= #ARGV do
  if key == "" then
    key = ARGV[index]
  else
    key = hex_xor(key, ARGV[index])
  end
  redis.call("SET", key, target)
  redis.call("EXPIRE", key, ttl)
  index = index + 1
end

return target`
)
|
|
|
|
type PrefixCacheLoadBalancer struct {
|
|
redisClient wrapper.RedisClient
|
|
redisKeyTTL int
|
|
}
|
|
|
|
func NewPrefixCacheLoadBalancer(json gjson.Result) (PrefixCacheLoadBalancer, error) {
|
|
lb := PrefixCacheLoadBalancer{}
|
|
serviceFQDN := json.Get("serviceFQDN").String()
|
|
servicePort := json.Get("servicePort").Int()
|
|
if serviceFQDN == "" || servicePort == 0 {
|
|
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
|
|
return lb, errors.New("invalid redis service config")
|
|
}
|
|
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
|
FQDN: serviceFQDN,
|
|
Port: servicePort,
|
|
})
|
|
username := json.Get("username").String()
|
|
password := json.Get("password").String()
|
|
timeout := json.Get("timeout").Int()
|
|
if timeout == 0 {
|
|
timeout = 3000
|
|
}
|
|
// database default is 0
|
|
database := json.Get("database").Int()
|
|
if json.Get("redisKeyTTL").Int() != 0 {
|
|
lb.redisKeyTTL = int(json.Get("redisKeyTTL").Int())
|
|
} else {
|
|
lb.redisKeyTTL = 1800
|
|
}
|
|
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
|
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
|
return types.HeaderStopIteration
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
|
var err error
|
|
routeName, err := utils.GetRouteName()
|
|
if err != nil || routeName == "" {
|
|
ctx.SetContext("error", true)
|
|
return types.ActionContinue
|
|
} else {
|
|
ctx.SetContext("routeName", routeName)
|
|
}
|
|
clusterName, err := utils.GetClusterName()
|
|
if err != nil || clusterName == "" {
|
|
ctx.SetContext("error", true)
|
|
return types.ActionContinue
|
|
} else {
|
|
ctx.SetContext("clusterName", clusterName)
|
|
}
|
|
hostInfos, err := proxywasm.GetUpstreamHosts()
|
|
if err != nil {
|
|
ctx.SetContext("error", true)
|
|
log.Error("get upstream cluster endpoints failed")
|
|
return types.ActionContinue
|
|
}
|
|
healthyHosts := []string{}
|
|
for _, hostInfo := range hostInfos {
|
|
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
|
healthyHosts = append(healthyHosts, hostInfo[0])
|
|
}
|
|
}
|
|
if len(healthyHosts) == 0 {
|
|
log.Info("upstream cluster has no healthy endpoints")
|
|
return types.ActionContinue
|
|
}
|
|
defaultHost := healthyHosts[rand.Intn(len(healthyHosts))]
|
|
params := []interface{}{}
|
|
rawStr := ""
|
|
messages := gjson.GetBytes(body, "messages").Array()
|
|
for index, obj := range messages {
|
|
if !obj.Get("role").Exists() || !obj.Get("content").Exists() {
|
|
ctx.SetContext("error", true)
|
|
log.Info("cannot extract role or content from request body, skip llm load balancing")
|
|
return types.ActionContinue
|
|
}
|
|
role := obj.Get("role").String()
|
|
content := obj.Get("content").String()
|
|
rawStr += role + ":" + content
|
|
if role == "user" || index == len(messages)-1 {
|
|
sha1Str := computeSHA1(rawStr)
|
|
params = append(params, sha1Str)
|
|
rawStr = ""
|
|
}
|
|
}
|
|
if len(params) == 0 {
|
|
return types.ActionContinue
|
|
}
|
|
keys := []interface{}{lb.redisKeyTTL, fmt.Sprintf(RedisKeyFormat, routeName, clusterName), defaultHost}
|
|
for _, v := range healthyHosts {
|
|
keys = append(keys, v)
|
|
}
|
|
err = lb.redisClient.Eval(RedisLua, len(keys), keys, params, func(response resp.Value) {
|
|
defer proxywasm.ResumeHttpRequest()
|
|
if err := response.Error(); err != nil {
|
|
ctx.SetContext("error", true)
|
|
log.Errorf("Redis eval failed: %+v", err)
|
|
return
|
|
}
|
|
hostSelected := response.String()
|
|
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
|
ctx.SetContext("error", true)
|
|
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
|
}
|
|
log.Debugf("host_selected: %s", hostSelected)
|
|
ctx.SetContext("host_selected", hostSelected)
|
|
})
|
|
if err != nil {
|
|
ctx.SetContext("error", true)
|
|
return types.ActionContinue
|
|
}
|
|
return types.ActionPause
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
|
return data
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func (lb PrefixCacheLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
|
|
isErr, _ := ctx.GetContext("error").(bool)
|
|
if !isErr {
|
|
routeName, _ := ctx.GetContext("routeName").(string)
|
|
clusterName, _ := ctx.GetContext("clusterName").(string)
|
|
host_selected, _ := ctx.GetContext("host_selected").(string)
|
|
if host_selected == "" {
|
|
log.Errorf("get host_selected failed")
|
|
} else {
|
|
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
|
}
|
|
}
|
|
}
|
|
|
|
// computeSHA1 returns the SHA-1 digest of data as a 40-character upper-case
// hexadecimal string. Upper-case matches the "%02X" formatting used by the
// Lua bytes_to_hex helper, so XOR-derived tree keys keep the same form.
func computeSHA1(data string) string {
	digest := sha1.Sum([]byte(data))
	return strings.ToUpper(hex.EncodeToString(digest[:]))
}
|