Files
higress/plugins/wasm-go/extensions/ai-proxy/test/fireworks.go

357 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// 测试配置:基本 Fireworks 配置
var basicFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-test123456789"},
"modelMapping": map[string]string{
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// 测试配置Fireworks 多模型配置
var fireworksMultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-multi-model"},
"modelMapping": map[string]string{
"gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// 测试配置:无效 Fireworks 配置(缺少 apiToken
var invalidFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{},
"modelMapping": map[string]string{},
},
})
return data
}()
// 测试配置:完整 Fireworks 配置
var completeFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-complete-test"},
"modelMapping": map[string]string{
"gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// RunFireworksParseConfigTests 测试 Fireworks 配置解析
func RunFireworksParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基本 Fireworks 配置解析
t.Run("basic fireworks config", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试 Fireworks 多模型配置解析
t.Run("fireworks multi model config", func(t *testing.T) {
host, status := test.NewTestHost(fireworksMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试无效 Fireworks 配置(缺少 apiToken
t.Run("invalid fireworks config - missing apiToken", func(t *testing.T) {
host, status := test.NewTestHost(invalidFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试完整 Fireworks 配置解析
t.Run("fireworks complete config", func(t *testing.T) {
host, status := test.NewTestHost(completeFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
// RunFireworksOnHttpRequestHeadersTests 测试 Fireworks 请求头处理
func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Fireworks 聊天完成请求头处理
t.Run("fireworks chat completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 应该返回 HeaderStopIteration因为需要处理请求体
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 Fireworks 域名
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "api.fireworks.ai", hostValue, "Host should be changed to Fireworks domain")
// 验证 Authorization 是否被设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token with Bearer prefix")
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, "/v1/chat/completions", pathValue, "Path should remain OpenAI compatible")
// 检查是否有相关的处理日志
debugLogs := host.GetDebugLogs()
hasFireworksLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "fireworks") || strings.Contains(log, "ai-proxy") {
hasFireworksLogs = true
break
}
}
require.True(t, hasFireworksLogs, "Should have Fireworks or ai-proxy processing logs")
})
// 测试 Fireworks 文本完成请求头处理
t.Run("fireworks completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 转换
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.fireworks.ai", hostValue)
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Equal(t, "/v1/completions", pathValue, "Path should remain OpenAI compatible for completions")
// 验证 Authorization 设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for completions")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
})
// 测试 Fireworks 模型列表请求头处理
t.Run("fireworks models request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/models"},
{":method", "GET"},
})
// TODO: Due to the limitations of the test framework, we just treat it as a request with body here.
//require.Equal(t, types.ActionContinue, action)
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 转换
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.fireworks.ai", hostValue)
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Equal(t, "/v1/models", pathValue, "Path should remain OpenAI compatible for models")
// 验证 Authorization 设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for models")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
})
})
}
// RunFireworksOnHttpRequestBodyTests 测试 Fireworks 请求体处理
func RunFireworksOnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Fireworks 聊天完成请求体处理
t.Run("fireworks chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试请求体
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, world!"}
],
"stream": false
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
"Model should be mapped to Fireworks model")
require.Contains(t, string(actualRequestBody), "Hello, world!",
"Request content should be preserved")
})
// 测试 Fireworks 流式聊天完成请求体处理
t.Run("fireworks streaming chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(fireworksMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试流式请求体
requestBody := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Write a poem about AI"}
],
"stream": true
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射gpt-4 应该映射到 70b 模型)
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-70b-instruct",
"GPT-4 should be mapped to Fireworks 70b model")
require.Contains(t, string(actualRequestBody), "Write a poem about AI",
"Request content should be preserved")
require.Contains(t, string(actualRequestBody), `"stream": true`,
"Stream flag should be preserved")
})
// 测试 Fireworks 文本完成请求体处理
t.Run("fireworks completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试完成请求体
requestBody := `{
"model": "gpt-3.5-turbo",
"prompt": "The future of AI is",
"max_tokens": 100
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
"Model should be mapped to Fireworks model")
require.Contains(t, string(actualRequestBody), "The future of AI is",
"Prompt should be preserved")
})
})
}