Support the completions API and allow configuring the OpenAI base URL through openaiCustomUrl (#1765)

This commit is contained in:
澄潭
2025-02-18 09:57:48 +08:00
committed by GitHub
parent 94aacf5153
commit 49aad4152c
6 changed files with 124 additions and 45 deletions

View File

@@ -270,6 +270,9 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, "/v1/completions") {
return provider.ApiNameCompletion
}
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}

View File

@@ -17,6 +17,7 @@ import (
// Connection defaults for the Claude provider.
const (
	// claudeDomain is the host name configured for the Claude API.
	claudeDomain = "api.anthropic.com"

	// Upstream request paths. DefaultCapabilities maps ApiNameChatCompletion
	// to the first and ApiNameCompletion to the second.
	claudeChatCompletionPath = "/v1/messages"
	claudeCompletionPath     = "/v1/complete"

	// Fallback values whose use sites are outside this view — presumably the
	// API version string and the max-token limit applied when the request
	// gives none; confirm at the use sites.
	defaultVersion   = "2023-06-01"
	defaultMaxTokens = 4096
)
@@ -88,6 +89,7 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): claudeChatCompletionPath,
string(ApiNameCompletion): claudeCompletionPath,
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}

View File

@@ -19,22 +19,55 @@ const (
)
// chatCompletionRequest describes the JSON body of a chat completion request;
// the struct tags give the wire name of every field.
//
// NOTE(review): this listing comes from a rendered diff, so the field list
// appears twice. The first 16 fields (Model through ResponseFormat) look like
// the pre-change declaration and the fields from the second Messages onward
// like the post-change declaration. Go rejects duplicate field names, so only
// one of the two lists can exist in the compiled struct.
//
// NOTE(review): numeric and boolean fields tagged omitempty are left out of the
// JSON when they hold their zero value, so an explicit temperature of 0 is not
// written back when the struct is re-marshaled — confirm that is intended.
type chatCompletionRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
Messages []chatMessage `json:"messages"`
Model string `json:"model"`
Store bool `json:"store,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
Logprobs bool `json:"logprobs,omitempty"`
TopLogprobs int `json:"top_logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Prediction map[string]interface{} `json:"prediction,omitempty"`
Audio map[string]interface{} `json:"audio,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
Seed int `json:"seed,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Stop []string `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
User string `json:"user,omitempty"`
}
// CompletionRequest is the JSON request body for the completions API
// (ApiNameCompletion). Model and Prompt are always serialized; every other
// field is tagged omitempty and is left out of the JSON when it holds its zero
// value.
//
// NOTE(review): because of omitempty, an explicit zero such as temperature 0
// is not written back when the struct is re-marshaled — confirm acceptable.
// NOTE(review): Prompt decodes only a JSON string and Stop only a JSON array
// of strings — confirm callers never send other shapes.
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	BestOf           int            `json:"best_of,omitempty"`
	Echo             bool           `json:"echo,omitempty"`
	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int `json:"logit_bias,omitempty"`
	Logprobs         int            `json:"logprobs,omitempty"`
	MaxTokens        int            `json:"max_tokens,omitempty"`
	N                int            `json:"n,omitempty"`
	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
	Seed             int            `json:"seed,omitempty"`
	Stop             []string       `json:"stop,omitempty"`
	Stream           bool           `json:"stream,omitempty"`
	StreamOptions    *streamOptions `json:"stream_options,omitempty"`
	Suffix           string         `json:"suffix,omitempty"`
	Temperature      float64        `json:"temperature,omitempty"`
	TopP             float64        `json:"top_p,omitempty"`
	User             string         `json:"user,omitempty"`
}
type streamOptions struct {
@@ -62,16 +95,18 @@ type chatCompletionResponse struct {
Choices []chatCompletionChoice `json:"choices"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Object string `json:"object,omitempty"`
Usage usage `json:"usage,omitempty"`
}
// chatCompletionChoice is one entry of the choices array in a chat completion
// response (see the Choices field of chatCompletionResponse). Message and Delta
// are both optional pointers to a chatMessage.
//
// NOTE(review): this listing comes from a rendered diff. The first four fields
// look like the pre-change declaration and the five fields after them (the
// same four plus Logprobs) like the post-change declaration; only one list can
// exist in the compiled struct.
type chatCompletionChoice struct {
Index int `json:"index"`
Message *chatMessage `json:"message,omitempty"`
Delta *chatMessage `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Index int `json:"index"`
Message *chatMessage `json:"message,omitempty"`
Delta *chatMessage `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Logprobs map[string]interface{} `json:"logprobs,omitempty"`
}
type usage struct {
@@ -81,10 +116,14 @@ type usage struct {
}
// chatMessage is a single chat message as carried in request message lists and
// in response choices. Content is typed any, so its JSON shape is not fixed by
// this struct.
//
// NOTE(review): this listing comes from a rendered diff. The first four fields
// (Name, Role, Content, ToolCalls) look like the pre-change declaration and the
// fields from Id onward like the post-change declaration, which adds Id, Audio,
// ReasoningContent and Refusal; only one list can exist in the compiled struct.
type chatMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
Id string `json:"id,omitempty"`
Audio map[string]interface{} `json:"audio,omitempty"`
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"`
}
type messageContent struct {

View File

@@ -4,10 +4,12 @@ import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -16,7 +18,10 @@ import (
// Default endpoint settings for the OpenAI provider. defaultOpenaiDomain is the
// host used when no custom domain is configured; the path constants are the
// values DefaultCapabilities assigns to each supported API name.
const (
	defaultOpenaiDomain             = "api.openai.com"
	defaultOpenaiChatCompletionPath = "/v1/chat/completions"
	defaultOpenaiCompletionPath     = "/v1/completions"
	// OpenAI serves embeddings at /v1/embeddings (the same value as
	// PathOpenAIEmbeddings); there is no /v1/chat/embeddings endpoint.
	defaultOpenaiEmbeddingsPath  = "/v1/embeddings"
	defaultOpenaiAudioSpeech     = "/v1/audio/speech"
	defaultOpenaiImageGeneration = "/v1/images/generations"
)
type openaiProviderInitializer struct {
@@ -28,13 +33,24 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
// DefaultCapabilities returns the OpenAI provider's built-in mapping from API
// name to upstream request path.
//
// NOTE(review): this listing comes from a rendered diff. The first two map
// entries look like the pre-change lines and the five entries after them like
// the post-change lines, which add completion, image generation and audio
// speech. Go rejects duplicate constant keys in a map literal, so only one of
// the two sets can exist in the compiled code.
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameCompletion): defaultOpenaiCompletionPath,
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
}
}
// isDirectPath reports whether path already ends in a concrete OpenAI-style
// endpoint (completions, chat completions, embeddings, audio speech or image
// generation) rather than being a base prefix to which the default endpoint
// paths get appended.
//
// The embeddings check matches any path ending in "/embeddings", so both the
// standard "/v1/embeddings" endpoint and the previously recognized
// "/chat/embeddings" form count as direct paths.
func isDirectPath(path string) bool {
	// The "/completions" suffix also covers "/chat/completions".
	return strings.HasSuffix(path, "/completions") ||
		strings.HasSuffix(path, "/embeddings") ||
		strings.HasSuffix(path, "/audio/speech") ||
		strings.HasSuffix(path, "/images/generations")
}
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
if config.openaiCustomUrl == "" {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &openaiProvider{
config: config,
contextCache: createContextCache(&config),
@@ -45,20 +61,32 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
if len(pairs) != 2 {
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
}
config.setDefaultCapabilities(m.DefaultCapabilities())
customPath := "/" + pairs[1]
isDirectCustomPath := isDirectPath(customPath)
capabilities := m.DefaultCapabilities()
if !isDirectCustomPath {
for key, mapPath := range capabilities {
capabilities[key] = path.Join(customPath, mapPath)
}
}
config.setDefaultCapabilities(capabilities)
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
pairs[0], customPath, isDirectCustomPath, capabilities)
return &openaiProvider{
config: config,
customDomain: pairs[0],
customPath: "/" + pairs[1],
contextCache: createContextCache(&config),
config: config,
customDomain: pairs[0],
customPath: customPath,
isDirectCustomPath: isDirectCustomPath,
contextCache: createContextCache(&config),
}, nil
}
// openaiProvider is the provider instance built by
// openaiProviderInitializer.CreateProvider. When openaiCustomUrl is set,
// CreateProvider fills customDomain and customPath from the two parts of that
// URL and sets isDirectCustomPath to the result of isDirectPath(customPath).
//
// NOTE(review): this listing comes from a rendered diff. The first four fields
// look like the pre-change declaration and the five fields after them (the
// same four plus isDirectCustomPath) like the post-change declaration; only
// one list can exist in the compiled struct.
type openaiProvider struct {
config ProviderConfig
customDomain string
customPath string
contextCache *contextCache
config ProviderConfig
customDomain string
customPath string
isDirectCustomPath bool
contextCache *contextCache
}
func (m *openaiProvider) GetProviderType() string {
@@ -71,15 +99,19 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" {
if m.customPath != "" {
if m.isDirectCustomPath || apiName == "" {
util.OverwriteRequestPathHeader(headers, m.customPath)
} else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
}
} else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
} else {
util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else {
if m.customDomain != "" {
util.OverwriteRequestHostHeader(headers, m.customDomain)
} else {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
}
if len(m.config.apiTokens) > 0 {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -22,11 +22,13 @@ const (
// ApiName 格式 {vendor}/{version}/{apitype}
// 表示遵循 厂商/版本/接口类型 的格式
// 目前openai是事实意义上的标准但是也有其他厂商存在其他任务的一些可能的标准比如cohere的rerank
ApiNameCompletion ApiName = "openai/v1/completions"
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
PathOpenAICompletions = "/v1/completions"
PathOpenAIChatCompletions = "/v1/chat/completions"
PathOpenAIEmbeddings = "/v1/embeddings"

View File

@@ -648,10 +648,11 @@ type qwenUsage struct {
}
// qwenMessage is the message shape used in the Qwen provider's JSON payloads.
// Role and Content are always serialized; the other fields are omitted when
// empty. Content is typed any, so its JSON shape is not fixed by this struct.
//
// NOTE(review): this listing comes from a rendered diff. The first four fields
// look like the pre-change declaration and the five fields after them (the
// same four plus ReasoningContent) like the post-change declaration; only one
// list can exist in the compiled struct.
type qwenMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role"`
Content any `json:"content"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
Name string `json:"name,omitempty"`
Role string `json:"role"`
Content any `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
type qwenVlMessageContent struct {