mirror of
https://github.com/alibaba/higress.git
synced 2026-02-27 14:10:51 +08:00
Supports completions API & support config openai baseUrl through openaiCustomUrl (#1765)
This commit is contained in:
@@ -270,6 +270,9 @@ func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
|
||||
func getApiName(path string) provider.ApiName {
|
||||
// openai style
|
||||
if strings.HasSuffix(path, "/v1/completions") {
|
||||
return provider.ApiNameCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/chat/completions") {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
const (
|
||||
claudeDomain = "api.anthropic.com"
|
||||
claudeChatCompletionPath = "/v1/messages"
|
||||
claudeCompletionPath = "/v1/complete"
|
||||
defaultVersion = "2023-06-01"
|
||||
defaultMaxTokens = 4096
|
||||
)
|
||||
@@ -88,6 +89,7 @@ func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): claudeChatCompletionPath,
|
||||
string(ApiNameCompletion): claudeCompletionPath,
|
||||
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
|
||||
@@ -19,22 +19,55 @@ const (
|
||||
)
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction map[string]interface{} `json:"prediction,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
BestOf int `json:"best_of,omitempty"`
|
||||
Echo bool `json:"echo,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||
Logprobs int `json:"logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type streamOptions struct {
|
||||
@@ -62,16 +95,18 @@ type chatCompletionResponse struct {
|
||||
Choices []chatCompletionChoice `json:"choices"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Usage usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message *chatMessage `json:"message,omitempty"`
|
||||
Delta *chatMessage `json:"delta,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Message *chatMessage `json:"message,omitempty"`
|
||||
Delta *chatMessage `json:"delta,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Logprobs map[string]interface{} `json:"logprobs,omitempty"`
|
||||
}
|
||||
|
||||
type usage struct {
|
||||
@@ -81,10 +116,14 @@ type usage struct {
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Audio map[string]interface{} `json:"audio,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
}
|
||||
|
||||
type messageContent struct {
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -16,7 +18,10 @@ import (
|
||||
const (
|
||||
defaultOpenaiDomain = "api.openai.com"
|
||||
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
|
||||
defaultOpenaiCompletionPath = "/v1/completions"
|
||||
defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings"
|
||||
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
||||
defaultOpenaiImageGeneration = "/v1/images/generations"
|
||||
)
|
||||
|
||||
type openaiProviderInitializer struct {
|
||||
@@ -28,13 +33,24 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||
string(ApiNameCompletion): defaultOpenaiCompletionPath,
|
||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
||||
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
||||
}
|
||||
}
|
||||
|
||||
func isDirectPath(path string) bool {
|
||||
return strings.HasSuffix(path, "/completions") ||
|
||||
strings.HasSuffix(path, "/chat/embeddings") ||
|
||||
strings.HasSuffix(path, "/audio/speech") ||
|
||||
strings.HasSuffix(path, "/images/generations")
|
||||
}
|
||||
|
||||
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
if config.openaiCustomUrl == "" {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &openaiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
@@ -45,20 +61,32 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
if len(pairs) != 2 {
|
||||
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
|
||||
}
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
customPath := "/" + pairs[1]
|
||||
isDirectCustomPath := isDirectPath(customPath)
|
||||
capabilities := m.DefaultCapabilities()
|
||||
if !isDirectCustomPath {
|
||||
for key, mapPath := range capabilities {
|
||||
capabilities[key] = path.Join(customPath, mapPath)
|
||||
}
|
||||
}
|
||||
config.setDefaultCapabilities(capabilities)
|
||||
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
|
||||
pairs[0], customPath, isDirectCustomPath, capabilities)
|
||||
return &openaiProvider{
|
||||
config: config,
|
||||
customDomain: pairs[0],
|
||||
customPath: "/" + pairs[1],
|
||||
contextCache: createContextCache(&config),
|
||||
config: config,
|
||||
customDomain: pairs[0],
|
||||
customPath: customPath,
|
||||
isDirectCustomPath: isDirectCustomPath,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type openaiProvider struct {
|
||||
config ProviderConfig
|
||||
customDomain string
|
||||
customPath string
|
||||
contextCache *contextCache
|
||||
config ProviderConfig
|
||||
customDomain string
|
||||
customPath string
|
||||
isDirectCustomPath bool
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *openaiProvider) GetProviderType() string {
|
||||
@@ -71,15 +99,19 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
if m.customPath == "" {
|
||||
if m.customPath != "" {
|
||||
if m.isDirectCustomPath || apiName == "" {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
} else {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
}
|
||||
} else {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
} else {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
}
|
||||
if m.customDomain == "" {
|
||||
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
|
||||
} else {
|
||||
if m.customDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, m.customDomain)
|
||||
} else {
|
||||
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
|
||||
}
|
||||
if len(m.config.apiTokens) > 0 {
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -22,11 +22,13 @@ const (
|
||||
// ApiName 格式 {vendor}/{version}/{apitype}
|
||||
// 表示遵循 厂商/版本/接口类型 的格式
|
||||
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
|
||||
ApiNameCompletion ApiName = "openai/v1/completions"
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
|
||||
PathOpenAICompletions = "/v1/completions"
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
|
||||
|
||||
@@ -648,10 +648,11 @@ type qwenUsage struct {
|
||||
}
|
||||
|
||||
type qwenMessage struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type qwenVlMessageContent struct {
|
||||
|
||||
Reference in New Issue
Block a user