Files
higress/plugins/wasm-go/extensions/ai-load-balancer/cluster_hash/lb_policy.go

124 lines
3.6 KiB
Go

package cluster_hash
import (
"fmt"
"hash/fnv"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
DefaultHashHeader = "x-mse-consumer"
DefaultClusterHeader = "x-higress-target-cluster"
)
type clusterEntry struct {
Cluster string
Weight int
}
type ClusterHashLoadBalancer struct {
HashHeader string
ClusterHeader string
// slots is expanded from clusters by weight, length == 100.
slots []string
}
func NewClusterHashLoadBalancer(json gjson.Result) (ClusterHashLoadBalancer, error) {
lb := ClusterHashLoadBalancer{}
lb.HashHeader = json.Get("hash_header").String()
if lb.HashHeader == "" {
lb.HashHeader = DefaultHashHeader
}
lb.ClusterHeader = json.Get("cluster_header").String()
if lb.ClusterHeader == "" {
lb.ClusterHeader = DefaultClusterHeader
}
clustersJson := json.Get("clusters")
if !clustersJson.Exists() || !clustersJson.IsArray() || len(clustersJson.Array()) == 0 {
return lb, fmt.Errorf("clusters is required and must be a non-empty array")
}
var clusters []clusterEntry
var totalWeight int
for _, c := range clustersJson.Array() {
cluster := c.Get("cluster").String()
if cluster == "" {
return lb, fmt.Errorf("each entry must have a non-empty cluster field")
}
weight := int(c.Get("weight").Int())
if weight <= 0 {
return lb, fmt.Errorf("cluster %q has invalid weight %d, must be > 0", cluster, weight)
}
clusters = append(clusters, clusterEntry{Cluster: cluster, Weight: weight})
totalWeight += weight
}
if totalWeight != 100 {
return lb, fmt.Errorf("sum of cluster weights must be 100, got %d", totalWeight)
}
slots := make([]string, 0, 100)
for _, c := range clusters {
for i := 0; i < c.Weight; i++ {
slots = append(slots, c.Cluster)
}
}
lb.slots = slots
return lb, nil
}
func (lb ClusterHashLoadBalancer) selectCluster(hashKey string) string {
h := fnv.New32a()
h.Write([]byte(hashKey))
index := int(h.Sum32()) % len(lb.slots)
if index < 0 {
index += len(lb.slots)
}
return lb.slots[index]
}
func (lb ClusterHashLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
hashKey, err := proxywasm.GetHttpRequestHeader(lb.HashHeader)
if err != nil || hashKey == "" {
log.Warnf("[ai-load-balancer/cluster_hash] missing hash header %q, rejecting request", lb.HashHeader)
_ = proxywasm.SendHttpResponse(403, nil, []byte("hash header required"), -1)
return types.ActionPause
}
cluster := lb.selectCluster(hashKey)
if err := proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, cluster); err != nil {
log.Errorf("[ai-load-balancer/cluster_hash] failed to set target header: %v", err)
_ = proxywasm.SendHttpResponse(500, nil, []byte("internal error"), -1)
return types.ActionPause
}
log.Debugf("[ai-load-balancer/cluster_hash] %s=%s -> %s=%s", lb.HashHeader, hashKey, lb.ClusterHeader, cluster)
return types.ActionContinue
}
func (lb ClusterHashLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb ClusterHashLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
func (lb ClusterHashLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
func (lb ClusterHashLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
func (lb ClusterHashLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}