Supports completions API & support config openai baseUrl through openaiCustomUrl (#1765)

This commit is contained in:
澄潭
2025-02-18 09:57:48 +08:00
committed by GitHub
parent 94aacf5153
commit 49aad4152c
6 changed files with 124 additions and 45 deletions

View File

@@ -4,10 +4,12 @@ import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -16,7 +18,10 @@ import (
const (
defaultOpenaiDomain = "api.openai.com"
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
defaultOpenaiCompletionPath = "/v1/completions"
defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings"
defaultOpenaiAudioSpeech = "/v1/audio/speech"
defaultOpenaiImageGeneration = "/v1/images/generations"
)
type openaiProviderInitializer struct {
@@ -28,13 +33,24 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameCompletion): defaultOpenaiCompletionPath,
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
}
}
func isDirectPath(path string) bool {
return strings.HasSuffix(path, "/completions") ||
strings.HasSuffix(path, "/chat/embeddings") ||
strings.HasSuffix(path, "/audio/speech") ||
strings.HasSuffix(path, "/images/generations")
}
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
if config.openaiCustomUrl == "" {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &openaiProvider{
config: config,
contextCache: createContextCache(&config),
@@ -45,20 +61,32 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
if len(pairs) != 2 {
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
}
config.setDefaultCapabilities(m.DefaultCapabilities())
customPath := "/" + pairs[1]
isDirectCustomPath := isDirectPath(customPath)
capabilities := m.DefaultCapabilities()
if !isDirectCustomPath {
for key, mapPath := range capabilities {
capabilities[key] = path.Join(customPath, mapPath)
}
}
config.setDefaultCapabilities(capabilities)
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
pairs[0], customPath, isDirectCustomPath, capabilities)
return &openaiProvider{
config: config,
customDomain: pairs[0],
customPath: "/" + pairs[1],
contextCache: createContextCache(&config),
config: config,
customDomain: pairs[0],
customPath: customPath,
isDirectCustomPath: isDirectCustomPath,
contextCache: createContextCache(&config),
}, nil
}
type openaiProvider struct {
config ProviderConfig
customDomain string
customPath string
contextCache *contextCache
config ProviderConfig
customDomain string
customPath string
isDirectCustomPath bool
contextCache *contextCache
}
func (m *openaiProvider) GetProviderType() string {
@@ -71,15 +99,19 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" {
if m.customPath != "" {
if m.isDirectCustomPath || apiName == "" {
util.OverwriteRequestPathHeader(headers, m.customPath)
} else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
}
} else {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
} else {
util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else {
if m.customDomain != "" {
util.OverwriteRequestHostHeader(headers, m.customDomain)
} else {
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
}
if len(m.config.apiTokens) > 0 {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))