feat(ai-proxy): add Fireworks AI support (#2917)

This commit is contained in:
aias00
2025-09-21 14:32:04 +08:00
committed by GitHub
parent 47827ad271
commit 88a679ee07
6 changed files with 569 additions and 0 deletions

View File

@@ -0,0 +1,354 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// 测试配置:基本 Fireworks 配置
var basicFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-test123456789"},
"modelMapping": map[string]string{
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// 测试配置Fireworks 多模型配置
var fireworksMultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-multi-model"},
"modelMapping": map[string]string{
"gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// 测试配置:无效 Fireworks 配置(缺少 apiToken
var invalidFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{},
"modelMapping": map[string]string{},
},
})
return data
}()
// 测试配置:完整 Fireworks 配置
var completeFireworksConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "fireworks",
"apiTokens": []string{"fw-complete-test"},
"modelMapping": map[string]string{
"gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
},
},
})
return data
}()
// RunFireworksParseConfigTests 测试 Fireworks 配置解析
func RunFireworksParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基本 Fireworks 配置解析
t.Run("basic fireworks config", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试 Fireworks 多模型配置解析
t.Run("fireworks multi model config", func(t *testing.T) {
host, status := test.NewTestHost(fireworksMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试无效 Fireworks 配置(缺少 apiToken
t.Run("invalid fireworks config - missing apiToken", func(t *testing.T) {
host, status := test.NewTestHost(invalidFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试完整 Fireworks 配置解析
t.Run("fireworks complete config", func(t *testing.T) {
host, status := test.NewTestHost(completeFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
// RunFireworksOnHttpRequestHeadersTests 测试 Fireworks 请求头处理
func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Fireworks 聊天完成请求头处理
t.Run("fireworks chat completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 应该返回 HeaderStopIteration因为需要处理请求体
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头是否被正确处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 是否被改为 Fireworks 域名
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "api.fireworks.ai", hostValue, "Host should be changed to Fireworks domain")
// 验证 Authorization 是否被设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token with Bearer prefix")
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Equal(t, "/v1/chat/completions", pathValue, "Path should remain OpenAI compatible")
// 检查是否有相关的处理日志
debugLogs := host.GetDebugLogs()
hasFireworksLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "fireworks") || strings.Contains(log, "ai-proxy") {
hasFireworksLogs = true
break
}
}
require.True(t, hasFireworksLogs, "Should have Fireworks or ai-proxy processing logs")
})
// 测试 Fireworks 文本完成请求头处理
t.Run("fireworks completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 转换
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.fireworks.ai", hostValue)
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Equal(t, "/v1/completions", pathValue, "Path should remain OpenAI compatible for completions")
// 验证 Authorization 设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for completions")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
})
// 测试 Fireworks 模型列表请求头处理
t.Run("fireworks models request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/models"},
{":method", "GET"},
})
require.Equal(t, types.ActionContinue, action)
// 验证请求头处理
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// 验证 Host 转换
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.fireworks.ai", hostValue)
// 验证 Path 保持 OpenAI 兼容格式
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Equal(t, "/v1/models", pathValue, "Path should remain OpenAI compatible for models")
// 验证 Authorization 设置
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for models")
require.Contains(t, authValue, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
})
})
}
// RunFireworksOnHttpRequestBodyTests 测试 Fireworks 请求体处理
func RunFireworksOnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试 Fireworks 聊天完成请求体处理
t.Run("fireworks chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试请求体
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, world!"}
],
"stream": false
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
"Model should be mapped to Fireworks model")
require.Contains(t, string(actualRequestBody), "Hello, world!",
"Request content should be preserved")
})
// 测试 Fireworks 流式聊天完成请求体处理
t.Run("fireworks streaming chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(fireworksMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试流式请求体
requestBody := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Write a poem about AI"}
],
"stream": true
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射gpt-4 应该映射到 70b 模型)
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-70b-instruct",
"GPT-4 should be mapped to Fireworks 70b model")
require.Contains(t, string(actualRequestBody), "Write a poem about AI",
"Request content should be preserved")
require.Contains(t, string(actualRequestBody), `"stream": true`,
"Stream flag should be preserved")
})
// 测试 Fireworks 文本完成请求体处理
t.Run("fireworks completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// 测试完成请求体
requestBody := `{
"model": "gpt-3.5-turbo",
"prompt": "The future of AI is",
"max_tokens": 100
}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// 验证请求体被正确处理
actualRequestBody := host.GetRequestBody()
require.NotNil(t, actualRequestBody)
// 验证模型映射
require.Contains(t, string(actualRequestBody), "accounts/fireworks/models/llama-v3p1-8b-instruct",
"Model should be mapped to Fireworks model")
require.Contains(t, string(actualRequestBody), "The future of AI is",
"Prompt should be preserved")
})
})
}