implement generic provider for vendor-agnostic passthrough (#3175)

This commit is contained in:
woody
2025-12-03 09:52:47 +08:00
committed by jingze
parent 48433a6549
commit 2b49fd5b26
6 changed files with 360 additions and 0 deletions

View File

@@ -0,0 +1,85 @@
package provider
import (
"net/http"
"strconv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// genericProviderInitializer 用于创建一个不做能力映射的通用 Provider。
type genericProviderInitializer struct{}
// ValidateConfig 通用 Provider 不需要额外的配置校验。
func (m *genericProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil
}
// DefaultCapabilities 返回空映射,表示不会做路径或能力重写。
func (m *genericProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{}
}
// CreateProvider 创建 generic provider并沿用通用的上下文缓存能力。
func (m *genericProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(m.DefaultCapabilities())
return &genericProvider{
config: config,
}, nil
}
// genericProvider 只负责公共的头部、请求体处理逻辑,不绑定任何厂商。
type genericProvider struct {
config ProviderConfig
}
func (m *genericProvider) GetProviderType() string {
return providerTypeGeneric
}
// OnRequestHeaders 复用通用的 handleRequestHeaders并在配置首包超时时写入相关头部。
func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
if m.config.firstByteTimeout > 0 {
ctx.SetContext(ctxKeyIsStreaming, true)
m.applyFirstByteTimeout()
}
return nil
}
func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
return types.ActionContinue, nil
}
// TransformRequestHeaders 只处理鉴权与 Host 改写,不做路径重写。
func (m *genericProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if len(m.config.apiTokens) > 0 {
if token := m.config.GetApiTokenInUse(ctx); token != "" {
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
}
}
if m.config.genericHost != "" {
util.OverwriteRequestHostHeader(headers, m.config.genericHost)
}
headers.Del("Content-Length")
}
// applyFirstByteTimeout 在配置了 firstByteTimeout 时,为所有流式请求写入超时头。
func (m *genericProvider) applyFirstByteTimeout() {
if m.config.firstByteTimeout == 0 {
return
}
err := proxywasm.ReplaceHttpRequestHeader(
"x-envoy-upstream-rq-first-byte-timeout-ms",
strconv.FormatUint(uint64(m.config.firstByteTimeout), 10),
)
if err != nil {
log.Errorf("generic provider: failed to set first byte timeout header: %v", err)
return
}
log.Debugf("[generic][firstByteTimeout] %d", m.config.firstByteTimeout)
}

View File

@@ -144,6 +144,7 @@ const (
providerTypeLongcat = "longcat"
providerTypeFireworks = "fireworks"
providerTypeVllm = "vllm"
providerTypeGeneric = "generic"
protocolOpenAI = "openai"
protocolOriginal = "original"
@@ -227,6 +228,7 @@ var (
providerTypeLongcat: &longcatProviderInitializer{},
providerTypeFireworks: &fireworksProviderInitializer{},
providerTypeVllm: &vllmProviderInitializer{},
providerTypeGeneric: &genericProviderInitializer{},
}
)
@@ -409,6 +411,9 @@ type ProviderConfig struct {
basePath string `required:"false" yaml:"basePath" json:"basePath"`
// @Title zh-CN basePathHandling用于指定basePath的处理方式可选值removePrefix、prepend
basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"`
// @Title zh-CN generic Provider 对应的Host
// @Description zh-CN 仅适用于generic provider用于覆盖请求转发的目标Host
genericHost string `required:"false" yaml:"genericHost" json:"genericHost"`
// @Title zh-CN 首包超时
// @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0表示不开启首包超时
firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"`
@@ -619,6 +624,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.basePath != "" && c.basePathHandling == "" {
c.basePathHandling = basePathHandlingRemovePrefix
}
c.genericHost = json.Get("genericHost").String()
c.vllmServerHost = json.Get("vllmServerHost").String()
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
}