mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
[feat] load balancing across different clusters and endpoints based on metrics (#3063)
This commit is contained in:
@@ -15,13 +15,18 @@ description: 针对LLM服务的负载均衡策略
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_type` | string | 选填 | endpoint | 负载均衡类型,可选`endpoint`,`cluster` |
|
||||
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
|
||||
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
|
||||
|
||||
目前支持的负载均衡策略包括:
|
||||
`lb_type` 为 `endpoint` 时支持的负载均衡策略包括:
|
||||
|
||||
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
|
||||
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
- `endpoint_metrics`: 基于 llm 服务暴露的 metrics 进行负载均衡
|
||||
|
||||
`lb_type` 为 `cluster` 时支持的负载均衡策略包括:
|
||||
- `cluster_metrics`: 基于网关统计的不同service的指标进行服务之间的负载均衡
|
||||
|
||||
# 全局最小请求数
|
||||
## 功能说明
|
||||
@@ -59,6 +64,7 @@ sequenceDiagram
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -116,11 +122,12 @@ lb_config:
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
|
||||
| `redisKeyTTL` | int | 选填 | 1800s | prompt 前缀对应的key的ttl |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -161,14 +168,73 @@ sequenceDiagram
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | 选填 | | critical的模型列表 |
|
||||
| `metric_policy` | string | 必填 | | 如何使用llm暴露的metrics做负载均衡,当前支持`[default, least, most]` |
|
||||
| `target_metric` | string | 选填 | | 要使用的metric名称,`metric_policy` 取值为 `least` 或者 `most` 时生效 |
|
||||
| `rate_limit` | string | 选填 | 1 | 单个节点处理请求比例上限,取值范围0~1 |
|
||||
|
||||
|
||||
## 配置示例
|
||||
|
||||
使用 [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 中的算法
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: default
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
根据当前排队请求数进行负载均衡
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_waiting
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
根据当前GPU中正在处理的请求数进行负载均衡
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_running
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
|
||||
# 跨服务负载均衡
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `mode` | string | 必填 | | 如何使用服务级指标做负载均衡,当前支持`[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
|
||||
| `service_list` | []string | 必填 | | 路由后端服务列表 |
|
||||
| `rate_limit` | string | 选填 | 1 | 单个服务处理请求比例上限,取值范围0~1 |
|
||||
| `cluster_header` | string | 选填 | `x-envoy-target-cluster` | 通过取该header的值得知需要路由到哪个后端服务 |
|
||||
| `queue_size` | int | 选填 | 100 | 根据最近的多少个请求进行观测指标的计算 |
|
||||
|
||||
`mode` 各取值含义如下:
|
||||
- `LeastBusy`: 路由到当前并发请求数最少的服务
|
||||
- `LeastTotalLatency`: 路由到当前RT最低的服务
|
||||
- `LeastFirstTokenLatency`: 路由到当前首包RT最低的服务
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_type: cluster
|
||||
lb_policy: cluster_metrics
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
mode: LeastTotalLatency # 策略名称
|
||||
queue_size: 100 # 统计指标时使用的最近请求数
|
||||
rate_limit: 0.6 # 单个服务承载的最大请求比例
|
||||
service_list:
|
||||
- outbound|80||test-1.dns
|
||||
- outbound|80||test-2.static
|
||||
```
|
||||
@@ -15,14 +15,19 @@ The configuration is:
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | required | | load balance type |
|
||||
| `lb_type` | string | optional | endpoint | load balance policy type, `endpoint` or `cluster` |
|
||||
| `lb_policy` | string | required | | load balance policy type |
|
||||
| `lb_config` | object | required | | configuration for the current load balance type |
|
||||
|
||||
Current supported load balance policies are:
|
||||
When `lb_type = endpoint`, current supported load balance policies are:
|
||||
|
||||
- `global_least_request`: global least request based on redis
|
||||
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
|
||||
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
- `endpoint_metrics`: Load balancing based on metrics exposed by the llm service
|
||||
|
||||
When `lb_type = cluster`, current supported load balance policies are:
|
||||
- `cluster_metrics`: Load balancing based on metrics of clusters
|
||||
|
||||
|
||||
# Global Least Request
|
||||
## Introduction
|
||||
@@ -60,6 +65,7 @@ sequenceDiagram
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -118,11 +124,12 @@ Then subsequent requests with the same prefix will also be routed to pod 1:
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
|
||||
| `redisKeyTTL` | int | optional | 1800s | prompt prefix key's ttl |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -164,14 +171,71 @@ sequenceDiagram
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | required | | critical model names |
|
||||
| `metric_policy` | string | required | | How to use the metrics exposed by LLM for load balancing, currently supporting `[default, least, most]` |
|
||||
| `target_metric` | string | optional | | The metric name to use. This is valid only when `metric_policy` is `least` or `most` |
|
||||
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, 0~1 |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
Use the algorithm of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md):
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: default
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
Load balancing based on the current number of queued requests:
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_waiting
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
Load balancing based on the number of requests currently being processed by the GPU:
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_running
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
# Cross-service load balancing
|
||||
|
||||
## Configuration
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `mode` | string | required | | how to use cluster metrics, value of `[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
|
||||
| `service_list` | []string | required | | service list of current route |
|
||||
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, value of 0~1 |
|
||||
| `cluster_header` | string | optional | `x-envoy-target-cluster` | By retrieving the value of this header, we can determine which backend service to route to |
|
||||
| `queue_size` | int | optional | 100 | The metrics is calculated based on the number of most recent requests. |
|
||||
|
||||
The meanings of the values for `mode` are as follows:
|
||||
|
||||
- `LeastBusy`: Routes to the service with the fewest concurrent requests.
|
||||
- `LeastTotalLatency`: Routes to the service with the lowest response time (RT).
|
||||
- `LeastFirstTokenLatency`: Routes to the service with the lowest RT for the first packet.
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_type: cluster
|
||||
lb_policy: cluster_metrics
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
mode: LeastTotalLatency
|
||||
rate_limit: 0.6
|
||||
service_list:
|
||||
- outbound|80||test-1.dns
|
||||
- outbound|80||test-2.static
|
||||
```
|
||||
@@ -0,0 +1,185 @@
|
||||
package cluster_metrics
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultQueueSize = 100
|
||||
DefaultClusterHeader = "x-envoy-target-cluster"
|
||||
)
|
||||
|
||||
type ClusterEndpointLoadBalancer struct {
|
||||
// Configurations
|
||||
Mode string
|
||||
ClusterHeader string
|
||||
ServiceList []string
|
||||
RateLimit float64
|
||||
// Statistic
|
||||
ServiceRequestOngoing map[string]int
|
||||
ServiceRequestCount map[string]int
|
||||
FirstTokenLatencyRequests map[string]*utils.FixedQueue[float64]
|
||||
TotalLatencyRequests map[string]*utils.FixedQueue[float64]
|
||||
}
|
||||
|
||||
func NewClusterEndpointLoadBalancer(json gjson.Result) (ClusterEndpointLoadBalancer, error) {
|
||||
lb := ClusterEndpointLoadBalancer{}
|
||||
lb.ServiceRequestOngoing = make(map[string]int)
|
||||
lb.ServiceRequestCount = make(map[string]int)
|
||||
lb.FirstTokenLatencyRequests = make(map[string]*utils.FixedQueue[float64])
|
||||
lb.TotalLatencyRequests = make(map[string]*utils.FixedQueue[float64])
|
||||
|
||||
lb.Mode = json.Get("mode").String()
|
||||
lb.ClusterHeader = json.Get("cluster_header").String()
|
||||
if lb.ClusterHeader == "" {
|
||||
lb.ClusterHeader = DefaultClusterHeader
|
||||
}
|
||||
if json.Get("rate_limit").Exists() {
|
||||
lb.RateLimit = json.Get("rate_limit").Float()
|
||||
} else {
|
||||
lb.RateLimit = 1.0
|
||||
}
|
||||
queueSize := int(json.Get("queue_size").Int())
|
||||
if queueSize == 0 {
|
||||
queueSize = DefaultQueueSize
|
||||
}
|
||||
|
||||
for _, svc := range json.Get("service_list").Array() {
|
||||
serviceName := svc.String()
|
||||
lb.ServiceList = append(lb.ServiceList, serviceName)
|
||||
lb.ServiceRequestOngoing[serviceName] = 0
|
||||
lb.ServiceRequestCount[serviceName] = 0
|
||||
lb.FirstTokenLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
|
||||
lb.TotalLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getRequestRate(serviceName string) float64 {
|
||||
totalRequestCount := 0
|
||||
for _, v := range lb.ServiceRequestCount {
|
||||
totalRequestCount += v
|
||||
}
|
||||
if totalRequestCount != 0 {
|
||||
return float64(lb.ServiceRequestCount[serviceName]) / float64(totalRequestCount)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getServiceTTFT(serviceName string) float64 {
|
||||
queue, ok := lb.FirstTokenLatencyRequests[serviceName]
|
||||
if !ok || queue.Size() == 0 {
|
||||
return 0
|
||||
}
|
||||
value := 0.0
|
||||
queue.ForEach(func(i int, item float64) {
|
||||
value += float64(item)
|
||||
})
|
||||
return value / float64(queue.Size())
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getServiceTotalRT(serviceName string) float64 {
|
||||
queue, ok := lb.TotalLatencyRequests[serviceName]
|
||||
if !ok || queue.Size() == 0 {
|
||||
return 0
|
||||
}
|
||||
value := 0.0
|
||||
queue.ForEach(func(i int, item float64) {
|
||||
value += float64(item)
|
||||
})
|
||||
return value / float64(queue.Size())
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.SetContext("request_start", time.Now().UnixMilli())
|
||||
candidate := lb.ServiceList[rand.Int()%len(lb.ServiceList)]
|
||||
switch lb.Mode {
|
||||
case "LeastBusy":
|
||||
for svc, ongoingNum := range lb.ServiceRequestOngoing {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
log.Debugf("[candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, [new candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}",
|
||||
candidate, lb.ServiceRequestOngoing[candidate], lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
|
||||
svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
} else if ongoingNum < lb.ServiceRequestOngoing[candidate] && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
}
|
||||
}
|
||||
case "LeastFirstTokenLatency":
|
||||
candidateTTFT := lb.getServiceTTFT(candidate)
|
||||
for _, svc := range lb.ServiceList {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
log.Debugf("[candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}",
|
||||
candidate, lb.getServiceTTFT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
|
||||
svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
} else if lb.getServiceTTFT(svc) < candidateTTFT && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
}
|
||||
}
|
||||
case "LeastTotalLatency":
|
||||
candidateTotalRT := lb.getServiceTotalRT(candidate)
|
||||
for _, svc := range lb.ServiceList {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
log.Debugf("[candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}",
|
||||
candidate, lb.getServiceTotalRT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
|
||||
svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
} else if lb.getServiceTotalRT(svc) < candidateTotalRT && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
}
|
||||
}
|
||||
}
|
||||
proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, candidate)
|
||||
ctx.SetContext(lb.ClusterHeader, candidate)
|
||||
lb.ServiceRequestOngoing[candidate] += 1
|
||||
lb.ServiceRequestCount[candidate] += 1
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if ctx.GetContext("ttft_recorded") == nil {
|
||||
candidate := ctx.GetContext(lb.ClusterHeader).(string)
|
||||
duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
|
||||
lb.FirstTokenLatencyRequests[candidate].Enqueue(float64(duration))
|
||||
ctx.SetContext("ttft_recorded", struct{}{})
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
|
||||
candidate := ctx.GetContext(lb.ClusterHeader).(string)
|
||||
duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
|
||||
lb.TotalLatencyRequests[candidate].Enqueue(float64(duration))
|
||||
lb.ServiceRequestOngoing[candidate] -= 1
|
||||
}
|
||||
@@ -40,13 +40,19 @@ type Metrics struct {
|
||||
KvCacheMaxTokenCapacity int
|
||||
}
|
||||
|
||||
type UserSelectedMetric struct {
|
||||
MetricName string
|
||||
MetricValue float64
|
||||
}
|
||||
|
||||
type PodMetrics struct {
|
||||
Pod
|
||||
Metrics
|
||||
UserSelectedMetric
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) String() string {
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v, UserSelectedMetric: %+v", pm.Pod, pm.Metrics, pm.UserSelectedMetric)
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
@@ -63,6 +69,10 @@ func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
KVCacheUsagePercent: pm.KVCacheUsagePercent,
|
||||
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
|
||||
},
|
||||
UserSelectedMetric: UserSelectedMetric{
|
||||
MetricName: pm.MetricName,
|
||||
MetricValue: pm.MetricValue,
|
||||
},
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/multierr"
|
||||
@@ -53,6 +53,16 @@ func PromToPodMetrics(
|
||||
) (*backend.PodMetrics, error) {
|
||||
var errs error
|
||||
updated := existing.Clone()
|
||||
// User selected metric
|
||||
if updated.MetricName != "" {
|
||||
metricValue, err := getLatestMetric(metricFamilies, updated.MetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.MetricValue = metricValue.GetGauge().GetValue()
|
||||
}
|
||||
return updated, errs
|
||||
}
|
||||
// Default metric
|
||||
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
@@ -0,0 +1,120 @@
|
||||
package endpoint_metrics
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
FixedQueueSize = 100
|
||||
)
|
||||
|
||||
type MetricsEndpointLoadBalancer struct {
|
||||
metricPolicy string
|
||||
targetMetric string
|
||||
endpointRequests *utils.FixedQueue[string]
|
||||
maxRate float64
|
||||
}
|
||||
|
||||
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
|
||||
lb := MetricsEndpointLoadBalancer{}
|
||||
if json.Get("metric_policy").Exists() {
|
||||
lb.metricPolicy = json.Get("metric_policy").String()
|
||||
} else {
|
||||
lb.metricPolicy = scheduling.MetricPolicyDefault
|
||||
}
|
||||
if json.Get("target_metric").Exists() {
|
||||
lb.targetMetric = json.Get("target_metric").String()
|
||||
}
|
||||
if json.Get("rate_limit").Exists() {
|
||||
lb.maxRate = json.Get("rate_limit").Float()
|
||||
} else {
|
||||
lb.maxRate = 1.0
|
||||
}
|
||||
lb.endpointRequests = utils.NewFixedQueue[string](FixedQueueSize)
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: true,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
finalAddress := targetPod.Address
|
||||
otherHosts := []string{} // 如果当前host超过请求数限制,那么在其中随机挑选一个
|
||||
currentRate := 0.0
|
||||
for k := range hostMetrics {
|
||||
if k != finalAddress {
|
||||
otherHosts = append(otherHosts, k)
|
||||
}
|
||||
}
|
||||
if lb.endpointRequests.Size() != 0 {
|
||||
count := 0.0
|
||||
lb.endpointRequests.ForEach(func(i int, item string) {
|
||||
if item == finalAddress {
|
||||
count += 1
|
||||
}
|
||||
})
|
||||
currentRate = count / float64(lb.endpointRequests.Size())
|
||||
}
|
||||
if currentRate > lb.maxRate && len(otherHosts) > 0 {
|
||||
finalAddress = otherHosts[rand.Intn(len(otherHosts))]
|
||||
}
|
||||
lb.endpointRequests.Enqueue(finalAddress)
|
||||
log.Debugf("pod %s is selected", finalAddress)
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
@@ -20,15 +20,22 @@ package scheduling
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend/vllm"
|
||||
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MetricPolicyDefault = "default"
|
||||
MetricPolicyLeast = "least"
|
||||
MetricPolicyMost = "most"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
kvCacheThreshold = 0.8
|
||||
@@ -107,11 +114,11 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
|
||||
func NewScheduler(pm []*backend.PodMetrics, filter Filter) *Scheduler {
|
||||
|
||||
return &Scheduler{
|
||||
podMetrics: pm,
|
||||
filter: defaultFilter,
|
||||
filter: filter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +137,7 @@ func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error)
|
||||
return pods[i].Pod, nil
|
||||
}
|
||||
|
||||
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
func GetScheduler(hostMetrics map[string]string, metricPolicy string, targetMetric string) (*Scheduler, error) {
|
||||
if len(hostMetrics) == 0 {
|
||||
return nil, errors.New("backend is not support llm scheduling")
|
||||
}
|
||||
@@ -147,6 +154,9 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
Address: addr,
|
||||
},
|
||||
Metrics: backend.Metrics{},
|
||||
UserSelectedMetric: backend.UserSelectedMetric{
|
||||
MetricName: targetMetric,
|
||||
},
|
||||
}
|
||||
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
|
||||
if err != nil {
|
||||
@@ -154,5 +164,60 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
}
|
||||
pms = append(pms, pm)
|
||||
}
|
||||
return NewScheduler(pms), nil
|
||||
if metricPolicy == MetricPolicyLeast {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue >= min && pod.MetricValue <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "least user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
} else if metricPolicy == MetricPolicyMost {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= max && pod.MetricValue >= max-(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "most user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
}
|
||||
return NewScheduler(pms, defaultFilter), nil
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package least_busy
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type LeastBusyLoadBalancer struct {
|
||||
criticalModels map[string]struct{}
|
||||
}
|
||||
|
||||
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
|
||||
lb := LeastBusyLoadBalancer{}
|
||||
lb.criticalModels = make(map[string]struct{})
|
||||
for _, model := range json.Get("criticalModels").Array() {
|
||||
lb.criticalModels[model.String()] = struct{}{}
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, isCritical := lb.criticalModels[requestModel.String()]
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: isCritical,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
|
||||
} else {
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
|
||||
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
@@ -37,34 +38,57 @@ type LoadBalancer interface {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
policy string
|
||||
lb LoadBalancer
|
||||
lbType string
|
||||
lbPolicy string
|
||||
lb LoadBalancer
|
||||
}
|
||||
|
||||
const (
|
||||
LeastBusyLoadBalancerPolicy = "least_busy"
|
||||
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
|
||||
PrefixCache = "prefix_cache"
|
||||
ClusterLoadBalancerType = "cluster"
|
||||
EndpointLoadBalancerType = "endpoint"
|
||||
// Cluster load balancer policies
|
||||
MetricsBasedCluster = "cluster_metrics"
|
||||
// Endpoint load balancer policies
|
||||
MetricsBasedEndpoint = "endpoint_metrics"
|
||||
MetricsBasedEndpointDeprecated = "metrics_based" // Compatible with old configurations, equal to `endpoint_metrics`
|
||||
GlobalLeastRequestEndpoint = "global_least_request"
|
||||
PrefixCacheEndpoint = "prefix_cache"
|
||||
)
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.policy = json.Get("lb_policy").String()
|
||||
config.lbType = json.Get("lb_type").String()
|
||||
// Compatible with old configurations
|
||||
if config.lbType == "" {
|
||||
config.lbType = EndpointLoadBalancerType
|
||||
}
|
||||
config.lbPolicy = json.Get("lb_policy").String()
|
||||
var err error
|
||||
switch config.policy {
|
||||
case LeastBusyLoadBalancerPolicy:
|
||||
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestLoadBalancerPolicy:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCache:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
switch config.lbType {
|
||||
case ClusterLoadBalancerType:
|
||||
switch config.lbPolicy {
|
||||
case MetricsBasedCluster:
|
||||
config.lb, err = cluster_metrics.NewClusterEndpointLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.lbPolicy)
|
||||
}
|
||||
case EndpointLoadBalancerType:
|
||||
switch config.lbPolicy {
|
||||
case MetricsBasedEndpoint, MetricsBasedEndpointDeprecated:
|
||||
config.lb, err = endpoint_metrics.NewMetricsEndpointLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestEndpoint:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCacheEndpoint:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_psolicy %s is not supported", config.lbPolicy)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
|
||||
err = fmt.Errorf("lb_type %s is not supported", config.lbType)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
ctx.DisableReroute()
|
||||
return config.lb.HandleHttpRequestHeaders(ctx)
|
||||
}
|
||||
|
||||
|
||||
175
plugins/wasm-go/extensions/ai-load-balancer/utils/queue.go
Normal file
175
plugins/wasm-go/extensions/ai-load-balancer/utils/queue.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// FixedQueue 实现了一个固定容量的环形缓冲区队列
|
||||
// 当队列满时,新元素会覆盖最旧的元素
|
||||
type FixedQueue[T any] struct {
|
||||
data []T
|
||||
head int
|
||||
tail int
|
||||
size int
|
||||
cap int
|
||||
}
|
||||
|
||||
// NewFixed 创建一个指定容量的固定队列
|
||||
func NewFixedQueue[T any](capacity int) *FixedQueue[T] {
|
||||
if capacity <= 0 {
|
||||
capacity = 16
|
||||
}
|
||||
return &FixedQueue[T]{
|
||||
data: make([]T, capacity),
|
||||
head: 0,
|
||||
tail: 0,
|
||||
size: 0,
|
||||
cap: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队操作
|
||||
// 如果队列已满,会覆盖最旧的元素
|
||||
func (q *FixedQueue[T]) Enqueue(item T) {
|
||||
if q.size < q.cap {
|
||||
// 队列未满,正常插入
|
||||
q.data[q.tail] = item
|
||||
q.tail = (q.tail + 1) % q.cap
|
||||
q.size++
|
||||
} else {
|
||||
// 队列已满,覆盖最旧元素
|
||||
q.data[q.tail] = item
|
||||
q.head = (q.head + 1) % q.cap // 移动head,丢弃最旧元素
|
||||
q.tail = (q.tail + 1) % q.cap // tail正常移动
|
||||
// size保持不变(仍然是cap)
|
||||
}
|
||||
}
|
||||
|
||||
// Dequeue 出队操作
|
||||
func (q *FixedQueue[T]) Dequeue() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
|
||||
item := q.data[q.head]
|
||||
// 清除引用,避免内存泄漏
|
||||
var zeroVal T
|
||||
q.data[q.head] = zeroVal
|
||||
|
||||
q.head = (q.head + 1) % q.cap
|
||||
q.size--
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Peek 查看队头元素但不移除
|
||||
func (q *FixedQueue[T]) Peek() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
return q.data[q.head], nil
|
||||
}
|
||||
|
||||
// Size 返回队列中元素的数量
|
||||
func (q *FixedQueue[T]) Size() int {
|
||||
return q.size
|
||||
}
|
||||
|
||||
// Capacity 返回队列的固定容量
|
||||
func (q *FixedQueue[T]) Capacity() int {
|
||||
return q.cap
|
||||
}
|
||||
|
||||
// IsEmpty 判断队列是否为空
|
||||
func (q *FixedQueue[T]) IsEmpty() bool {
|
||||
return q.size == 0
|
||||
}
|
||||
|
||||
// IsFull 判断队列是否已满
|
||||
func (q *FixedQueue[T]) IsFull() bool {
|
||||
return q.size == q.cap
|
||||
}
|
||||
|
||||
// OverwriteCount 返回被覆盖的元素数量
|
||||
// 注意:这个实现中我们不直接跟踪覆盖次数,
|
||||
// 但可以通过其他方式计算(如果需要的话)
|
||||
func (q *FixedQueue[T]) OverwriteCount() int {
|
||||
// 如果需要跟踪覆盖次数,可以添加一个字段
|
||||
// 目前这个实现不提供此功能
|
||||
return 0
|
||||
}
|
||||
|
||||
// Clear 清空队列
|
||||
func (q *FixedQueue[T]) Clear() {
|
||||
// 清除所有引用
|
||||
for i := 0; i < q.size; i++ {
|
||||
idx := (q.head + i) % q.cap
|
||||
var zero T
|
||||
q.data[idx] = zero
|
||||
}
|
||||
q.head = 0
|
||||
q.tail = 0
|
||||
q.size = 0
|
||||
}
|
||||
|
||||
// ToSlice 返回队列元素的切片副本(按队列顺序,从最旧到最新)
|
||||
func (q *FixedQueue[T]) ToSlice() []T {
|
||||
if q.size == 0 {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
result := make([]T, q.size)
|
||||
if q.head <= q.tail || q.size == q.cap {
|
||||
if q.head < q.tail {
|
||||
// 数据连续且未满
|
||||
copy(result, q.data[q.head:q.tail])
|
||||
} else {
|
||||
// 数据连续但已满(head == tail)
|
||||
// 或者数据跨越边界
|
||||
if q.head == q.tail && q.size == q.cap {
|
||||
// 已满且head == tail的情况
|
||||
copy(result, q.data[q.head:])
|
||||
if len(result) > q.cap-q.head {
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
} else {
|
||||
// 跨越边界
|
||||
copy(result, q.data[q.head:])
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 跨越边界的情况
|
||||
copy(result, q.data[q.head:])
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Oldest 返回最旧的元素(队头)
|
||||
func (q *FixedQueue[T]) Oldest() (T, error) {
|
||||
return q.Peek()
|
||||
}
|
||||
|
||||
// Newest 返回最新的元素(队尾的前一个元素)
|
||||
func (q *FixedQueue[T]) Newest() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
|
||||
// tail指向下一个插入位置,所以最新元素在 (tail - 1 + cap) % cap
|
||||
newestIndex := (q.tail - 1 + q.cap) % q.cap
|
||||
return q.data[newestIndex], nil
|
||||
}
|
||||
|
||||
// ForEach 对队列中的每个元素执行回调函数
|
||||
func (q *FixedQueue[T]) ForEach(fn func(index int, item T)) {
|
||||
for i := 0; i < q.size; i++ {
|
||||
idx := (q.head + i) % q.cap
|
||||
fn(i, q.data[idx])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user