feat(ai-proxy): add batches & files support (#2355)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
Xijun Dai
2025-06-03 09:42:36 +08:00
committed by GitHub
parent 19946d46ca
commit 33fc47cefb
11 changed files with 145 additions and 86 deletions

View File

@@ -361,9 +361,21 @@ func getApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/batches") {
return provider.ApiNameBatches
}
if util.RegRetrieveBatchPath.MatchString(path) {
return provider.ApiNameRetrieveBatch
}
if util.RegCancelBatchPath.MatchString(path) {
return provider.ApiNameCancelBatch
}
if strings.HasSuffix(path, "/v1/files") {
return provider.ApiNameFiles
}
if util.RegRetrieveFilePath.MatchString(path) {
return provider.ApiNameRetrieveFile
}
if util.RegRetrieveFileContentPath.MatchString(path) {
return provider.ApiNameRetrieveFileContent
}
if strings.HasSuffix(path, "/v1/models") {
return provider.ApiNameModels
}

View File

@@ -42,13 +42,10 @@ const (
hunyuanAuthIdLen = 36
// docs: https://cloud.tencent.com/document/product/1729/111007
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
hunyuanOpenAiRequestPath = "/v1/chat/completions"
hunyuanOpenAiEmbeddings = "/v1/embeddings"
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
)
type hunyuanProviderInitializer struct {
}
type hunyuanProviderInitializer struct{}
// ref: https://console.cloud.tencent.com/api/explorer?Product=hunyuan&Version=2023-09-01&Action=ChatCompletions
type hunyuanTextGenRequest struct {
@@ -105,8 +102,8 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
}
@@ -324,7 +321,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
}
// hunyuan的流式返回:
//data: {"Note":"以上内容为AI生成不代表开发者立场请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
// data: {"Note":"以上内容为AI生成不代表开发者立场请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
// openai的流式返回
// data: {"id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", "object": "chat.completion.chunk", "created": 1677664795, "model": "gpt-3.5-turbo-0613", "choices": [{"delta": {"content": "The "}, "index": 0, "finish_reason": null}]}
@@ -338,7 +335,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
}
// 初始化处理下标，以及将要返回的处理过的chunks
var newEventPivot = -1
newEventPivot := -1
var outputBuffer []byte
// 从buffer区取出若干完整的chunk，将其转为openAI格式后返回
@@ -451,7 +448,6 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
}
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
fileMessage := hunyuanChatMessage{
Role: roleSystem,
Content: content,

View File

@@ -18,13 +18,10 @@ import (
// moonshotProvider is the provider for Moonshot AI service.
const (
moonshotDomain = "api.moonshot.cn"
moonshotChatCompletionPath = "/v1/chat/completions"
moonshotModelsPath = "/v1/models"
moonshotDomain = "api.moonshot.cn"
)
type moonshotProviderInitializer struct {
}
type moonshotProviderInitializer struct{}
func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.moonshotFileId != "" && config.context != nil {
@@ -38,8 +35,8 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): moonshotChatCompletionPath,
string(ApiNameModels): moonshotModelsPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameModels): PathOpenAIModels,
}
}

View File

@@ -15,17 +15,7 @@ import (
// openaiProvider is the provider for OpenAI service.
const (
defaultOpenaiDomain = "api.openai.com"
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
defaultOpenaiCompletionPath = "/v1/completions"
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
defaultOpenaiAudioSpeech = "/v1/audio/speech"
defaultOpenaiImageGeneration = "/v1/images/generations"
defaultOpenaiImageEdit = "/v1/images/edits"
defaultOpenaiImageVariation = "/v1/images/variations"
defaultOpenaiModels = "/v1/models"
defaultOpenaiFiles = "/v1/files"
defaultOpenaiBatchs = "/v1/batches"
defaultOpenaiDomain = "api.openai.com"
)
type openaiProviderInitializer struct{}
@@ -36,15 +26,20 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameCompletion): defaultOpenaiCompletionPath,
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameImageEdit): defaultOpenaiImageEdit,
string(ApiNameImageVariation): defaultOpenaiImageVariation,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
string(ApiNameModels): defaultOpenaiModels,
string(ApiNameFiles): defaultOpenaiFiles,
string(ApiNameCompletion): PathOpenAICompletions,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
string(ApiNameImageGeneration): PathOpenAIImageGeneration,
string(ApiNameImageEdit): PathOpenAIImageEdit,
string(ApiNameImageVariation): PathOpenAIImageVariation,
string(ApiNameAudioSpeech): PathOpenAIAudioSpeech,
string(ApiNameModels): PathOpenAIModels,
string(ApiNameFiles): PathOpenAIFiles,
string(ApiNameRetrieveFile): PathOpenAIRetrieveFile,
string(ApiNameRetrieveFileContent): PathOpenAIRetrieveFileContent,
string(ApiNameBatches): PathOpenAIBatches,
string(ApiNameRetrieveBatch): PathOpenAIRetrieveBatch,
string(ApiNameCancelBatch): PathOpenAICancelBatch,
}
}

View File

@@ -26,23 +26,35 @@ const (
// ApiName 格式 {vendor}/{version}/{apitype}
// 表示遵循 厂商/版本/接口类型 的格式
// 目前openai是事实意义上的标准，但是也有其他厂商存在其他任务的一些可能的标准，比如cohere的rerank
ApiNameCompletion ApiName = "openai/v1/completions"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameImageEdit ApiName = "openai/v1/imageedit"
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameFiles ApiName = "openai/v1/files"
ApiNameBatches ApiName = "openai/v1/batches"
ApiNameModels ApiName = "openai/v1/models"
ApiNameCompletion ApiName = "openai/v1/completions"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameImageEdit ApiName = "openai/v1/imageedit"
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
ApiNameFiles ApiName = "openai/v1/files"
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
ApiNameBatches ApiName = "openai/v1/batches"
ApiNameRetrieveBatch ApiName = "openai/v1/retrievebatch"
ApiNameCancelBatch ApiName = "openai/v1/cancelbatch"
ApiNameModels ApiName = "openai/v1/models"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
PathOpenAIFiles = "/v1/files"
PathOpenAIBatches = "/v1/batches"
PathOpenAIModels = "/v1/models"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"
PathOpenAIFiles = "/v1/files"
PathOpenAIRetrieveFile = "/v1/files/{file_id}"
PathOpenAIRetrieveFileContent = "/v1/files/{file_id}/content"
PathOpenAIBatches = "/v1/batches"
PathOpenAIRetrieveBatch = "/v1/batches/{batch_id}"
PathOpenAICancelBatch = "/v1/batches/{batch_id}/cancel"
PathOpenAIModels = "/v1/models"
PathOpenAIImageGeneration = "/v1/images/generations"
PathOpenAIImageEdit = "/v1/images/edits"
PathOpenAIImageVariation = "/v1/images/variations"
PathOpenAIAudioSpeech = "/v1/audio/speech"
// TODO: 以下是一些非标准的API名称，需要进一步确认是否支持
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"

View File

@@ -23,13 +23,19 @@ import (
const (
qwenResultFormatMessage = "message"
qwenDefaultDomain = "dashscope.aliyuncs.com"
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenChatCompatiblePath = "/compatible-mode/v1/chat/completions"
qwenTextEmbeddingCompatiblePath = "/compatible-mode/v1/embeddings"
qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenDefaultDomain = "dashscope.aliyuncs.com"
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatibleChatCompletionPath = "/compatible-mode/v1/chat/completions"
qwenCompatibleCompletionsPath = "/compatible-mode/v1/completions"
qwenCompatibleTextEmbeddingPath = "/compatible-mode/v1/embeddings"
qwenCompatibleFilesPath = "/compatible-mode/v1/files"
qwenCompatibleRetrieveFilePath = "/compatible-mode/v1/files/{file_id}"
qwenCompatibleRetrieveFileContentPath = "/compatible-mode/v1/files/{file_id}/content"
qwenCompatibleBatchesPath = "/compatible-mode/v1/batches"
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenTopPMin = 0.000001
qwenTopPMax = 0.999999
@@ -40,8 +46,7 @@ const (
qwenVlModelPrefixName = "qwen-vl"
)
type qwenProviderInitializer struct {
}
type qwenProviderInitializer struct{}
func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.qwenFileIds) != 0 && config.context != nil {
@@ -56,8 +61,14 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool) map[string]string {
if qwenEnableCompatible {
return map[string]string{
string(ApiNameChatCompletion): qwenChatCompatiblePath,
string(ApiNameEmbeddings): qwenTextEmbeddingCompatiblePath,
string(ApiNameChatCompletion): qwenCompatibleChatCompletionPath,
string(ApiNameEmbeddings): qwenCompatibleTextEmbeddingPath,
string(ApiNameCompletion): qwenCompatibleCompletionsPath,
string(ApiNameFiles): qwenCompatibleFilesPath,
string(ApiNameRetrieveFile): qwenCompatibleRetrieveFilePath,
string(ApiNameRetrieveFileContent): qwenCompatibleRetrieveFileContentPath,
string(ApiNameBatches): qwenCompatibleBatchesPath,
string(ApiNameRetrieveBatch): qwenCompatibleRetrieveBatchPath,
}
} else {
return map[string]string{
@@ -673,10 +684,10 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath),
strings.Contains(path, qwenBailianPath),
strings.Contains(path, qwenChatCompatiblePath):
strings.Contains(path, qwenCompatibleChatCompletionPath):
return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath),
strings.Contains(path, qwenTextEmbeddingCompatiblePath):
strings.Contains(path, qwenCompatibleTextEmbeddingPath):
return ApiNameEmbeddings
default:
return ""

View File

@@ -15,12 +15,10 @@ import (
// sparkProvider is the provider for SparkLLM AI service.
const (
sparkHost = "spark-api-open.xf-yun.com"
sparkChatCompletionPath = "/v1/chat/completions"
sparkHost = "spark-api-open.xf-yun.com"
)
type sparkProviderInitializer struct {
}
type sparkProviderInitializer struct{}
type sparkProvider struct {
config ProviderConfig
@@ -58,7 +56,7 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): sparkChatCompletionPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
}
}

View File

@@ -10,12 +10,10 @@ import (
)
const (
stepfunDomain = "api.stepfun.com"
stepfunChatCompletionPath = "/v1/chat/completions"
stepfunDomain = "api.stepfun.com"
)
type stepfunProviderInitializer struct {
}
type stepfunProviderInitializer struct{}
func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
@@ -27,7 +25,7 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// stepfun的chat接口path和OpenAI的chat接口一样
string(ApiNameChatCompletion): stepfunChatCompletionPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
}
}

View File

@@ -11,8 +11,7 @@ import (
)
const (
togetherAIDomain = "api.together.xyz"
togetherAICompletionPath = "/v1/chat/completions"
togetherAIDomain = "api.together.xyz"
)
type togetherAIProviderInitializer struct{}
@@ -26,7 +25,7 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): togetherAICompletionPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
}
}
@@ -67,7 +66,7 @@ func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
}
func (m *togetherAIProvider) GetApiName(path string) ApiName {
if strings.Contains(path, togetherAICompletionPath) {
if strings.Contains(path, PathOpenAIChatCompletions) {
return ApiNameChatCompletion
}
return ""

View File

@@ -10,12 +10,10 @@ import (
)
const (
yiDomain = "api.lingyiwanwu.com"
yiChatCompletionPath = "/v1/chat/completions"
yiDomain = "api.lingyiwanwu.com"
)
type yiProviderInitializer struct {
}
type yiProviderInitializer struct{}
func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
@@ -26,7 +24,7 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
func (m *yiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): yiChatCompletionPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
}
}

View File

@@ -2,6 +2,8 @@ package util
import (
"net/http"
"regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
@@ -13,6 +15,13 @@ const (
MimeTypeApplicationJson = "application/json"
)
// Request-path matchers for the file and batch endpoints that carry a
// resource ID in the URL. Every pattern is anchored at both ends, accepts any
// prefix in front of "/v1/", and captures the slash-free ID segment in a named
// group ("file_id" or "batch_id").
var (
	// RegRetrieveFilePath matches paths ending in "/v1/files/<file_id>".
	RegRetrieveFilePath = regexp.MustCompile(`^.*/v1/files/(?P<file_id>[^/]+)$`)
	// RegRetrieveFileContentPath matches paths ending in "/v1/files/<file_id>/content".
	RegRetrieveFileContentPath = regexp.MustCompile(`^.*/v1/files/(?P<file_id>[^/]+)/content$`)
	// RegRetrieveBatchPath matches paths ending in "/v1/batches/<batch_id>".
	RegRetrieveBatchPath = regexp.MustCompile(`^.*/v1/batches/(?P<batch_id>[^/]+)$`)
	// RegCancelBatchPath matches paths ending in "/v1/batches/<batch_id>/cancel".
	RegCancelBatchPath = regexp.MustCompile(`^.*/v1/batches/(?P<batch_id>[^/]+)/cancel$`)
)
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
var ErrorHandler ErrorHandlerFunc = func(statusCodeDetails string, err error) error {
@@ -62,9 +71,43 @@ func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string,
if !exist {
return
}
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
originPath, err := proxywasm.GetHttpRequestHeader(":path")
if err == nil {
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
}
/**
这里实现不太优雅,理应通过 apiName 来判断使用哪个正则替换
但 ApiName 定义在 provider 中, 而 provider 中又引用了 util
会导致循环引用
**/
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
replacements := []struct {
regx *regexp.Regexp
key string
}{
{RegRetrieveFilePath, "file_id"},
{RegRetrieveFileContentPath, "file_id"},
{RegRetrieveBatchPath, "batch_id"},
{RegCancelBatchPath, "batch_id"},
}
for _, r := range replacements {
if r.regx.MatchString(originPath) {
subMatch := r.regx.FindStringSubmatch(originPath)
if subMatch == nil {
continue
}
index := r.regx.SubexpIndex(r.key)
if index < 0 || index >= len(subMatch) {
continue
}
id := subMatch[index]
mappedPath = r.regx.ReplaceAllStringFunc(mappedPath, func(s string) string {
return strings.Replace(s, "{"+r.key+"}", id, 1)
})
}
}
}
headers.Set(":path", mappedPath)
}