diff --git a/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics/lb_policy.go b/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics/lb_policy.go
index 236c9abf9..5cc4f5438 100644
--- a/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics/lb_policy.go
+++ b/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics/lb_policy.go
@@ -1,6 +1,7 @@
 package cluster_metrics
 
 import (
+	"fmt"
 	"math/rand"
 	"time"
 
@@ -14,7 +15,7 @@ import (
 
 const (
 	DefaultQueueSize     = 100
-	DefaultClusterHeader = "x-envoy-target-cluster"
+	DefaultClusterHeader = "x-higress-target-cluster"
 )
 
 type ClusterEndpointLoadBalancer struct {
@@ -102,52 +103,62 @@ func (lb ClusterEndpointLoadBalancer) getServiceTotalRT(serviceName string) floa
 func (lb ClusterEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
 	ctx.SetContext("request_start", time.Now().UnixMilli())
 	candidate := lb.ServiceList[rand.Int()%len(lb.ServiceList)]
+	var debugInfo string
 	switch lb.Mode {
 	case "LeastBusy":
 		for svc, ongoingNum := range lb.ServiceRequestOngoing {
 			if candidate == svc {
 				continue
 			}
-			log.Debugf("[candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, [new candidate: %s] {ongoing request: %d, total request: %d, request rate: %.2f}",
-				candidate, lb.ServiceRequestOngoing[candidate], lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
-				svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
 			if lb.getRequestRate(candidate) >= lb.RateLimit {
 				candidate = svc
 			} else if ongoingNum < lb.ServiceRequestOngoing[candidate] && lb.getRequestRate(svc) < lb.RateLimit {
 				candidate = svc
 			}
 		}
+		for svc := range lb.ServiceRequestOngoing {
+			debugInfo += fmt.Sprintf("[service: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, ",
+				svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
+		}
 	case "LeastFirstTokenLatency":
 		candidateTTFT := lb.getServiceTTFT(candidate)
 		for _, svc := range lb.ServiceList {
 			if candidate == svc {
 				continue
 			}
-			log.Debugf("[candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}",
-				candidate, lb.getServiceTTFT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
-				svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
 			if lb.getRequestRate(candidate) >= lb.RateLimit {
 				candidate = svc
+				candidateTTFT = lb.getServiceTTFT(svc)
 			} else if lb.getServiceTTFT(svc) < candidateTTFT && lb.getRequestRate(svc) < lb.RateLimit {
 				candidate = svc
+				candidateTTFT = lb.getServiceTTFT(svc)
 			}
 		}
+		for _, svc := range lb.ServiceList {
+			debugInfo += fmt.Sprintf("[service: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, ",
+				svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
+		}
 	case "LeastTotalLatency":
 		candidateTotalRT := lb.getServiceTotalRT(candidate)
 		for _, svc := range lb.ServiceList {
 			if candidate == svc {
 				continue
 			}
-			log.Debugf("[candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, [new candidate: %s] {average latency: %.2f, total request: %d, request rate: %.2f}",
-				candidate, lb.getServiceTotalRT(candidate), lb.ServiceRequestCount[candidate], lb.getRequestRate(candidate),
-				svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
 			if lb.getRequestRate(candidate) >= lb.RateLimit {
 				candidate = svc
+				candidateTotalRT = lb.getServiceTotalRT(svc)
 			} else if lb.getServiceTotalRT(svc) < candidateTotalRT && lb.getRequestRate(svc) < lb.RateLimit {
 				candidate = svc
+				candidateTotalRT = lb.getServiceTotalRT(svc)
 			}
 		}
+		for _, svc := range lb.ServiceList {
+			debugInfo += fmt.Sprintf("[service: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, ",
+				svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
+		}
 	}
+	debugInfo += fmt.Sprintf("final service: %s", candidate)
+	log.Debug(debugInfo)
 	proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, candidate)
 	ctx.SetContext(lb.ClusterHeader, candidate)
 	lb.ServiceRequestOngoing[candidate] += 1
@@ -160,14 +171,26 @@ func (lb ClusterEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpCont
 }
 
 func (lb ClusterEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
+	statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
+	ctx.SetContext("statusCode", statusCode)
 	return types.ActionContinue
 }
 
 func (lb ClusterEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
 	if ctx.GetContext("ttft_recorded") == nil {
 		candidate := ctx.GetContext(lb.ClusterHeader).(string)
-		duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
-		lb.FirstTokenLatencyRequests[candidate].Enqueue(float64(duration))
+		duration := float64(time.Now().UnixMilli() - ctx.GetContext("request_start").(int64))
+		// punish failed request
+		if ctx.GetContext("statusCode").(string) != "200" {
+			for _, svc := range lb.ServiceList {
+				ttft := lb.getServiceTTFT(svc)
+				if duration < ttft {
+					duration = ttft
+				}
+			}
+			duration *= 2
+		}
+		lb.FirstTokenLatencyRequests[candidate].Enqueue(duration)
 		ctx.SetContext("ttft_recorded", struct{}{})
 	}
 	return data
@@ -179,7 +202,17 @@ func (lb ClusterEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpCon
 
 func (lb ClusterEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
 	candidate := ctx.GetContext(lb.ClusterHeader).(string)
-	duration := time.Now().UnixMilli() - ctx.GetContext("request_start").(int64)
-	lb.TotalLatencyRequests[candidate].Enqueue(float64(duration))
 	lb.ServiceRequestOngoing[candidate] -= 1
+	duration := float64(time.Now().UnixMilli() - ctx.GetContext("request_start").(int64))
+	// punish failed request
+	if ctx.GetContext("statusCode").(string) != "200" {
+		for _, svc := range lb.ServiceList {
+			rt := lb.getServiceTotalRT(svc)
+			if duration < rt {
+				duration = rt
+			}
+		}
+		duration *= 2
+	}
+	lb.TotalLatencyRequests[candidate].Enqueue(duration)
 }
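
The snippet below is a rough standalone sketch, not part of the plugin, of the failure penalty this patch applies when recording TTFT and total latency: for a non-200 response the measured duration is raised to at least the largest per-service average before being doubled, so a fast-failing upstream does not keep winning the latency-based policies. The function name penalizedLatency and its inputs are hypothetical and only for illustration.

package main

import "fmt"

// penalizedLatency returns the latency sample to record for a request.
// Successful requests are recorded as measured; failed ones are bumped to
// at least the slowest service's average latency, then doubled, mirroring
// the "punish failed request" logic added by the patch (names are assumed).
func penalizedLatency(measuredMs float64, statusCode string, serviceAvgMs map[string]float64) float64 {
	if statusCode == "200" {
		return measuredMs
	}
	duration := measuredMs
	for _, avg := range serviceAvgMs {
		if duration < avg {
			duration = avg
		}
	}
	return duration * 2
}

func main() {
	avgs := map[string]float64{"svc-a": 120, "svc-b": 300}
	fmt.Println(penalizedLatency(80, "200", avgs)) // 80: success recorded as measured
	fmt.Println(penalizedLatency(80, "503", avgs)) // 600: failure recorded as 2x the worst average
}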