mirror of
https://github.com/alibaba/higress.git
synced 2026-05-26 13:47:27 +08:00
test(ai-proxy): expand wasm integration tests, coverage, and provider matrix (#3790)
Signed-off-by: jingze <daijingze.djz@alibaba-inc.com>
This commit is contained in:
@@ -1141,6 +1141,25 @@ func TestNormalizeFinishReason(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_streaming_tool_call_smoke(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
start := `data: {"id":"tc1","choices":[{"index":0,"delta":{"role":"assistant","content":""}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
|
||||
_, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(start))
|
||||
require.NoError(t, err)
|
||||
|
||||
toolChunk := `data: {"id":"tc1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"my_fn","arguments":""}}]}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
|
||||
out, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(toolChunk))
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(out), "content_block_start")
|
||||
require.Contains(t, string(out), "tool_use")
|
||||
|
||||
argChunk := `data: {"id":"tc1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"x\":1}"}}]}}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
|
||||
out2, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(argChunk))
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(out2), "input_json_delta")
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_ConvertOpenAIStreamResponseToClaude_Compatibility(t *testing.T) {
|
||||
t.Run("finish_reason empty string should not stop stream", func(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
@@ -544,6 +544,19 @@ func TestFailover_FromJson_FailoverOnStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailover_Validate(t *testing.T) {
|
||||
t.Run("missing_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true}`))
|
||||
assert.Error(t, f.Validate())
|
||||
})
|
||||
t.Run("ok_with_healthCheckModel", func(t *testing.T) {
|
||||
f := &failover{}
|
||||
f.FromJson(gjson.Parse(`{"enabled":true,"healthCheckModel":"gpt-4o-mini"}`))
|
||||
assert.NoError(t, f.Validate())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHealthCheckEndpoint_Struct(t *testing.T) {
|
||||
t.Run("health_check_endpoint_fields", func(t *testing.T) {
|
||||
endpoint := HealthCheckEndpoint{
|
||||
@@ -679,6 +692,32 @@ func TestProviderConfig_SetDefaultCapabilities(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateProvider(t *testing.T) {
|
||||
t.Run("generic_success", func(t *testing.T) {
|
||||
var pc ProviderConfig
|
||||
pc.FromJson(gjson.Parse(`{"type":"generic","genericHost":"http://127.0.0.1:8080","apiTokens":["t"]}`))
|
||||
p, err := CreateProvider(pc)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, providerTypeGeneric, p.GetProviderType())
|
||||
})
|
||||
|
||||
t.Run("openai_minimal_success", func(t *testing.T) {
|
||||
var pc ProviderConfig
|
||||
pc.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["sk-test"]}`))
|
||||
p, err := CreateProvider(pc)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, providerTypeOpenAI, p.GetProviderType())
|
||||
})
|
||||
|
||||
t.Run("unknown_type", func(t *testing.T) {
|
||||
var pc ProviderConfig
|
||||
pc.FromJson(gjson.Parse(`{"type":"no-such-provider-xyz","apiTokens":["t"]}`))
|
||||
_, err := CreateProvider(pc)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown provider type")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStripClaudeInternalMessageFields(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model":"claude",
|
||||
|
||||
126
plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go
Normal file
126
plugins/wasm-go/extensions/ai-proxy/provider/retry_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/iface"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// mapCtx is a minimal wrapper.HttpContext for offline tests (no import cycle with test package).
|
||||
type mapCtx struct {
|
||||
kv map[string]interface{}
|
||||
}
|
||||
|
||||
func newMapCtx() *mapCtx {
|
||||
return &mapCtx{kv: make(map[string]interface{})}
|
||||
}
|
||||
|
||||
func (m *mapCtx) SetContext(key string, value interface{}) { m.kv[key] = value }
|
||||
func (m *mapCtx) GetContext(key string) interface{} { return m.kv[key] }
|
||||
func (m *mapCtx) GetBoolContext(key string, def bool) bool { return def }
|
||||
func (m *mapCtx) GetStringContext(key, def string) string { return def }
|
||||
func (m *mapCtx) GetByteSliceContext(key string, def []byte) []byte { return def }
|
||||
func (m *mapCtx) Scheme() string { return "" }
|
||||
func (m *mapCtx) Host() string { return "" }
|
||||
func (m *mapCtx) Path() string { return "" }
|
||||
func (m *mapCtx) Method() string { return "" }
|
||||
func (m *mapCtx) GetUserAttribute(key string) interface{} { return nil }
|
||||
func (m *mapCtx) SetUserAttribute(key string, value interface{}) {}
|
||||
func (m *mapCtx) SetUserAttributeMap(kvmap map[string]interface{}) {}
|
||||
func (m *mapCtx) GetUserAttributeMap() map[string]interface{} { return nil }
|
||||
func (m *mapCtx) WriteUserAttributeToLog() error { return nil }
|
||||
func (m *mapCtx) WriteUserAttributeToLogWithKey(key string) error { return nil }
|
||||
func (m *mapCtx) WriteUserAttributeToTrace() error { return nil }
|
||||
func (m *mapCtx) DontReadRequestBody() {}
|
||||
func (m *mapCtx) DontReadResponseBody() {}
|
||||
func (m *mapCtx) BufferRequestBody() {}
|
||||
func (m *mapCtx) BufferResponseBody() {}
|
||||
func (m *mapCtx) NeedPauseStreamingResponse() {}
|
||||
func (m *mapCtx) PushBuffer(buffer []byte) {}
|
||||
func (m *mapCtx) PopBuffer() []byte { return nil }
|
||||
func (m *mapCtx) BufferQueueSize() int { return 0 }
|
||||
func (m *mapCtx) DisableReroute() {}
|
||||
func (m *mapCtx) SetRequestBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mapCtx) SetResponseBodyBufferLimit(byteSize uint32) {}
|
||||
func (m *mapCtx) RouteCall(method, url string, headers [][2]string, body []byte, callback iface.RouteResponseCallback) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mapCtx) GetExecutionPhase() iface.HTTPExecutionPhase { return 0 }
|
||||
func (m *mapCtx) HasRequestBody() bool { return false }
|
||||
func (m *mapCtx) HasResponseBody() bool { return false }
|
||||
func (m *mapCtx) IsWebsocket() bool { return false }
|
||||
func (m *mapCtx) IsBinaryRequestBody() bool { return false }
|
||||
func (m *mapCtx) IsBinaryResponseBody() bool { return false }
|
||||
|
||||
var _ wrapper.HttpContext = (*mapCtx)(nil)
|
||||
|
||||
type stubProviderType struct{}
|
||||
|
||||
func (stubProviderType) GetProviderType() string { return providerTypeOpenAI }
|
||||
|
||||
func TestRemoveApiTokenFromRetryList(t *testing.T) {
|
||||
t.Run("removes_token", func(t *testing.T) {
|
||||
got := removeApiTokenFromRetryList([]string{"a", "b", "c"}, "b")
|
||||
assert.Equal(t, []string{"a", "c"}, got)
|
||||
})
|
||||
t.Run("removes_all_when_single", func(t *testing.T) {
|
||||
got := removeApiTokenFromRetryList([]string{"x"}, "x")
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
t.Run("no_match_unchanged", func(t *testing.T) {
|
||||
got := removeApiTokenFromRetryList([]string{"a", "b"}, "z")
|
||||
assert.Equal(t, []string{"a", "b"}, got)
|
||||
})
|
||||
t.Run("empty_input", func(t *testing.T) {
|
||||
got := removeApiTokenFromRetryList(nil, "a")
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRandomToken(t *testing.T) {
|
||||
assert.Equal(t, "", GetRandomToken(nil))
|
||||
assert.Equal(t, "", GetRandomToken([]string{}))
|
||||
assert.Equal(t, "only", GetRandomToken([]string{"only"}))
|
||||
tokens := []string{"a", "b", "c"}
|
||||
for i := 0; i < 20; i++ {
|
||||
got := GetRandomToken(tokens)
|
||||
assert.Contains(t, tokens, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryOnFailure_FromJson_defaults(t *testing.T) {
|
||||
var c ProviderConfig
|
||||
c.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["t"],"retryOnFailure":{"enabled":true}}`))
|
||||
require.True(t, c.IsRetryOnFailureEnabled())
|
||||
assert.Equal(t, int64(1), c.retryOnFailure.maxRetries)
|
||||
assert.Equal(t, int64(60*1000), c.retryOnFailure.retryTimeout)
|
||||
assert.Equal(t, []string{"4.*", "5.*"}, c.retryOnFailure.retryOnStatus)
|
||||
}
|
||||
|
||||
func TestOnRequestFailed_offlineBranches(t *testing.T) {
|
||||
t.Run("no_failover_no_retry_always_continue", func(t *testing.T) {
|
||||
var c ProviderConfig
|
||||
c.FromJson(gjson.Parse(`{"type":"openai","apiTokens":["t"]}`))
|
||||
ctx := newMapCtx()
|
||||
act := c.OnRequestFailed(stubProviderType{}, ctx, "t", []string{"t"}, "503")
|
||||
assert.Equal(t, types.ActionContinue, act)
|
||||
})
|
||||
|
||||
t.Run("retry_enabled_single_token_returns_continue_before_post", func(t *testing.T) {
|
||||
var c ProviderConfig
|
||||
c.FromJson(gjson.Parse(`{
|
||||
"type":"openai",
|
||||
"apiTokens":["only"],
|
||||
"retryOnFailure":{"enabled":true,"retryOnStatus":["429","503"]}
|
||||
}`))
|
||||
ctx := newMapCtx()
|
||||
ctx.SetContext(CtxKeyApiName, ApiNameChatCompletion)
|
||||
act := c.OnRequestFailed(stubProviderType{}, ctx, "only", []string{"only"}, "503")
|
||||
assert.Equal(t, types.ActionContinue, act)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractStreamingEvents(t *testing.T) {
|
||||
t.Run("empty_chunk", func(t *testing.T) {
|
||||
ctx := newMapCtx()
|
||||
events := ExtractStreamingEvents(ctx, nil)
|
||||
assert.Empty(t, events)
|
||||
})
|
||||
|
||||
t.Run("crlf_normalized", func(t *testing.T) {
|
||||
ctx := newMapCtx()
|
||||
chunk := "event:msg\r\ndata:{\"k\":1}\r\n\r\n"
|
||||
events := ExtractStreamingEvents(ctx, []byte(chunk))
|
||||
require.NotEmpty(t, events)
|
||||
})
|
||||
|
||||
t.Run("qwen_style_block", func(t *testing.T) {
|
||||
ctx := newMapCtx()
|
||||
chunk := "event:result\n:HTTP_STATUS/200\ndata:{\"output\":1}\n\n"
|
||||
events := ExtractStreamingEvents(ctx, []byte(chunk))
|
||||
require.NotEmpty(t, events)
|
||||
foundData := false
|
||||
for _, e := range events {
|
||||
if strings.Contains(e.RawEvent, "data:") {
|
||||
foundData = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundData, "expected a data line in parsed events: %#v", events)
|
||||
})
|
||||
|
||||
t.Run("split_chunk_buffers_incomplete", func(t *testing.T) {
|
||||
ctx := newMapCtx()
|
||||
part1 := []byte("event:a\n")
|
||||
_ = ExtractStreamingEvents(ctx, part1)
|
||||
buf, has := ctx.GetContext(ctxKeyStreamingBody).([]byte)
|
||||
require.True(t, has, "expected streaming body buffer after incomplete chunk")
|
||||
require.NotEmpty(t, buf)
|
||||
|
||||
part2 := []byte("data:{}\n\n")
|
||||
events := ExtractStreamingEvents(ctx, part2)
|
||||
require.NotEmpty(t, events)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user