mirror of
https://github.com/alibaba/higress.git
synced 2026-05-21 11:17:28 +08:00
136 lines
4.9 KiB
Go
136 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
wasmtest "github.com/higress-group/wasm-go/pkg/test"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type streamBodyStub struct {
|
|
out []byte
|
|
err error
|
|
}
|
|
|
|
func (s *streamBodyStub) GetProviderType() string { return "stream-body-stub" }
|
|
|
|
func (s *streamBodyStub) OnStreamingResponseBody(ctx wrapper.HttpContext, name provider.ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
|
_ = ctx
|
|
_ = name
|
|
_ = chunk
|
|
_ = isLastChunk
|
|
return s.out, s.err
|
|
}
|
|
|
|
type streamEventStub struct {
|
|
eventsOut []provider.StreamEvent
|
|
err error
|
|
}
|
|
|
|
func (s *streamEventStub) GetProviderType() string { return "stream-event-stub" }
|
|
|
|
func (s *streamEventStub) OnStreamingEvent(ctx wrapper.HttpContext, name provider.ApiName, event provider.StreamEvent) ([]provider.StreamEvent, error) {
|
|
_ = ctx
|
|
_ = name
|
|
_ = event
|
|
return s.eventsOut, s.err
|
|
}
|
|
|
|
func pluginConfigWithStubProvider(t *testing.T, p provider.Provider) config.PluginConfig {
|
|
t.Helper()
|
|
var c config.PluginConfig
|
|
c.FromJson(gjson.Parse(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["tok"]}}`))
|
|
require.NoError(t, c.Validate())
|
|
require.NoError(t, c.Complete())
|
|
c.SetActiveProviderForTest(p)
|
|
return c
|
|
}
|
|
|
|
func TestOnStreamingResponseBody_matrix(t *testing.T) {
|
|
wasmtest.RunGoTest(t, func(t *testing.T) {
|
|
bootstrap, st := wasmtest.NewTestHost(json.RawMessage(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:1","apiTokens":["bootstrap"]}}`))
|
|
require.Equal(t, types.OnPluginStartStatusOK, st)
|
|
defer bootstrap.Reset()
|
|
|
|
t.Run("nil_provider_returns_chunk", func(t *testing.T) {
|
|
var c config.PluginConfig
|
|
c.FromJson(gjson.Parse(`{"providers":[{"id":"x","type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}]}`))
|
|
require.NoError(t, c.Validate())
|
|
require.NoError(t, c.Complete())
|
|
ctx := test.NewMockHttpContext()
|
|
in := []byte("keep")
|
|
out := OnStreamingResponseBodyForTest(ctx, c, in, false)
|
|
require.Equal(t, in, out)
|
|
})
|
|
|
|
t.Run("streaming_body_handler_err_returns_original_chunk", func(t *testing.T) {
|
|
stub := &streamBodyStub{out: []byte("x"), err: errors.New("handler failed")}
|
|
pc := pluginConfigWithStubProvider(t, stub)
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
in := []byte("original")
|
|
out := OnStreamingResponseBodyForTest(ctx, pc, in, false)
|
|
require.Equal(t, in, out)
|
|
})
|
|
|
|
t.Run("streaming_body_handler_nil_modified_returns_original_chunk", func(t *testing.T) {
|
|
stub := &streamBodyStub{out: nil, err: nil}
|
|
pc := pluginConfigWithStubProvider(t, stub)
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
in := []byte("original")
|
|
out := OnStreamingResponseBodyForTest(ctx, pc, in, false)
|
|
require.Equal(t, in, out)
|
|
})
|
|
|
|
t.Run("streaming_body_handler_ok_returns_modified", func(t *testing.T) {
|
|
stub := &streamBodyStub{out: []byte("modified"), err: nil}
|
|
pc := pluginConfigWithStubProvider(t, stub)
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
in := []byte("in")
|
|
out := OnStreamingResponseBodyForTest(ctx, pc, in, false)
|
|
require.Equal(t, "modified", string(out))
|
|
})
|
|
|
|
t.Run("streaming_event_handler_zero_events_returns_empty", func(t *testing.T) {
|
|
stub := &streamEventStub{}
|
|
pc := pluginConfigWithStubProvider(t, stub)
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
out := OnStreamingResponseBodyForTest(ctx, pc, []byte("incomplete"), false)
|
|
require.Equal(t, []byte(""), out)
|
|
})
|
|
|
|
t.Run("streaming_event_handler_on_event_err_returns_chunk", func(t *testing.T) {
|
|
stub := &streamEventStub{err: errors.New("event failed")}
|
|
pc := pluginConfigWithStubProvider(t, stub)
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
chunk := []byte("data: {\"x\":1}\n\n")
|
|
out := OnStreamingResponseBodyForTest(ctx, pc, chunk, false)
|
|
require.Equal(t, chunk, out)
|
|
})
|
|
|
|
t.Run("no_handler_no_flags_returns_chunk", func(t *testing.T) {
|
|
var c config.PluginConfig
|
|
c.FromJson(gjson.Parse(`{"provider":{"type":"generic","genericHost":"http://127.0.0.1:9","apiTokens":["t"]}}`))
|
|
require.NoError(t, c.Validate())
|
|
require.NoError(t, c.Complete())
|
|
ctx := test.NewMockHttpContext()
|
|
ctx.SetContext(provider.CtxKeyApiName, provider.ApiNameChatCompletion)
|
|
in := []byte("passthrough")
|
|
out := OnStreamingResponseBodyForTest(ctx, c, in, false)
|
|
require.Equal(t, in, out)
|
|
})
|
|
})
|
|
}
|