// Mirror of https://github.com/alibaba/higress.git, synced 2026-03-04 08:30:48 +08:00.

package test
import (
|
||
"encoding/json"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/test"
|
||
"github.com/stretchr/testify/require"
|
||
)
// basicFireworksConfig is the minimal valid Fireworks provider configuration:
// a single API token and a catch-all ("*") model mapping to the 8B Llama model.
var basicFireworksConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "fireworks",
		"apiTokens": []string{"fw-test123456789"},
		"modelMapping": map[string]string{
			"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
		},
	}
	// Marshalling maps of strings and string slices cannot fail, so the error is discarded.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// fireworksMultiModelConfig is a Fireworks provider configuration with explicit
// per-model mappings: "gpt-4" goes to the 70B Llama model, "gpt-3.5-turbo" to the
// 8B model, and the "*" entry covers every other requested model with the 8B model.
var fireworksMultiModelConfig = func() json.RawMessage {
	const (
		llama8B  = "accounts/fireworks/models/llama-v3p1-8b-instruct"
		llama70B = "accounts/fireworks/models/llama-v3p1-70b-instruct"
	)
	mapping := map[string]string{
		"gpt-4":         llama70B,
		"gpt-3.5-turbo": llama8B,
		"*":             llama8B,
	}
	// Marshalling maps of strings and string slices cannot fail, so the error is discarded.
	raw, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":         "fireworks",
			"apiTokens":    []string{"fw-multi-model"},
			"modelMapping": mapping,
		},
	})
	return raw
}()
// invalidFireworksConfig is a Fireworks provider configuration with no API tokens
// (and an empty model mapping); plugin start is expected to fail for it.
var invalidFireworksConfig = func() json.RawMessage {
	// Deliberately empty but non-nil, so they encode as [] and {} rather than null.
	noTokens := []string{}
	noMapping := map[string]string{}
	// Marshalling maps of strings and string slices cannot fail, so the error is discarded.
	raw, _ := json.Marshal(map[string]interface{}{
		"provider": map[string]interface{}{
			"type":         "fireworks",
			"apiTokens":    noTokens,
			"modelMapping": noMapping,
		},
	})
	return raw
}()
// completeFireworksConfig is a full Fireworks provider configuration: one API
// token plus explicit "gpt-4" (70B) and "gpt-3.5-turbo" (8B) mappings and a "*"
// entry mapping any other model to the 8B Llama model.
var completeFireworksConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "fireworks",
		"apiTokens": []string{"fw-complete-test"},
	}
	provider["modelMapping"] = map[string]string{
		"gpt-4":         "accounts/fireworks/models/llama-v3p1-70b-instruct",
		"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
		"*":             "accounts/fireworks/models/llama-v3p1-8b-instruct",
	}
	// Marshalling maps of strings and string slices cannot fail, so the error is discarded.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// RunFireworksParseConfigTests 测试 Fireworks 配置解析
|
||
func RunFireworksParseConfigTests(t *testing.T) {
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
// 测试基本 Fireworks 配置解析
|
||
t.Run("basic fireworks config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试 Fireworks 多模型配置解析
|
||
t.Run("fireworks multi model config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(fireworksMultiModelConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试无效 Fireworks 配置(缺少 apiToken)
|
||
t.Run("invalid fireworks config - missing apiToken", func(t *testing.T) {
|
||
host, status := test.NewTestHost(invalidFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
// 测试完整 Fireworks 配置解析
|
||
t.Run("fireworks complete config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(completeFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
})
|
||
}
// RunFireworksOnHttpRequestHeadersTests 测试 Fireworks 请求头处理
|
||
func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试 Fireworks 聊天完成请求头处理
|
||
t.Run("fireworks chat completion request headers", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
{"Content-Type", "application/json"},
|
||
})
|
||
|
||
// 应该返回 HeaderStopIteration,因为需要处理请求体
|
||
require.Equal(t, types.HeaderStopIteration, action)
|
||
|
||
// 验证请求头是否被正确处理
|
||
requestHeaders := host.GetRequestHeaders()
|
||
require.NotNil(t, requestHeaders)
|
||
|
||
// 验证 Host 是否被改为 Fireworks 域名
|
||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||
require.True(t, hasHost, "Host header should exist")
|
||
require.Equal(t, "api.fireworks.ai", hostValue, "Host should be changed to Fireworks domain")
|
||
|
||
// 验证 Authorization 是否被设置
|
||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||
require.True(t, hasAuth, "Authorization header should exist")
|
||
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token with Bearer prefix")
|
||
|
||
// 验证 Path 保持 OpenAI 兼容格式
|
||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||
require.True(t, hasPath, "Path header should exist")
|
||
require.Equal(t, "/v1/chat/completions", pathValue, "Path should remain OpenAI compatible")
|
||
|
||
// 检查是否有相关的处理日志
|
||
debugLogs := host.GetDebugLogs()
|
||
hasFireworksLogs := false
|
||
for _, log := range debugLogs {
|
||
if strings.Contains(log, "fireworks") || strings.Contains(log, "ai-proxy") {
|
||
hasFireworksLogs = true
|
||
break
|
||
}
|
||
}
|
||
require.True(t, hasFireworksLogs, "Should have Fireworks or ai-proxy processing logs")
|
||
})
|
||
|
||
// 测试 Fireworks 文本完成请求头处理
|
||
t.Run("fireworks completion request headers", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/completions"},
|
||
{":method", "POST"},
|
||
{"Content-Type", "application/json"},
|
||
})
|
||
|
||
require.Equal(t, types.HeaderStopIteration, action)
|
||
|
||
// 验证请求头处理
|
||
requestHeaders := host.GetRequestHeaders()
|
||
require.NotNil(t, requestHeaders)
|
||
|
||
// 验证 Host 转换
|
||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||
require.True(t, hasHost)
|
||
require.Equal(t, "api.fireworks.ai", hostValue)
|
||
|
||
// 验证 Path 保持 OpenAI 兼容格式
|
||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||
require.True(t, hasPath)
|
||
require.Equal(t, "/v1/completions", pathValue, "Path should remain OpenAI compatible for completions")
|
||
|
||
// 验证 Authorization 设置
|
||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||
require.True(t, hasAuth, "Authorization header should exist for completions")
|
||
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
|
||
})
|
||
|
||
// 测试 Fireworks 模型列表请求头处理
|
||
t.Run("fireworks models request headers", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/models"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// TODO: Due to the limitations of the test framework, we just treat it as a request with body here.
|
||
//require.Equal(t, types.ActionContinue, action)
|
||
require.Equal(t, types.HeaderStopIteration, action)
|
||
|
||
// 验证请求头处理
|
||
requestHeaders := host.GetRequestHeaders()
|
||
require.NotNil(t, requestHeaders)
|
||
|
||
// 验证 Host 转换
|
||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||
require.True(t, hasHost)
|
||
require.Equal(t, "api.fireworks.ai", hostValue)
|
||
|
||
// 验证 Path 保持 OpenAI 兼容格式
|
||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||
require.True(t, hasPath)
|
||
require.Equal(t, "/v1/models", pathValue, "Path should remain OpenAI compatible for models")
|
||
|
||
// 验证 Authorization 设置
|
||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||
require.True(t, hasAuth, "Authorization header should exist for models")
|
||
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
|
||
})
|
||
})
|
||
}
// RunFireworksOnHttpRequestBodyTests 测试 Fireworks 请求体处理
|
||
func RunFireworksOnHttpRequestBodyTests(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试 Fireworks 聊天完成请求体处理
|
||
t.Run("fireworks chat completion request body", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
{"Content-Type", "application/json"},
|
||
})
|
||
|
||
// 测试请求体
|
||
requestBody := `{
|
||
"model": "gpt-3.5-turbo",
|
||
"messages": [
|
||
{"role": "user", "content": "Hello, world!"}
|
||
],
|
||
"stream": false
|
||
}`
|
||
|
||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
// 验证请求体被正确处理
|
||
actualRequestBody := host.GetRequestBody()
|
||
require.NotNil(t, actualRequestBody)
|
||
|
||
// 验证模型映射
|
||
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||
"Model should be mapped to Fireworks model")
|
||
require.Contains(t, string(actualRequestBody), "Hello, world!",
|
||
"Request content should be preserved")
|
||
})
|
||
|
||
// 测试 Fireworks 流式聊天完成请求体处理
|
||
t.Run("fireworks streaming chat completion request body", func(t *testing.T) {
|
||
host, status := test.NewTestHost(fireworksMultiModelConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/chat/completions"},
|
||
{":method", "POST"},
|
||
{"Content-Type", "application/json"},
|
||
})
|
||
|
||
// 测试流式请求体
|
||
requestBody := `{
|
||
"model": "gpt-4",
|
||
"messages": [
|
||
{"role": "user", "content": "Write a poem about AI"}
|
||
],
|
||
"stream": true
|
||
}`
|
||
|
||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
// 验证请求体被正确处理
|
||
actualRequestBody := host.GetRequestBody()
|
||
require.NotNil(t, actualRequestBody)
|
||
|
||
// 验证模型映射(gpt-4 应该映射到 70b 模型)
|
||
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||
"GPT-4 should be mapped to Fireworks 70b model")
|
||
require.Contains(t, string(actualRequestBody), "Write a poem about AI",
|
||
"Request content should be preserved")
|
||
require.Contains(t, string(actualRequestBody), `"stream": true`,
|
||
"Stream flag should be preserved")
|
||
})
|
||
|
||
// 测试 Fireworks 文本完成请求体处理
|
||
t.Run("fireworks completion request body", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicFireworksConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/v1/completions"},
|
||
{":method", "POST"},
|
||
{"Content-Type", "application/json"},
|
||
})
|
||
|
||
// 测试完成请求体
|
||
requestBody := `{
|
||
"model": "gpt-3.5-turbo",
|
||
"prompt": "The future of AI is",
|
||
"max_tokens": 100
|
||
}`
|
||
|
||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
// 验证请求体被正确处理
|
||
actualRequestBody := host.GetRequestBody()
|
||
require.NotNil(t, actualRequestBody)
|
||
|
||
// 验证模型映射
|
||
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
|
||
"Model should be mapped to Fireworks model")
|
||
require.Contains(t, string(actualRequestBody), "The future of AI is",
|
||
"Prompt should be preserved")
|
||
})
|
||
})
|
||
}