mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat(ai-proxy): add batches & files support (#2355)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
@@ -361,9 +361,21 @@ func getApiName(path string) provider.ApiName {
|
|||||||
if strings.HasSuffix(path, "/v1/batches") {
|
if strings.HasSuffix(path, "/v1/batches") {
|
||||||
return provider.ApiNameBatches
|
return provider.ApiNameBatches
|
||||||
}
|
}
|
||||||
|
if util.RegRetrieveBatchPath.MatchString(path) {
|
||||||
|
return provider.ApiNameRetrieveBatch
|
||||||
|
}
|
||||||
|
if util.RegCancelBatchPath.MatchString(path) {
|
||||||
|
return provider.ApiNameCancelBatch
|
||||||
|
}
|
||||||
if strings.HasSuffix(path, "/v1/files") {
|
if strings.HasSuffix(path, "/v1/files") {
|
||||||
return provider.ApiNameFiles
|
return provider.ApiNameFiles
|
||||||
}
|
}
|
||||||
|
if util.RegRetrieveFilePath.MatchString(path) {
|
||||||
|
return provider.ApiNameRetrieveFile
|
||||||
|
}
|
||||||
|
if util.RegRetrieveFileContentPath.MatchString(path) {
|
||||||
|
return provider.ApiNameRetrieveFileContent
|
||||||
|
}
|
||||||
if strings.HasSuffix(path, "/v1/models") {
|
if strings.HasSuffix(path, "/v1/models") {
|
||||||
return provider.ApiNameModels
|
return provider.ApiNameModels
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,13 +42,10 @@ const (
|
|||||||
hunyuanAuthIdLen = 36
|
hunyuanAuthIdLen = 36
|
||||||
|
|
||||||
// docs: https://cloud.tencent.com/document/product/1729/111007
|
// docs: https://cloud.tencent.com/document/product/1729/111007
|
||||||
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
|
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
|
||||||
hunyuanOpenAiRequestPath = "/v1/chat/completions"
|
|
||||||
hunyuanOpenAiEmbeddings = "/v1/embeddings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type hunyuanProviderInitializer struct {
|
type hunyuanProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// ref: https://console.cloud.tencent.com/api/explorer?Product=hunyuan&Version=2023-09-01&Action=ChatCompletions
|
// ref: https://console.cloud.tencent.com/api/explorer?Product=hunyuan&Version=2023-09-01&Action=ChatCompletions
|
||||||
type hunyuanTextGenRequest struct {
|
type hunyuanTextGenRequest struct {
|
||||||
@@ -105,8 +102,8 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
|||||||
|
|
||||||
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings,
|
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,7 +321,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hunyuan的流式返回:
|
// hunyuan的流式返回:
|
||||||
//data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
|
// data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
|
||||||
|
|
||||||
// openai的流式返回
|
// openai的流式返回
|
||||||
// data: {"id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", "object": "chat.completion.chunk", "created": 1677664795, "model": "gpt-3.5-turbo-0613", "choices": [{"delta": {"content": "The "}, "index": 0, "finish_reason": null}]}
|
// data: {"id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", "object": "chat.completion.chunk", "created": 1677664795, "model": "gpt-3.5-turbo-0613", "choices": [{"delta": {"content": "The "}, "index": 0, "finish_reason": null}]}
|
||||||
@@ -338,7 +335,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 初始化处理下标,以及将要返回的处理过的chunks
|
// 初始化处理下标,以及将要返回的处理过的chunks
|
||||||
var newEventPivot = -1
|
newEventPivot := -1
|
||||||
var outputBuffer []byte
|
var outputBuffer []byte
|
||||||
|
|
||||||
// 从buffer区取出若干完整的chunk,将其转为openAI格式后返回
|
// 从buffer区取出若干完整的chunk,将其转为openAI格式后返回
|
||||||
@@ -451,7 +448,6 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
||||||
|
|
||||||
fileMessage := hunyuanChatMessage{
|
fileMessage := hunyuanChatMessage{
|
||||||
Role: roleSystem,
|
Role: roleSystem,
|
||||||
Content: content,
|
Content: content,
|
||||||
|
|||||||
@@ -18,13 +18,10 @@ import (
|
|||||||
// moonshotProvider is the provider for Moonshot AI service.
|
// moonshotProvider is the provider for Moonshot AI service.
|
||||||
|
|
||||||
const (
|
const (
|
||||||
moonshotDomain = "api.moonshot.cn"
|
moonshotDomain = "api.moonshot.cn"
|
||||||
moonshotChatCompletionPath = "/v1/chat/completions"
|
|
||||||
moonshotModelsPath = "/v1/models"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type moonshotProviderInitializer struct {
|
type moonshotProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if config.moonshotFileId != "" && config.context != nil {
|
if config.moonshotFileId != "" && config.context != nil {
|
||||||
@@ -38,8 +35,8 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
|||||||
|
|
||||||
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): moonshotChatCompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
string(ApiNameModels): moonshotModelsPath,
|
string(ApiNameModels): PathOpenAIModels,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,17 +15,7 @@ import (
|
|||||||
// openaiProvider is the provider for OpenAI service.
|
// openaiProvider is the provider for OpenAI service.
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultOpenaiDomain = "api.openai.com"
|
defaultOpenaiDomain = "api.openai.com"
|
||||||
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
|
|
||||||
defaultOpenaiCompletionPath = "/v1/completions"
|
|
||||||
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
|
|
||||||
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
|
||||||
defaultOpenaiImageGeneration = "/v1/images/generations"
|
|
||||||
defaultOpenaiImageEdit = "/v1/images/edits"
|
|
||||||
defaultOpenaiImageVariation = "/v1/images/variations"
|
|
||||||
defaultOpenaiModels = "/v1/models"
|
|
||||||
defaultOpenaiFiles = "/v1/files"
|
|
||||||
defaultOpenaiBatchs = "/v1/batches"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type openaiProviderInitializer struct{}
|
type openaiProviderInitializer struct{}
|
||||||
@@ -36,15 +26,20 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
|||||||
|
|
||||||
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameCompletion): defaultOpenaiCompletionPath,
|
string(ApiNameCompletion): PathOpenAICompletions,
|
||||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||||
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
string(ApiNameImageGeneration): PathOpenAIImageGeneration,
|
||||||
string(ApiNameImageEdit): defaultOpenaiImageEdit,
|
string(ApiNameImageEdit): PathOpenAIImageEdit,
|
||||||
string(ApiNameImageVariation): defaultOpenaiImageVariation,
|
string(ApiNameImageVariation): PathOpenAIImageVariation,
|
||||||
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
string(ApiNameAudioSpeech): PathOpenAIAudioSpeech,
|
||||||
string(ApiNameModels): defaultOpenaiModels,
|
string(ApiNameModels): PathOpenAIModels,
|
||||||
string(ApiNameFiles): defaultOpenaiFiles,
|
string(ApiNameFiles): PathOpenAIFiles,
|
||||||
|
string(ApiNameRetrieveFile): PathOpenAIRetrieveFile,
|
||||||
|
string(ApiNameRetrieveFileContent): PathOpenAIRetrieveFileContent,
|
||||||
|
string(ApiNameBatches): PathOpenAIBatches,
|
||||||
|
string(ApiNameRetrieveBatch): PathOpenAIRetrieveBatch,
|
||||||
|
string(ApiNameCancelBatch): PathOpenAICancelBatch,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,23 +26,35 @@ const (
|
|||||||
// ApiName 格式 {vendor}/{version}/{apitype}
|
// ApiName 格式 {vendor}/{version}/{apitype}
|
||||||
// 表示遵循 厂商/版本/接口类型 的格式
|
// 表示遵循 厂商/版本/接口类型 的格式
|
||||||
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
|
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
|
||||||
ApiNameCompletion ApiName = "openai/v1/completions"
|
ApiNameCompletion ApiName = "openai/v1/completions"
|
||||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||||
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||||
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||||
ApiNameFiles ApiName = "openai/v1/files"
|
ApiNameFiles ApiName = "openai/v1/files"
|
||||||
ApiNameBatches ApiName = "openai/v1/batches"
|
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
|
||||||
ApiNameModels ApiName = "openai/v1/models"
|
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
|
||||||
|
ApiNameBatches ApiName = "openai/v1/batches"
|
||||||
|
ApiNameRetrieveBatch ApiName = "openai/v1/retrievebatch"
|
||||||
|
ApiNameCancelBatch ApiName = "openai/v1/cancelbatch"
|
||||||
|
ApiNameModels ApiName = "openai/v1/models"
|
||||||
|
|
||||||
PathOpenAICompletions = "/v1/completions"
|
PathOpenAICompletions = "/v1/completions"
|
||||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||||
PathOpenAIFiles = "/v1/files"
|
PathOpenAIFiles = "/v1/files"
|
||||||
PathOpenAIBatches = "/v1/batches"
|
PathOpenAIRetrieveFile = "/v1/files/{file_id}"
|
||||||
PathOpenAIModels = "/v1/models"
|
PathOpenAIRetrieveFileContent = "/v1/files/{file_id}/content"
|
||||||
|
PathOpenAIBatches = "/v1/batches"
|
||||||
|
PathOpenAIRetrieveBatch = "/v1/batches/{batch_id}"
|
||||||
|
PathOpenAICancelBatch = "/v1/batches/{batch_id}/cancel"
|
||||||
|
PathOpenAIModels = "/v1/models"
|
||||||
|
PathOpenAIImageGeneration = "/v1/images/generations"
|
||||||
|
PathOpenAIImageEdit = "/v1/images/edits"
|
||||||
|
PathOpenAIImageVariation = "/v1/images/variations"
|
||||||
|
PathOpenAIAudioSpeech = "/v1/audio/speech"
|
||||||
|
|
||||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||||
|
|||||||
@@ -23,13 +23,19 @@ import (
|
|||||||
const (
|
const (
|
||||||
qwenResultFormatMessage = "message"
|
qwenResultFormatMessage = "message"
|
||||||
|
|
||||||
qwenDefaultDomain = "dashscope.aliyuncs.com"
|
qwenDefaultDomain = "dashscope.aliyuncs.com"
|
||||||
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
||||||
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||||
qwenChatCompatiblePath = "/compatible-mode/v1/chat/completions"
|
qwenCompatibleChatCompletionPath = "/compatible-mode/v1/chat/completions"
|
||||||
qwenTextEmbeddingCompatiblePath = "/compatible-mode/v1/embeddings"
|
qwenCompatibleCompletionsPath = "/compatible-mode/v1/completions"
|
||||||
qwenBailianPath = "/api/v1/apps"
|
qwenCompatibleTextEmbeddingPath = "/compatible-mode/v1/embeddings"
|
||||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
qwenCompatibleFilesPath = "/compatible-mode/v1/files"
|
||||||
|
qwenCompatibleRetrieveFilePath = "/compatible-mode/v1/files/{file_id}"
|
||||||
|
qwenCompatibleRetrieveFileContentPath = "/compatible-mode/v1/files/{file_id}/content"
|
||||||
|
qwenCompatibleBatchesPath = "/compatible-mode/v1/batches"
|
||||||
|
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
||||||
|
qwenBailianPath = "/api/v1/apps"
|
||||||
|
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
|
|
||||||
qwenTopPMin = 0.000001
|
qwenTopPMin = 0.000001
|
||||||
qwenTopPMax = 0.999999
|
qwenTopPMax = 0.999999
|
||||||
@@ -40,8 +46,7 @@ const (
|
|||||||
qwenVlModelPrefixName = "qwen-vl"
|
qwenVlModelPrefixName = "qwen-vl"
|
||||||
)
|
)
|
||||||
|
|
||||||
type qwenProviderInitializer struct {
|
type qwenProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if len(config.qwenFileIds) != 0 && config.context != nil {
|
if len(config.qwenFileIds) != 0 && config.context != nil {
|
||||||
@@ -56,8 +61,14 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
|||||||
func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool) map[string]string {
|
func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool) map[string]string {
|
||||||
if qwenEnableCompatible {
|
if qwenEnableCompatible {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): qwenChatCompatiblePath,
|
string(ApiNameChatCompletion): qwenCompatibleChatCompletionPath,
|
||||||
string(ApiNameEmbeddings): qwenTextEmbeddingCompatiblePath,
|
string(ApiNameEmbeddings): qwenCompatibleTextEmbeddingPath,
|
||||||
|
string(ApiNameCompletion): qwenCompatibleCompletionsPath,
|
||||||
|
string(ApiNameFiles): qwenCompatibleFilesPath,
|
||||||
|
string(ApiNameRetrieveFile): qwenCompatibleRetrieveFilePath,
|
||||||
|
string(ApiNameRetrieveFileContent): qwenCompatibleRetrieveFileContentPath,
|
||||||
|
string(ApiNameBatches): qwenCompatibleBatchesPath,
|
||||||
|
string(ApiNameRetrieveBatch): qwenCompatibleRetrieveBatchPath,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
@@ -673,10 +684,10 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
|
|||||||
case strings.Contains(path, qwenChatCompletionPath),
|
case strings.Contains(path, qwenChatCompletionPath),
|
||||||
strings.Contains(path, qwenMultimodalGenerationPath),
|
strings.Contains(path, qwenMultimodalGenerationPath),
|
||||||
strings.Contains(path, qwenBailianPath),
|
strings.Contains(path, qwenBailianPath),
|
||||||
strings.Contains(path, qwenChatCompatiblePath):
|
strings.Contains(path, qwenCompatibleChatCompletionPath):
|
||||||
return ApiNameChatCompletion
|
return ApiNameChatCompletion
|
||||||
case strings.Contains(path, qwenTextEmbeddingPath),
|
case strings.Contains(path, qwenTextEmbeddingPath),
|
||||||
strings.Contains(path, qwenTextEmbeddingCompatiblePath):
|
strings.Contains(path, qwenCompatibleTextEmbeddingPath):
|
||||||
return ApiNameEmbeddings
|
return ApiNameEmbeddings
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -15,12 +15,10 @@ import (
|
|||||||
|
|
||||||
// sparkProvider is the provider for SparkLLM AI service.
|
// sparkProvider is the provider for SparkLLM AI service.
|
||||||
const (
|
const (
|
||||||
sparkHost = "spark-api-open.xf-yun.com"
|
sparkHost = "spark-api-open.xf-yun.com"
|
||||||
sparkChatCompletionPath = "/v1/chat/completions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type sparkProviderInitializer struct {
|
type sparkProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
type sparkProvider struct {
|
type sparkProvider struct {
|
||||||
config ProviderConfig
|
config ProviderConfig
|
||||||
@@ -58,7 +56,7 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
|||||||
|
|
||||||
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
|
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): sparkChatCompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
stepfunDomain = "api.stepfun.com"
|
stepfunDomain = "api.stepfun.com"
|
||||||
stepfunChatCompletionPath = "/v1/chat/completions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type stepfunProviderInitializer struct {
|
type stepfunProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||||
@@ -27,7 +25,7 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
|||||||
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
// stepfun的chat接口path和OpenAI的chat接口一样
|
// stepfun的chat接口path和OpenAI的chat接口一样
|
||||||
string(ApiNameChatCompletion): stepfunChatCompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
togetherAIDomain = "api.together.xyz"
|
togetherAIDomain = "api.together.xyz"
|
||||||
togetherAICompletionPath = "/v1/chat/completions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type togetherAIProviderInitializer struct{}
|
type togetherAIProviderInitializer struct{}
|
||||||
@@ -26,7 +25,7 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e
|
|||||||
|
|
||||||
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): togetherAICompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,7 +66,7 @@ func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *togetherAIProvider) GetApiName(path string) ApiName {
|
func (m *togetherAIProvider) GetApiName(path string) ApiName {
|
||||||
if strings.Contains(path, togetherAICompletionPath) {
|
if strings.Contains(path, PathOpenAIChatCompletions) {
|
||||||
return ApiNameChatCompletion
|
return ApiNameChatCompletion
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -10,12 +10,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
yiDomain = "api.lingyiwanwu.com"
|
yiDomain = "api.lingyiwanwu.com"
|
||||||
yiChatCompletionPath = "/v1/chat/completions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type yiProviderInitializer struct {
|
type yiProviderInitializer struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||||
@@ -26,7 +24,7 @@ func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
|||||||
|
|
||||||
func (m *yiProviderInitializer) DefaultCapabilities() map[string]string {
|
func (m *yiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
string(ApiNameChatCompletion): yiChatCompletionPath,
|
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
)
|
)
|
||||||
@@ -13,6 +15,13 @@ const (
|
|||||||
MimeTypeApplicationJson = "application/json"
|
MimeTypeApplicationJson = "application/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
RegRetrieveBatchPath = regexp.MustCompile(`^.*/v1/batches/(?P<batch_id>[^/]+)$`)
|
||||||
|
RegCancelBatchPath = regexp.MustCompile(`^.*/v1/batches/(?P<batch_id>[^/]+)/cancel$`)
|
||||||
|
RegRetrieveFilePath = regexp.MustCompile(`^.*/v1/files/(?P<file_id>[^/]+)$`)
|
||||||
|
RegRetrieveFileContentPath = regexp.MustCompile(`^.*/v1/files/(?P<file_id>[^/]+)/content$`)
|
||||||
|
)
|
||||||
|
|
||||||
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
|
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
|
||||||
|
|
||||||
var ErrorHandler ErrorHandlerFunc = func(statusCodeDetails string, err error) error {
|
var ErrorHandler ErrorHandlerFunc = func(statusCodeDetails string, err error) error {
|
||||||
@@ -62,9 +71,43 @@ func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string,
|
|||||||
if !exist {
|
if !exist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
|
originPath, err := proxywasm.GetHttpRequestHeader(":path")
|
||||||
|
if err == nil {
|
||||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
这里实现不太优雅,理应通过 apiName 来判断使用哪个正则替换
|
||||||
|
但 ApiName 定义在 provider 中, 而 provider 中又引用了 util
|
||||||
|
会导致循环引用
|
||||||
|
**/
|
||||||
|
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
|
||||||
|
replacements := []struct {
|
||||||
|
regx *regexp.Regexp
|
||||||
|
key string
|
||||||
|
}{
|
||||||
|
{RegRetrieveFilePath, "file_id"},
|
||||||
|
{RegRetrieveFileContentPath, "file_id"},
|
||||||
|
{RegRetrieveBatchPath, "batch_id"},
|
||||||
|
{RegCancelBatchPath, "batch_id"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range replacements {
|
||||||
|
if r.regx.MatchString(originPath) {
|
||||||
|
subMatch := r.regx.FindStringSubmatch(originPath)
|
||||||
|
if subMatch == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
index := r.regx.SubexpIndex(r.key)
|
||||||
|
if index < 0 || index >= len(subMatch) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := subMatch[index]
|
||||||
|
mappedPath = r.regx.ReplaceAllStringFunc(mappedPath, func(s string) string {
|
||||||
|
return strings.Replace(s, "{"+r.key+"}", id, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
headers.Set(":path", mappedPath)
|
headers.Set(":path", mappedPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user