feat: advanced load balance policys for LLM service through wasm plugin (#2531)

This commit is contained in:
rinfx
2025-07-01 20:08:44 +08:00
committed by GitHub
parent db7dbb24a2
commit 9d68ccbf35
15 changed files with 1656 additions and 0 deletions

View File

@@ -0,0 +1 @@
test/

View File

@@ -0,0 +1,174 @@
---
title: AI负载均衡
keywords: [higress, llm, load balance]
description: 针对LLM服务的负载均衡策略
---
# 功能说明
**注意**
- Higress网关版本需要>=v2.1.5
对LLM服务提供热插拔的负载均衡策略如果关闭插件负载均衡策略会退化为服务本身的负载均衡策略轮训、本地最小请求数、随机、一致性hash等
配置如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
目前支持的负载均衡策略包括:
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
# 全局最小请求数
## 功能说明
```mermaid
sequenceDiagram
participant C as Client
participant H as Higress
participant R as Redis
participant H1 as Host1
participant H2 as Host2
C ->> H: 发起请求
H ->> R: 获取 host ongoing 请求数
R ->> H: 返回结果
H ->> R: 根据结果选择当前请求数最小的host计数+1
R ->> H: 返回结果
H ->> H1: 绕过service原本的负载均衡策略转发请求到对应host
H1 ->> H: 返回响应
H ->> R: host计数-1
H ->> C: 返回响应
```
## 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN例如: `redis.dns` |
| `servicePort` | int | 必填 | | redis 服务的port |
| `username` | string | 必填 | | redis 用户名 |
| `password` | string | 选填 | 空 | redis 密码 |
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
| `database` | int | 选填 | 0 | redis 数据库序号 |
## 配置示例
```yaml
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
servicePort: 6379
username: default
password: '123456'
```
# 前缀匹配
## 功能说明
根据 prompt 前缀匹配选择 pod以复用 KV Cache如果通过前缀匹配无法匹配到节点则通过全局最小请求数进行服务节点的选择
例如以下请求被路由到了pod 1
```json
{
"model": "qwen-turbo",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}
```
那么后续具有相同前缀的请求也会被路由到 pod 1
```json
{
"model": "qwen-turbo",
"messages": [
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hi! How can I assist you today? 😊"
},
{
"role": "user",
"content": "write a short story aboud 100 words"
}
]
}
```
## 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN例如: `redis.dns` |
| `servicePort` | int | 必填 | | redis 服务的port |
| `username` | string | 必填 | | redis 用户名 |
| `password` | string | 选填 | 空 | redis 密码 |
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
| `database` | int | 选填 | 0 | redis 数据库序号 |
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
## 配置示例
```yaml
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
servicePort: 6379
username: default
password: '123456'
```
# 最小负载
## 功能说明
[gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
```mermaid
sequenceDiagram
participant C as Client
participant H as Higress
participant H1 as Host1
participant H2 as Host2
loop 定期拉取metrics
H ->> H1: /metrics
H1 ->> H: vllm metrics
H ->> H2: /metrics
H2 ->> H: vllm metrics
end
C ->> H: 发起请求
H ->> H1: 根据vllm metrics选择合适的pod绕过服务原始的lb policy直接转发
H1 ->> H: 返回响应
H ->> C: 返回响应
```
<!-- pod选取流程图如下
![](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/docs/scheduler-flowchart.png) -->
## 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | 选填 | | critical的模型列表 |
## 配置示例
```yaml
lb_policy: least_busy
lb_config:
criticalModels:
- meta-llama/Llama-2-7b-hf
- sql-lora
```

View File

@@ -0,0 +1,177 @@
---
title: AI Load Balance
keywords: [higress, llm, load balance]
description: LLM-oriented load balance policies
---
# Introduction
**Attention**:
- Version of Higress should >= v2.1.5
This plug-in provides the llm-oriented load balancing capability in a hot-swappable manner. If the plugin is closed, the load balancing strategy will degenerate into the load balancing strategy of the service itself (round robin, local minimum request number, random, consistent hash, etc.).
The configuration is:
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_policy` | string | required | | load balance type |
| `lb_config` | object | required | | configuration for the current load balance type |
Current supported load balance policies are:
- `global_least_request`: global least request based on redis
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
# Global Least Request
## Introduction
```mermaid
sequenceDiagram
participant C as Client
participant H as Higress
participant R as Redis
participant H1 as Host1
participant H2 as Host2
C ->> H: Send request
H ->> R: Get host ongoing request number
R ->> H: Return result
H ->> R: According to the result, select the host with the smallest number of current requests, host rq count +1.
R ->> H: Return result
H ->> H1: Bypass the service's original load balancing strategy and forward the request to the corresponding host
H1 ->> H: Return result
H ->> R: host rq count -1
H ->> C: Receive response
```
## Configuration
| Name | Type | required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
| `servicePort` | int | required | | redis port |
| `username` | string | required | | redis username |
| `password` | string | optional | `` | redis password |
| `timeout` | int | optional | 3000ms | redis request timeout |
| `database` | int | optional | 0 | redis database number |
## Configuration Example
```yaml
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
servicePort: 6379
username: default
password: '123456'
```
# Prefix Cache
## Introduction
Select pods based on the prompt prefix match to reuse KV Cache. If no node can be matched by prefix match, select the service node based on the global minimum number of requests.
For example, the following request is routed to pod 1:
```json
{
"model": "qwen-turbo",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}
```
Then subsequent requests with the same prefix will also be routed to pod 1:
```json
{
"model": "qwen-turbo",
"messages": [
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hi! How can I assist you today? 😊"
},
{
"role": "user",
"content": "write a short story aboud 100 words"
}
]
}
```
## Configuration
| Name | Type | required | default | description |
|--------------------|-----------------|-----------------------|-------------|---------------------------------|
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
| `servicePort` | int | required | | redis port |
| `username` | string | required | | redis username |
| `password` | string | optional | `` | redis password |
| `timeout` | int | optional | 3000ms | redis request timeout |
| `database` | int | optional | 0 | redis database number |
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
## Configuration Example
```yaml
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
servicePort: 6379
username: default
password: '123456'
```
# Least Busy
## Introduction
wasm implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
```mermaid
sequenceDiagram
participant C as Client
participant H as Higress
participant H1 as Host1
participant H2 as Host2
loop fetch metrics periodically
H ->> H1: /metrics
H1 ->> H: vllm metrics
H ->> H2: /metrics
H2 ->> H: vllm metrics
end
C ->> H: request
H ->> H1: select pod according to vllm metrics, bypassing original service load balance policy
H1 ->> H: response
H ->> C: response
```
<!-- flowchart for pod selection:
![](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/docs/scheduler-flowchart.png) -->
## Configuration
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | required | | critical model names |
## Configuration Example
```yaml
lb_policy: least_busy
lb_config:
criticalModels:
- meta-llama/Llama-2-7b-hf
- sql-lora
```

View File

@@ -0,0 +1,178 @@
package global_least_request
import (
"errors"
"fmt"
"math/rand"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLua = `local seed = KEYS[1]
local hset_key = KEYS[2]
local current_target = KEYS[3]
local current_count = 0
math.randomseed(seed)
local function randomBool()
return math.random() >= 0.5
end
local function is_healthy(addr)
for i = 4, #KEYS do
if addr == KEYS[i] then
return true
end
end
return false
end
if redis.call('HEXISTS', hset_key, current_target) ~= 0 then
current_count = redis.call('HGET', hset_key, current_target)
local hash = redis.call('HGETALL', hset_key)
for i = 1, #hash, 2 do
local addr = hash[i]
local count = hash[i+1]
if is_healthy(addr) then
if count < current_count then
current_target = addr
current_count = count
elseif count == current_count and randomBool() then
current_target = addr
current_count = count
end
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
return current_target`
)
type GlobalLeastRequestLoadBalancer struct {
redisClient wrapper.RedisClient
}
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
lb := GlobalLeastRequestLoadBalancer{}
serviceFQDN := json.Get("serviceFQDN").String()
servicePort := json.Get("servicePort").Int()
if serviceFQDN == "" || servicePort == 0 {
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
return lb, errors.New("invalid redis service config")
}
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: serviceFQDN,
Port: servicePort,
})
username := json.Get("username").String()
password := json.Get("password").String()
timeout := json.Get("timeout").Int()
if timeout == 0 {
timeout = 3000
}
// database default is 0
database := json.Get("database").Int()
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
}
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
routeName, err := utils.GetRouteName()
if err != nil || routeName == "" {
ctx.SetContext("error", true)
return types.ActionContinue
} else {
ctx.SetContext("routeName", routeName)
}
clusterName, err := utils.GetClusterName()
if err != nil || clusterName == "" {
ctx.SetContext("error", true)
return types.ActionContinue
} else {
ctx.SetContext("clusterName", clusterName)
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
ctx.SetContext("error", true)
return types.ActionContinue
}
// Only healthy host can be selected
healthyHostArray := []string{}
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
healthyHostArray = append(healthyHostArray, hostInfo[0])
}
}
if len(healthyHostArray) == 0 {
ctx.SetContext("error", true)
return types.ActionContinue
}
randomIndex := rand.Intn(len(healthyHostArray))
hostSelected := healthyHostArray[randomIndex]
keys := []interface{}{time.Now().Unix(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
for _, v := range healthyHostArray {
keys = append(keys, v)
}
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("HGetAll failed: %+v", err)
ctx.SetContext("error", true)
proxywasm.ResumeHttpRequest()
return
}
hostSelected = response.String()
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
ctx.SetContext("error", true)
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
}
log.Debugf("host_selected: %s", hostSelected)
ctx.SetContext("host_selected", hostSelected)
proxywasm.ResumeHttpRequest()
})
if err != nil {
ctx.SetContext("error", true)
return types.ActionContinue
}
return types.ActionPause
}
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
if endOfStream {
isErr, _ := ctx.GetContext("error").(bool)
if !isErr {
routeName, _ := ctx.GetContext("routeName").(string)
clusterName, _ := ctx.GetContext("clusterName").(string)
host_selected, _ := ctx.GetContext("host_selected").(string)
if host_selected == "" {
log.Errorf("get host_selected failed")
} else {
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
}
}
}
return data
}
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}

View File

@@ -0,0 +1,23 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer
go 1.24.1
toolchain go1.24.3
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.2
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
go.uber.org/multierr v1.11.0
)
require (
github.com/google/uuid v1.6.0 // indirect
github.com/prometheus/common v0.64.0
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
)

View File

@@ -0,0 +1,35 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545 h1:zPXEonKCAeLvXI1IpwGpIeVSvLY5AZ9h9uTJnOuiA3Q=
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4=
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,68 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package backend
import "fmt"
type PodSet map[Pod]bool
type Pod struct {
Name string
Address string
}
func (p Pod) String() string {
return p.Name + ":" + p.Address
}
type Metrics struct {
// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
ActiveModels map[string]int
// MaxActiveModels is the maximum number of models that can be loaded to GPU.
MaxActiveModels int
RunningQueueSize int
WaitingQueueSize int
KVCacheUsagePercent float64
KvCacheMaxTokenCapacity int
}
type PodMetrics struct {
Pod
Metrics
}
func (pm *PodMetrics) String() string {
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
}
func (pm *PodMetrics) Clone() *PodMetrics {
cm := make(map[string]int, len(pm.ActiveModels))
for k, v := range pm.ActiveModels {
cm[k] = v
}
clone := &PodMetrics{
Pod: pm.Pod,
Metrics: Metrics{
ActiveModels: cm,
RunningQueueSize: pm.RunningQueueSize,
WaitingQueueSize: pm.WaitingQueueSize,
KVCacheUsagePercent: pm.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
},
}
return clone
}

View File

@@ -0,0 +1,150 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package vllm provides vllm specific pod metrics implementation.
package vllm
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
dto "github.com/prometheus/client_model/go"
"go.uber.org/multierr"
)
const (
LoraRequestInfoMetricName = "vllm:lora_requests_info"
LoraRequestInfoRunningAdaptersMetricName = "running_lora_adapters"
LoraRequestInfoMaxAdaptersMetricName = "max_lora"
// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
RunningQueueSizeMetricName = "vllm:num_requests_running"
WaitingQueueSizeMetricName = "vllm:num_requests_waiting"
/* TODO: Uncomment this once the following are added to the fork.
RunningQueueSizeMetricName = "vllm:num_tokens_running"
WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
*/
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
)
// promToPodMetrics updates internal pod metrics with scraped prometheus metrics.
// A combined error is returned if errors occur in one or more metric processing.
// it returns a new PodMetrics pointer which can be used to atomically update the pod metrics map.
func PromToPodMetrics(
metricFamilies map[string]*dto.MetricFamily,
existing *backend.PodMetrics,
) (*backend.PodMetrics, error) {
var errs error
updated := existing.Clone()
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
errs = multierr.Append(errs, err)
if err == nil {
updated.RunningQueueSize = int(runningQueueSize.GetGauge().GetValue())
}
waitingQueueSize, err := getLatestMetric(metricFamilies, WaitingQueueSizeMetricName)
errs = multierr.Append(errs, err)
if err == nil {
updated.WaitingQueueSize = int(waitingQueueSize.GetGauge().GetValue())
}
cachePercent, err := getLatestMetric(metricFamilies, KVCacheUsagePercentMetricName)
errs = multierr.Append(errs, err)
if err == nil {
updated.KVCacheUsagePercent = cachePercent.GetGauge().GetValue()
}
loraMetrics, _, err := getLatestLoraMetric(metricFamilies)
errs = multierr.Append(errs, err)
/* TODO: uncomment once this is available in vllm.
kvCap, _, err := getGaugeLatestValue(metricFamilies, KvCacheMaxTokenCapacityMetricName)
errs = multierr.Append(errs, err)
if err != nil {
updated.KvCacheMaxTokenCapacity = int(kvCap)
}
*/
if loraMetrics != nil {
updated.ActiveModels = make(map[string]int)
for _, label := range loraMetrics.GetLabel() {
if label.GetName() == LoraRequestInfoRunningAdaptersMetricName {
if label.GetValue() != "" {
adapterList := strings.Split(label.GetValue(), ",")
for _, adapter := range adapterList {
updated.ActiveModels[adapter] = 0
}
}
}
if label.GetName() == LoraRequestInfoMaxAdaptersMetricName {
if label.GetValue() != "" {
updated.MaxActiveModels, err = strconv.Atoi(label.GetValue())
if err != nil {
errs = multierr.Append(errs, err)
}
}
}
}
}
return updated, errs
}
// getLatestLoraMetric gets latest lora metric series in gauge metric family `vllm:lora_requests_info`
// reason its specially fetched is because each label key value pair permutation generates new series
// and only most recent is useful. The value of each series is the creation timestamp so we can
// retrieve the latest by sorting the value.
func getLatestLoraMetric(metricFamilies map[string]*dto.MetricFamily) (*dto.Metric, time.Time, error) {
loraRequests, ok := metricFamilies[LoraRequestInfoMetricName]
if !ok {
// klog.Warningf("metric family %q not found", LoraRequestInfoMetricName)
return nil, time.Time{}, fmt.Errorf("metric family %q not found", LoraRequestInfoMetricName)
}
var latestTs float64
var latest *dto.Metric
for _, m := range loraRequests.GetMetric() {
if m.GetGauge().GetValue() > latestTs {
latestTs = m.GetGauge().GetValue()
latest = m
}
}
return latest, time.Unix(0, int64(latestTs*1000)), nil
}
// getLatestMetric gets the latest metric of a family. This should be used to get the latest Gauge metric.
// Since vllm doesn't set the timestamp in metric, this metric essentially gets the first metric.
func getLatestMetric(metricFamilies map[string]*dto.MetricFamily, metricName string) (*dto.Metric, error) {
mf, ok := metricFamilies[metricName]
if !ok {
// klog.Warningf("metric family %q not found", metricName)
return nil, fmt.Errorf("metric family %q not found", metricName)
}
if len(mf.GetMetric()) == 0 {
return nil, fmt.Errorf("no metrics available for %q", metricName)
}
var latestTs int64
var latest *dto.Metric
for _, m := range mf.GetMetric() {
if m.GetTimestampMs() >= latestTs {
latestTs = m.GetTimestampMs()
latest = m
}
}
// klog.V(logutil.TRACE).Infof("Got metric value %+v for metric %v", latest, metricName)
return latest, nil
}

View File

@@ -0,0 +1,79 @@
package least_busy
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type LeastBusyLoadBalancer struct {
criticalModels map[string]struct{}
}
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
lb := LeastBusyLoadBalancer{}
lb.criticalModels = make(map[string]struct{})
for _, model := range json.Get("criticalModels").Array() {
lb.criticalModels[model.String()] = struct{}{}
}
return lb, nil
}
// Callbacks which are called in request path
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
requestModel := gjson.GetBytes(body, "model")
if !requestModel.Exists() {
return types.ActionContinue
}
_, isCritical := lb.criticalModels[requestModel.String()]
llmReq := &scheduling.LLMRequest{
Model: requestModel.String(),
Critical: isCritical,
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
return types.ActionContinue
}
hostMetrics := make(map[string]string)
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
}
}
scheduler, err := scheduling.GetScheduler(hostMetrics)
if err != nil {
log.Debugf("initial scheduler failed: %v", err)
return types.ActionContinue
}
targetPod, err := scheduler.Schedule(llmReq)
log.Debugf("targetPod: %+v", targetPod.Address)
if err != nil {
log.Debugf("pod select failed: %v", err)
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
} else {
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
}
return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
ctx.DontReadResponseBody()
return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}

View File

@@ -0,0 +1,203 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package scheduling
import (
"errors"
"math"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
type Filter interface {
Name() string
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
}
// filter applies current filterFunc, and then recursively applies next filters depending success or
// failure of the current filterFunc.
// It can be used to construct a flow chart algorithm.
type filter struct {
name string
filter filterFunc
// nextOnSuccess filter will be applied after successfully applying the current filter.
// The filtered results will be passed to the next filter.
nextOnSuccess *filter
// nextOnFailure filter will be applied if current filter fails.
// The original input will be passed to the next filter.
nextOnFailure *filter
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
// success or failure of the current filter.
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
nextOnSuccessOrFailure *filter
// callbacks api.FilterCallbackHandler
}
func (f *filter) Name() string {
if f == nil {
return "nil"
}
return f.name
}
func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
proxywasm.LogDebugf("Running filter %q on request %v with %v pods", f.name, req, len(pods))
filtered, err := f.filter(req, pods)
next := f.nextOnSuccessOrFailure
if err == nil && len(filtered) > 0 {
if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
// No succeeding filters to run, return.
return filtered, err
}
if f.nextOnSuccess != nil {
next = f.nextOnSuccess
}
// On success, pass the filtered result to the next filter.
return next.Filter(req, filtered)
} else {
if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
// No succeeding filters to run, return.
return filtered, err
}
if f.nextOnFailure != nil {
next = f.nextOnFailure
}
// On failure, pass the initial set of pods to the next filter.
return next.Filter(req, pods)
}
}
// filterFunc filters a set of input pods to a subset.
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
func toFilterFunc(pp podPredicate) filterFunc {
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
pass := pp(req, pod)
if pass {
filtered = append(filtered, pod)
}
}
if len(filtered) == 0 {
return nil, errors.New("no pods left")
}
return filtered, nil
}
}
// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range
// (max-min) by the number of pods, and finds the pods that fall into the first range.
// The intuition is that if there are multiple pods that share similar queue size in the low range,
// we should consider them all instead of the absolute minimum one. This worked better than picking
// the least one as it gives more choices for the next filter, which on aggregate gave better
// results.
// TODO: Compare this strategy with other strategies such as top K.
func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxInt
max := 0
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
if pod.WaitingQueueSize <= min {
min = pod.WaitingQueueSize
}
if pod.WaitingQueueSize >= max {
max = pod.WaitingQueueSize
}
}
for _, pod := range pods {
if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
func lowQueueingPodPredicate(_ *LLMRequest, pod *backend.PodMetrics) bool {
return pod.WaitingQueueSize < queueingThresholdLoRA
}
// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
// (max-min) by the number of pods, and finds the pods that fall into the first range.
// The intuition is that if there are multiple pods that share similar KV cache in the low range, we
// should consider them all instead of the absolute minimum one. This worked better than picking the
// least one as it gives more choices for the next filter, which on aggregate gave better results.
// TODO: Compare this strategy with other strategies such as top K.
func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxFloat64
var max float64 = 0
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
if pod.KVCacheUsagePercent <= min {
min = pod.KVCacheUsagePercent
}
if pod.KVCacheUsagePercent >= max {
max = pod.KVCacheUsagePercent
}
}
for _, pod := range pods {
if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
// podPredicate is a filter function to check whether a pod is desired.
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
// We consider serving an adapter low cost it the adapter is active in the model server, or the
// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by
// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to
// a single pod. This gave good performance in our initial benchmarking results in the scenario
// where # of lora slots > # of lora adapters.
func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
_, ok := pod.ActiveModels[req.Model]
return ok || len(pod.ActiveModels) < pod.MaxActiveModels
}
// loRAAffinityPredicate is a filter function to check whether a pod has affinity to the lora requested.
func loRAAffinityPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
_, ok := pod.ActiveModels[req.Model]
return ok
}
// canAcceptNewLoraPredicate is a filter function to check whether a pod has room to load the adapter.
func canAcceptNewLoraPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
return len(pod.ActiveModels) < pod.MaxActiveModels
}
func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
return req.Critical
}
func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
return func(req *LLMRequest, pod *backend.PodMetrics) bool {
return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
}
}

View File

@@ -0,0 +1,158 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package scheduling implements request scheduling algorithms.
package scheduling
import (
"errors"
"fmt"
"math/rand"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
"github.com/prometheus/common/expfmt"
)
const (
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
kvCacheThreshold = 0.8
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
queueThresholdCritical = 5
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
// The value of 50 is arrived heuristicically based on experiments.
queueingThresholdLoRA = 50
)
var (
defaultFilter = &filter{
name: "critical request",
filter: toFilterFunc(criticalRequestPredicate),
nextOnSuccess: lowLatencyFilter,
nextOnFailure: sheddableRequestFilter,
}
// queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter
queueLoRAAndKVCacheFilter = &filter{
name: "least queuing",
filter: leastQueuingFilterFunc,
nextOnSuccessOrFailure: &filter{
name: "low cost LoRA",
filter: toFilterFunc(lowLoRACostPredicate),
nextOnSuccessOrFailure: &filter{
name: "least KV cache percent",
filter: leastKVCacheFilterFunc,
},
},
}
// queueAndKVCacheFilter applies least queue followed by least KV Cache filter
queueAndKVCacheFilter = &filter{
name: "least queuing",
filter: leastQueuingFilterFunc,
nextOnSuccessOrFailure: &filter{
name: "least KV cache percent",
filter: leastKVCacheFilterFunc,
},
}
lowLatencyFilter = &filter{
name: "low queueing filter",
filter: toFilterFunc((lowQueueingPodPredicate)),
nextOnSuccess: &filter{
name: "affinity LoRA",
filter: toFilterFunc(loRAAffinityPredicate),
nextOnSuccess: queueAndKVCacheFilter,
nextOnFailure: &filter{
name: "can accept LoRA Adapter",
filter: toFilterFunc(canAcceptNewLoraPredicate),
nextOnSuccessOrFailure: queueAndKVCacheFilter,
},
},
nextOnFailure: queueLoRAAndKVCacheFilter,
}
sheddableRequestFilter = &filter{
// When there is at least one model server that's not queuing requests, and still has KV
// cache below a certain threshold, we consider this model server has capacity to handle
// a sheddable request without impacting critical requests.
name: "has capacity for sheddable requests",
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)),
nextOnSuccess: queueLoRAAndKVCacheFilter,
// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
// request to make room for critical requests.
nextOnFailure: &filter{
name: "drop request",
filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
// api.LogDebugf("Dropping request %v", req)
return []*backend.PodMetrics{}, errors.New("dropping request due to limited backend resources")
},
},
}
)
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
return &Scheduler{
podMetrics: pm,
filter: defaultFilter,
}
}
type Scheduler struct {
podMetrics []*backend.PodMetrics
filter Filter
}
// Schedule finds the target pod based on metrics and the requested lora adapter.
func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error) {
pods, err := s.filter.Filter(req, s.podMetrics)
if err != nil || len(pods) == 0 {
return backend.Pod{}, fmt.Errorf("failed to apply filter, resulted %v pods: %w", len(pods), err)
}
i := rand.Intn(len(pods))
return pods[i].Pod, nil
}
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
if len(hostMetrics) == 0 {
return nil, errors.New("backend is not support llm scheduling")
}
var pms []*backend.PodMetrics
for addr, metric := range hostMetrics {
parser := expfmt.TextParser{}
metricFamilies, err := parser.TextToMetricFamilies(strings.NewReader(metric))
if err != nil {
return nil, err
}
pm := &backend.PodMetrics{
Pod: backend.Pod{
Name: addr,
Address: addr,
},
Metrics: backend.Metrics{},
}
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
if err != nil {
return nil, err
}
pms = append(pms, pm)
}
return NewScheduler(pms), nil
}

View File

@@ -0,0 +1,7 @@
package scheduling
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
type LLMRequest struct {
Model string
Critical bool
}

View File

@@ -0,0 +1,82 @@
package main
import (
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
)
func main() {}
func init() {
wrapper.SetCtx(
"ai-load-balancer",
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
)
}
type LoadBalancer interface {
HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action
HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action
HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action
HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte
HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action
}
type Config struct {
policy string
lb LoadBalancer
}
const (
LeastBusyLoadBalancerPolicy = "least_busy"
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
PrefixCache = "prefix_cache"
)
func parseConfig(json gjson.Result, config *Config) error {
config.policy = json.Get("lb_policy").String()
var err error
switch config.policy {
case LeastBusyLoadBalancerPolicy:
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
case GlobalLeastRequestLoadBalancerPolicy:
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
case PrefixCache:
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
default:
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
}
return err
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
return config.lb.HandleHttpRequestHeaders(ctx)
}
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
return config.lb.HandleHttpRequestBody(ctx, body)
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
return config.lb.HandleHttpResponseHeaders(ctx)
}
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config Config, data []byte, endOfStream bool) []byte {
return config.lb.HandleHttpStreamingResponseBody(ctx, data, endOfStream)
}
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
return config.lb.HandleHttpResponseBody(ctx, body)
}

View File

@@ -0,0 +1,302 @@
package prefix_cache
import (
"crypto/sha1"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLua = `-- hex string => bytes
local function hex_to_bytes(hex)
local bytes = {}
for i = 1, #hex, 2 do
local byte_str = hex:sub(i, i+1)
local byte_val = tonumber(byte_str, 16)
table.insert(bytes, byte_val)
end
return bytes
end
-- bytes => hex string
local function bytes_to_hex(bytes)
local result = ""
for _, byte in ipairs(bytes) do
result = result .. string.format("%02X", byte)
end
return result
end
-- byte XOR
local function byte_xor(a, b)
local result = 0
for i = 0, 7 do
local bit_val = 2^i
if ((a % (bit_val * 2)) >= bit_val) ~= ((b % (bit_val * 2)) >= bit_val) then
result = result + bit_val
end
end
return result
end
-- hex string XOR
local function hex_xor(a, b)
if #a ~= #b then
error("Hex strings must be of equal length, first is " .. a .. " second is " .. b)
end
local a_bytes = hex_to_bytes(a)
local b_bytes = hex_to_bytes(b)
local result_bytes = {}
for i = 1, #a_bytes do
table.insert(result_bytes, byte_xor(a_bytes[i], b_bytes[i]))
end
return bytes_to_hex(result_bytes)
end
-- check host whether healthy
local function is_healthy(addr)
for i = 4, #KEYS do
if addr == KEYS[i] then
return true
end
end
return false
end
local target = ""
local key = ""
local current_key = ""
local count = #ARGV
local ttl = KEYS[1]
local hset_key = KEYS[2]
local default_target = KEYS[3]
if count == 0 then
return target
end
-- find longest prefix
local index = 1
while index <= count do
if current_key == "" then
current_key = ARGV[index]
else
current_key = hex_xor(current_key, ARGV[index])
end
if redis.call("EXISTS", current_key) == 1 then
key = current_key
local tmp_target = redis.call("GET", key)
if not is_healthy(tmp_target) then
break
end
target = tmp_target
-- update ttl for exist keys
redis.call("EXPIRE", key, ttl)
index = index + 1
else
break
end
end
-- global least request
if target == "" then
index = 1
local current_count = 0
target = default_target
if redis.call('HEXISTS', hset_key, target) ~= 0 then
current_count = redis.call('HGET', hset_key, target)
local hash = redis.call('HGETALL', hset_key)
for i = 1, #hash, 2 do
local addr = hash[i]
local count = hash[i+1]
if count < current_count and is_healthy(addr) then
target = addr
current_count = count
end
end
end
end
-- update request count
redis.call("HINCRBY", hset_key, target, 1)
-- add tree-path
while index <= count do
if key == "" then
key = ARGV[index]
else
key = hex_xor(key, ARGV[index])
end
redis.call("SET", key, target)
redis.call("EXPIRE", key, ttl)
index = index + 1
end
return target`
)
type PrefixCacheLoadBalancer struct {
redisClient wrapper.RedisClient
redisKeyTTL int
}
func NewPrefixCacheLoadBalancer(json gjson.Result) (PrefixCacheLoadBalancer, error) {
lb := PrefixCacheLoadBalancer{}
serviceFQDN := json.Get("serviceFQDN").String()
servicePort := json.Get("servicePort").Int()
if serviceFQDN == "" || servicePort == 0 {
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
return lb, errors.New("invalid redis service config")
}
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: serviceFQDN,
Port: servicePort,
})
username := json.Get("username").String()
password := json.Get("password").String()
timeout := json.Get("timeout").Int()
if timeout == 0 {
timeout = 3000
}
// database default is 0
database := json.Get("database").Int()
if json.Get("redisKeyTTL").Int() == 0 {
lb.redisKeyTTL = int(json.Get("redisKeyTTL").Int())
} else {
lb.redisKeyTTL = 1800
}
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
}
func (lb PrefixCacheLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb PrefixCacheLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
var err error
routeName, err := utils.GetRouteName()
if err != nil || routeName == "" {
ctx.SetContext("error", true)
return types.ActionContinue
} else {
ctx.SetContext("routeName", routeName)
}
clusterName, err := utils.GetClusterName()
if err != nil || clusterName == "" {
ctx.SetContext("error", true)
return types.ActionContinue
} else {
ctx.SetContext("clusterName", clusterName)
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
ctx.SetContext("error", true)
log.Error("get upstream cluster endpoints failed")
return types.ActionContinue
}
healthyHosts := []string{}
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
healthyHosts = append(healthyHosts, hostInfo[0])
}
}
if len(healthyHosts) == 0 {
log.Info("upstream cluster has no healthy endpoints")
return types.ActionContinue
}
defaultHost := healthyHosts[rand.Intn(len(healthyHosts))]
params := []interface{}{}
rawStr := ""
messages := gjson.GetBytes(body, "messages").Array()
for index, obj := range messages {
if !obj.Get("role").Exists() || !obj.Get("content").Exists() {
ctx.SetContext("error", true)
log.Info("cannot extract role or content from request body, skip llm load balancing")
return types.ActionContinue
}
role := obj.Get("role").String()
content := obj.Get("content").String()
rawStr += role + ":" + content
if role == "user" || index == len(messages)-1 {
sha1Str := computeSHA1(rawStr)
params = append(params, sha1Str)
rawStr = ""
}
}
if len(params) == 0 {
return types.ActionContinue
}
keys := []interface{}{lb.redisKeyTTL, fmt.Sprintf(RedisKeyFormat, routeName, clusterName), defaultHost}
for _, v := range healthyHosts {
keys = append(keys, v)
}
err = lb.redisClient.Eval(RedisLua, len(keys), keys, params, func(response resp.Value) {
defer proxywasm.ResumeHttpRequest()
if err := response.Error(); err != nil {
ctx.SetContext("error", true)
log.Errorf("Redis eval failed: %+v", err)
return
}
hostSelected := response.String()
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
ctx.SetContext("error", true)
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
}
log.Debugf("host_selected: %s", hostSelected)
ctx.SetContext("host_selected", hostSelected)
})
if err != nil {
ctx.SetContext("error", true)
return types.ActionContinue
}
return types.ActionPause
}
func (lb PrefixCacheLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
func (lb PrefixCacheLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
if endOfStream {
isErr, _ := ctx.GetContext("error").(bool)
if !isErr {
routeName, _ := ctx.GetContext("routeName").(string)
clusterName, _ := ctx.GetContext("clusterName").(string)
host_selected, _ := ctx.GetContext("host_selected").(string)
if host_selected == "" {
log.Errorf("get host_selected failed")
} else {
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
}
}
}
return data
}
func (lb PrefixCacheLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func computeSHA1(data string) string {
hasher := sha1.New()
hasher.Write([]byte(data))
return strings.ToUpper(hex.EncodeToString(hasher.Sum(nil)))
}

View File

@@ -0,0 +1,19 @@
package utils
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
func GetRouteName() (string, error) {
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
return "", err
} else {
return string(raw), nil
}
}
func GetClusterName() (string, error) {
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
return "", err
} else {
return string(raw), nil
}
}