feat(ai-proxy): add Fireworks AI support (#2917)

This commit is contained in:
aias00
2025-09-21 14:32:04 +08:00
committed by GitHub
parent 47827ad271
commit 88a679ee07
6 changed files with 569 additions and 0 deletions

View File

@@ -177,6 +177,10 @@ Grok 所对应的 `type` 为 `grok`。它并无特有的配置字段。
OpenRouter 所对应的 `type` 为 `openrouter`。它并无特有的配置字段。 OpenRouter 所对应的 `type` 为 `openrouter`。它并无特有的配置字段。
#### Fireworks AI
Fireworks AI 所对应的 `type` 为 `fireworks`。它并无特有的配置字段。
#### 文心一言(Baidu) #### 文心一言(Baidu)
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。 文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
@@ -1018,6 +1022,63 @@ provider:
} }
``` ```
### 使用 OpenAI 协议代理 Fireworks AI 服务
**配置信息**
```yaml
provider:
  type: fireworks
  apiTokens:
    - "YOUR_FIREWORKS_API_TOKEN"
  modelMapping:
    "gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct"
    "gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct"
    "*": "accounts/fireworks/models/llama-v3p1-8b-instruct"
```
**请求示例**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"temperature": 0.7,
"max_tokens": 100
}
```
**响应示例**
```json
{
"id": "fw-123456789",
"object": "chat.completion",
"created": 1699123456,
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是一个由 Fireworks AI 提供的人工智能助手,基于 Llama 3.1 模型。我可以帮助回答问题、进行对话和提供各种信息。有什么我可以帮助你的吗?"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 15,
"completion_tokens": 45,
"total_tokens": 60
}
}
```
### 使用自动协议兼容功能 ### 使用自动协议兼容功能
插件现在支持自动协议检测,可以同时处理 OpenAI 和 Claude 两种协议格式的请求。 插件现在支持自动协议检测,可以同时处理 OpenAI 和 Claude 两种协议格式的请求。
@@ -1982,6 +2043,7 @@ provider:
} }
} }
``` ```
### 使用 OpenAI 协议代理 NVIDIA Triton Inference Server 服务 ### 使用 OpenAI 协议代理 NVIDIA Triton Inference Server 服务
**配置信息** **配置信息**
@@ -2011,6 +2073,7 @@ providers:
"stream": false "stream": false
} }
``` ```
**响应示例** **响应示例**
```json ```json

View File

@@ -148,6 +148,10 @@ For Grok, the corresponding `type` is `grok`. It has no unique configuration fie
For OpenRouter, the corresponding `type` is `openrouter`. It has no unique configuration fields. For OpenRouter, the corresponding `type` is `openrouter`. It has no unique configuration fields.
#### Fireworks AI
For Fireworks AI, the corresponding `type` is `fireworks`. It has no unique configuration fields.
#### ERNIE Bot #### ERNIE Bot
For ERNIE Bot, the corresponding `type` is `baidu`. It has no unique configuration fields. For ERNIE Bot, the corresponding `type` is `baidu`. It has no unique configuration fields.
@@ -955,6 +959,63 @@ provider:
} }
``` ```
### Using OpenAI Protocol Proxy for Fireworks AI Service
**Configuration Information**
```yaml
provider:
  type: fireworks
  apiTokens:
    - "YOUR_FIREWORKS_API_TOKEN"
  modelMapping:
    "gpt-4": "accounts/fireworks/models/llama-v3p1-70b-instruct"
    "gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct"
    "*": "accounts/fireworks/models/llama-v3p1-8b-instruct"
```
**Request Example**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
],
"temperature": 0.7,
"max_tokens": 100
}
```
**Response Example**
```json
{
"id": "fw-123456789",
"object": "chat.completion",
"created": 1699123456,
"model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am an AI assistant powered by Fireworks AI, based on the Llama 3.1 model. I can help answer questions, engage in conversations, and provide various information. How can I assist you today?"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 15,
"completion_tokens": 38,
"total_tokens": 53
}
}
```
### Using Auto Protocol Compatibility ### Using Auto Protocol Compatibility
The plugin now supports automatic protocol detection, capable of handling both OpenAI and Claude protocol format requests simultaneously. The plugin now supports automatic protocol detection, capable of handling both OpenAI and Claude protocol format requests simultaneously.

View File

@@ -103,3 +103,9 @@ func TestAzure(t *testing.T) {
test.RunAzureOnHttpResponseHeadersTests(t) test.RunAzureOnHttpResponseHeadersTests(t)
test.RunAzureOnHttpResponseBodyTests(t) test.RunAzureOnHttpResponseBodyTests(t)
} }
// TestFireworks runs the Fireworks AI provider suites in a fixed order:
// config parsing, request-header handling, then request-body handling.
func TestFireworks(t *testing.T) {
	suites := []func(*testing.T){
		test.RunFireworksParseConfigTests,
		test.RunFireworksOnHttpRequestHeadersTests,
		test.RunFireworksOnHttpRequestBodyTests,
	}
	for _, runSuite := range suites {
		runSuite(t)
	}
}

View File

@@ -0,0 +1,83 @@
package provider
import (
"errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// fireworksDomain is the host name written into the request host header of
// requests forwarded to Fireworks AI.
const fireworksDomain = "api.fireworks.ai"
// fireworksProviderInitializer validates Fireworks AI provider configuration,
// supplies its default capabilities and creates fireworksProvider instances.
// It carries no state.
type fireworksProviderInitializer struct{}
// ValidateConfig returns an error when the provider config carries no API
// token, and nil otherwise.
func (f *fireworksProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	// len of a nil slice is 0, so one length check covers both a missing and
	// an empty apiTokens list; a separate nil comparison is redundant.
	if len(config.apiTokens) == 0 {
		return errors.New("no apiToken found in provider config")
	}
	return nil
}
// DefaultCapabilities returns the API-name-to-path table used when the
// configuration does not provide one: chat completions, completions and
// models, each mapped to its OpenAI path constant.
func (f *fireworksProviderInitializer) DefaultCapabilities() map[string]string {
	capabilities := make(map[string]string, 3)
	capabilities[string(ApiNameChatCompletion)] = PathOpenAIChatCompletions
	capabilities[string(ApiNameCompletion)] = PathOpenAICompletions
	capabilities[string(ApiNameModels)] = PathOpenAIModels
	return capabilities
}
// CreateProvider applies the default capabilities to config and returns a
// fireworksProvider built from it. The returned error is always nil.
func (f *fireworksProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(f.DefaultCapabilities())

	provider := &fireworksProvider{config: config}
	// The context cache is created from the local config value, not from
	// the copy stored in the provider.
	provider.contextCache = createContextCache(&config)
	return provider, nil
}
// fireworksProvider is the provider for the Fireworks AI service.
type fireworksProvider struct {
// config is the provider configuration this instance was created with.
config ProviderConfig
// contextCache is passed to the config's handleRequestBody when a request
// body is processed.
contextCache *contextCache
}
// GetProviderType returns the provider type identifier for Fireworks AI.
func (f *fireworksProvider) GetProviderType() string {
return providerTypeFireworks
}
// OnRequestHeaders hands request-header processing to the provider config's
// handleRequestHeaders, passing this provider along. It always returns nil.
func (f *fireworksProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
f.config.handleRequestHeaders(f, ctx, apiName)
return nil
}
// OnRequestBody processes a request body for apiName. When the provider
// config supports the API, the work is delegated to the config's
// handleRequestBody together with this provider's context cache; otherwise
// it returns types.ActionContinue and errUnsupportedApiName.
func (f *fireworksProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	if f.config.isSupportedAPI(apiName) {
		return f.config.handleRequestBody(f, f.contextCache, ctx, apiName, body)
	}
	return types.ActionContinue, errUnsupportedApiName
}
// TransformRequestHeaders rewrites headers for the Fireworks AI upstream: the
// path is overwritten from the configured capabilities for apiName, the host
// becomes fireworksDomain, and Authorization is set to a Bearer token taken
// from the config for this request context.
func (f *fireworksProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	capabilityKey := string(apiName)
	util.OverwriteRequestPathHeaderByCapability(headers, capabilityKey, f.config.capabilities)
	util.OverwriteRequestHostHeader(headers, fireworksDomain)

	bearer := "Bearer " + f.config.GetApiTokenInUse(ctx)
	util.OverwriteRequestAuthorizationHeader(headers, bearer)

	// NOTE(review): presumably dropped because the body may be rewritten
	// later and the original length would no longer match — confirm.
	headers.Del("Content-Length")
}
// GetApiName maps a request path to an ApiName by substring match against the
// OpenAI path constants, tried in order: chat completions, completions,
// models. It returns the empty ApiName when none of them occurs in path.
func (f *fireworksProvider) GetApiName(path string) ApiName {
	switch {
	case strings.Contains(path, PathOpenAIChatCompletions):
		return ApiNameChatCompletion
	case strings.Contains(path, PathOpenAICompletions):
		return ApiNameCompletion
	case strings.Contains(path, PathOpenAIModels):
		return ApiNameModels
	default:
		return ""
	}
}

View File

@@ -134,6 +134,7 @@ const (
providerTypeTriton = "triton" providerTypeTriton = "triton"
providerTypeOpenRouter = "openrouter" providerTypeOpenRouter = "openrouter"
providerTypeLongcat = "longcat" providerTypeLongcat = "longcat"
providerTypeFireworks = "fireworks"
protocolOpenAI = "openai" protocolOpenAI = "openai"
protocolOriginal = "original" protocolOriginal = "original"
@@ -215,6 +216,7 @@ var (
providerTypeTriton: &tritonProviderInitializer{}, providerTypeTriton: &tritonProviderInitializer{},
providerTypeOpenRouter: &openrouterProviderInitializer{}, providerTypeOpenRouter: &openrouterProviderInitializer{},
providerTypeLongcat: &longcatProviderInitializer{}, providerTypeLongcat: &longcatProviderInitializer{},
providerTypeFireworks: &fireworksProviderInitializer{},
} }
) )

View File

@@ -0,0 +1,354 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// basicFireworksConfig is a minimal Fireworks plugin configuration: one API
// token and a wildcard model mapping.
var basicFireworksConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "fireworks",
		"apiTokens": []string{"fw-test123456789"},
		"modelMapping": map[string]string{
			"*": "accounts/fireworks/models/llama-v3p1-8b-instruct",
		},
	}
	// The error is ignored: the value consists only of maps, slices and
	// strings, which encoding/json can always encode.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// fireworksMultiModelConfig is a Fireworks plugin configuration that maps
// several OpenAI model names, plus a wildcard, to Fireworks models.
var fireworksMultiModelConfig = func() json.RawMessage {
	modelMapping := map[string]string{
		"gpt-4":         "accounts/fireworks/models/llama-v3p1-70b-instruct",
		"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
		"*":             "accounts/fireworks/models/llama-v3p1-8b-instruct",
	}
	provider := map[string]interface{}{
		"type":         "fireworks",
		"apiTokens":    []string{"fw-multi-model"},
		"modelMapping": modelMapping,
	}
	// The error is ignored: the value consists only of maps, slices and
	// strings, which encoding/json can always encode.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// invalidFireworksConfig is a Fireworks plugin configuration with an empty
// apiTokens list; the parse-config tests expect plugin start to fail on it.
var invalidFireworksConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type": "fireworks",
		// Both values are empty but non-nil, so they encode as [] and {}
		// rather than null.
		"apiTokens":    []string{},
		"modelMapping": map[string]string{},
	}
	// The error is ignored: the value consists only of maps, slices and
	// strings, which encoding/json can always encode.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// completeFireworksConfig is a full Fireworks plugin configuration: one API
// token and explicit plus wildcard model mappings.
var completeFireworksConfig = func() json.RawMessage {
	modelMapping := map[string]string{
		"gpt-4":         "accounts/fireworks/models/llama-v3p1-70b-instruct",
		"gpt-3.5-turbo": "accounts/fireworks/models/llama-v3p1-8b-instruct",
		"*":             "accounts/fireworks/models/llama-v3p1-8b-instruct",
	}
	provider := map[string]interface{}{
		"type":         "fireworks",
		"apiTokens":    []string{"fw-complete-test"},
		"modelMapping": modelMapping,
	}
	// The error is ignored: the value consists only of maps, slices and
	// strings, which encoding/json can always encode.
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// RunFireworksParseConfigTests 测试 Fireworks 配置解析
func RunFireworksParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基本 Fireworks 配置解析
t.Run("basic fireworks config", func(t *testing.T) {
host, status := test.NewTestHost(basicFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试 Fireworks 多模型配置解析
t.Run("fireworks multi model config", func(t *testing.T) {
host, status := test.NewTestHost(fireworksMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// 测试无效 Fireworks 配置(缺少 apiToken
t.Run("invalid fireworks config - missing apiToken", func(t *testing.T) {
host, status := test.NewTestHost(invalidFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// 测试完整 Fireworks 配置解析
t.Run("fireworks complete config", func(t *testing.T) {
host, status := test.NewTestHost(completeFireworksConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
// RunFireworksOnHttpRequestHeadersTests exercises request-header handling for
// the Fireworks provider on the chat-completions, completions and models
// paths. Each case checks the returned action and the resulting :authority,
// :path and Authorization headers.
func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Chat completions request headers.
		t.Run("fireworks chat completion request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicFireworksConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			incoming := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			}
			// HeaderStopIteration is expected because the request body still
			// has to be processed.
			require.Equal(t, types.HeaderStopIteration, host.CallOnHttpRequestHeaders(incoming))

			outgoing := host.GetRequestHeaders()
			require.NotNil(t, outgoing)

			// The host must be rewritten to the Fireworks domain.
			authority, found := test.GetHeaderValue(outgoing, ":authority")
			require.True(t, found, "Host header should exist")
			require.Equal(t, "api.fireworks.ai", authority, "Host should be changed to Fireworks domain")

			// Authorization must carry the configured token.
			authorization, found := test.GetHeaderValue(outgoing, "Authorization")
			require.True(t, found, "Authorization header should exist")
			require.Contains(t, authorization, "Bearer fw-test123456789", "Authorization should contain Fireworks API token with Bearer prefix")

			// The path keeps its OpenAI-compatible form.
			path, found := test.GetHeaderValue(outgoing, ":path")
			require.True(t, found, "Path header should exist")
			require.Equal(t, "/v1/chat/completions", path, "Path should remain OpenAI compatible")

			// At least one debug log line must mention the provider or the plugin.
			sawProxyLog := false
			for _, line := range host.GetDebugLogs() {
				if strings.Contains(line, "fireworks") || strings.Contains(line, "ai-proxy") {
					sawProxyLog = true
					break
				}
			}
			require.True(t, sawProxyLog, "Should have Fireworks or ai-proxy processing logs")
		})

		// Text completions request headers.
		t.Run("fireworks completion request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicFireworksConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			incoming := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			}
			require.Equal(t, types.HeaderStopIteration, host.CallOnHttpRequestHeaders(incoming))

			outgoing := host.GetRequestHeaders()
			require.NotNil(t, outgoing)

			// Host rewrite.
			authority, found := test.GetHeaderValue(outgoing, ":authority")
			require.True(t, found)
			require.Equal(t, "api.fireworks.ai", authority)

			// The path keeps its OpenAI-compatible form.
			path, found := test.GetHeaderValue(outgoing, ":path")
			require.True(t, found)
			require.Equal(t, "/v1/completions", path, "Path should remain OpenAI compatible for completions")

			// Authorization header.
			authorization, found := test.GetHeaderValue(outgoing, "Authorization")
			require.True(t, found, "Authorization header should exist for completions")
			require.Contains(t, authorization, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
		})

		// Model list request headers.
		t.Run("fireworks models request headers", func(t *testing.T) {
			host, status := test.NewTestHost(basicFireworksConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			incoming := [][2]string{
				{":authority", "example.com"},
				{":path", "/v1/models"},
				{":method", "GET"},
			}
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestHeaders(incoming))

			outgoing := host.GetRequestHeaders()
			require.NotNil(t, outgoing)

			// Host rewrite.
			authority, found := test.GetHeaderValue(outgoing, ":authority")
			require.True(t, found)
			require.Equal(t, "api.fireworks.ai", authority)

			// The path keeps its OpenAI-compatible form.
			path, found := test.GetHeaderValue(outgoing, ":path")
			require.True(t, found)
			require.Equal(t, "/v1/models", path, "Path should remain OpenAI compatible for models")

			// Authorization header.
			authorization, found := test.GetHeaderValue(outgoing, "Authorization")
			require.True(t, found, "Authorization header should exist for models")
			require.Contains(t, authorization, "Bearer fw-test123456789", "Authorization should contain Fireworks API token")
		})
	})
}
// RunFireworksOnHttpRequestBodyTests exercises request-body handling for the
// Fireworks provider: non-streaming chat, streaming chat and text
// completions. Each case sends headers first, then the body, and checks that
// the model name is mapped while the rest of the payload is preserved.
func RunFireworksOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Chat completions request body.
		t.Run("fireworks chat completion request body", func(t *testing.T) {
			host, status := test.NewTestHost(basicFireworksConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers go first; the body is handled afterwards.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			payload := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, world!"}
],
"stream": false
}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody([]byte(payload)))

			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)
			got := string(rewritten)

			// The wildcard mapping must replace the model name.
			require.Contains(t, got, "accounts/fireworks/models/llama-v3p1-8b-instruct",
				"Model should be mapped to Fireworks model")
			require.Contains(t, got, "Hello, world!",
				"Request content should be preserved")
		})

		// Streaming chat completions request body.
		t.Run("fireworks streaming chat completion request body", func(t *testing.T) {
			host, status := test.NewTestHost(fireworksMultiModelConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers go first; the body is handled afterwards.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			payload := `{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Write a poem about AI"}
],
"stream": true
}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody([]byte(payload)))

			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)
			got := string(rewritten)

			// gpt-4 must map to the 70b model.
			require.Contains(t, got, "accounts/fireworks/models/llama-v3p1-70b-instruct",
				"GPT-4 should be mapped to Fireworks 70b model")
			require.Contains(t, got, "Write a poem about AI",
				"Request content should be preserved")
			require.Contains(t, got, `"stream": true`,
				"Stream flag should be preserved")
		})

		// Text completions request body.
		t.Run("fireworks completion request body", func(t *testing.T) {
			host, status := test.NewTestHost(basicFireworksConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// Headers go first; the body is handled afterwards.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})

			payload := `{
"model": "gpt-3.5-turbo",
"prompt": "The future of AI is",
"max_tokens": 100
}`
			require.Equal(t, types.ActionContinue, host.CallOnHttpRequestBody([]byte(payload)))

			rewritten := host.GetRequestBody()
			require.NotNil(t, rewritten)
			got := string(rewritten)

			// The wildcard mapping must replace the model name.
			require.Contains(t, got, "accounts/fireworks/models/llama-v3p1-8b-instruct",
				"Model should be mapped to Fireworks model")
			require.Contains(t, got, "The future of AI is",
				"Prompt should be preserved")
		})
	})
}