From 49aad4152ca216aeff138e9d5ca9d3d07e41b417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Tue, 18 Feb 2025 09:57:48 +0800 Subject: [PATCH] Supports completions API & support config openai baseUrl through `openaiCustomUrl` (#1765) --- plugins/wasm-go/extensions/ai-proxy/main.go | 3 + .../extensions/ai-proxy/provider/claude.go | 2 + .../extensions/ai-proxy/provider/model.go | 87 ++++++++++++++----- .../extensions/ai-proxy/provider/openai.go | 66 ++++++++++---- .../extensions/ai-proxy/provider/provider.go | 2 + .../extensions/ai-proxy/provider/qwen.go | 9 +- 6 files changed, 124 insertions(+), 45 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 569a0e9ba..dc6bc123c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -270,6 +270,9 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { func getApiName(path string) provider.ApiName { // openai style + if strings.HasSuffix(path, "/v1/completions") { + return provider.ApiNameCompletion + } if strings.HasSuffix(path, "/v1/chat/completions") { return provider.ApiNameChatCompletion } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index b185cae54..ceef37437 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -17,6 +17,7 @@ import ( const ( claudeDomain = "api.anthropic.com" claudeChatCompletionPath = "/v1/messages" + claudeCompletionPath = "/v1/complete" defaultVersion = "2023-06-01" defaultMaxTokens = 4096 ) @@ -88,6 +89,7 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ string(ApiNameChatCompletion): claudeChatCompletionPath, + string(ApiNameCompletion): claudeCompletionPath, // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api string(ApiNameEmbeddings): PathOpenAIEmbeddings, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 61bef7467..726a18fca 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -19,22 +19,55 @@ const ( ) type chatCompletionRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Seed int `json:"seed,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *streamOptions `json:"stream_options,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Tools []tool `json:"tools,omitempty"` - ToolChoice *toolChoice `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - Stop []string `json:"stop,omitempty"` - ResponseFormat map[string]interface{} `json:"response_format,omitempty"` + Messages []chatMessage `json:"messages"` + Model string `json:"model"` + Store bool `json:"store,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + Logprobs bool `json:"logprobs,omitempty"` + TopLogprobs int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Prediction map[string]interface{} `json:"prediction,omitempty"` + Audio map[string]interface{} `json:"audio,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat map[string]interface{} `json:"response_format,omitempty"` + Seed int `json:"seed,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Tools []tool `json:"tools,omitempty"` + ToolChoice *toolChoice `json:"tool_choice,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` +} + +type CompletionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + Logprobs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Seed int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *streamOptions `json:"stream_options,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } type streamOptions struct { @@ -62,16 +95,18 @@ type chatCompletionResponse struct { Choices []chatCompletionChoice `json:"choices"` Created int64 `json:"created,omitempty"` Model string `json:"model,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` SystemFingerprint string `json:"system_fingerprint,omitempty"` Object string `json:"object,omitempty"` Usage usage `json:"usage,omitempty"` } type chatCompletionChoice struct { - Index int `json:"index"` - Message *chatMessage `json:"message,omitempty"` - Delta *chatMessage `json:"delta,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Message *chatMessage `json:"message,omitempty"` + Delta *chatMessage `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Logprobs map[string]interface{} `json:"logprobs,omitempty"` } type usage struct { @@ -81,10 +116,14 @@ type usage struct { } type chatMessage struct { - Name string `json:"name,omitempty"` - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - ToolCalls []toolCall `json:"tool_calls,omitempty"` + Id string `json:"id,omitempty"` + Audio map[string]interface{} `json:"audio,omitempty"` + Name string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []toolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` } type messageContent struct { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 60767b0d0..af53bb6ed 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -4,10 +4,12 @@ import ( "encoding/json" "fmt" "net/http" + "path" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -16,7 +18,10 @@ import ( const ( defaultOpenaiDomain = "api.openai.com" defaultOpenaiChatCompletionPath = "/v1/chat/completions" + defaultOpenaiCompletionPath = "/v1/completions" defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings" + defaultOpenaiAudioSpeech = "/v1/audio/speech" + defaultOpenaiImageGeneration = "/v1/images/generations" ) type openaiProviderInitializer struct { @@ -28,13 +33,24 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string { return map[string]string{ - string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, - string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, + string(ApiNameCompletion): defaultOpenaiCompletionPath, + string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, + string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, + string(ApiNameImageGeneration): defaultOpenaiImageGeneration, + string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech, } } +func isDirectPath(path string) bool { + return strings.HasSuffix(path, "/completions") || + strings.HasSuffix(path, "/chat/embeddings") || + strings.HasSuffix(path, "/audio/speech") || + strings.HasSuffix(path, "/images/generations") +} + func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { if config.openaiCustomUrl == "" { + config.setDefaultCapabilities(m.DefaultCapabilities()) return &openaiProvider{ config: config, contextCache: createContextCache(&config), @@ -45,20 +61,32 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi if len(pairs) != 2 { return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) } - config.setDefaultCapabilities(m.DefaultCapabilities()) + customPath := "/" + pairs[1] + isDirectCustomPath := isDirectPath(customPath) + capabilities := m.DefaultCapabilities() + if !isDirectCustomPath { + for key, mapPath := range capabilities { + capabilities[key] = path.Join(customPath, mapPath) + } + } + config.setDefaultCapabilities(capabilities) + proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v", + pairs[0], customPath, isDirectCustomPath, capabilities) return &openaiProvider{ - config: config, - customDomain: pairs[0], - customPath: "/" + pairs[1], - contextCache: createContextCache(&config), + config: config, + customDomain: pairs[0], + customPath: customPath, + isDirectCustomPath: isDirectCustomPath, + contextCache: createContextCache(&config), }, nil } type openaiProvider struct { - config ProviderConfig - customDomain string - customPath string - contextCache *contextCache + config ProviderConfig + customDomain string + customPath string + isDirectCustomPath bool + contextCache *contextCache } func (m *openaiProvider) GetProviderType() string { @@ -71,15 +99,19 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa } func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { - if m.customPath == "" { + if m.customPath != "" { + if m.isDirectCustomPath || apiName == "" { + util.OverwriteRequestPathHeader(headers, m.customPath) + } else { + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) + } + } else { util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) - } else { - util.OverwriteRequestPathHeader(headers, m.customPath) } - if m.customDomain == "" { - util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain) - } else { + if m.customDomain != "" { util.OverwriteRequestHostHeader(headers, m.customDomain) + } else { + util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain) } if len(m.config.apiTokens) > 0 { util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index d2d9efbbc..67cce2888 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -22,11 +22,13 @@ const ( // ApiName 格式 {vendor}/{version}/{apitype} // 表示遵循 厂商/版本/接口类型 的格式 // 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank + ApiNameCompletion ApiName = "openai/v1/completions" ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" ApiNameEmbeddings ApiName = "openai/v1/embeddings" ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" + PathOpenAICompletions = "/v1/completions" PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIEmbeddings = "/v1/embeddings" diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index ddf70e791..2f757c683 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -648,10 +648,11 @@ type qwenUsage struct { } type qwenMessage struct { - Name string `json:"name,omitempty"` - Role string `json:"role"` - Content any `json:"content"` - ToolCalls []toolCall `json:"tool_calls,omitempty"` + Name string `json:"name,omitempty"` + Role string `json:"role"` + Content any `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []toolCall `json:"tool_calls,omitempty"` } type qwenVlMessageContent struct {