Support the completions API and allow configuring the OpenAI base URL through openaiCustomUrl (#1765)

This commit is contained in:
澄潭
2025-02-18 09:57:48 +08:00
committed by GitHub
parent 94aacf5153
commit 49aad4152c
6 changed files with 124 additions and 45 deletions

View File

@@ -270,6 +270,9 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
func getApiName(path string) provider.ApiName { func getApiName(path string) provider.ApiName {
// openai style // openai style
if strings.HasSuffix(path, "/v1/completions") {
return provider.ApiNameCompletion
}
if strings.HasSuffix(path, "/v1/chat/completions") { if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion return provider.ApiNameChatCompletion
} }

View File

@@ -17,6 +17,7 @@ import (
const ( const (
claudeDomain = "api.anthropic.com" claudeDomain = "api.anthropic.com"
claudeChatCompletionPath = "/v1/messages" claudeChatCompletionPath = "/v1/messages"
claudeCompletionPath = "/v1/complete"
defaultVersion = "2023-06-01" defaultVersion = "2023-06-01"
defaultMaxTokens = 4096 defaultMaxTokens = 4096
) )
@@ -88,6 +89,7 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{ return map[string]string{
string(ApiNameChatCompletion): claudeChatCompletionPath, string(ApiNameChatCompletion): claudeChatCompletionPath,
string(ApiNameCompletion): claudeCompletionPath,
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
string(ApiNameEmbeddings): PathOpenAIEmbeddings, string(ApiNameEmbeddings): PathOpenAIEmbeddings,
} }

View File

@@ -19,22 +19,55 @@ const (
) )
type chatCompletionRequest struct { type chatCompletionRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"` Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"` Model string `json:"model"`
Store bool `json:"store,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
Logprobs bool `json:"logprobs,omitempty"`
TopLogprobs int `json:"top_logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Prediction map[string]interface{} `json:"prediction,omitempty"`
Audio map[string]interface{} `json:"audio,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
Seed int `json:"seed,omitempty"` Seed int `json:"seed,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Stop []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"` StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"` Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"` ToolChoice *toolChoice `json:"tool_choice,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
}
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
BestOf int `json:"best_of,omitempty"`
Echo bool `json:"echo,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
Logprobs int `json:"logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stop []string `json:"stop,omitempty"` Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Suffix string `json:"suffix,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
} }
type streamOptions struct { type streamOptions struct {
@@ -62,6 +95,7 @@ type chatCompletionResponse struct {
Choices []chatCompletionChoice `json:"choices"` Choices []chatCompletionChoice `json:"choices"`
Created int64 `json:"created,omitempty"` Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"` SystemFingerprint string `json:"system_fingerprint,omitempty"`
Object string `json:"object,omitempty"` Object string `json:"object,omitempty"`
Usage usage `json:"usage,omitempty"` Usage usage `json:"usage,omitempty"`
@@ -72,6 +106,7 @@ type chatCompletionChoice struct {
Message *chatMessage `json:"message,omitempty"` Message *chatMessage `json:"message,omitempty"`
Delta *chatMessage `json:"delta,omitempty"` Delta *chatMessage `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"` FinishReason string `json:"finish_reason,omitempty"`
Logprobs map[string]interface{} `json:"logprobs,omitempty"`
} }
type usage struct { type usage struct {
@@ -81,10 +116,14 @@ type usage struct {
} }
type chatMessage struct { type chatMessage struct {
Id string `json:"id,omitempty"`
Audio map[string]interface{} `json:"audio,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"` Content any `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"` ToolCalls []toolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"`
} }
type messageContent struct { type messageContent struct {

View File

@@ -4,10 +4,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"path"
"strings" "strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
) )
@@ -16,7 +18,10 @@ import (
const ( const (
defaultOpenaiDomain = "api.openai.com" defaultOpenaiDomain = "api.openai.com"
defaultOpenaiChatCompletionPath = "/v1/chat/completions" defaultOpenaiChatCompletionPath = "/v1/chat/completions"
defaultOpenaiCompletionPath = "/v1/completions"
defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings" defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings"
defaultOpenaiAudioSpeech = "/v1/audio/speech"
defaultOpenaiImageGeneration = "/v1/images/generations"
) )
type openaiProviderInitializer struct { type openaiProviderInitializer struct {
@@ -28,13 +33,24 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string { func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{ return map[string]string{
string(ApiNameCompletion): defaultOpenaiCompletionPath,
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath, string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath, string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
} }
} }
// isDirectPath reports whether path already ends with a concrete OpenAI-style
// API endpoint, as opposed to being a base prefix onto which per-capability
// paths should be appended.
//
// Matching is done purely by suffix:
//   - "/completions" covers both "/v1/completions" and "/v1/chat/completions".
//   - "/embeddings" covers the standard "/v1/embeddings" endpoint and also the
//     "/chat/embeddings" form, so every path accepted before is still accepted.
//   - "/audio/speech" and "/images/generations" cover the speech and image
//     generation endpoints.
//
// A path with a trailing slash (e.g. "/v1/completions/") is not considered
// direct, because the suffix comparison is exact.
func isDirectPath(path string) bool {
	return strings.HasSuffix(path, "/completions") ||
		strings.HasSuffix(path, "/embeddings") ||
		strings.HasSuffix(path, "/audio/speech") ||
		strings.HasSuffix(path, "/images/generations")
}
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
if config.openaiCustomUrl == "" { if config.openaiCustomUrl == "" {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &openaiProvider{ return &openaiProvider{
config: config, config: config,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
@@ -45,11 +61,22 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
if len(pairs) != 2 { if len(pairs) != 2 {
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl) return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
} }
config.setDefaultCapabilities(m.DefaultCapabilities()) customPath := "/" + pairs[1]
isDirectCustomPath := isDirectPath(customPath)
capabilities := m.DefaultCapabilities()
if !isDirectCustomPath {
for key, mapPath := range capabilities {
capabilities[key] = path.Join(customPath, mapPath)
}
}
config.setDefaultCapabilities(capabilities)
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
pairs[0], customPath, isDirectCustomPath, capabilities)
return &openaiProvider{ return &openaiProvider{
config: config, config: config,
customDomain: pairs[0], customDomain: pairs[0],
customPath: "/" + pairs[1], customPath: customPath,
isDirectCustomPath: isDirectCustomPath,
contextCache: createContextCache(&config), contextCache: createContextCache(&config),
}, nil }, nil
} }
@@ -58,6 +85,7 @@ type openaiProvider struct {
config ProviderConfig config ProviderConfig
customDomain string customDomain string
customPath string customPath string
isDirectCustomPath bool
contextCache *contextCache contextCache *contextCache
} }
@@ -71,15 +99,19 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
} }
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" { if m.customPath != "" {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) if m.isDirectCustomPath || apiName == "" {
} else {
util.OverwriteRequestPathHeader(headers, m.customPath) util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else { } else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
}
} else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
}
if m.customDomain != "" {
util.OverwriteRequestHostHeader(headers, m.customDomain) util.OverwriteRequestHostHeader(headers, m.customDomain)
} else {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} }
if len(m.config.apiTokens) > 0 { if len(m.config.apiTokens) > 0 {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -22,11 +22,13 @@ const (
// ApiName 格式 {vendor}/{version}/{apitype} // ApiName 格式 {vendor}/{version}/{apitype}
// 表示遵循 厂商/版本/接口类型 的格式 // 表示遵循 厂商/版本/接口类型 的格式
// 目前openai是事实意义上的标准但是也有其他厂商存在其他任务的一些可能的标准比如cohere的rerank // 目前openai是事实意义上的标准但是也有其他厂商存在其他任务的一些可能的标准比如cohere的rerank
ApiNameCompletion ApiName = "openai/v1/completions"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions" ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings" ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration" ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech" ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions" PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings" PathOpenAIEmbeddings = "/v1/embeddings"

View File

@@ -651,6 +651,7 @@ type qwenMessage struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content any `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"` ToolCalls []toolCall `json:"tool_calls,omitempty"`
} }