Files
higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/lb_policy.go

121 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package endpoint_metrics
import (
"math/rand"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// FixedQueueSize is the capacity handed to utils.NewFixedQueue for the window
// of recently selected endpoint addresses. That window is used to compute each
// host's share of recent traffic. (Eviction behaviour when full is defined by
// utils.FixedQueue — presumably oldest-first; confirm there.)
const FixedQueueSize = 100
// MetricsEndpointLoadBalancer selects an upstream host per request using
// endpoint metrics reported by the proxy. It also caps how much of the recent
// traffic any single host may receive.
type MetricsEndpointLoadBalancer struct {
	// metricPolicy is the scheduling policy name passed to scheduling.GetScheduler.
	metricPolicy string
	// targetMetric is the metric name passed to scheduling.GetScheduler.
	targetMetric string
	// endpointRequests records the addresses chosen for recent requests.
	// It is a pointer, so copies of this value share the same window.
	endpointRequests *utils.FixedQueue[string]
	// maxRate is the share of endpointRequests a host may exceed before
	// requests are diverted to another healthy host.
	maxRate float64
}
// NewMetricsEndpointLoadBalancer builds a MetricsEndpointLoadBalancer from the
// plugin's JSON configuration.
//
// Recognized keys:
//   - metric_policy: scheduling policy name; defaults to scheduling.MetricPolicyDefault.
//   - target_metric: metric name passed to the scheduler; left empty when absent.
//   - rate_limit: maximum share of the recent request window a single host may
//     take before traffic is diverted; defaults to 1.0 when absent.
//
// The returned error is currently always nil.
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
	lb := MetricsEndpointLoadBalancer{
		metricPolicy:     scheduling.MetricPolicyDefault,
		maxRate:          1.0,
		endpointRequests: utils.NewFixedQueue[string](FixedQueueSize),
	}
	if v := json.Get("metric_policy"); v.Exists() {
		lb.metricPolicy = v.String()
	}
	if v := json.Get("target_metric"); v.Exists() {
		lb.targetMetric = v.String()
	}
	if v := json.Get("rate_limit"); v.Exists() {
		lb.maxRate = v.Float()
	}
	return lb, nil
}
// Callbacks which are called in request path

// HandleHttpRequestHeaders holds header processing so that the upstream host
// can still be overridden once the request body has been inspected.
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
	// Returning types.ActionContinue here would make a later
	// SetUpstreamOverrideHost call ineffective, so stop iteration and let
	// HandleHttpRequestBody choose the host.
	return types.HeaderStopIteration
}
// HandleHttpRequestBody picks an upstream endpoint for the request and pins it
// with proxywasm.SetUpstreamOverrideHost.
//
// Flow:
//  1. Read the "model" field from the JSON body. If it is absent, continue
//     without an override.
//  2. Collect the "metrics" of every upstream host whose "health_status" is
//     "Healthy".
//  3. Build a scheduler for the configured metric policy and target metric,
//     then ask it for a target pod.
//  4. If that host exceeds lb.maxRate of the recent request window, divert to
//     a random other healthy host (see rateLimitedAddress).
//
// Every failure path returns types.ActionContinue without setting an override.
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
	requestModel := gjson.GetBytes(body, "model")
	if !requestModel.Exists() {
		return types.ActionContinue
	}
	llmReq := &scheduling.LLMRequest{
		Model:    requestModel.String(),
		Critical: true,
	}

	hostInfos, err := proxywasm.GetUpstreamHosts()
	if err != nil {
		return types.ActionContinue
	}
	// Map of healthy host address -> raw metrics string.
	hostMetrics := make(map[string]string, len(hostInfos))
	for _, hostInfo := range hostInfos {
		if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
			hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
		}
	}

	scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
	if err != nil {
		log.Debugf("initial scheduler failed: %v", err)
		return types.ActionContinue
	}
	targetPod, err := scheduler.Schedule(llmReq)
	// Check the error before touching targetPod: on failure its value is not
	// meaningful (and would be nil if Schedule returns a pointer).
	if err != nil {
		log.Debugf("pod select failed: %v", err)
		return types.ActionContinue
	}
	log.Debugf("targetPod: %+v", targetPod.Address)

	finalAddress := lb.rateLimitedAddress(targetPod.Address, hostMetrics)
	lb.endpointRequests.Enqueue(finalAddress)
	log.Debugf("pod %s is selected", finalAddress)
	proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
	return types.ActionContinue
}

// rateLimitedAddress returns selected unless its share of lb.endpointRequests
// exceeds lb.maxRate. In that case it returns a random other host from
// hostMetrics. If no other host is available, selected is returned unchanged.
// An empty window counts as a share of 0.
func (lb MetricsEndpointLoadBalancer) rateLimitedAddress(selected string, hostMetrics map[string]string) string {
	rate := 0.0
	if window := lb.endpointRequests.Size(); window != 0 {
		hits := 0.0
		lb.endpointRequests.ForEach(func(_ int, item string) {
			if item == selected {
				hits++
			}
		})
		rate = hits / float64(window)
	}
	if rate > lb.maxRate {
		// The selected host is over its limit, so pick one of the other
		// healthy hosts at random.
		others := make([]string, 0, len(hostMetrics))
		for host := range hostMetrics {
			if host != selected {
				others = append(others, host)
			}
		}
		if len(others) > 0 {
			return others[rand.Intn(len(others))]
		}
	}
	return selected
}
// HandleHttpResponseHeaders opts this plugin out of reading the response body
// (via ctx.DontReadResponseBody) and lets the response proceed.
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
	// Endpoint selection happens entirely on the request path; nothing in the
	// response body is needed.
	ctx.DontReadResponseBody()
	return types.ActionContinue
}
// HandleHttpStreamingResponseBody passes each streamed response chunk through
// unmodified.
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
	chunk := data
	return chunk
}
// HandleHttpResponseBody does not inspect the buffered response body; it
// simply lets processing continue.
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
	// No response-side work for this load-balancing policy.
	return types.ActionContinue
}
// HandleHttpStreamDone is a no-op: this policy keeps no per-request state that
// needs releasing when the stream finishes.
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
	// Intentionally empty.
}