mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 20:57:32 +08:00
[feat] load balancing across different clusters and endpoints based on metrics (#3063)
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package backend
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PodSet map[Pod]bool
|
||||
|
||||
type Pod struct {
|
||||
Name string
|
||||
Address string
|
||||
}
|
||||
|
||||
func (p Pod) String() string {
|
||||
return p.Name + ":" + p.Address
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
|
||||
ActiveModels map[string]int
|
||||
// MaxActiveModels is the maximum number of models that can be loaded to GPU.
|
||||
MaxActiveModels int
|
||||
RunningQueueSize int
|
||||
WaitingQueueSize int
|
||||
KVCacheUsagePercent float64
|
||||
KvCacheMaxTokenCapacity int
|
||||
}
|
||||
|
||||
type UserSelectedMetric struct {
|
||||
MetricName string
|
||||
MetricValue float64
|
||||
}
|
||||
|
||||
type PodMetrics struct {
|
||||
Pod
|
||||
Metrics
|
||||
UserSelectedMetric
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) String() string {
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v, UserSelectedMetric: %+v", pm.Pod, pm.Metrics, pm.UserSelectedMetric)
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
cm := make(map[string]int, len(pm.ActiveModels))
|
||||
for k, v := range pm.ActiveModels {
|
||||
cm[k] = v
|
||||
}
|
||||
clone := &PodMetrics{
|
||||
Pod: pm.Pod,
|
||||
Metrics: Metrics{
|
||||
ActiveModels: cm,
|
||||
RunningQueueSize: pm.RunningQueueSize,
|
||||
WaitingQueueSize: pm.WaitingQueueSize,
|
||||
KVCacheUsagePercent: pm.KVCacheUsagePercent,
|
||||
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
|
||||
},
|
||||
UserSelectedMetric: UserSelectedMetric{
|
||||
MetricName: pm.MetricName,
|
||||
MetricValue: pm.MetricValue,
|
||||
},
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package vllm provides vllm specific pod metrics implementation.
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
const (
|
||||
LoraRequestInfoMetricName = "vllm:lora_requests_info"
|
||||
LoraRequestInfoRunningAdaptersMetricName = "running_lora_adapters"
|
||||
LoraRequestInfoMaxAdaptersMetricName = "max_lora"
|
||||
// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_requests_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_requests_waiting"
|
||||
/* TODO: Uncomment this once the following are added to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_tokens_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
|
||||
*/
|
||||
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
|
||||
KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
|
||||
)
|
||||
|
||||
// promToPodMetrics updates internal pod metrics with scraped prometheus metrics.
|
||||
// A combined error is returned if errors occur in one or more metric processing.
|
||||
// it returns a new PodMetrics pointer which can be used to atomically update the pod metrics map.
|
||||
func PromToPodMetrics(
|
||||
metricFamilies map[string]*dto.MetricFamily,
|
||||
existing *backend.PodMetrics,
|
||||
) (*backend.PodMetrics, error) {
|
||||
var errs error
|
||||
updated := existing.Clone()
|
||||
// User selected metric
|
||||
if updated.MetricName != "" {
|
||||
metricValue, err := getLatestMetric(metricFamilies, updated.MetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.MetricValue = metricValue.GetGauge().GetValue()
|
||||
}
|
||||
return updated, errs
|
||||
}
|
||||
// Default metric
|
||||
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.RunningQueueSize = int(runningQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
waitingQueueSize, err := getLatestMetric(metricFamilies, WaitingQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.WaitingQueueSize = int(waitingQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
cachePercent, err := getLatestMetric(metricFamilies, KVCacheUsagePercentMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.KVCacheUsagePercent = cachePercent.GetGauge().GetValue()
|
||||
}
|
||||
|
||||
loraMetrics, _, err := getLatestLoraMetric(metricFamilies)
|
||||
errs = multierr.Append(errs, err)
|
||||
/* TODO: uncomment once this is available in vllm.
|
||||
kvCap, _, err := getGaugeLatestValue(metricFamilies, KvCacheMaxTokenCapacityMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err != nil {
|
||||
updated.KvCacheMaxTokenCapacity = int(kvCap)
|
||||
}
|
||||
*/
|
||||
|
||||
if loraMetrics != nil {
|
||||
updated.ActiveModels = make(map[string]int)
|
||||
for _, label := range loraMetrics.GetLabel() {
|
||||
if label.GetName() == LoraRequestInfoRunningAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
adapterList := strings.Split(label.GetValue(), ",")
|
||||
for _, adapter := range adapterList {
|
||||
updated.ActiveModels[adapter] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if label.GetName() == LoraRequestInfoMaxAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
updated.MaxActiveModels, err = strconv.Atoi(label.GetValue())
|
||||
if err != nil {
|
||||
errs = multierr.Append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return updated, errs
|
||||
}
|
||||
|
||||
// getLatestLoraMetric gets latest lora metric series in gauge metric family `vllm:lora_requests_info`
|
||||
// reason its specially fetched is because each label key value pair permutation generates new series
|
||||
// and only most recent is useful. The value of each series is the creation timestamp so we can
|
||||
// retrieve the latest by sorting the value.
|
||||
func getLatestLoraMetric(metricFamilies map[string]*dto.MetricFamily) (*dto.Metric, time.Time, error) {
|
||||
loraRequests, ok := metricFamilies[LoraRequestInfoMetricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
return nil, time.Time{}, fmt.Errorf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
}
|
||||
var latestTs float64
|
||||
var latest *dto.Metric
|
||||
for _, m := range loraRequests.GetMetric() {
|
||||
if m.GetGauge().GetValue() > latestTs {
|
||||
latestTs = m.GetGauge().GetValue()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
return latest, time.Unix(0, int64(latestTs*1000)), nil
|
||||
}
|
||||
|
||||
// getLatestMetric gets the latest metric of a family. This should be used to get the latest Gauge metric.
|
||||
// Since vllm doesn't set the timestamp in metric, this metric essentially gets the first metric.
|
||||
func getLatestMetric(metricFamilies map[string]*dto.MetricFamily, metricName string) (*dto.Metric, error) {
|
||||
mf, ok := metricFamilies[metricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", metricName)
|
||||
return nil, fmt.Errorf("metric family %q not found", metricName)
|
||||
}
|
||||
if len(mf.GetMetric()) == 0 {
|
||||
return nil, fmt.Errorf("no metrics available for %q", metricName)
|
||||
}
|
||||
var latestTs int64
|
||||
var latest *dto.Metric
|
||||
for _, m := range mf.GetMetric() {
|
||||
if m.GetTimestampMs() >= latestTs {
|
||||
latestTs = m.GetTimestampMs()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
// klog.V(logutil.TRACE).Infof("Got metric value %+v for metric %v", latest, metricName)
|
||||
return latest, nil
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package endpoint_metrics
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
FixedQueueSize = 100
|
||||
)
|
||||
|
||||
type MetricsEndpointLoadBalancer struct {
|
||||
metricPolicy string
|
||||
targetMetric string
|
||||
endpointRequests *utils.FixedQueue[string]
|
||||
maxRate float64
|
||||
}
|
||||
|
||||
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
|
||||
lb := MetricsEndpointLoadBalancer{}
|
||||
if json.Get("metric_policy").Exists() {
|
||||
lb.metricPolicy = json.Get("metric_policy").String()
|
||||
} else {
|
||||
lb.metricPolicy = scheduling.MetricPolicyDefault
|
||||
}
|
||||
if json.Get("target_metric").Exists() {
|
||||
lb.targetMetric = json.Get("target_metric").String()
|
||||
}
|
||||
if json.Get("rate_limit").Exists() {
|
||||
lb.maxRate = json.Get("rate_limit").Float()
|
||||
} else {
|
||||
lb.maxRate = 1.0
|
||||
}
|
||||
lb.endpointRequests = utils.NewFixedQueue[string](FixedQueueSize)
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: true,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
finalAddress := targetPod.Address
|
||||
otherHosts := []string{} // 如果当前host超过请求数限制,那么在其中随机挑选一个
|
||||
currentRate := 0.0
|
||||
for k := range hostMetrics {
|
||||
if k != finalAddress {
|
||||
otherHosts = append(otherHosts, k)
|
||||
}
|
||||
}
|
||||
if lb.endpointRequests.Size() != 0 {
|
||||
count := 0.0
|
||||
lb.endpointRequests.ForEach(func(i int, item string) {
|
||||
if item == finalAddress {
|
||||
count += 1
|
||||
}
|
||||
})
|
||||
currentRate = count / float64(lb.endpointRequests.Size())
|
||||
}
|
||||
if currentRate > lb.maxRate && len(otherHosts) > 0 {
|
||||
finalAddress = otherHosts[rand.Intn(len(otherHosts))]
|
||||
}
|
||||
lb.endpointRequests.Enqueue(finalAddress)
|
||||
log.Debugf("pod %s is selected", finalAddress)
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
type Filter interface {
|
||||
Name() string
|
||||
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
}
|
||||
|
||||
// filter applies current filterFunc, and then recursively applies next filters depending success or
|
||||
// failure of the current filterFunc.
|
||||
// It can be used to construct a flow chart algorithm.
|
||||
type filter struct {
|
||||
name string
|
||||
filter filterFunc
|
||||
// nextOnSuccess filter will be applied after successfully applying the current filter.
|
||||
// The filtered results will be passed to the next filter.
|
||||
nextOnSuccess *filter
|
||||
// nextOnFailure filter will be applied if current filter fails.
|
||||
// The original input will be passed to the next filter.
|
||||
nextOnFailure *filter
|
||||
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
|
||||
// success or failure of the current filter.
|
||||
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
|
||||
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
|
||||
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
|
||||
nextOnSuccessOrFailure *filter
|
||||
|
||||
// callbacks api.FilterCallbackHandler
|
||||
}
|
||||
|
||||
func (f *filter) Name() string {
|
||||
if f == nil {
|
||||
return "nil"
|
||||
}
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
proxywasm.LogDebugf("Running filter %q on request %v with %v pods", f.name, req, len(pods))
|
||||
filtered, err := f.filter(req, pods)
|
||||
|
||||
next := f.nextOnSuccessOrFailure
|
||||
if err == nil && len(filtered) > 0 {
|
||||
if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnSuccess != nil {
|
||||
next = f.nextOnSuccess
|
||||
}
|
||||
// On success, pass the filtered result to the next filter.
|
||||
return next.Filter(req, filtered)
|
||||
} else {
|
||||
if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnFailure != nil {
|
||||
next = f.nextOnFailure
|
||||
}
|
||||
// On failure, pass the initial set of pods to the next filter.
|
||||
return next.Filter(req, pods)
|
||||
}
|
||||
}
|
||||
|
||||
// filterFunc filters a set of input pods to a subset.
|
||||
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
|
||||
// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
|
||||
func toFilterFunc(pp podPredicate) filterFunc {
|
||||
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
filtered := []*backend.PodMetrics{}
|
||||
for _, pod := range pods {
|
||||
pass := pp(req, pod)
|
||||
if pass {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, errors.New("no pods left")
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
}
|
||||
|
||||
// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar queue size in the low range,
|
||||
// we should consider them all instead of the absolute minimum one. This worked better than picking
|
||||
// the least one as it gives more choices for the next filter, which on aggregate gave better
|
||||
// results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxInt
|
||||
max := 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize <= min {
|
||||
min = pod.WaitingQueueSize
|
||||
}
|
||||
if pod.WaitingQueueSize >= max {
|
||||
max = pod.WaitingQueueSize
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func lowQueueingPodPredicate(_ *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize < queueingThresholdLoRA
|
||||
}
|
||||
|
||||
// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar KV cache in the low range, we
|
||||
// should consider them all instead of the absolute minimum one. This worked better than picking the
|
||||
// least one as it gives more choices for the next filter, which on aggregate gave better results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
var max float64 = 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent <= min {
|
||||
min = pod.KVCacheUsagePercent
|
||||
}
|
||||
if pod.KVCacheUsagePercent >= max {
|
||||
max = pod.KVCacheUsagePercent
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// podPredicate is a filter function to check whether a pod is desired.
|
||||
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
|
||||
|
||||
// We consider serving an adapter low cost it the adapter is active in the model server, or the
|
||||
// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by
|
||||
// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to
|
||||
// a single pod. This gave good performance in our initial benchmarking results in the scenario
|
||||
// where # of lora slots > # of lora adapters.
|
||||
func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok || len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
// loRAAffinityPredicate is a filter function to check whether a pod has affinity to the lora requested.
|
||||
func loRAAffinityPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok
|
||||
}
|
||||
|
||||
// canAcceptNewLoraPredicate is a filter function to check whether a pod has room to load the adapter.
|
||||
func canAcceptNewLoraPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return req.Critical
|
||||
}
|
||||
|
||||
func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
|
||||
return func(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package scheduling implements request scheduling algorithms.
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend/vllm"
|
||||
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MetricPolicyDefault = "default"
|
||||
MetricPolicyLeast = "least"
|
||||
MetricPolicyMost = "most"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
kvCacheThreshold = 0.8
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
queueThresholdCritical = 5
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
|
||||
// The value of 50 is arrived heuristicically based on experiments.
|
||||
queueingThresholdLoRA = 50
|
||||
)
|
||||
|
||||
var (
|
||||
defaultFilter = &filter{
|
||||
name: "critical request",
|
||||
filter: toFilterFunc(criticalRequestPredicate),
|
||||
nextOnSuccess: lowLatencyFilter,
|
||||
nextOnFailure: sheddableRequestFilter,
|
||||
}
|
||||
|
||||
// queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter
|
||||
queueLoRAAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "low cost LoRA",
|
||||
filter: toFilterFunc(lowLoRACostPredicate),
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// queueAndKVCacheFilter applies least queue followed by least KV Cache filter
|
||||
queueAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
}
|
||||
|
||||
lowLatencyFilter = &filter{
|
||||
name: "low queueing filter",
|
||||
filter: toFilterFunc((lowQueueingPodPredicate)),
|
||||
nextOnSuccess: &filter{
|
||||
name: "affinity LoRA",
|
||||
filter: toFilterFunc(loRAAffinityPredicate),
|
||||
nextOnSuccess: queueAndKVCacheFilter,
|
||||
nextOnFailure: &filter{
|
||||
name: "can accept LoRA Adapter",
|
||||
filter: toFilterFunc(canAcceptNewLoraPredicate),
|
||||
nextOnSuccessOrFailure: queueAndKVCacheFilter,
|
||||
},
|
||||
},
|
||||
nextOnFailure: queueLoRAAndKVCacheFilter,
|
||||
}
|
||||
|
||||
sheddableRequestFilter = &filter{
|
||||
// When there is at least one model server that's not queuing requests, and still has KV
|
||||
// cache below a certain threshold, we consider this model server has capacity to handle
|
||||
// a sheddable request without impacting critical requests.
|
||||
name: "has capacity for sheddable requests",
|
||||
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)),
|
||||
nextOnSuccess: queueLoRAAndKVCacheFilter,
|
||||
// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
|
||||
// request to make room for critical requests.
|
||||
nextOnFailure: &filter{
|
||||
name: "drop request",
|
||||
filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
// api.LogDebugf("Dropping request %v", req)
|
||||
return []*backend.PodMetrics{}, errors.New("dropping request due to limited backend resources")
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func NewScheduler(pm []*backend.PodMetrics, filter Filter) *Scheduler {
|
||||
|
||||
return &Scheduler{
|
||||
podMetrics: pm,
|
||||
filter: filter,
|
||||
}
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
podMetrics []*backend.PodMetrics
|
||||
filter Filter
|
||||
}
|
||||
|
||||
// Schedule finds the target pod based on metrics and the requested lora adapter.
|
||||
func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error) {
|
||||
pods, err := s.filter.Filter(req, s.podMetrics)
|
||||
if err != nil || len(pods) == 0 {
|
||||
return backend.Pod{}, fmt.Errorf("failed to apply filter, resulted %v pods: %w", len(pods), err)
|
||||
}
|
||||
i := rand.Intn(len(pods))
|
||||
return pods[i].Pod, nil
|
||||
}
|
||||
|
||||
func GetScheduler(hostMetrics map[string]string, metricPolicy string, targetMetric string) (*Scheduler, error) {
|
||||
if len(hostMetrics) == 0 {
|
||||
return nil, errors.New("backend is not support llm scheduling")
|
||||
}
|
||||
var pms []*backend.PodMetrics
|
||||
for addr, metric := range hostMetrics {
|
||||
parser := expfmt.TextParser{}
|
||||
metricFamilies, err := parser.TextToMetricFamilies(strings.NewReader(metric))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pm := &backend.PodMetrics{
|
||||
Pod: backend.Pod{
|
||||
Name: addr,
|
||||
Address: addr,
|
||||
},
|
||||
Metrics: backend.Metrics{},
|
||||
UserSelectedMetric: backend.UserSelectedMetric{
|
||||
MetricName: targetMetric,
|
||||
},
|
||||
}
|
||||
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pms = append(pms, pm)
|
||||
}
|
||||
if metricPolicy == MetricPolicyLeast {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue >= min && pod.MetricValue <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "least user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
} else if metricPolicy == MetricPolicyMost {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= max && pod.MetricValue >= max-(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "most user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
}
|
||||
return NewScheduler(pms, defaultFilter), nil
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package scheduling
|
||||
|
||||
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
|
||||
type LLMRequest struct {
|
||||
Model string
|
||||
Critical bool
|
||||
}
|
||||
Reference in New Issue
Block a user