Files
higress/plugins/wasm-go/extensions/ai-proxy/streaming_matrix_test.go

136 lines
4.9 KiB
Go

package main
import (
"encoding/json"
"errors"
"testing"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
wasmtest "github.com/higress-group/wasm-go/pkg/test"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// streamBodyStub is a test double for the chunk-level streaming path: its
// OnStreamingResponseBody method ignores its inputs and hands back the canned
// out/err pair stored on the struct.
type streamBodyStub struct {
	out []byte // bytes returned by OnStreamingResponseBody
	err error  // error returned by OnStreamingResponseBody
}

// GetProviderType returns the fixed type label of this stub.
func (b *streamBodyStub) GetProviderType() string {
	return "stream-body-stub"
}

// OnStreamingResponseBody discards all of its arguments and returns the
// preconfigured out and err fields unchanged.
func (b *streamBodyStub) OnStreamingResponseBody(ctx wrapper.HttpContext, name provider.ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
	// The inputs are intentionally unused; the stub's reply is fully canned.
	_, _, _, _ = ctx, name, chunk, isLastChunk
	return b.out, b.err
}
// streamEventStub is a test double for the event-level streaming path: its
// OnStreamingEvent method ignores its inputs and hands back the canned
// eventsOut/err pair stored on the struct.
type streamEventStub struct {
	eventsOut []provider.StreamEvent // events returned by OnStreamingEvent
	err       error                  // error returned by OnStreamingEvent
}

// GetProviderType returns the fixed type label of this stub.
func (e *streamEventStub) GetProviderType() string {
	return "stream-event-stub"
}

// OnStreamingEvent discards all of its arguments and returns the
// preconfigured eventsOut and err fields unchanged.
func (e *streamEventStub) OnStreamingEvent(ctx wrapper.HttpContext, name provider.ApiName, event provider.StreamEvent) ([]provider.StreamEvent, error) {
	// The inputs are intentionally unused; the stub's reply is fully canned.
	_, _, _ = ctx, name, event
	return e.eventsOut, e.err
}
// pluginConfigWithStubProvider builds a PluginConfig from a minimal "generic"
// provider JSON document, requires Validate and Complete to succeed, and then
// hands p to SetActiveProviderForTest before returning the config by value.
//
// NOTE(review): SetActiveProviderForTest presumably swaps the active provider
// for p; its exact semantics are defined outside this file.
func pluginConfigWithStubProvider(t *testing.T, p provider.Provider) config.PluginConfig {
	t.Helper()

	const genericProviderJSON = `{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["tok"]}}`

	var cfg config.PluginConfig
	cfg.FromJson(gjson.Parse(genericProviderJSON))

	// Both steps must succeed before the stub provider is installed.
	require.NoError(t, cfg.Validate())
	require.NoError(t, cfg.Complete())

	cfg.SetActiveProviderForTest(p)
	return cfg
}
// TestOnStreamingResponseBody_matrix drives OnStreamingResponseBodyForTest
// through a matrix of provider setups (plain parsed configs, a chunk-level
// streaming stub, and an event-level streaming stub) and asserts which bytes
// come back for each combination.
func TestOnStreamingResponseBody_matrix(t *testing.T) {
	wasmtest.RunGoTest(t, func(t *testing.T) {
		const bootstrapJSON = `{"provider":{"type":"generic","genericHost":"http://127.0.0.1:1","apiTokens":["bootstrap"]}}`

		host, startStatus := wasmtest.NewTestHost(json.RawMessage(bootstrapJSON))
		require.Equal(t, types.OnPluginStartStatusOK, startStatus)
		defer host.Reset()

		// loadConfig parses raw into a PluginConfig and requires that both
		// Validate and Complete succeed.
		loadConfig := func(t *testing.T, raw string) config.PluginConfig {
			t.Helper()
			var cfg config.PluginConfig
			cfg.FromJson(gjson.Parse(raw))
			require.NoError(t, cfg.Validate())
			require.NoError(t, cfg.Complete())
			return cfg
		}

		// expectChunk runs a non-final chunk through the handler on a fresh
		// mock context tagged with the chat-completion API name and requires
		// the result to equal want.
		expectChunk := func(t *testing.T, pc config.PluginConfig, in, want []byte) {
			t.Helper()
			httpCtx := test.NewMockHttpContext()
			httpCtx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
			got := OnStreamingResponseBodyForTest(httpCtx, pc, in, false)
			require.Equal(t, want, got)
		}

		t.Run("nil_provider_returns_chunk", func(t *testing.T) {
			cfg := loadConfig(t, `{"providers":[{"id":"x","type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}]}`)
			payload := []byte("keep")
			// No API name is stored on the context in this case.
			got := OnStreamingResponseBodyForTest(test.NewMockHttpContext(), cfg, payload, false)
			require.Equal(t, payload, got)
		})

		t.Run("streaming_body_handler_err_returns_original_chunk", func(t *testing.T) {
			failing := &streamBodyStub{out: []byte("x"), err: errors.New("handler failed")}
			cfg := pluginConfigWithStubProvider(t, failing)
			payload := []byte("original")
			expectChunk(t, cfg, payload, payload)
		})

		t.Run("streaming_body_handler_nil_modified_returns_original_chunk", func(t *testing.T) {
			silent := &streamBodyStub{out: nil, err: nil}
			cfg := pluginConfigWithStubProvider(t, silent)
			payload := []byte("original")
			expectChunk(t, cfg, payload, payload)
		})

		t.Run("streaming_body_handler_ok_returns_modified", func(t *testing.T) {
			rewriting := &streamBodyStub{out: []byte("modified"), err: nil}
			cfg := pluginConfigWithStubProvider(t, rewriting)
			httpCtx := test.NewMockHttpContext()
			httpCtx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
			got := OnStreamingResponseBodyForTest(httpCtx, cfg, []byte("in"), false)
			require.Equal(t, "modified", string(got))
		})

		t.Run("streaming_event_handler_zero_events_returns_empty", func(t *testing.T) {
			cfg := pluginConfigWithStubProvider(t, &streamEventStub{})
			expectChunk(t, cfg, []byte("incomplete"), []byte(""))
		})

		t.Run("streaming_event_handler_on_event_err_returns_chunk", func(t *testing.T) {
			failing := &streamEventStub{err: errors.New("event failed")}
			cfg := pluginConfigWithStubProvider(t, failing)
			payload := []byte("data: {\"x\":1}\n\n")
			expectChunk(t, cfg, payload, payload)
		})

		t.Run("no_handler_no_flags_returns_chunk", func(t *testing.T) {
			cfg := loadConfig(t, `{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}}`)
			payload := []byte("passthrough")
			expectChunk(t, cfg, payload, payload)
		})
	})
}