[feat] load balancing across different clusters and endpoints based on metrics (#3063)

This commit is contained in:
rinfx
2025-11-25 10:32:34 +08:00
committed by GitHub
parent 7a504fd67d
commit 42334f21df
12 changed files with 764 additions and 126 deletions

View File

@@ -15,13 +15,18 @@ description: Load balancing policies for LLM services
| Name | Type | Required | Default | Description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_type` | string | optional | endpoint | Load balancer type, either `endpoint` or `cluster` |
| `lb_policy` | string | required | | Load balancing policy |
| `lb_config` | object | required | | Configuration for the selected load balancing policy |
When `lb_type` is `endpoint`, the supported load balancing policies are:
- `global_least_request`: global least-request load balancing backed by Redis
- `prefix_cache`: selects the backend node by prompt prefix matching; if no node matches the prefix, falls back to global least-request selection
- `least_busy`: a wasm implementation of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
- `endpoint_metrics`: load balancing based on the metrics exposed by the LLM service
When `lb_type` is `cluster`, the supported load balancing policies are:
- `cluster_metrics`: load balancing across services based on per-service metrics collected by the gateway
# Global Least Request
## Introduction
@@ -59,6 +64,7 @@ sequenceDiagram
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
@@ -116,11 +122,12 @@ lb_config:
| `password` | string | optional | `` | Redis password |
| `timeout` | int | optional | 3000ms | Redis request timeout |
| `database` | int | optional | 0 | Redis database number |
| `redisKeyTTL` | int | optional | 1800s | TTL of the key for each prompt prefix |
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
@@ -161,14 +168,73 @@ sequenceDiagram
| Name | Type | Required | Default | Description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | required | | List of critical model names |
| `metric_policy` | string | required | | How to use the metrics exposed by the LLM service for load balancing; one of `[default, least, most]` |
| `target_metric` | string | optional | | Name of the metric to use; only takes effect when `metric_policy` is `least` or `most` |
| `rate_limit` | string | optional | 1 | Upper bound on the proportion of requests a single node may receive, in the range 0~1 |
## Configuration Example
Use the algorithm from [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md):
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: default
rate_limit: 0.6 # maximum proportion of requests a single node may carry
```
Load balancing based on the current number of queued requests:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_waiting
rate_limit: 0.6 # maximum proportion of requests a single node may carry
```
Load balancing based on the number of requests currently running on the GPU:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_running
rate_limit: 0.6 # maximum proportion of requests a single node may carry
```
# Cross-Service Load Balancing
## Configuration
| Name | Type | Required | Default | Description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `mode` | string | required | | How to use service-level metrics for load balancing; one of `[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency]` |
| `service_list` | []string | required | | List of backend services for the current route |
| `rate_limit` | string | optional | 1 | Upper bound on the proportion of requests a single service may receive, in the range 0~1 |
| `cluster_header` | string | optional | `x-envoy-target-cluster` | Header whose value determines which backend service the request is routed to |
| `queue_size` | int | optional | 100 | Number of most recent requests used to compute the observed metrics |
The meanings of the `mode` values are as follows:
- `LeastBusy`: routes to the service with the fewest in-flight requests
- `LeastTotalLatency`: routes to the service with the lowest current response time (RT)
- `LeastFirstTokenLatency`: routes to the service with the lowest current time to first token (first-packet RT)
## Configuration Example
```yaml
lb_type: cluster
lb_policy: cluster_metrics
lb_config:
mode: LeastTotalLatency # policy mode
queue_size: 100 # number of recent requests used when computing the metrics
rate_limit: 0.6 # maximum proportion of requests a single service may carry
service_list:
- outbound|80||test-1.dns
- outbound|80||test-2.static
```

View File

@@ -15,14 +15,19 @@ The configuration is:
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_type` | string | optional | endpoint | load balancer type, either `endpoint` or `cluster` |
| `lb_policy` | string | required | | load balance policy |
| `lb_config` | object | required | | configuration for the selected load balance policy |
When `lb_type = endpoint`, current supported load balance policies are:
- `global_least_request`: global least request based on redis
- `prefix_cache`: Selects the backend node by prompt prefix matching; if no node matches the prefix, the node is selected by global least request.
- `least_busy`: a wasm implementation of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
- `endpoint_metrics`: Load balancing based on metrics exposed by the LLM service
When `lb_type = cluster`, current supported load balance policies are:
- `cluster_metrics`: Load balancing across services based on per-service metrics collected by the gateway
# Global Least Request
## Introduction
@@ -60,6 +65,7 @@ sequenceDiagram
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
@@ -118,11 +124,12 @@ Then subsequent requests with the same prefix will also be routed to pod 1:
| `password` | string | optional | `` | redis password |
| `timeout` | int | optional | 3000ms | redis request timeout |
| `database` | int | optional | 0 | redis database number |
| `redisKeyTTL` | int | optional | 1800s | prompt prefix key's ttl |
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
@@ -164,14 +171,71 @@ sequenceDiagram
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | required | | critical model names |
| `metric_policy` | string | required | | How to use the metrics exposed by LLM for load balancing, currently supporting `[default, least, most]` |
| `target_metric` | string | optional | | The metric name to use. This is valid only when `metric_policy` is `least` or `most` |
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, 0~1 |
## Configuration Example
Use the algorithm of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md):
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: default
rate_limit: 0.6
```
Load balancing based on the current number of queued requests:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_waiting
rate_limit: 0.6
```
Load balancing based on the number of requests currently being processed by the GPU:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_running
rate_limit: 0.6
```
# Cross-Service Load Balancing
## Configuration
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `mode` | string | required | | How to use service-level metrics for load balancing; one of `[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency]` |
| `service_list` | []string | required | | List of backend services for the current route |
| `rate_limit` | string | optional | 1 | The maximum proportion of requests a single service can receive, in the range 0~1 |
| `cluster_header` | string | optional | `x-envoy-target-cluster` | Header whose value determines which backend service the request is routed to |
| `queue_size` | int | optional | 100 | Number of most recent requests used to compute the observed metrics |
The meanings of the values for `mode` are as follows:
- `LeastBusy`: Routes to the service with the fewest concurrent requests.
- `LeastTotalLatency`: Routes to the service with the lowest response time (RT).
- `LeastFirstTokenLatency`: Routes to the service with the lowest time to first token (first-packet RT).
## Configuration Example
```yaml
lb_type: cluster
lb_policy: cluster_metrics
lb_config:
mode: LeastTotalLatency
rate_limit: 0.6
service_list:
- outbound|80||test-1.dns
- outbound|80||test-2.static
```
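The `rate_limit` guard works together with `mode`: a service is preferred only when it is better on the chosen metric and still below its configured share of recent requests, while a candidate that has already exceeded the limit is always handed over. Below is a minimal Go sketch of that selection rule (an illustration under assumed inputs `avgRT` and `share`, not the plugin's actual code):

```go
// Sketch of the selection rule described above (illustrative only).
// avgRT: rolling average latency per service; share: fraction of recent
// requests already routed to each service (both hypothetical inputs).
func pickService(services []string, avgRT, share map[string]float64, rateLimit float64) string {
	candidate := services[0]
	for _, svc := range services[1:] {
		if share[candidate] >= rateLimit {
			// Candidate already exceeds its allowed share: hand over.
			candidate = svc
		} else if avgRT[svc] < avgRT[candidate] && share[svc] < rateLimit {
			// svc is faster and still under the limit: prefer it.
			candidate = svc
		}
	}
	return candidate
}
```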

View File

@@ -0,0 +1,185 @@
package cluster_metrics
import (
"math/rand"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
DefaultQueueSize = 100
DefaultClusterHeader = "x-envoy-target-cluster"
)
type ClusterEndpointLoadBalancer struct {
// Configurations
Mode string
ClusterHeader string
ServiceList []string
RateLimit float64
// Statistic
ServiceRequestOngoing map[string]int
ServiceRequestCount map[string]int
FirstTokenLatencyRequests map[string]*utils.FixedQueue[float64]
TotalLatencyRequests map[string]*utils.FixedQueue[float64]
}
func NewClusterEndpointLoadBalancer(json gjson.Result) (ClusterEndpointLoadBalancer, error) {
lb := ClusterEndpointLoadBalancer{}
lb.ServiceRequestOngoing = make(map[string]int)
lb.ServiceRequestCount = make(map[string]int)
lb.FirstTokenLatencyRequests = make(map[string]*utils.FixedQueue[float64])
lb.TotalLatencyRequests = make(map[string]*utils.FixedQueue[float64])
lb.Mode = json.Get("mode").String()
lb.ClusterHeader = json.Get("cluster_header").String()
if lb.ClusterHeader == "" {
lb.ClusterHeader = DefaultClusterHeader
}
if json.Get("rate_limit").Exists() {
lb.RateLimit = json.Get("rate_limit").Float()
} else {
lb.RateLimit = 1.0
}
queueSize := int(json.Get("queue_size").Int())
if queueSize == 0 {
queueSize = DefaultQueueSize
}
for _, svc := range json.Get("service_list").Array() {
serviceName := svc.String()
lb.ServiceList = append(lb.ServiceList, serviceName)
lb.ServiceRequestOngoing[serviceName] = 0
lb.ServiceRequestCount[serviceName] = 0
lb.FirstTokenLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
lb.TotalLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
}
return lb, nil
}
func (lb ClusterEndpointLoadBalancer) getRequestRate(serviceName string) float64 {
totalRequestCount := 0
for _, v := range lb.ServiceRequestCount {
totalRequestCount += v
}
if totalRequestCount != 0 {
return float64(lb.ServiceRequestCount[serviceName]) / float64(totalRequestCount)
}
return 0
}
func (lb ClusterEndpointLoadBalancer) getServiceTTFT(serviceName string) float64 {
queue, ok := lb.FirstTokenLatencyRequests[serviceName]
if !ok || queue.Size() == 0 {
return 0
}
value := 0.0
queue.ForEach(func(i int, item float64) {
value += float64(item)
})
return value / float64(queue.Size())
}
func (lb ClusterEndpointLoadBalancer) getServiceTotalRT(serviceName string) float64 {
queue, ok := lb.TotalLatencyRequests[serviceName]
if !ok || queue.Size() == 0 {
return 0
}
value := 0.0
queue.ForEach(func(i int, item float64) {
value += float64(item)
})
return value / float64(queue.Size())
}
// Callbacks which are called in request path
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
ctx.SetContext("request_start", time.Now().UnixMilli())
candidate := lb.ServiceList[rand.Int()%len(lb.ServiceList)]
switch lb.Mode {
case "LeastBusy":
for svc, ongoingNum := range lb.ServiceRequestOngoing {
if candidate == svc {
continue
}
log.Debugf("[candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, [new candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}",
candidate, lb.ServiceRequestOngoing[candidate], lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
if lb.getRequestRate(candidate) >= lb.RateLimit {
candidate = svc
} else if ongoingNum < lb.ServiceRequestOngoing[candidate] && lb.getRequestRate(svc) < lb.RateLimit {
candidate = svc
}
}
case "LeastFirstTokenLatency":
candidateTTFT := lb.getServiceTTFT(candidate)
for _, svc := range lb.ServiceList {
if candidate == svc {
continue
}
log.Debugf("[candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}",
candidate, lb.getServiceTTFT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
if lb.getRequestRate(candidate) >= lb.RateLimit {
candidate = svc
} else if lb.getServiceTTFT(svc) < candidateTTFT && lb.getRequestRate(svc) < lb.RateLimit {
candidate = svc
}
}
case "LeastTotalLatency":
candidateTotalRT := lb.getServiceTotalRT(candidate)
for _, svc := range lb.ServiceList {
if candidate == svc {
continue
}
log.Debugf("[candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}",
candidate, lb.getServiceTotalRT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
if lb.getRequestRate(candidate) >= lb.RateLimit {
candidate = svc
} else if lb.getServiceTotalRT(svc) < candidateTotalRT && lb.getRequestRate(svc) < lb.RateLimit {
candidate = svc
}
}
}
proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, candidate)
ctx.SetContext(lb.ClusterHeader, candidate)
lb.ServiceRequestOngoing[candidate] += 1
lb.ServiceRequestCount[candidate] += 1
return types.ActionContinue
}
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
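// Record time-to-first-token once, when the first streamed chunk of the response arrives.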
if ctx.GetContext("ttft_recorded") == nil {
candidate := ctx.GetContext(lb.ClusterHeader).(string)
duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
lb.FirstTokenLatencyRequests[candidate].Enqueue(float64(duration))
ctx.SetContext("ttft_recorded", struct{}{})
}
return data
}
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
candidate := ctx.GetContext(lb.ClusterHeader).(string)
duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
lb.TotalLatencyRequests[candidate].Enqueue(float64(duration))
lb.ServiceRequestOngoing[candidate] -= 1
}
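For reference, a minimal sketch of how the `lb_config` block documented above maps onto this constructor (illustrative same-package usage; not part of this commit):

```go
// Illustrative only: feed the documented lb_config JSON into the constructor.
func exampleClusterMetricsConfig() {
	cfg := gjson.Parse(`{
	  "mode": "LeastTotalLatency",
	  "queue_size": 100,
	  "rate_limit": 0.6,
	  "service_list": ["outbound|80||test-1.dns", "outbound|80||test-2.static"]
	}`)
	lb, _ := NewClusterEndpointLoadBalancer(cfg)
	// lb.Mode == "LeastTotalLatency", lb.RateLimit == 0.6, and lb.ClusterHeader
	// falls back to "x-envoy-target-cluster" because cluster_header is unset.
	_ = lb
}
```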

View File

@@ -40,13 +40,19 @@ type Metrics struct {
KvCacheMaxTokenCapacity int
}
type UserSelectedMetric struct {
MetricName string
MetricValue float64
}
type PodMetrics struct {
Pod
Metrics
UserSelectedMetric
}
func (pm *PodMetrics) String() string {
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
return fmt.Sprintf("Pod: %+v; Metrics: %+v, UserSelectedMetric: %+v", pm.Pod, pm.Metrics, pm.UserSelectedMetric)
}
func (pm *PodMetrics) Clone() *PodMetrics {
@@ -63,6 +69,10 @@ func (pm *PodMetrics) Clone() *PodMetrics {
KVCacheUsagePercent: pm.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
},
UserSelectedMetric: UserSelectedMetric{
MetricName: pm.MetricName,
MetricValue: pm.MetricValue,
},
}
return clone
}

View File

@@ -23,7 +23,7 @@ import (
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
dto "github.com/prometheus/client_model/go"
"go.uber.org/multierr"
@@ -53,6 +53,16 @@ func PromToPodMetrics(
) (*backend.PodMetrics, error) {
var errs error
updated := existing.Clone()
// User selected metric
if updated.MetricName != "" {
metricValue, err := getLatestMetric(metricFamilies, updated.MetricName)
errs = multierr.Append(errs, err)
if err == nil {
updated.MetricValue = metricValue.GetGauge().GetValue()
}
return updated, errs
}
// Default metric
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
errs = multierr.Append(errs, err)
if err == nil {

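The `getLatestMetric` helper referenced above is outside this hunk; a minimal sketch of what such a lookup could look like, assuming each Prometheus sample carries a timestamp and `fmt` is imported (`latestGaugeValue` is a hypothetical name, not the plugin's actual implementation):

```go
// Hypothetical: return the most recently written sample of a metric family.
func latestGaugeValue(families map[string]*dto.MetricFamily, name string) (float64, error) {
	family, ok := families[name]
	if !ok || len(family.GetMetric()) == 0 {
		return 0, fmt.Errorf("metric %q not found", name)
	}
	latest := family.GetMetric()[0]
	for _, m := range family.GetMetric() {
		if m.GetTimestampMs() >= latest.GetTimestampMs() {
			latest = m
		}
	}
	return latest.GetGauge().GetValue(), nil
}
```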
View File

@@ -0,0 +1,120 @@
package endpoint_metrics
import (
"math/rand"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
FixedQueueSize = 100
)
type MetricsEndpointLoadBalancer struct {
metricPolicy string
targetMetric string
endpointRequests *utils.FixedQueue[string]
maxRate float64
}
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
lb := MetricsEndpointLoadBalancer{}
if json.Get("metric_policy").Exists() {
lb.metricPolicy = json.Get("metric_policy").String()
} else {
lb.metricPolicy = scheduling.MetricPolicyDefault
}
if json.Get("target_metric").Exists() {
lb.targetMetric = json.Get("target_metric").String()
}
if json.Get("rate_limit").Exists() {
lb.maxRate = json.Get("rate_limit").Float()
} else {
lb.maxRate = 1.0
}
lb.endpointRequests = utils.NewFixedQueue[string](FixedQueueSize)
return lb, nil
}
// Callbacks which are called in request path
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
requestModel := gjson.GetBytes(body, "model")
if !requestModel.Exists() {
return types.ActionContinue
}
llmReq := &scheduling.LLMRequest{
Model: requestModel.String(),
Critical: true,
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
return types.ActionContinue
}
hostMetrics := make(map[string]string)
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
}
}
scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
if err != nil {
log.Debugf("initial scheduler failed: %v", err)
return types.ActionContinue
}
targetPod, err := scheduler.Schedule(llmReq)
log.Debugf("targetPod: %+v", targetPod.Address)
if err != nil {
log.Debugf("pod select failed: %v", err)
return types.ActionContinue
}
finalAddress := targetPod.Address
otherHosts := []string{} // if the selected host already exceeds the request-share limit, pick one of these at random
currentRate := 0.0
for k := range hostMetrics {
if k != finalAddress {
otherHosts = append(otherHosts, k)
}
}
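// Compute the share of the recent request window already routed to the selected host.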
if lb.endpointRequests.Size() != 0 {
count := 0.0
lb.endpointRequests.ForEach(func(i int, item string) {
if item == finalAddress {
count += 1
}
})
currentRate = count / float64(lb.endpointRequests.Size())
}
if currentRate > lb.maxRate && len(otherHosts) > 0 {
finalAddress = otherHosts[rand.Intn(len(otherHosts))]
}
lb.endpointRequests.Enqueue(finalAddress)
log.Debugf("pod %s is selected", finalAddress)
proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
ctx.DontReadResponseBody()
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
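For reference, a minimal sketch of how one of the documented `lb_config` examples maps onto this load balancer (illustrative same-package usage; not part of this commit):

```go
// Illustrative only: parse the documented endpoint-metrics configuration.
func exampleEndpointMetricsConfig() {
	cfg := gjson.Parse(`{
	  "metric_policy": "least",
	  "target_metric": "vllm:num_requests_waiting",
	  "rate_limit": 0.6
	}`)
	lb, _ := NewMetricsEndpointLoadBalancer(cfg)
	// lb.metricPolicy == "least", lb.targetMetric == "vllm:num_requests_waiting",
	// lb.maxRate == 0.6; host choices are tracked in a FixedQueue of the last
	// 100 selections to enforce the rate_limit share.
	_ = lb
}
```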

View File

@@ -20,7 +20,7 @@ import (
"errors"
"math"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)

View File

@@ -20,15 +20,22 @@ package scheduling
import (
"errors"
"fmt"
"math"
"math/rand"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend/vllm"
"github.com/prometheus/common/expfmt"
)
const (
MetricPolicyDefault = "default"
MetricPolicyLeast = "least"
MetricPolicyMost = "most"
)
const (
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
kvCacheThreshold = 0.8
@@ -107,11 +114,11 @@ var (
}
)
func NewScheduler(pm []*backend.PodMetrics, filter Filter) *Scheduler {
return &Scheduler{
podMetrics: pm,
filter: filter,
}
}
@@ -130,7 +137,7 @@ func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error)
return pods[i].Pod, nil
}
func GetScheduler(hostMetrics map[string]string, metricPolicy string, targetMetric string) (*Scheduler, error) {
if len(hostMetrics) == 0 {
return nil, errors.New("backend is not support llm scheduling")
}
@@ -147,6 +154,9 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
Address: addr,
},
Metrics: backend.Metrics{},
UserSelectedMetric: backend.UserSelectedMetric{
MetricName: targetMetric,
},
}
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
if err != nil {
@@ -154,5 +164,60 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
}
pms = append(pms, pm)
}
if metricPolicy == MetricPolicyLeast {
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxFloat64
max := 0.0
filtered := []*backend.PodMetrics{}
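// First pass: find the minimum and maximum of the user-selected metric across pods.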
for _, pod := range pods {
if pod.MetricValue <= min {
min = pod.MetricValue
}
if pod.MetricValue >= max {
max = pod.MetricValue
}
}
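// Second pass: keep only pods whose metric lies within (max-min)/len(pods) of the minimum.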
for _, pod := range pods {
if pod.MetricValue >= min && pod.MetricValue <= min+(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
filter := filter{
name: "least user selected metric",
filter: filterFunc,
}
return NewScheduler(pms, &filter), nil
} else if metricPolicy == MetricPolicyMost {
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxFloat64
max := 0.0
filtered := []*backend.PodMetrics{}
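// Same two-pass scan, mirrored for the `most` policy: find the metric range first.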
for _, pod := range pods {
if pod.MetricValue <= min {
min = pod.MetricValue
}
if pod.MetricValue >= max {
max = pod.MetricValue
}
}
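// Then keep only pods whose metric lies within (max-min)/len(pods) of the maximum.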
for _, pod := range pods {
if pod.MetricValue <= max && pod.MetricValue >= max-(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
filter := filter{
name: "most user selected metric",
filter: filterFunc,
}
return NewScheduler(pms, &filter), nil
}
return NewScheduler(pms, defaultFilter), nil
}

View File

@@ -1,81 +0,0 @@
package least_busy
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type LeastBusyLoadBalancer struct {
criticalModels map[string]struct{}
}
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
lb := LeastBusyLoadBalancer{}
lb.criticalModels = make(map[string]struct{})
for _, model := range json.Get("criticalModels").Array() {
lb.criticalModels[model.String()] = struct{}{}
}
return lb, nil
}
// Callbacks which are called in request path
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
requestModel := gjson.GetBytes(body, "model")
if !requestModel.Exists() {
return types.ActionContinue
}
_, isCritical := lb.criticalModels[requestModel.String()]
llmReq := &scheduling.LLMRequest{
Model: requestModel.String(),
Critical: isCritical,
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
return types.ActionContinue
}
hostMetrics := make(map[string]string)
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
}
}
scheduler, err := scheduling.GetScheduler(hostMetrics)
if err != nil {
log.Debugf("initial scheduler failed: %v", err)
return types.ActionContinue
}
targetPod, err := scheduler.Schedule(llmReq)
log.Debugf("targetPod: %+v", targetPod.Address)
if err != nil {
log.Debugf("pod select failed: %v", err)
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
} else {
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
}
return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
ctx.DontReadResponseBody()
return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}

View File

@@ -7,9 +7,10 @@ import (
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
)
func main() {}
@@ -37,34 +38,57 @@ type LoadBalancer interface {
}
type Config struct {
lbType string
lbPolicy string
lb LoadBalancer
}
const (
ClusterLoadBalancerType = "cluster"
EndpointLoadBalancerType = "endpoint"
// Cluster load balancer policies
MetricsBasedCluster = "cluster_metrics"
// Endpoint load balancer policies
MetricsBasedEndpoint = "endpoint_metrics"
MetricsBasedEndpointDeprecated = "metrics_based" // Compatible with old configurations, equal to `endpoint_metrics`
GlobalLeastRequestEndpoint = "global_least_request"
PrefixCacheEndpoint = "prefix_cache"
)
func parseConfig(json gjson.Result, config *Config) error {
config.lbType = json.Get("lb_type").String()
// Compatible with old configurations
if config.lbType == "" {
config.lbType = EndpointLoadBalancerType
}
config.lbPolicy = json.Get("lb_policy").String()
var err error
switch config.lbType {
case ClusterLoadBalancerType:
switch config.lbPolicy {
case MetricsBasedCluster:
config.lb, err = cluster_metrics.NewClusterEndpointLoadBalancer(json.Get("lb_config"))
default:
err = fmt.Errorf("lb_policy %s is not supported", config.lbPolicy)
}
case EndpointLoadBalancerType:
switch config.lbPolicy {
case MetricsBasedEndpoint, MetricsBasedEndpointDeprecated:
config.lb, err = endpoint_metrics.NewMetricsEndpointLoadBalancer(json.Get("lb_config"))
case GlobalLeastRequestEndpoint:
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
case PrefixCacheEndpoint:
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
default:
err = fmt.Errorf("lb_psolicy %s is not supported", config.lbPolicy)
}
default:
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
err = fmt.Errorf("lb_type %s is not supported", config.lbType)
}
return err
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
ctx.DisableReroute()
return config.lb.HandleHttpRequestHeaders(ctx)
}

View File

@@ -0,0 +1,175 @@
package utils
import (
"errors"
)
// FixedQueue implements a fixed-capacity ring-buffer queue.
// When the queue is full, new elements overwrite the oldest ones.
type FixedQueue[T any] struct {
data []T
head int
tail int
size int
cap int
}
// NewFixedQueue creates a fixed queue with the given capacity.
func NewFixedQueue[T any](capacity int) *FixedQueue[T] {
if capacity <= 0 {
capacity = 16
}
return &FixedQueue[T]{
data: make([]T, capacity),
head: 0,
tail: 0,
size: 0,
cap: capacity,
}
}
// Enqueue adds an element to the queue.
// If the queue is full, the oldest element is overwritten.
func (q *FixedQueue[T]) Enqueue(item T) {
if q.size < q.cap {
// Queue not full: insert normally
q.data[q.tail] = item
q.tail = (q.tail + 1) % q.cap
q.size++
} else {
// Queue full: overwrite the oldest element
q.data[q.tail] = item
q.head = (q.head + 1) % q.cap // advance head, dropping the oldest element
q.tail = (q.tail + 1) % q.cap // advance tail as usual
// size stays unchanged at cap
}
}
// Dequeue removes and returns the element at the head of the queue.
func (q *FixedQueue[T]) Dequeue() (T, error) {
var zero T
if q.size == 0 {
return zero, errors.New("queue is empty")
}
item := q.data[q.head]
// Clear the slot to avoid retaining references
var zeroVal T
q.data[q.head] = zeroVal
q.head = (q.head + 1) % q.cap
q.size--
return item, nil
}
// Peek returns the element at the head of the queue without removing it.
func (q *FixedQueue[T]) Peek() (T, error) {
var zero T
if q.size == 0 {
return zero, errors.New("queue is empty")
}
return q.data[q.head], nil
}
// Size returns the number of elements currently in the queue.
func (q *FixedQueue[T]) Size() int {
return q.size
}
// Capacity returns the fixed capacity of the queue.
func (q *FixedQueue[T]) Capacity() int {
return q.cap
}
// IsEmpty reports whether the queue is empty.
func (q *FixedQueue[T]) IsEmpty() bool {
return q.size == 0
}
// IsFull reports whether the queue is full.
func (q *FixedQueue[T]) IsFull() bool {
return q.size == q.cap
}
// OverwriteCount returns the number of elements that have been overwritten.
// Note: this implementation does not track overwrites directly,
// though they could be computed another way if needed.
func (q *FixedQueue[T]) OverwriteCount() int {
// Tracking overwrites would require an extra field;
// this implementation does not provide it yet.
return 0
}
// Clear empties the queue.
func (q *FixedQueue[T]) Clear() {
// Release all references
for i := 0; i < q.size; i++ {
idx := (q.head + i) % q.cap
var zero T
q.data[idx] = zero
}
q.head = 0
q.tail = 0
q.size = 0
}
// ToSlice returns a copy of the queue's elements in order, from oldest to newest.
func (q *FixedQueue[T]) ToSlice() []T {
if q.size == 0 {
return []T{}
}
result := make([]T, q.size)
if q.head < q.tail {
// Elements are contiguous and the queue is not full
copy(result, q.data[q.head:q.tail])
} else {
// Elements wrap around the end of the buffer; this also covers the
// full-queue case where head == tail
n := copy(result, q.data[q.head:])
copy(result[n:], q.data[:q.tail])
}
return result
}
// Oldest returns the oldest element (the head of the queue).
func (q *FixedQueue[T]) Oldest() (T, error) {
return q.Peek()
}
// Newest returns the most recently added element (the one just before tail).
func (q *FixedQueue[T]) Newest() (T, error) {
var zero T
if q.size == 0 {
return zero, errors.New("queue is empty")
}
// tail points at the next insertion slot, so the newest element is at (tail - 1 + cap) % cap
newestIndex := (q.tail - 1 + q.cap) % q.cap
return q.data[newestIndex], nil
}
// ForEach invokes fn for each element in the queue, from oldest to newest.
func (q *FixedQueue[T]) ForEach(fn func(index int, item T)) {
for i := 0; i < q.size; i++ {
idx := (q.head + i) % q.cap
fn(i, q.data[idx])
}
}
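A short usage sketch of the queue as a sliding window, the way the load balancers above use it (illustrative only; not part of this commit):

```go
// Illustrative only: a sliding window of the last 3 samples.
func exampleFixedQueueWindow() {
	q := NewFixedQueue[float64](3)
	for _, v := range []float64{10, 20, 30, 40} {
		q.Enqueue(v) // once full, 10 is overwritten by 40
	}
	sum := 0.0
	q.ForEach(func(_ int, v float64) { sum += v })
	avg := sum / float64(q.Size()) // (20+30+40)/3 = 30
	_ = avg
}
```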