From dcea483c61a5727fe46fd01ff9c2139dcbeaa4c9 Mon Sep 17 00:00:00 2001 From: urlyy Date: Thu, 15 Aug 2024 18:53:56 +0800 Subject: [PATCH] Feat: Add Deepl support for plugins/ai-proxy (#1147) --- plugins/wasm-go/extensions/ai-proxy/README.md | 61 ++++++ plugins/wasm-go/extensions/ai-proxy/main.go | 20 +- .../extensions/ai-proxy/provider/deepl.go | 176 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 6 + 4 files changed, 255 insertions(+), 8 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/deepl.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index af0d56fbe..3f655a992 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -165,6 +165,14 @@ Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下: | --------------------- | -------- | -------- |-----|-------------------------------------------------------------------------------------------------| | `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) | +#### DeepL + +DeepL 所对应的 `type` 为 `deepl`。它特有的配置字段如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ------------ | -------- | -------- | ------ | ---------------------------- | +| `targetLang` | string | 必填 | - | DeepL 翻译服务需要的目标语种 | + ## 用法示例 ### 使用 OpenAI 协议代理 Azure OpenAI 服务 @@ -1008,6 +1016,59 @@ provider: } ``` +### 使用 OpenAI 协议代理 DeepL 文本翻译服务 + +**配置信息** + +```yaml +provider: + type: deepl + apiTokens: + - "YOUR_DEEPL_API_TOKEN" + targetLang: "ZH" +``` + +**请求示例** +此处 `model` 表示 DeepL 的服务类型,只能填 `Free` 或 `Pro`。`content` 中设置需要翻译的文本;在 `role: system` 的 `content` 中可以包含可能影响翻译但本身不会被翻译的上下文,例如翻译产品名称时,可以将产品描述作为上下文传递,这种额外的上下文可能会提高翻译的质量。 + +```json +{ + "model": "Free", + "messages": [ + { + "role": "system", + "content": "money" + }, + { + "content": "sit by the bank" + }, + { + "content": "a bank in China" + } + ] +} +``` + +**响应示例** +```json 
+{ + "choices": [ + { + "index": 0, + "message": { "name": "EN", "role": "assistant", "content": "坐庄" } + }, + { + "index": 1, + "message": { "name": "EN", "role": "assistant", "content": "中国银行" } + } + ], + "created": 1722747752, + "model": "Free", + "object": "chat.completion", + "usage": {} +} +``` + ## 完整配置示例 ### Kubernetes 示例 diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index df2f36456..f09e0d4af 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -140,24 +140,18 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo return types.ActionContinue } - contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") - if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { - if err != nil { - log.Errorf("unable to load content-type header from response: %v", err) - } - ctx.BufferResponseBody() - } - if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseHeaders(ctx, apiName, log) if err == nil { + checkStream(ctx, log) return action } _ = util.SendResponse(500, "ai-proxy.proc_resp_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response headers: %v", err)) return types.ActionContinue } + checkStream(ctx, log) _, needHandleBody := activeProvider.(provider.ResponseBodyHandler) _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) if !needHandleBody && !needHandleStreamingBody { @@ -223,3 +217,13 @@ func getOpenAiApiName(path string) provider.ApiName { } return "" } + +func checkStream(ctx wrapper.HttpContext, log wrapper.Log) { + contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") + if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { + if err != nil { + log.Errorf("unable to load content-type
header from response: %v", err) + } + ctx.BufferResponseBody() + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go new file mode 100644 index 000000000..fb233d7ae --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -0,0 +1,176 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// DeepL API hosts for the Pro and Free plans, and the text translation endpoint path. +const ( + deeplHostPro = "api.deepl.com" + deeplHostFree = "api-free.deepl.com" + deeplChatCompletionPath = "/v2/translate" +) + +type deeplProviderInitializer struct { +} + +type deeplProvider struct { + config ProviderConfig + contextCache *contextCache +} + +// spec reference: https://developers.deepl.com/docs/v/zh/api-reference/translate/openapi-spec-for-text-translation +type deeplRequest struct { + // "Model" parameter is used to distinguish which service to use + Model string `json:"model,omitempty"` + Text []string `json:"text"` + SourceLang string `json:"source_lang,omitempty"` + TargetLang string `json:"target_lang"` + Context string `json:"context,omitempty"` + SplitSentences string `json:"split_sentences,omitempty"` + PreserveFormatting bool `json:"preserve_formatting,omitempty"` + Formality string `json:"formality,omitempty"` + GlossaryId string `json:"glossary_id,omitempty"` + TagHandling string `json:"tag_handling,omitempty"` + OutlineDetection bool `json:"outline_detection,omitempty"` + NonSplittingTags []string `json:"non_splitting_tags,omitempty"` + SplittingTags []string `json:"splitting_tags,omitempty"` + IgnoreTags []string `json:"ignore_tags,omitempty"` +} + +type deeplResponse struct { + Translations []deeplResponseTranslation
`json:"translations,omitempty"` + Message string `json:"message,omitempty"` +} + +type deeplResponseTranslation struct { + DetectedSourceLanguage string `json:"detected_source_language"` + Text string `json:"text"` +} + +func (d *deeplProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.targetLang == "" { + return errors.New("missing targetLang in deepl provider config") + } + return nil +} + +func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &deeplProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +func (d *deeplProvider) GetProviderType() string { + return providerTypeDeepl +} + +func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestPath(deeplChatCompletionPath) + _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken()) + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + return types.HeaderStopIteration, nil +} + +func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + if d.config.protocol == protocolOriginal { + request := &deeplRequest{} + if err := json.Unmarshal(body, request); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %w", err) + } + if err := d.overwriteRequestHost(request.Model); err != nil { + return types.ActionContinue, err + } + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + return types.ActionContinue, replaceJsonRequestBody(request, log) + } else { + originRequest := &chatCompletionRequest{} + if err :=
decodeChatCompletionRequest(body, originRequest); err != nil { + return types.ActionContinue, err + } + if err := d.overwriteRequestHost(originRequest.Model); err != nil { + return types.ActionContinue, err + } + ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model) + translateRequest := &deeplRequest{ + Text: make([]string, 0), + TargetLang: d.config.targetLang, + } + for _, msg := range originRequest.Messages { + if msg.Role == roleSystem { + translateRequest.Context = msg.Content + } else { + translateRequest.Text = append(translateRequest.Text, msg.Content) + } + } + return types.ActionContinue, replaceJsonRequestBody(translateRequest, log) + } +} + +func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + _ = proxywasm.RemoveHttpResponseHeader("Content-Length") + return types.ActionContinue, nil +} + +func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + translateResponse := &deeplResponse{} + if err := json.Unmarshal(body, translateResponse); err != nil { + return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %w", err) + } + response := d.responseDeepl2OpenAI(ctx, translateResponse) + return types.ActionContinue, replaceJsonResponseBody(response, log) +} + +func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse { + var choices []chatCompletionChoice + // Failure: the response carries a message, so relay it as the single assistant choice. + if deeplResponse.Message != "" { + choices = make([]chatCompletionChoice, 1) + choices[0] = chatCompletionChoice{ + Message: &chatMessage{Role: roleAssistant, Content: deeplResponse.Message}, + Index: 0, + } + } else { + // Success: one choice per translation, in order; Name holds the detected source language. + choices = make([]chatCompletionChoice, len(deeplResponse.Translations)) + for idx, t := range deeplResponse.Translations { + choices[idx] = chatCompletionChoice{ + Index: idx, + Message: &chatMessage{Role: roleAssistant, Content: t.Text, Name:
t.DetectedSourceLanguage}, + } + } + } + return &chatCompletionResponse{ + Created: time.Now().Unix(), + Object: objectChatCompletion, + Choices: choices, + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + } +} + +func (d *deeplProvider) overwriteRequestHost(model string) error { + if model == "Pro" { + _ = util.OverwriteRequestHost(deeplHostPro) + } else if model == "Free" { + _ = util.OverwriteRequestHost(deeplHostFree) + } else { + return errors.New(`deepl model should be "Free" or "Pro"`) + } + return nil +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index b747c8da1..abcf65991 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -35,6 +35,7 @@ const ( providerTypeCloudflare = "cloudflare" providerTypeSpark = "spark" providerTypeGemini = "gemini" + providerTypeDeepl = "deepl" protocolOpenAI = "openai" protocolOriginal = "original" @@ -88,6 +89,7 @@ var ( providerTypeCloudflare: &cloudflareProviderInitializer{}, providerTypeSpark: &sparkProviderInitializer{}, providerTypeGemini: &geminiProviderInitializer{}, + providerTypeDeepl: &deeplProviderInitializer{}, } ) @@ -176,6 +178,9 @@ type ProviderConfig struct { // @Title zh-CN Gemini AI内容过滤和安全级别设定 // @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"` + // @Title zh-CN 翻译服务需指定的目标语种 + // @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。 + targetLang string `required:"false" yaml:"targetLang" json:"targetLang"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -223,6 +228,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.geminiSafetySetting[k] = v.String() } } + c.targetLang = json.Get("targetLang").String() } func (c *ProviderConfig) Validate() error {