[feat] load balancing across different clusters and endpoints based on metrics (#3063)

This commit is contained in:
rinfx
2025-11-25 10:32:34 +08:00
committed by GitHub
parent 7a504fd67d
commit 42334f21df
12 changed files with 764 additions and 126 deletions

View File

@@ -0,0 +1,120 @@
package endpoint_metrics
import (
"math/rand"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
FixedQueueSize = 100
)
type MetricsEndpointLoadBalancer struct {
metricPolicy string
targetMetric string
endpointRequests *utils.FixedQueue[string]
maxRate float64
}
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
lb := MetricsEndpointLoadBalancer{}
if json.Get("metric_policy").Exists() {
lb.metricPolicy = json.Get("metric_policy").String()
} else {
lb.metricPolicy = scheduling.MetricPolicyDefault
}
if json.Get("target_metric").Exists() {
lb.targetMetric = json.Get("target_metric").String()
}
if json.Get("rate_limit").Exists() {
lb.maxRate = json.Get("rate_limit").Float()
} else {
lb.maxRate = 1.0
}
lb.endpointRequests = utils.NewFixedQueue[string](FixedQueueSize)
return lb, nil
}
// Callbacks which are called in request path
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
return types.HeaderStopIteration
}
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
requestModel := gjson.GetBytes(body, "model")
if !requestModel.Exists() {
return types.ActionContinue
}
llmReq := &scheduling.LLMRequest{
Model: requestModel.String(),
Critical: true,
}
hostInfos, err := proxywasm.GetUpstreamHosts()
if err != nil {
return types.ActionContinue
}
hostMetrics := make(map[string]string)
for _, hostInfo := range hostInfos {
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
}
}
scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
if err != nil {
log.Debugf("initial scheduler failed: %v", err)
return types.ActionContinue
}
targetPod, err := scheduler.Schedule(llmReq)
log.Debugf("targetPod: %+v", targetPod.Address)
if err != nil {
log.Debugf("pod select failed: %v", err)
return types.ActionContinue
}
finalAddress := targetPod.Address
otherHosts := []string{} // 如果当前host超过请求数限制那么在其中随机挑选一个
currentRate := 0.0
for k := range hostMetrics {
if k != finalAddress {
otherHosts = append(otherHosts, k)
}
}
if lb.endpointRequests.Size() != 0 {
count := 0.0
lb.endpointRequests.ForEach(func(i int, item string) {
if item == finalAddress {
count += 1
}
})
currentRate = count / float64(lb.endpointRequests.Size())
}
if currentRate > lb.maxRate && len(otherHosts) > 0 {
finalAddress = otherHosts[rand.Intn(len(otherHosts))]
}
lb.endpointRequests.Enqueue(finalAddress)
log.Debugf("pod %s is selected", finalAddress)
proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
ctx.DontReadResponseBody()
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}