diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 4c321c825..83d468d2e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -231,6 +231,18 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下: | `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 | | `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为 11434 | +#### 通用代理(Generic) + +当只需要借助 AI Proxy 的鉴权、basePath 处理或首包超时能力,且不希望插件改写路径时,可将 `provider.type` 设置为 `generic`。该 Provider 不绑定任何模型厂商,也不会做能力映射。 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ------------- | -------- | -------- | ------ | -------------------------------------------------------------------- | +| `genericHost` | string | 非必填 | - | 指定要转发到的目标 Host;未配置时沿用客户端请求的 Host。 | + +- 配置了 `apiTokens` 时,会自动写入 `Authorization: Bearer <token>` 请求头,复用全局的 Token 轮询能力。 +- 当配置了 `firstByteTimeout` 时,会自动注入 `x-envoy-upstream-rq-first-byte-timeout-ms`。 +- `basePath` 与 `basePathHandling` 同样适用,可在通用转发中快捷地移除或添加统一前缀。 + #### 混元 混元所对应的 `type` 为 `hunyuan`。它特有的配置字段如下: diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index 9cf005323..0064ec851 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -197,6 +197,18 @@ For Ollama, the corresponding `type` is `ollama`. Its unique configuration field | `ollamaServerHost` | string | Required | - | The host address of the Ollama server. | | `ollamaServerPort` | number | Required | - | The port number of the Ollama server, defaults to 11434. | +#### Generic + +For a vendor-agnostic passthrough, set the provider `type` to `generic`. Requests are forwarded without path remapping, while still benefiting from the shared header/basePath utilities. 
+ +| Name | Data Type | Requirement | Default | Description | +|----------------|-----------|-------------|---------|----------------------------------------------------------------------------------------------------------| +| `genericHost` | string | Optional | - | Overrides the upstream `Host` header. Use it to route traffic to a specific backend domain for generic proxying. | + +- When `apiTokens` are configured, the Generic provider injects `Authorization: Bearer <token>` automatically. +- When `firstByteTimeout` is configured, the Generic provider injects the `x-envoy-upstream-rq-first-byte-timeout-ms` request header on every request, regardless of whether the body sets `stream: true`. +- `basePath` and `basePathHandling` remain available to strip or prepend prefixes before forwarding. + #### Hunyuan For Hunyuan, the corresponding `type` is `hunyuan`. Its unique configuration fields are: diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 97cb2b6ee..8851efa98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -117,3 +117,9 @@ func TestFireworks(t *testing.T) { func TestUtil(t *testing.T) { test.RunMapRequestPathByCapabilityTests(t) } + +func TestGeneric(t *testing.T) { + test.RunGenericParseConfigTests(t) + test.RunGenericOnHttpRequestHeadersTests(t) + test.RunGenericOnHttpRequestBodyTests(t) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/generic.go b/plugins/wasm-go/extensions/ai-proxy/provider/generic.go new file mode 100644 index 000000000..ee35f65f7 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/generic.go @@ -0,0 +1,85 @@ +package provider + +import ( + "net/http" + "strconv" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + 
"github.com/higress-group/wasm-go/pkg/wrapper" +) + +// genericProviderInitializer 用于创建一个不做能力映射的通用 Provider。 +type genericProviderInitializer struct{} + +// ValidateConfig 通用 Provider 不需要额外的配置校验。 +func (m *genericProviderInitializer) ValidateConfig(config *ProviderConfig) error { + return nil +} + +// DefaultCapabilities 返回空映射,表示不会做路径或能力重写。 +func (m *genericProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{} +} + +// CreateProvider 创建 generic provider,并沿用通用的上下文缓存能力。 +func (m *genericProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) + return &genericProvider{ + config: config, + }, nil +} + +// genericProvider 只负责公共的头部、请求体处理逻辑,不绑定任何厂商。 +type genericProvider struct { + config ProviderConfig +} + +func (m *genericProvider) GetProviderType() string { + return providerTypeGeneric +} + +// OnRequestHeaders 复用通用的 handleRequestHeaders,并在配置首包超时时写入相关头部。 +func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { + m.config.handleRequestHeaders(m, ctx, apiName) + if m.config.firstByteTimeout > 0 { + ctx.SetContext(ctxKeyIsStreaming, true) + m.applyFirstByteTimeout() + } + return nil +} + +func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + return types.ActionContinue, nil +} + +// TransformRequestHeaders 只处理鉴权与 Host 改写,不做路径重写。 +func (m *genericProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + if len(m.config.apiTokens) > 0 { + if token := m.config.GetApiTokenInUse(ctx); token != "" { + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token) + } + } + if m.config.genericHost != "" { + util.OverwriteRequestHostHeader(headers, m.config.genericHost) + } + headers.Del("Content-Length") +} + +// applyFirstByteTimeout 在配置了 firstByteTimeout 时,为所有流式请求写入超时头。 +func (m *genericProvider) 
applyFirstByteTimeout() { + if m.config.firstByteTimeout == 0 { + return + } + err := proxywasm.ReplaceHttpRequestHeader( + "x-envoy-upstream-rq-first-byte-timeout-ms", + strconv.FormatUint(uint64(m.config.firstByteTimeout), 10), + ) + if err != nil { + log.Errorf("generic provider: failed to set first byte timeout header: %v", err) + return + } + log.Debugf("[generic][firstByteTimeout] %d", m.config.firstByteTimeout) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index c35597a9e..62c4e1ec9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -144,6 +144,7 @@ const ( providerTypeLongcat = "longcat" providerTypeFireworks = "fireworks" providerTypeVllm = "vllm" + providerTypeGeneric = "generic" protocolOpenAI = "openai" protocolOriginal = "original" @@ -227,6 +228,7 @@ var ( providerTypeLongcat: &longcatProviderInitializer{}, providerTypeFireworks: &fireworksProviderInitializer{}, providerTypeVllm: &vllmProviderInitializer{}, + providerTypeGeneric: &genericProviderInitializer{}, } ) @@ -409,6 +411,9 @@ type ProviderConfig struct { basePath string `required:"false" yaml:"basePath" json:"basePath"` // @Title zh-CN basePathHandling用于指定basePath的处理方式,可选值:removePrefix、prepend basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"` + // @Title zh-CN generic Provider 对应的Host + // @Description zh-CN 仅适用于generic provider,用于覆盖请求转发的目标Host + genericHost string `required:"false" yaml:"genericHost" json:"genericHost"` // @Title zh-CN 首包超时 // @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0,表示不开启首包超时 firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"` @@ -619,6 +624,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.basePath != "" && c.basePathHandling == "" { c.basePathHandling = basePathHandlingRemovePrefix } + 
c.genericHost = json.Get("genericHost").String() c.vllmServerHost = json.Get("vllmServerHost").String() c.vllmCustomUrl = json.Get("vllmCustomUrl").String() } diff --git a/plugins/wasm-go/extensions/ai-proxy/test/generic.go b/plugins/wasm-go/extensions/ai-proxy/test/generic.go new file mode 100644 index 000000000..a2e4f3e8c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/generic.go @@ -0,0 +1,239 @@ +package test + +import ( + "encoding/json" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" +) + +// 通用测试配置:最简配置,覆盖 host 与 token 注入。 +var genericBasicConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "generic", + "apiTokens": []string{"sk-generic-basic"}, + "genericHost": "generic.backend.internal", + }, + }) + return data +}() + +// 通用测试配置:开启 basePath removePrefix。 +var genericBasePathConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "generic", + "apiTokens": []string{"sk-generic-basepath"}, + "genericHost": "basepath.backend.internal", + "basePath": "/proxy", + "basePathHandling": "removePrefix", + }, + }) + return data +}() + +// 通用测试配置:开启 basePath prepend。 +var genericPrependBasePathConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "generic", + "apiTokens": []string{"sk-generic-prepend"}, + "genericHost": "prepend.backend.internal", + "basePath": "/custom", + "basePathHandling": "prepend", + }, + }) + return data +}() + +// 通用测试配置:覆盖 firstByteTimeout,用于流式能力验证。 +var genericStreamingConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "generic", + "apiTokens": []string{"sk-generic-stream"}, + "genericHost": 
"stream.backend.internal", + "firstByteTimeout": 1500, + }, + }) + return data +}() + +// 通用测试配置:无 token,也不设置 host。 +var genericNoTokenConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "generic", + }, + }) + return data +}() + +func RunGenericParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("generic basic config", func(t *testing.T) { + host, status := test.NewTestHost(genericBasicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + t.Run("generic config without token", func(t *testing.T) { + host, status := test.NewTestHost(genericNoTokenConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + }) + + t.Run("generic config with streaming options", func(t *testing.T) { + host, status := test.NewTestHost(genericStreamingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +func RunGenericOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("generic injects token and custom host", func(t *testing.T) { + host, status := test.NewTestHost(genericBasicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "generic.backend.internal")) + require.True(t, test.HasHeaderWithValue(requestHeaders, "Authorization", "Bearer 
sk-generic-basic")) + + _, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length") + require.False(t, hasContentLength, "generic provider should remove Content-Length") + }) + + t.Run("generic removes basePath prefix", func(t *testing.T) { + host, status := test.NewTestHost(genericBasePathConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/proxy/service/echo"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/service/echo")) + require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "basepath.backend.internal")) + }) + + t.Run("generic prepends basePath when configured", func(t *testing.T) { + host, status := test.NewTestHost(genericPrependBasePathConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/v1/echo"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/custom/v1/echo")) + }) + + t.Run("generic firstByteTimeout injects timeout header only", func(t *testing.T) { + host, status := test.NewTestHost(genericStreamingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + 
require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500")) + + _, hasAccept := test.GetHeaderValue(requestHeaders, "Accept") + require.False(t, hasAccept, "Accept header should remain untouched when enabling firstByteTimeout") + }) + }) +} + +func RunGenericOnHttpRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("generic body passthrough keeps headers unchanged with timeout", func(t *testing.T) { + host, status := test.NewTestHost(genericStreamingConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + body := `{"model":"gpt-any","stream":true}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500")) + _, hasAccept := test.GetHeaderValue(requestHeaders, "Accept") + require.False(t, hasAccept, "Accept header should remain untouched even when firstByteTimeout is enabled") + + processedBody := host.GetRequestBody() + require.JSONEq(t, body, string(processedBody)) + }) + + t.Run("generic without first byte timeout keeps headers untouched", func(t *testing.T) { + host, status := test.NewTestHost(genericBasicConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "client.local"}, + {":path", "/v1/chat/completions"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + body := `{"model":"gpt-any","stream":true}` + action := host.CallOnHttpRequestBody([]byte(body)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + _, 
hasAccept := test.GetHeaderValue(requestHeaders, "Accept") + require.False(t, hasAccept, "Accept header should remain untouched when first byte timeout is disabled") + + _, hasTimeout := test.GetHeaderValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms") + require.False(t, hasTimeout, "timeout header should not be added when first byte timeout is disabled") + + processedBody := host.GetRequestBody() + require.JSONEq(t, body, string(processedBody)) + }) + }) +}