diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 3ef90a175..975d64e5c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -135,6 +135,7 @@ const ( providerTypeOpenRouter = "openrouter" providerTypeLongcat = "longcat" providerTypeFireworks = "fireworks" + providerTypeVllm = "vllm" protocolOpenAI = "openai" protocolOriginal = "original" @@ -217,6 +218,7 @@ var ( providerTypeOpenRouter: &openrouterProviderInitializer{}, providerTypeLongcat: &longcatProviderInitializer{}, providerTypeFireworks: &fireworksProviderInitializer{}, + providerTypeVllm: &vllmProviderInitializer{}, } ) @@ -408,6 +410,12 @@ type ProviderConfig struct { // @Title zh-CN Triton Server 部署的 Domain // @Description 仅适用于 NVIDIA Triton Interference Server :path 中的 modelVersion 参考:"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/protocol/extension_generate.html" tritonDomain string `required:"false" yaml:"tritonDomain" json:"tritonDomain"` + // @Title zh-CN vLLM自定义后端URL + // @Description zh-CN 仅适用于vLLM服务。vLLM服务的完整URL,包含协议、域名、端口等 + vllmCustomUrl string `required:"false" yaml:"vllmCustomUrl" json:"vllmCustomUrl"` + // @Title zh-CN vLLM主机地址 + // @Description zh-CN 仅适用于vLLM服务,指定vLLM服务器的主机地址,例如:vllm-service.cluster.local + vllmServerHost string `required:"false" yaml:"vllmServerHost" json:"vllmServerHost"` } func (c *ProviderConfig) GetId() string { @@ -422,6 +430,14 @@ func (c *ProviderConfig) GetProtocol() string { return c.protocol } +func (c *ProviderConfig) GetVllmCustomUrl() string { + return c.vllmCustomUrl +} + +func (c *ProviderConfig) GetVllmServerHost() string { + return c.vllmServerHost +} + func (c *ProviderConfig) IsOpenAIProtocol() bool { return c.protocol == protocolOpenAI } @@ -591,6 +607,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.basePath != "" && c.basePathHandling == "" { c.basePathHandling = basePathHandlingRemovePrefix } + c.vllmServerHost = json.Get("vllmServerHost").String() + c.vllmCustomUrl = json.Get("vllmCustomUrl").String() } func (c *ProviderConfig) Validate() error { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go b/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go new file mode 100644 index 000000000..994e03930 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vllm.go @@ -0,0 +1,178 @@ +package provider + +import ( + "net/http" + "path" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +const ( + defaultVllmDomain = "vllm-service.cluster.local" +) + +// isVllmDirectPath checks if the path is a known standard vLLM interface path. +func isVllmDirectPath(path string) bool { + return strings.HasSuffix(path, "/completions") || + strings.HasSuffix(path, "/rerank") +} + +type vllmProviderInitializer struct{} + +func (m *vllmProviderInitializer) ValidateConfig(config *ProviderConfig) error { + // vLLM supports both authenticated and unauthenticated access + // If API tokens are configured, they will be used for authentication + // If no tokens are configured, the service will be accessed without authentication + return nil +} + +func (m *vllmProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameCompletion): PathOpenAICompletions, + string(ApiNameModels): PathOpenAIModels, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameCohereV1Rerank): PathCohereV1Rerank, + } +} + +func (m *vllmProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + if config.GetVllmCustomUrl() == "" { + config.setDefaultCapabilities(m.DefaultCapabilities()) + return &vllmProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil + } + + // Parse custom URL to extract domain and path + customUrl := strings.TrimPrefix(strings.TrimPrefix(config.GetVllmCustomUrl(), "http://"), "https://") + pairs := strings.SplitN(customUrl, "/", 2) + customPath := "/" + if len(pairs) == 2 { + customPath += pairs[1] + } + + // Check if the custom path is a direct path + isDirectCustomPath := isVllmDirectPath(customPath) + capabilities := m.DefaultCapabilities() + if !isDirectCustomPath { + for key, mapPath := range capabilities { + capabilities[key] = path.Join(customPath, strings.TrimPrefix(mapPath, "/v1")) + } + } + config.setDefaultCapabilities(capabilities) + + return &vllmProvider{ + config: config, + customDomain: pairs[0], + customPath: customPath, + isDirectCustomPath: isDirectCustomPath, + contextCache: createContextCache(&config), + }, nil +} + +type vllmProvider struct { + config ProviderConfig + customDomain string + customPath string + isDirectCustomPath bool + contextCache *contextCache +} + +func (m *vllmProvider) GetProviderType() string { + return providerTypeVllm +} + +func (m *vllmProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { + m.config.handleRequestHeaders(m, ctx, apiName) + return nil +} + +func (m *vllmProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + if !m.config.isSupportedAPI(apiName) { + return types.ActionContinue, errUnsupportedApiName + } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body) +} + +func (m *vllmProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + if m.isDirectCustomPath { + util.OverwriteRequestPathHeader(headers, m.customPath) + } else if apiName != "" { + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) + } + + // Set vLLM server host + if m.customDomain != "" { + util.OverwriteRequestHostHeader(headers, m.customDomain) + } else { + // Fallback to legacy vllmServerHost configuration + serverHost := m.config.GetVllmServerHost() + if serverHost == "" { + serverHost = defaultVllmDomain + } else { + // Extract domain from host:port format if present + if strings.Contains(serverHost, ":") { + parts := strings.SplitN(serverHost, ":", 2) + serverHost = parts[0] + } + } + util.OverwriteRequestHostHeader(headers, serverHost) + } + + // Add Bearer Token authentication if API tokens are configured + if len(m.config.apiTokens) > 0 { + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + } + + // Remove Content-Length header to allow body modification + headers.Del("Content-Length") +} + +func (m *vllmProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + // For vLLM, we can use the default transformation which handles model mapping + return m.config.defaultTransformRequestBody(ctx, apiName, body) +} + +func (m *vllmProvider) GetApiName(path string) ApiName { + if strings.Contains(path, PathOpenAIChatCompletions) { + return ApiNameChatCompletion + } + if strings.Contains(path, PathOpenAICompletions) { + return ApiNameCompletion + } + if strings.Contains(path, PathOpenAIModels) { + return ApiNameModels + } + if strings.Contains(path, PathOpenAIEmbeddings) { + return ApiNameEmbeddings + } + if strings.Contains(path, PathCohereV1Rerank) { + return ApiNameCohereV1Rerank + } + return "" +} + +// TransformResponseHeaders handles response header transformation for vLLM +func (m *vllmProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + // Remove Content-Length header to allow response body modification + headers.Del("Content-Length") +} + +// TransformResponseBody handles response body transformation for vLLM +func (m *vllmProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + // For now, just return the body as-is + // This can be extended to handle vLLM-specific response transformations + return body, nil +} + +// OnStreamingResponseBody handles streaming response body for vLLM +func (m *vllmProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { + // For now, just return the chunk as-is + // This can be extended to handle vLLM-specific streaming transformations + return chunk, nil +}