feat(ai-load-balancer): add cluster_hash load balancing policy with FNV-1a consistent hashing (#3898)

Signed-off-by: zat366 <authentic.zhao@gmail.com>
This commit is contained in:
zat366
2026-06-01 10:19:46 +08:00
committed by GitHub
parent c21a38e783
commit 52c99eb27d
5 changed files with 388 additions and 5 deletions

View File

@@ -0,0 +1,123 @@
package cluster_hash
import (
"fmt"
"hash/fnv"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// Default header names used when the corresponding configuration keys are
// missing or empty.
const (
// DefaultHashHeader is the request header read for the hash key when
// "hash_header" is not configured.
DefaultHashHeader = "x-mse-consumer"
// DefaultClusterHeader is the request header that receives the selected
// cluster when "cluster_header" is not configured.
DefaultClusterHeader = "x-higress-target-cluster"
)
// clusterEntry is one parsed element of the "clusters" configuration array.
type clusterEntry struct {
// Cluster is the entry's "cluster" value; NewClusterHashLoadBalancer
// rejects entries where it is empty.
Cluster string
// Weight is the number of slots this cluster occupies in the slot table.
Weight int
}
// ClusterHashLoadBalancer routes a request to a cluster chosen by hashing the
// value of a request header and writes the chosen cluster into another
// request header.
type ClusterHashLoadBalancer struct {
// HashHeader is the name of the request header whose value is hashed.
HashHeader string
// ClusterHeader is the name of the request header set to the selected cluster.
ClusterHeader string
// slots holds each configured cluster name repeated Weight times, in
// configuration order. NewClusterHashLoadBalancer only succeeds when the
// weights sum to 100, so a constructed balancer has exactly 100 slots.
slots []string
}
// NewClusterHashLoadBalancer builds a ClusterHashLoadBalancer from the plugin
// configuration.
//
// Recognized keys:
//   - "hash_header": header whose value is hashed; DefaultHashHeader when empty.
//   - "cluster_header": header that receives the selected cluster;
//     DefaultClusterHeader when empty.
//   - "clusters": required non-empty array of objects with a non-empty
//     "cluster" string and an integer "weight".
//
// Every weight must be in the range [1, 100] and all weights must sum to
// exactly 100. Each cluster is then expanded into Weight consecutive slots.
//
// On failure a non-nil error is returned; the returned balancer must not be
// used in that case.
func NewClusterHashLoadBalancer(json gjson.Result) (ClusterHashLoadBalancer, error) {
	// totalSlots is the fixed slot-table size; weights are shares of it.
	const totalSlots = 100

	lb := ClusterHashLoadBalancer{}
	lb.HashHeader = json.Get("hash_header").String()
	if lb.HashHeader == "" {
		lb.HashHeader = DefaultHashHeader
	}
	lb.ClusterHeader = json.Get("cluster_header").String()
	if lb.ClusterHeader == "" {
		lb.ClusterHeader = DefaultClusterHeader
	}

	// Materialize the array once; a missing or non-array value leaves items empty.
	clustersJson := json.Get("clusters")
	var items []gjson.Result
	if clustersJson.Exists() && clustersJson.IsArray() {
		items = clustersJson.Array()
	}
	if len(items) == 0 {
		return lb, fmt.Errorf("clusters is required and must be a non-empty array")
	}

	clusters := make([]clusterEntry, 0, len(items))
	totalWeight := 0
	for _, c := range items {
		cluster := c.Get("cluster").String()
		if cluster == "" {
			return lb, fmt.Errorf("each entry must have a non-empty cluster field")
		}
		// Validate the int64 value before narrowing it to int. Where int is
		// 32 bits wide an out-of-range weight could otherwise truncate to a
		// valid-looking number, and an unbounded weight could overflow the sum
		// or drive an enormous slot expansion.
		weight := c.Get("weight").Int()
		if weight <= 0 {
			return lb, fmt.Errorf("cluster %q has invalid weight %d, must be > 0", cluster, weight)
		}
		if weight > totalSlots {
			return lb, fmt.Errorf("cluster %q has invalid weight %d, must be <= %d", cluster, weight, totalSlots)
		}
		clusters = append(clusters, clusterEntry{Cluster: cluster, Weight: int(weight)})
		totalWeight += int(weight)
	}
	if totalWeight != totalSlots {
		return lb, fmt.Errorf("sum of cluster weights must be 100, got %d", totalWeight)
	}

	// Expand each cluster into Weight consecutive slots.
	slots := make([]string, 0, totalSlots)
	for _, c := range clusters {
		for i := 0; i < c.Weight; i++ {
			slots = append(slots, c.Cluster)
		}
	}
	lb.slots = slots
	return lb, nil
}
// selectCluster maps hashKey to a cluster: it hashes the key with 32-bit
// FNV-1a and returns the slot at index (hash mod number of slots).
//
// The modulo is taken on the unsigned 32-bit hash, so the same key selects the
// same slot whether int is 32 or 64 bits wide. Converting the hash to int
// first would produce negative values on 32-bit targets and a different slot
// after wrap-around correction.
//
// It returns "" when the balancer has no slots (for example the zero value)
// instead of panicking on a division by zero.
func (lb ClusterHashLoadBalancer) selectCluster(hashKey string) string {
	if len(lb.slots) == 0 {
		return ""
	}
	h := fnv.New32a()
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write([]byte(hashKey))
	return lb.slots[h.Sum32()%uint32(len(lb.slots))]
}
// HandleHttpRequestHeaders reads the hash header, picks a cluster for its
// value and stores that cluster in the cluster header.
//
// A missing or empty hash header is answered with a 403 local response; a
// failure to set the cluster header is answered with a 500 local response. In
// both cases the request is paused. Otherwise processing continues.
func (lb ClusterHashLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
	key, getErr := proxywasm.GetHttpRequestHeader(lb.HashHeader)
	switch {
	case getErr != nil, key == "":
		// No usable hash key: reject rather than route arbitrarily.
		log.Warnf("[ai-load-balancer/cluster_hash] missing hash header %q, rejecting request", lb.HashHeader)
		_ = proxywasm.SendHttpResponse(403, nil, []byte("hash header required"), -1)
		return types.ActionPause
	}

	target := lb.selectCluster(key)
	setErr := proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, target)
	if setErr != nil {
		log.Errorf("[ai-load-balancer/cluster_hash] failed to set target header: %v", setErr)
		_ = proxywasm.SendHttpResponse(500, nil, []byte("internal error"), -1)
		return types.ActionPause
	}

	log.Debugf("[ai-load-balancer/cluster_hash] %s=%s -> %s=%s", lb.HashHeader, key, lb.ClusterHeader, target)
	return types.ActionContinue
}
// HandleHttpRequestBody does no work on the request body and always returns
// types.ActionContinue.
func (lb ClusterHashLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
// HandleHttpResponseHeaders does no work on the response headers and always
// returns types.ActionContinue.
func (lb ClusterHashLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
return types.ActionContinue
}
// HandleHttpStreamingResponseBody returns each response chunk unchanged,
// regardless of endOfStream.
func (lb ClusterHashLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
return data
}
// HandleHttpResponseBody does no work on the response body and always returns
// types.ActionContinue.
func (lb ClusterHashLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
return types.ActionContinue
}
// HandleHttpStreamDone is a no-op.
func (lb ClusterHashLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}

View File

@@ -0,0 +1,171 @@
package cluster_hash
import (
"fmt"
"testing"
"github.com/tidwall/gjson"
)
// TestParseConfig_Valid checks that a minimal valid configuration applies the
// default header names and expands the clusters into 100 slots whose
// per-cluster counts match the configured weights.
func TestParseConfig_Valid(t *testing.T) {
	const (
		clusterA = "outbound|443||llm-a.internal.dns"
		clusterB = "outbound|443||llm-b.internal.dns"
	)
	json := gjson.Parse(`{
"clusters": [
{"cluster": "outbound|443||llm-a.internal.dns", "weight": 70},
{"cluster": "outbound|443||llm-b.internal.dns", "weight": 30}
]
}`)
	lb, err := NewClusterHashLoadBalancer(json)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if lb.HashHeader != DefaultHashHeader {
		t.Errorf("expected default hash_header %q, got %q", DefaultHashHeader, lb.HashHeader)
	}
	if lb.ClusterHeader != DefaultClusterHeader {
		t.Errorf("expected default cluster_header %q, got %q", DefaultClusterHeader, lb.ClusterHeader)
	}
	if len(lb.slots) != 100 {
		t.Errorf("expected 100 slots, got %d", len(lb.slots))
	}

	// The total alone would not catch a wrong expansion (e.g. swapped
	// weights), so verify the share of each cluster as well.
	counts := make(map[string]int, 2)
	for _, s := range lb.slots {
		counts[s]++
	}
	if len(counts) != 2 || counts[clusterA] != 70 || counts[clusterB] != 30 {
		t.Errorf("expected 70 slots for %q and 30 for %q, got %v", clusterA, clusterB, counts)
	}
}
// TestParseConfig_CustomHeaders checks that explicitly configured header names
// override the defaults.
func TestParseConfig_CustomHeaders(t *testing.T) {
	cfg := gjson.Parse(`{
"hash_header": "x-custom-key",
"cluster_header": "x-custom-target",
"clusters": [
{"cluster": "outbound|443||llm-a.internal.dns", "weight": 100}
]
}`)

	lb, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := lb.HashHeader; got != "x-custom-key" {
		t.Errorf("got hash_header %q", got)
	}
	if got := lb.ClusterHeader; got != "x-custom-target" {
		t.Errorf("got cluster_header %q", got)
	}
}
// TestParseConfig_WeightNotSum100 checks that weights summing to something
// other than 100 are rejected.
func TestParseConfig_WeightNotSum100(t *testing.T) {
	cfg := gjson.Parse(`{
"clusters": [
{"cluster": "outbound|443||llm-a.internal.dns", "weight": 60},
{"cluster": "outbound|443||llm-b.internal.dns", "weight": 30}
]
}`)

	_, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		return
	}
	t.Fatal("expected error for weights not summing to 100")
}
// TestParseConfig_EmptyClusters checks that an empty "clusters" array is
// rejected.
func TestParseConfig_EmptyClusters(t *testing.T) {
	cfg := gjson.Parse(`{"clusters": []}`)

	_, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		return
	}
	t.Fatal("expected error for empty clusters")
}
// TestParseConfig_MissingClusters checks that a configuration without a
// "clusters" key is rejected.
func TestParseConfig_MissingClusters(t *testing.T) {
	cfg := gjson.Parse(`{}`)

	_, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		return
	}
	t.Fatal("expected error for missing clusters field")
}
// TestParseConfig_MissingClusterField checks that an entry without a "cluster"
// name is rejected.
func TestParseConfig_MissingClusterField(t *testing.T) {
	cfg := gjson.Parse(`{
"clusters": [
{"weight": 100}
]
}`)

	_, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		return
	}
	t.Fatal("expected error for missing cluster field")
}
// TestParseConfig_ZeroWeight checks that an entry with weight 0 is rejected
// even when the remaining weights sum to 100.
func TestParseConfig_ZeroWeight(t *testing.T) {
	cfg := gjson.Parse(`{
"clusters": [
{"cluster": "outbound|443||llm-a.internal.dns", "weight": 0},
{"cluster": "outbound|443||llm-b.internal.dns", "weight": 100}
]
}`)

	_, err := NewClusterHashLoadBalancer(cfg)
	if err != nil {
		return
	}
	t.Fatal("expected error for zero weight")
}
func TestSelectCluster_Consistency(t *testing.T) {
lb := buildLB(t, []clusterEntry{
{Cluster: "outbound|443||llm-a.internal.dns", Weight: 50},
{Cluster: "outbound|443||llm-b.internal.dns", Weight: 50},
})
key := "alice"
first := lb.selectCluster(key)
for range 10 {
if got := lb.selectCluster(key); got != first {
t.Errorf("inconsistent result for same key: got %q, want %q", got, first)
}
}
}
// TestSelectCluster_Distribution checks that both clusters are present in the
// slot table and that hashing a range of keys reaches both of them.
func TestSelectCluster_Distribution(t *testing.T) {
	const (
		clusterA = "outbound|443||llm-a.internal.dns"
		clusterB = "outbound|443||llm-b.internal.dns"
	)
	lb := buildLB(t, []clusterEntry{
		{Cluster: clusterA, Weight: 70},
		{Cluster: clusterB, Weight: 30},
	})

	// Every cluster must occupy at least one slot.
	present := make(map[string]bool, 2)
	for _, slot := range lb.slots {
		present[slot] = true
	}
	hasA, hasB := present[clusterA], present[clusterB]
	if !(hasA && hasB) {
		t.Fatalf("weight-expanded slots must include both clusters, hasA=%v hasB=%v", hasA, hasB)
	}

	// Probe up to 256 distinct keys, stopping once both clusters were hit.
	seen := map[string]struct{}{}
	for i := 0; i < 256; i++ {
		if len(seen) >= 2 {
			break
		}
		picked := lb.selectCluster(fmt.Sprintf("key-%d", i))
		seen[picked] = struct{}{}
	}
	if len(seen) < 2 {
		t.Errorf("expected hash routing to reach at least 2 clusters, got %v", seen)
	}
}
// TestSelectCluster_SingleCluster checks that every key maps to the only
// configured cluster.
func TestSelectCluster_SingleCluster(t *testing.T) {
	const target = "outbound|443||llm-a.internal.dns"

	lb := buildLB(t, []clusterEntry{{Cluster: target, Weight: 100}})

	keys := []string{"alice", "bob", "carol"}
	for i := range keys {
		got := lb.selectCluster(keys[i])
		if got == target {
			continue
		}
		t.Errorf("single cluster: expected %q, got %q for key %q", target, got, keys[i])
	}
}
// buildLB is a test helper that assembles a balancer directly from entries,
// bypassing configuration parsing and validation. Each entry contributes
// Weight consecutive slots; the default header names are used.
func buildLB(t *testing.T, entries []clusterEntry) ClusterHashLoadBalancer {
	t.Helper()

	lb := ClusterHashLoadBalancer{
		HashHeader:    DefaultHashHeader,
		ClusterHeader: DefaultClusterHeader,
		slots:         make([]string, 0, 100),
	}
	for idx := range entries {
		for remaining := entries[idx].Weight; remaining > 0; remaining-- {
			lb.slots = append(lb.slots, entries[idx].Cluster)
		}
	}
	return lb
}