diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index fb2557db4..a83bee186 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -300,6 +300,15 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下 AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下: +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|---------------------------|--------|------|-----|------------------------------| +| `modelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version | +| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain | + +#### NVIDIA Triton Interference Server + +NVIDIA Triton Interference Server 所对应的 type 为 triton。它特有的配置字段如下: + | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |---------------------------|--------|------|-----|------------------------------| | `awsAccessKey` | string | 必填 | - | AWS Access Key,用于身份认证 | @@ -1973,6 +1982,57 @@ provider: } } ``` +### 使用 OpenAI 协议代理 NVIDIA Triton Interference Server 服务 + +**配置信息** + +```yaml +providers: + - type: triton + tritonDomain: + tritonModelVersion: + apiTokens: + - "****" + modelMapping: + "*": gpt2 +``` + +**请求示例** + +```json +{ + "model": "gpt2", + "messages": [ + { + "role": "user", + "content": "你好,你是谁?" + } + ], + "stream": false +} +``` +**响应示例** + +```json +{ + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "我是一个AI模型" + }, + "finish_reason": "stop", + } + ], + "model": "gpt2", +} +``` + + + + + ## 完整配置示例 diff --git a/plugins/wasm-go/extensions/ai-proxy/README_EN.md b/plugins/wasm-go/extensions/ai-proxy/README_EN.md index 5c052f382..e924fe2e3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README_EN.md +++ b/plugins/wasm-go/extensions/ai-proxy/README_EN.md @@ -1748,6 +1748,53 @@ provider: } ``` +### Utilizing OpenAI Protocol Proxy for NVIDIA Triton Interference Server Services + +**Configuration Information** + +```yaml +providers: + - type: triton + tritonDomain: + tritonModelVersion: + apiTokens: + - "****" + modelMapping: + "*": gpt2 +``` + +**Request Example** + +```json +{ + "model": "gpt2", + "messages": [ + { + "role": "user", + "content": "Hi, who are you?" + } + ], + "stream": false +} +``` +**Response Example** + +```json +{ + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I am a lagguage model." + }, + "finish_reason": "stop", + } + ], + "model": "gpt2", +} +``` + ## Full Configuration Example ### Kubernetes Example diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index a67e4f7a7..513a3e47d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -131,6 +131,7 @@ const ( providerTypeDify = "dify" providerTypeBedrock = "bedrock" providerTypeVertex = "vertex" + providerTypeTriton = "triton" providerTypeOpenRouter = "openrouter" protocolOpenAI = "openai" @@ -210,6 +211,7 @@ var ( providerTypeDify: &difyProviderInitializer{}, providerTypeBedrock: &bedrockProviderInitializer{}, providerTypeVertex: &vertexProviderInitializer{}, + providerTypeTriton: &tritonProviderInitializer{}, providerTypeOpenRouter: &openrouterProviderInitializer{}, } ) @@ -396,6 +398,12 @@ type ProviderConfig struct { // @Title zh-CN 首包超时 // @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0,表示不开启首包超时 firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"` + // @Title zh-CN Triton Model Version + // @Description 仅适用于 NVIDIA Triton Interference Server :path 中的 modelVersion 参考:"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/protocol/extension_generate.html" + tritonModelVersion string `required:"false" yaml:"tritonModelVersion" json:"tritonModelVersion"` + // @Title zh-CN Triton Server 部署的 Domain + // @Description 仅适用于 NVIDIA Triton Interference Server :path 中的 modelVersion 参考:"https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/protocol/extension_generate.html" + tritonDomain string `required:"false" yaml:"tritonDomain" json:"tritonDomain"` } func (c *ProviderConfig) GetId() string { @@ -551,6 +559,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.inputVariable = json.Get("inputVariable").String() c.outputVariable = json.Get("outputVariable").String() + // NVIDIA triton + c.tritonModelVersion = json.Get("tritonModelVersion").String() + c.tritonDomain = json.Get("tritonDomain").String() + c.capabilities = make(map[string]string) for capability, pathJson := range json.Get("capabilities").Map() { // 过滤掉不受支持的能力 diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/triton.go b/plugins/wasm-go/extensions/ai-proxy/provider/triton.go new file mode 100644 index 000000000..ddfd9eeaf --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/triton.go @@ -0,0 +1,226 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +const ( + tritonChatGenerationPath = "/v2/models/{MODEL_NAME}/generate" + tritonChatGenerationWithVersionPath = "/v2/models/{MODEL_NAME}/versions/{MODEL_VERSION}/generate" + tritonChatGenerationStreamPath = "/v2/models/{MODEL_NAME}/versions/generate_stream" + tritonChatGenerationStreamWithVersionPath = "/v2/models/{MODEL_NAME}/versions/${MODEL_VERSION}/generate_stream" +) + +type tritonProviderInitializer struct{} + +func (t *tritonProviderInitializer) ValidateConfig(config *ProviderConfig) error { + if config.apiTokens == nil || len(config.apiTokens) == 0 { + return errors.New("no apiToken found in provider config") + } + return nil +} + +func (t *tritonProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): tritonChatGenerationPath, + // string(d): tritonChatCompletionPath, + } +} + +func (t *tritonProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(t.DefaultCapabilities()) + return &tritonProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type tritonProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (t *tritonProvider) GetProviderType() string { + return providerTypeTriton +} + +func (t *tritonProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { + t.config.handleRequestHeaders(t, ctx, apiName) + return nil +} + +func (t *tritonProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + if !t.config.isSupportedAPI(apiName) { + return types.ActionContinue, errUnsupportedApiName + } + return t.config.handleRequestBody(t, t.contextCache, ctx, apiName, body) +} + +func (t *tritonProvider) GetApiName(path string) ApiName { + if strings.Contains(path, tritonChatGenerationPath) { + return ApiNameChatCompletion + } + return "" +} + +func (t *tritonProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { + request := &chatCompletionRequest{} + if err := t.config.parseRequestAndMapModel(ctx, request, body); err != nil { + return nil, err + } + + tritonRequest := t.BuildTritonTexGenRequest(request) + + finalPath := t.getFinalRequestPath(ctx, request, request.Stream) + util.OverwriteRequestPathHeader(headers, finalPath) + util.OverwriteRequestHostHeader(headers, t.config.tritonDomain) + log.Debugf("get current config.tritonDomain: %s", t.config.tritonDomain) + headers.Del("Content-Length") + + return json.Marshal(tritonRequest) +} + +func (t *tritonProvider) getFinalRequestPath(ctx wrapper.HttpContext, oriRequest *chatCompletionRequest, streaming bool) string { + res := tritonChatGenerationPath + log.Debugf("[Triton Server]: CurrentModelVersion: %s", t.config.tritonModelVersion) + if t.config.tritonModelVersion != "" { + res = tritonChatGenerationWithVersionPath + res = strings.Replace(res, "{MODEL_VERSION}", t.config.tritonModelVersion, 1) + } + + res = strings.Replace(res, "{MODEL_NAME}", oriRequest.Model, 1) + if streaming { + res += "_stream" + } + + log.Debugf("[Triton Server]: Get final RequestPath: %v", res) + return res +} + +type TritonGenerateRequest struct { + Id string `json:"id"` + TextInput string `json:"text_input"` + Parameters TritonGenerateParameter `json:"parameters"` +} + +type TritonGenerateParameter struct { + Stream bool `json:"stream"` + Temperature float64 `json:"temperature"` +} + +type TritonGenerateResponse struct { + Id string `json:"id"` + ModelName string `json:"model_name"` + ModelVersion string `json:"model_version"` + TextOutput string `json:"text_output"` + Error string `json:"error"` +} + +func (t *tritonProvider) BuildTritonTexGenRequest(origRequest *chatCompletionRequest) *TritonGenerateRequest { + res := &TritonGenerateRequest{ + Id: "", + TextInput: "", + Parameters: TritonGenerateParameter{ + Stream: origRequest.Stream, + Temperature: origRequest.Temperature, + }, + } + + for _, msg := range origRequest.Messages { + res.Id = msg.Id + res.TextInput = msg.StringContent() + } + return res +} +func (t *tritonProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + if apiName != ApiNameChatCompletion { + return body, nil + } + tritonRes := &TritonGenerateResponse{} + if err := json.Unmarshal(body, tritonRes); err != nil { + return nil, fmt.Errorf("unable to unmarshal claude response: %v", err) + } + if tritonRes.Error != "" { + return nil, fmt.Errorf("triton response error, error_message: %s", tritonRes.Error) + } + response := t.ParseResponse2OpenAI(tritonRes) + return json.Marshal(response) +} + +func (t *tritonProvider) ParseResponse2OpenAI(tritonRes *TritonGenerateResponse) *chatCompletionResponse { + res := &chatCompletionResponse{ + Id: tritonRes.Id, + Choices: []chatCompletionChoice{{ + Index: 0, + Message: &chatMessage{ + Id: "", + Role: roleAssistant, + Content: tritonRes.TextOutput, + }, + FinishReason: util.Ptr(finishReasonStop), + }}, + Model: tritonRes.ModelName, + } + return res +} + +func (t *tritonProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) { + log.Infof("[tritonProvider] receive chunk body: %s", string(chunk)) + if isLastChunk || len(chunk) == 0 { + return nil, nil + } + if name != ApiNameChatCompletion { + return chunk, nil + } + responseBuilder := &strings.Builder{} + lines := strings.Split(string(chunk), "\n") + for _, data := range lines { + if len(data) < 6 { + // ignore blank line or wrong format + continue + } + data = data[6:] + var tritonRes TritonGenerateResponse + if err := json.Unmarshal([]byte(data), &tritonRes); err != nil { + log.Errorf("unable to unmarshal triton response: %v", err) + continue + } + response := t.buildOpenAIStreamResponse(ctx, &tritonRes) + responseBody, err := json.Marshal(response) + if err != nil { + log.Errorf("unable to marshal response: %v", err) + return nil, err + } + responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) + } + modifiedResponseChunk := responseBuilder.String() + log.Debugf("[tritonStreamingProvider] modified response chunk: %s", modifiedResponseChunk) + return []byte(modifiedResponseChunk), nil +} + +func (t *tritonProvider) buildOpenAIStreamResponse(ctx wrapper.HttpContext, tritonRes *TritonGenerateResponse) *chatCompletionResponse { + var choice chatCompletionChoice + choice.Delta = &chatMessage{ + Id: tritonRes.Id, + Content: tritonRes.TextOutput, + } + streamResponse := chatCompletionResponse{ + Id: tritonRes.Id, + Object: objectChatCompletionChunk, + Created: time.Now().UnixMilli() / 1000, + Model: tritonRes.ModelName, + Choices: []chatCompletionChoice{choice}, + } + return &streamResponse + +}