mirror of
https://github.com/alibaba/higress.git
synced 2026-03-07 18:10:54 +08:00
feat: advanced load balance policys for LLM service through wasm plugin (#2531)
This commit is contained in:
1
plugins/wasm-go/extensions/ai-load-balancer/.gitignore
vendored
Normal file
1
plugins/wasm-go/extensions/ai-load-balancer/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
test/
|
||||
174
plugins/wasm-go/extensions/ai-load-balancer/README.md
Normal file
174
plugins/wasm-go/extensions/ai-load-balancer/README.md
Normal file
@@ -0,0 +1,174 @@
|
||||
---
|
||||
title: AI负载均衡
|
||||
keywords: [higress, llm, load balance]
|
||||
description: 针对LLM服务的负载均衡策略
|
||||
---
|
||||
|
||||
# 功能说明
|
||||
|
||||
**注意**:
|
||||
- Higress网关版本需要>=v2.1.5
|
||||
|
||||
对LLM服务提供热插拔的负载均衡策略,如果关闭插件,负载均衡策略会退化为服务本身的负载均衡策略(轮训、本地最小请求数、随机、一致性hash等)。
|
||||
|
||||
配置如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
|
||||
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
|
||||
|
||||
目前支持的负载均衡策略包括:
|
||||
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
|
||||
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
|
||||
# 全局最小请求数
|
||||
## 功能说明
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant R as Redis
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
C ->> H: 发起请求
|
||||
H ->> R: 获取 host ongoing 请求数
|
||||
R ->> H: 返回结果
|
||||
H ->> R: 根据结果选择当前请求数最小的host,计数+1
|
||||
R ->> H: 返回结果
|
||||
H ->> H1: 绕过service原本的负载均衡策略,转发请求到对应host
|
||||
H1 ->> H: 返回响应
|
||||
H ->> R: host计数-1
|
||||
H ->> C: 返回响应
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN,例如: `redis.dns` |
|
||||
| `servicePort` | int | 必填 | | redis 服务的port |
|
||||
| `username` | string | 必填 | | redis 用户名 |
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# 前缀匹配
|
||||
## 功能说明
|
||||
根据 prompt 前缀匹配选择 pod,以复用 KV Cache,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
|
||||
例如以下请求被路由到了pod 1
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
那么后续具有相同前缀的请求也会被路由到 pod 1
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi! How can I assist you today? 😊"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write a short story aboud 100 words"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN,例如: `redis.dns` |
|
||||
| `servicePort` | int | 必填 | | redis 服务的port |
|
||||
| `username` | string | 必填 | | redis 用户名 |
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# 最小负载
|
||||
## 功能说明
|
||||
[gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
loop 定期拉取metrics
|
||||
H ->> H1: /metrics
|
||||
H1 ->> H: vllm metrics
|
||||
H ->> H2: /metrics
|
||||
H2 ->> H: vllm metrics
|
||||
end
|
||||
|
||||
C ->> H: 发起请求
|
||||
H ->> H1: 根据vllm metrics选择合适的pod,绕过服务原始的lb policy直接转发
|
||||
H1 ->> H: 返回响应
|
||||
H ->> C: 返回响应
|
||||
```
|
||||
|
||||
<!-- pod选取流程图如下:
|
||||
|
||||
 -->
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | 选填 | | critical的模型列表 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
177
plugins/wasm-go/extensions/ai-load-balancer/README_EN.md
Normal file
177
plugins/wasm-go/extensions/ai-load-balancer/README_EN.md
Normal file
@@ -0,0 +1,177 @@
|
||||
---
|
||||
title: AI Load Balance
|
||||
keywords: [higress, llm, load balance]
|
||||
description: LLM-oriented load balance policies
|
||||
---
|
||||
|
||||
# Introduction
|
||||
|
||||
**Attention**:
|
||||
- Version of Higress should >= v2.1.5
|
||||
|
||||
This plug-in provides the llm-oriented load balancing capability in a hot-swappable manner. If the plugin is closed, the load balancing strategy will degenerate into the load balancing strategy of the service itself (round robin, local minimum request number, random, consistent hash, etc.).
|
||||
|
||||
The configuration is:
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | required | | load balance type |
|
||||
| `lb_config` | object | required | | configuration for the current load balance type |
|
||||
|
||||
Current supported load balance policies are:
|
||||
|
||||
- `global_least_request`: global least request based on redis
|
||||
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
|
||||
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
|
||||
# Global Least Request
|
||||
## Introduction
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant R as Redis
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
C ->> H: Send request
|
||||
H ->> R: Get host ongoing request number
|
||||
R ->> H: Return result
|
||||
H ->> R: According to the result, select the host with the smallest number of current requests, host rq count +1.
|
||||
R ->> H: Return result
|
||||
H ->> H1: Bypass the service's original load balancing strategy and forward the request to the corresponding host
|
||||
H1 ->> H: Return result
|
||||
H ->> R: host rq count -1
|
||||
H ->> C: Receive response
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
|
||||
| `servicePort` | int | required | | redis port |
|
||||
| `username` | string | required | | redis username |
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# Prefix Cache
|
||||
## Introduction
|
||||
Select pods based on the prompt prefix match to reuse KV Cache. If no node can be matched by prefix match, select the service node based on the global minimum number of requests.
|
||||
|
||||
For example, the following request is routed to pod 1:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Then subsequent requests with the same prefix will also be routed to pod 1:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi! How can I assist you today? 😊"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write a short story aboud 100 words"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | required | default | description |
|
||||
|--------------------|-----------------|-----------------------|-------------|---------------------------------|
|
||||
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
|
||||
| `servicePort` | int | required | | redis port |
|
||||
| `username` | string | required | | redis username |
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# Least Busy
|
||||
## Introduction
|
||||
|
||||
wasm implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
loop fetch metrics periodically
|
||||
H ->> H1: /metrics
|
||||
H1 ->> H: vllm metrics
|
||||
H ->> H2: /metrics
|
||||
H2 ->> H: vllm metrics
|
||||
end
|
||||
|
||||
C ->> H: request
|
||||
H ->> H1: select pod according to vllm metrics, bypassing original service load balance policy
|
||||
H1 ->> H: response
|
||||
H ->> C: response
|
||||
```
|
||||
|
||||
<!-- flowchart for pod selection:
|
||||
|
||||
 -->
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | required | | critical model names |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
@@ -0,0 +1,178 @@
|
||||
package global_least_request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `local seed = KEYS[1]
|
||||
local hset_key = KEYS[2]
|
||||
local current_target = KEYS[3]
|
||||
local current_count = 0
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
local function randomBool()
|
||||
return math.random() >= 0.5
|
||||
end
|
||||
|
||||
local function is_healthy(addr)
|
||||
for i = 4, #KEYS do
|
||||
if addr == KEYS[i] then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
if redis.call('HEXISTS', hset_key, current_target) ~= 0 then
|
||||
current_count = redis.call('HGET', hset_key, current_target)
|
||||
local hash = redis.call('HGETALL', hset_key)
|
||||
for i = 1, #hash, 2 do
|
||||
local addr = hash[i]
|
||||
local count = hash[i+1]
|
||||
if is_healthy(addr) then
|
||||
if count < current_count then
|
||||
current_target = addr
|
||||
current_count = count
|
||||
elseif count == current_count and randomBool() then
|
||||
current_target = addr
|
||||
current_count = count
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
|
||||
return current_target`
|
||||
)
|
||||
|
||||
type GlobalLeastRequestLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
}
|
||||
|
||||
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
|
||||
lb := GlobalLeastRequestLoadBalancer{}
|
||||
serviceFQDN := json.Get("serviceFQDN").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
if serviceFQDN == "" || servicePort == 0 {
|
||||
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
|
||||
return lb, errors.New("invalid redis service config")
|
||||
}
|
||||
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceFQDN,
|
||||
Port: servicePort,
|
||||
})
|
||||
username := json.Get("username").String()
|
||||
password := json.Get("password").String()
|
||||
timeout := json.Get("timeout").Int()
|
||||
if timeout == 0 {
|
||||
timeout = 3000
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
routeName, err := utils.GetRouteName()
|
||||
if err != nil || routeName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("routeName", routeName)
|
||||
}
|
||||
clusterName, err := utils.GetClusterName()
|
||||
if err != nil || clusterName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("clusterName", clusterName)
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
// Only healthy host can be selected
|
||||
healthyHostArray := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHostArray = append(healthyHostArray, hostInfo[0])
|
||||
}
|
||||
}
|
||||
if len(healthyHostArray) == 0 {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
randomIndex := rand.Intn(len(healthyHostArray))
|
||||
hostSelected := healthyHostArray[randomIndex]
|
||||
keys := []interface{}{time.Now().Unix(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
|
||||
for _, v := range healthyHostArray {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("HGetAll failed: %+v", err)
|
||||
ctx.SetContext("error", true)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = response.String()
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
}
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if endOfStream {
|
||||
isErr, _ := ctx.GetContext("error").(bool)
|
||||
if !isErr {
|
||||
routeName, _ := ctx.GetContext("routeName").(string)
|
||||
clusterName, _ := ctx.GetContext("clusterName").(string)
|
||||
host_selected, _ := ctx.GetContext("host_selected").(string)
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
23
plugins/wasm-go/extensions/ai-load-balancer/go.mod
Normal file
23
plugins/wasm-go/extensions/ai-load-balancer/go.mod
Normal file
@@ -0,0 +1,23 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.3
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
go.uber.org/multierr v1.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/prometheus/common v0.64.0
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
)
|
||||
35
plugins/wasm-go/extensions/ai-load-balancer/go.sum
Normal file
35
plugins/wasm-go/extensions/ai-load-balancer/go.sum
Normal file
@@ -0,0 +1,35 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545 h1:zPXEonKCAeLvXI1IpwGpIeVSvLY5AZ9h9uTJnOuiA3Q=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4=
|
||||
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package backend
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PodSet map[Pod]bool
|
||||
|
||||
type Pod struct {
|
||||
Name string
|
||||
Address string
|
||||
}
|
||||
|
||||
func (p Pod) String() string {
|
||||
return p.Name + ":" + p.Address
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
|
||||
ActiveModels map[string]int
|
||||
// MaxActiveModels is the maximum number of models that can be loaded to GPU.
|
||||
MaxActiveModels int
|
||||
RunningQueueSize int
|
||||
WaitingQueueSize int
|
||||
KVCacheUsagePercent float64
|
||||
KvCacheMaxTokenCapacity int
|
||||
}
|
||||
|
||||
type PodMetrics struct {
|
||||
Pod
|
||||
Metrics
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) String() string {
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
cm := make(map[string]int, len(pm.ActiveModels))
|
||||
for k, v := range pm.ActiveModels {
|
||||
cm[k] = v
|
||||
}
|
||||
clone := &PodMetrics{
|
||||
Pod: pm.Pod,
|
||||
Metrics: Metrics{
|
||||
ActiveModels: cm,
|
||||
RunningQueueSize: pm.RunningQueueSize,
|
||||
WaitingQueueSize: pm.WaitingQueueSize,
|
||||
KVCacheUsagePercent: pm.KVCacheUsagePercent,
|
||||
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
|
||||
},
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package vllm provides vllm specific pod metrics implementation.
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
const (
|
||||
LoraRequestInfoMetricName = "vllm:lora_requests_info"
|
||||
LoraRequestInfoRunningAdaptersMetricName = "running_lora_adapters"
|
||||
LoraRequestInfoMaxAdaptersMetricName = "max_lora"
|
||||
// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_requests_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_requests_waiting"
|
||||
/* TODO: Uncomment this once the following are added to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_tokens_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
|
||||
*/
|
||||
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
|
||||
KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
|
||||
)
|
||||
|
||||
// promToPodMetrics updates internal pod metrics with scraped prometheus metrics.
|
||||
// A combined error is returned if errors occur in one or more metric processing.
|
||||
// it returns a new PodMetrics pointer which can be used to atomically update the pod metrics map.
|
||||
func PromToPodMetrics(
|
||||
metricFamilies map[string]*dto.MetricFamily,
|
||||
existing *backend.PodMetrics,
|
||||
) (*backend.PodMetrics, error) {
|
||||
var errs error
|
||||
updated := existing.Clone()
|
||||
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.RunningQueueSize = int(runningQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
waitingQueueSize, err := getLatestMetric(metricFamilies, WaitingQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.WaitingQueueSize = int(waitingQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
cachePercent, err := getLatestMetric(metricFamilies, KVCacheUsagePercentMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.KVCacheUsagePercent = cachePercent.GetGauge().GetValue()
|
||||
}
|
||||
|
||||
loraMetrics, _, err := getLatestLoraMetric(metricFamilies)
|
||||
errs = multierr.Append(errs, err)
|
||||
/* TODO: uncomment once this is available in vllm.
|
||||
kvCap, _, err := getGaugeLatestValue(metricFamilies, KvCacheMaxTokenCapacityMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err != nil {
|
||||
updated.KvCacheMaxTokenCapacity = int(kvCap)
|
||||
}
|
||||
*/
|
||||
|
||||
if loraMetrics != nil {
|
||||
updated.ActiveModels = make(map[string]int)
|
||||
for _, label := range loraMetrics.GetLabel() {
|
||||
if label.GetName() == LoraRequestInfoRunningAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
adapterList := strings.Split(label.GetValue(), ",")
|
||||
for _, adapter := range adapterList {
|
||||
updated.ActiveModels[adapter] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if label.GetName() == LoraRequestInfoMaxAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
updated.MaxActiveModels, err = strconv.Atoi(label.GetValue())
|
||||
if err != nil {
|
||||
errs = multierr.Append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return updated, errs
|
||||
}
|
||||
|
||||
// getLatestLoraMetric gets latest lora metric series in gauge metric family `vllm:lora_requests_info`
|
||||
// reason its specially fetched is because each label key value pair permutation generates new series
|
||||
// and only most recent is useful. The value of each series is the creation timestamp so we can
|
||||
// retrieve the latest by sorting the value.
|
||||
func getLatestLoraMetric(metricFamilies map[string]*dto.MetricFamily) (*dto.Metric, time.Time, error) {
|
||||
loraRequests, ok := metricFamilies[LoraRequestInfoMetricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
return nil, time.Time{}, fmt.Errorf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
}
|
||||
var latestTs float64
|
||||
var latest *dto.Metric
|
||||
for _, m := range loraRequests.GetMetric() {
|
||||
if m.GetGauge().GetValue() > latestTs {
|
||||
latestTs = m.GetGauge().GetValue()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
return latest, time.Unix(0, int64(latestTs*1000)), nil
|
||||
}
|
||||
|
||||
// getLatestMetric gets the latest metric of a family. This should be used to get the latest Gauge metric.
|
||||
// Since vllm doesn't set the timestamp in metric, this metric essentially gets the first metric.
|
||||
func getLatestMetric(metricFamilies map[string]*dto.MetricFamily, metricName string) (*dto.Metric, error) {
|
||||
mf, ok := metricFamilies[metricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", metricName)
|
||||
return nil, fmt.Errorf("metric family %q not found", metricName)
|
||||
}
|
||||
if len(mf.GetMetric()) == 0 {
|
||||
return nil, fmt.Errorf("no metrics available for %q", metricName)
|
||||
}
|
||||
var latestTs int64
|
||||
var latest *dto.Metric
|
||||
for _, m := range mf.GetMetric() {
|
||||
if m.GetTimestampMs() >= latestTs {
|
||||
latestTs = m.GetTimestampMs()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
// klog.V(logutil.TRACE).Infof("Got metric value %+v for metric %v", latest, metricName)
|
||||
return latest, nil
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package least_busy
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type LeastBusyLoadBalancer struct {
|
||||
criticalModels map[string]struct{}
|
||||
}
|
||||
|
||||
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
|
||||
lb := LeastBusyLoadBalancer{}
|
||||
lb.criticalModels = make(map[string]struct{})
|
||||
for _, model := range json.Get("criticalModels").Array() {
|
||||
lb.criticalModels[model.String()] = struct{}{}
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, isCritical := lb.criticalModels[requestModel.String()]
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: isCritical,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
|
||||
} else {
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
type Filter interface {
|
||||
Name() string
|
||||
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
}
|
||||
|
||||
// filter applies current filterFunc, and then recursively applies next filters depending success or
|
||||
// failure of the current filterFunc.
|
||||
// It can be used to construct a flow chart algorithm.
|
||||
type filter struct {
|
||||
name string
|
||||
filter filterFunc
|
||||
// nextOnSuccess filter will be applied after successfully applying the current filter.
|
||||
// The filtered results will be passed to the next filter.
|
||||
nextOnSuccess *filter
|
||||
// nextOnFailure filter will be applied if current filter fails.
|
||||
// The original input will be passed to the next filter.
|
||||
nextOnFailure *filter
|
||||
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
|
||||
// success or failure of the current filter.
|
||||
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
|
||||
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
|
||||
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
|
||||
nextOnSuccessOrFailure *filter
|
||||
|
||||
// callbacks api.FilterCallbackHandler
|
||||
}
|
||||
|
||||
func (f *filter) Name() string {
|
||||
if f == nil {
|
||||
return "nil"
|
||||
}
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
proxywasm.LogDebugf("Running filter %q on request %v with %v pods", f.name, req, len(pods))
|
||||
filtered, err := f.filter(req, pods)
|
||||
|
||||
next := f.nextOnSuccessOrFailure
|
||||
if err == nil && len(filtered) > 0 {
|
||||
if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnSuccess != nil {
|
||||
next = f.nextOnSuccess
|
||||
}
|
||||
// On success, pass the filtered result to the next filter.
|
||||
return next.Filter(req, filtered)
|
||||
} else {
|
||||
if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnFailure != nil {
|
||||
next = f.nextOnFailure
|
||||
}
|
||||
// On failure, pass the initial set of pods to the next filter.
|
||||
return next.Filter(req, pods)
|
||||
}
|
||||
}
|
||||
|
||||
// filterFunc filters a set of input pods to a subset.
|
||||
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
|
||||
// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
|
||||
func toFilterFunc(pp podPredicate) filterFunc {
|
||||
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
filtered := []*backend.PodMetrics{}
|
||||
for _, pod := range pods {
|
||||
pass := pp(req, pod)
|
||||
if pass {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, errors.New("no pods left")
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
}
|
||||
|
||||
// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar queue size in the low range,
|
||||
// we should consider them all instead of the absolute minimum one. This worked better than picking
|
||||
// the least one as it gives more choices for the next filter, which on aggregate gave better
|
||||
// results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxInt
|
||||
max := 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize <= min {
|
||||
min = pod.WaitingQueueSize
|
||||
}
|
||||
if pod.WaitingQueueSize >= max {
|
||||
max = pod.WaitingQueueSize
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func lowQueueingPodPredicate(_ *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize < queueingThresholdLoRA
|
||||
}
|
||||
|
||||
// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar KV cache in the low range, we
|
||||
// should consider them all instead of the absolute minimum one. This worked better than picking the
|
||||
// least one as it gives more choices for the next filter, which on aggregate gave better results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
var max float64 = 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent <= min {
|
||||
min = pod.KVCacheUsagePercent
|
||||
}
|
||||
if pod.KVCacheUsagePercent >= max {
|
||||
max = pod.KVCacheUsagePercent
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// podPredicate is a filter function to check whether a pod is desired.
|
||||
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
|
||||
|
||||
// We consider serving an adapter low cost it the adapter is active in the model server, or the
|
||||
// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by
|
||||
// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to
|
||||
// a single pod. This gave good performance in our initial benchmarking results in the scenario
|
||||
// where # of lora slots > # of lora adapters.
|
||||
func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok || len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
// loRAAffinityPredicate is a filter function to check whether a pod has affinity to the lora requested.
|
||||
func loRAAffinityPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok
|
||||
}
|
||||
|
||||
// canAcceptNewLoraPredicate is a filter function to check whether a pod has room to load the adapter.
|
||||
func canAcceptNewLoraPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return req.Critical
|
||||
}
|
||||
|
||||
func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
|
||||
return func(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package scheduling implements request scheduling algorithms.
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
|
||||
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
kvCacheThreshold = 0.8
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
queueThresholdCritical = 5
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
|
||||
// The value of 50 is arrived heuristicically based on experiments.
|
||||
queueingThresholdLoRA = 50
|
||||
)
|
||||
|
||||
var (
|
||||
defaultFilter = &filter{
|
||||
name: "critical request",
|
||||
filter: toFilterFunc(criticalRequestPredicate),
|
||||
nextOnSuccess: lowLatencyFilter,
|
||||
nextOnFailure: sheddableRequestFilter,
|
||||
}
|
||||
|
||||
// queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter
|
||||
queueLoRAAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "low cost LoRA",
|
||||
filter: toFilterFunc(lowLoRACostPredicate),
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// queueAndKVCacheFilter applies least queue followed by least KV Cache filter
|
||||
queueAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
}
|
||||
|
||||
lowLatencyFilter = &filter{
|
||||
name: "low queueing filter",
|
||||
filter: toFilterFunc((lowQueueingPodPredicate)),
|
||||
nextOnSuccess: &filter{
|
||||
name: "affinity LoRA",
|
||||
filter: toFilterFunc(loRAAffinityPredicate),
|
||||
nextOnSuccess: queueAndKVCacheFilter,
|
||||
nextOnFailure: &filter{
|
||||
name: "can accept LoRA Adapter",
|
||||
filter: toFilterFunc(canAcceptNewLoraPredicate),
|
||||
nextOnSuccessOrFailure: queueAndKVCacheFilter,
|
||||
},
|
||||
},
|
||||
nextOnFailure: queueLoRAAndKVCacheFilter,
|
||||
}
|
||||
|
||||
sheddableRequestFilter = &filter{
|
||||
// When there is at least one model server that's not queuing requests, and still has KV
|
||||
// cache below a certain threshold, we consider this model server has capacity to handle
|
||||
// a sheddable request without impacting critical requests.
|
||||
name: "has capacity for sheddable requests",
|
||||
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)),
|
||||
nextOnSuccess: queueLoRAAndKVCacheFilter,
|
||||
// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
|
||||
// request to make room for critical requests.
|
||||
nextOnFailure: &filter{
|
||||
name: "drop request",
|
||||
filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
// api.LogDebugf("Dropping request %v", req)
|
||||
return []*backend.PodMetrics{}, errors.New("dropping request due to limited backend resources")
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
|
||||
|
||||
return &Scheduler{
|
||||
podMetrics: pm,
|
||||
filter: defaultFilter,
|
||||
}
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
podMetrics []*backend.PodMetrics
|
||||
filter Filter
|
||||
}
|
||||
|
||||
// Schedule finds the target pod based on metrics and the requested lora adapter.
|
||||
func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error) {
|
||||
pods, err := s.filter.Filter(req, s.podMetrics)
|
||||
if err != nil || len(pods) == 0 {
|
||||
return backend.Pod{}, fmt.Errorf("failed to apply filter, resulted %v pods: %w", len(pods), err)
|
||||
}
|
||||
i := rand.Intn(len(pods))
|
||||
return pods[i].Pod, nil
|
||||
}
|
||||
|
||||
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
if len(hostMetrics) == 0 {
|
||||
return nil, errors.New("backend is not support llm scheduling")
|
||||
}
|
||||
var pms []*backend.PodMetrics
|
||||
for addr, metric := range hostMetrics {
|
||||
parser := expfmt.TextParser{}
|
||||
metricFamilies, err := parser.TextToMetricFamilies(strings.NewReader(metric))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pm := &backend.PodMetrics{
|
||||
Pod: backend.Pod{
|
||||
Name: addr,
|
||||
Address: addr,
|
||||
},
|
||||
Metrics: backend.Metrics{},
|
||||
}
|
||||
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pms = append(pms, pm)
|
||||
}
|
||||
return NewScheduler(pms), nil
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package scheduling
|
||||
|
||||
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
|
||||
type LLMRequest struct {
|
||||
Model string
|
||||
Critical bool
|
||||
}
|
||||
82
plugins/wasm-go/extensions/ai-load-balancer/main.go
Normal file
82
plugins/wasm-go/extensions/ai-load-balancer/main.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
|
||||
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"ai-load-balancer",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
type LoadBalancer interface {
|
||||
HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action
|
||||
HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action
|
||||
HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action
|
||||
HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte
|
||||
HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
policy string
|
||||
lb LoadBalancer
|
||||
}
|
||||
|
||||
const (
|
||||
LeastBusyLoadBalancerPolicy = "least_busy"
|
||||
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
|
||||
PrefixCache = "prefix_cache"
|
||||
)
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.policy = json.Get("lb_policy").String()
|
||||
var err error
|
||||
switch config.policy {
|
||||
case LeastBusyLoadBalancerPolicy:
|
||||
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestLoadBalancerPolicy:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCache:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
return config.lb.HandleHttpRequestHeaders(ctx)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
return config.lb.HandleHttpRequestBody(ctx, body)
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
return config.lb.HandleHttpResponseHeaders(ctx)
|
||||
}
|
||||
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config Config, data []byte, endOfStream bool) []byte {
|
||||
return config.lb.HandleHttpStreamingResponseBody(ctx, data, endOfStream)
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
return config.lb.HandleHttpResponseBody(ctx, body)
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package prefix_cache
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `-- hex string => bytes
|
||||
local function hex_to_bytes(hex)
|
||||
local bytes = {}
|
||||
for i = 1, #hex, 2 do
|
||||
local byte_str = hex:sub(i, i+1)
|
||||
local byte_val = tonumber(byte_str, 16)
|
||||
table.insert(bytes, byte_val)
|
||||
end
|
||||
return bytes
|
||||
end
|
||||
|
||||
-- bytes => hex string
|
||||
local function bytes_to_hex(bytes)
|
||||
local result = ""
|
||||
for _, byte in ipairs(bytes) do
|
||||
result = result .. string.format("%02X", byte)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- byte XOR
|
||||
local function byte_xor(a, b)
|
||||
local result = 0
|
||||
for i = 0, 7 do
|
||||
local bit_val = 2^i
|
||||
if ((a % (bit_val * 2)) >= bit_val) ~= ((b % (bit_val * 2)) >= bit_val) then
|
||||
result = result + bit_val
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- hex string XOR
|
||||
local function hex_xor(a, b)
|
||||
if #a ~= #b then
|
||||
error("Hex strings must be of equal length, first is " .. a .. " second is " .. b)
|
||||
end
|
||||
|
||||
local a_bytes = hex_to_bytes(a)
|
||||
local b_bytes = hex_to_bytes(b)
|
||||
|
||||
local result_bytes = {}
|
||||
for i = 1, #a_bytes do
|
||||
table.insert(result_bytes, byte_xor(a_bytes[i], b_bytes[i]))
|
||||
end
|
||||
|
||||
return bytes_to_hex(result_bytes)
|
||||
end
|
||||
|
||||
-- check host whether healthy
|
||||
local function is_healthy(addr)
|
||||
for i = 4, #KEYS do
|
||||
if addr == KEYS[i] then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local target = ""
|
||||
local key = ""
|
||||
local current_key = ""
|
||||
local count = #ARGV
|
||||
local ttl = KEYS[1]
|
||||
local hset_key = KEYS[2]
|
||||
local default_target = KEYS[3]
|
||||
|
||||
if count == 0 then
|
||||
return target
|
||||
end
|
||||
|
||||
-- find longest prefix
|
||||
local index = 1
|
||||
while index <= count do
|
||||
if current_key == "" then
|
||||
current_key = ARGV[index]
|
||||
else
|
||||
current_key = hex_xor(current_key, ARGV[index])
|
||||
end
|
||||
if redis.call("EXISTS", current_key) == 1 then
|
||||
key = current_key
|
||||
local tmp_target = redis.call("GET", key)
|
||||
if not is_healthy(tmp_target) then
|
||||
break
|
||||
end
|
||||
target = tmp_target
|
||||
-- update ttl for exist keys
|
||||
redis.call("EXPIRE", key, ttl)
|
||||
index = index + 1
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
-- global least request
|
||||
if target == "" then
|
||||
index = 1
|
||||
local current_count = 0
|
||||
target = default_target
|
||||
if redis.call('HEXISTS', hset_key, target) ~= 0 then
|
||||
current_count = redis.call('HGET', hset_key, target)
|
||||
local hash = redis.call('HGETALL', hset_key)
|
||||
for i = 1, #hash, 2 do
|
||||
local addr = hash[i]
|
||||
local count = hash[i+1]
|
||||
if count < current_count and is_healthy(addr) then
|
||||
target = addr
|
||||
current_count = count
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- update request count
|
||||
redis.call("HINCRBY", hset_key, target, 1)
|
||||
|
||||
-- add tree-path
|
||||
while index <= count do
|
||||
if key == "" then
|
||||
key = ARGV[index]
|
||||
else
|
||||
key = hex_xor(key, ARGV[index])
|
||||
end
|
||||
redis.call("SET", key, target)
|
||||
redis.call("EXPIRE", key, ttl)
|
||||
index = index + 1
|
||||
end
|
||||
|
||||
return target`
|
||||
)
|
||||
|
||||
type PrefixCacheLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
redisKeyTTL int
|
||||
}
|
||||
|
||||
func NewPrefixCacheLoadBalancer(json gjson.Result) (PrefixCacheLoadBalancer, error) {
|
||||
lb := PrefixCacheLoadBalancer{}
|
||||
serviceFQDN := json.Get("serviceFQDN").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
if serviceFQDN == "" || servicePort == 0 {
|
||||
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
|
||||
return lb, errors.New("invalid redis service config")
|
||||
}
|
||||
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceFQDN,
|
||||
Port: servicePort,
|
||||
})
|
||||
username := json.Get("username").String()
|
||||
password := json.Get("password").String()
|
||||
timeout := json.Get("timeout").Int()
|
||||
if timeout == 0 {
|
||||
timeout = 3000
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
if json.Get("redisKeyTTL").Int() == 0 {
|
||||
lb.redisKeyTTL = int(json.Get("redisKeyTTL").Int())
|
||||
} else {
|
||||
lb.redisKeyTTL = 1800
|
||||
}
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
var err error
|
||||
routeName, err := utils.GetRouteName()
|
||||
if err != nil || routeName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("routeName", routeName)
|
||||
}
|
||||
clusterName, err := utils.GetClusterName()
|
||||
if err != nil || clusterName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("clusterName", clusterName)
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Error("get upstream cluster endpoints failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
healthyHosts := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHosts = append(healthyHosts, hostInfo[0])
|
||||
}
|
||||
}
|
||||
if len(healthyHosts) == 0 {
|
||||
log.Info("upstream cluster has no healthy endpoints")
|
||||
return types.ActionContinue
|
||||
}
|
||||
defaultHost := healthyHosts[rand.Intn(len(healthyHosts))]
|
||||
params := []interface{}{}
|
||||
rawStr := ""
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for index, obj := range messages {
|
||||
if !obj.Get("role").Exists() || !obj.Get("content").Exists() {
|
||||
ctx.SetContext("error", true)
|
||||
log.Info("cannot extract role or content from request body, skip llm load balancing")
|
||||
return types.ActionContinue
|
||||
}
|
||||
role := obj.Get("role").String()
|
||||
content := obj.Get("content").String()
|
||||
rawStr += role + ":" + content
|
||||
if role == "user" || index == len(messages)-1 {
|
||||
sha1Str := computeSHA1(rawStr)
|
||||
params = append(params, sha1Str)
|
||||
rawStr = ""
|
||||
}
|
||||
}
|
||||
if len(params) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
keys := []interface{}{lb.redisKeyTTL, fmt.Sprintf(RedisKeyFormat, routeName, clusterName), defaultHost}
|
||||
for _, v := range healthyHosts {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, params, func(response resp.Value) {
|
||||
defer proxywasm.ResumeHttpRequest()
|
||||
if err := response.Error(); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("Redis eval failed: %+v", err)
|
||||
return
|
||||
}
|
||||
hostSelected := response.String()
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
}
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if endOfStream {
|
||||
isErr, _ := ctx.GetContext("error").(bool)
|
||||
if !isErr {
|
||||
routeName, _ := ctx.GetContext("routeName").(string)
|
||||
clusterName, _ := ctx.GetContext("clusterName").(string)
|
||||
host_selected, _ := ctx.GetContext("host_selected").(string)
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func computeSHA1(data string) string {
|
||||
hasher := sha1.New()
|
||||
hasher.Write([]byte(data))
|
||||
return strings.ToUpper(hex.EncodeToString(hasher.Sum(nil)))
|
||||
}
|
||||
19
plugins/wasm-go/extensions/ai-load-balancer/utils/utils.go
Normal file
19
plugins/wasm-go/extensions/ai-load-balancer/utils/utils.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
|
||||
func GetRouteName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return string(raw), nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetClusterName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return string(raw), nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user