From 50f79c9099a5020868b3d39f98bf74815e19691e Mon Sep 17 00:00:00 2001 From: Yifan Gao <65636766+Claire-w@users.noreply.github.com> Date: Tue, 28 May 2024 10:55:59 +0800 Subject: [PATCH] feat: support ollama ai model (#1001) Co-authored-by: Kent Dong --- plugins/wasm-go/extensions/ai-proxy/README.md | 9 ++ plugins/wasm-go/extensions/ai-proxy/main.go | 2 +- .../extensions/ai-proxy/provider/ollama.go | 114 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 12 +- 4 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/ollama.go diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 018214d27..d8d30a83d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -93,6 +93,15 @@ Anthropic Claude 所对应的 `type` 为 `claude`。它特有的配置字段如 |-----------|--------|-----|-----|-------------------| | `version` | string | 必填 | - | Claude 服务的 API 版本 | +#### Ollama + +Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-------------------|--------|------|-----|----------------------------------------------| +| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 | +| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为11434 | + ## 用法示例 ### 使用 OpenAI 协议代理 Azure OpenAI 服务 diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 3c5676129..876e00563 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -36,7 +36,7 @@ func main() { } func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error { - //log.Debugf("loading config: %s", json.String()) + // log.Debugf("loading config: %s", json.String()) pluginConfig.FromJson(json) if err := pluginConfig.Validate(); err != nil { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go new file mode 100644 index 000000000..b8df86bf3 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -0,0 +1,114 @@ +package provider + +import ( + "errors" + "fmt" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// ollamaProvider is the provider for Ollama service. + +const ( + ollamaChatCompletionPath = "/v1/chat/completions" +) + +type ollamaProviderInitializer struct { +} + +func (m *ollamaProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.ollamaServerHost == "" { + return errors.New("missing ollamaServerHost in provider config") + } + if config.ollamaServerPort == 0 { + return errors.New("missing ollamaServerPort in provider config") + } + return nil +} + +func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + serverPortStr := fmt.Sprintf("%d", config.ollamaServerPort) + serviceDomain := config.ollamaServerHost + ":" + serverPortStr + return &ollamaProvider{ + config: config, + serviceDomain: serviceDomain, + contextCache: createContextCache(&config), + }, nil +} + +type ollamaProvider struct { + config ProviderConfig + serviceDomain string + contextCache *contextCache +} + +func (m *ollamaProvider) GetProviderType() string { + return providerTypeOllama +} + +func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestPath(ollamaChatCompletionPath) + _ = util.OverwriteRequestHost(m.serviceDomain) + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + + return types.ActionContinue, nil +} + +func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + + if m.config.modelMapping == nil && m.contextCache == nil { + return types.ActionContinue, nil + } + + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + + model := request.Model + if model == "" { + return types.ActionContinue, errors.New("missing model in chat completion request") + } + mappedModel := getMappedModel(model, m.config.modelMapping, log) + if mappedModel == "" { + return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + } + request.Model = mappedModel + + if m.contextCache != nil { + err := m.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + insertContextMessage(request, content) + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } else { + return types.ActionContinue, err + } + } else { + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + return types.ActionContinue, err + } + _ = proxywasm.ResumeHttpRequest() + return types.ActionPause, nil + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index dcd87047e..ff396e28e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -23,6 +23,7 @@ const ( providerTypeBaichuan = "baichuan" providerTypeYi = "yi" providerTypeDeepSeek = "deepseek" + providerTypeOllama = "ollama" protocolOpenAI = "openai" protocolOriginal = "original" @@ -60,6 +61,7 @@ var ( providerTypeBaichuan: &baichuanProviderInitializer{}, providerTypeYi: &yiProviderInitializer{}, providerTypeDeepSeek: &deepseekProviderInitializer{}, + providerTypeOllama: &ollamaProviderInitializer{}, } ) @@ -89,7 +91,7 @@ type ResponseBodyHandler interface { type ProviderConfig struct { // @Title zh-CN AI服务提供商 - // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi" + // @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"ollama" typ string `required:"true" yaml:"type" json:"type"` // @Title zh-CN API Tokens // @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。 @@ -109,6 +111,12 @@ type ProviderConfig struct { // @Title zh-CN 启用通义千问搜索服务 // @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。 qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"` + // @Title zh-CN Ollama Server IP/Domain + // @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的主机地址。 + ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"` + // @Title zh-CN Ollama Server Port + // @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的端口号。 + ollamaServerPort uint32 `required:"false" yaml:"ollamaServerPort" json:"ollamaServerPort"` // @Title zh-CN 模型名称映射表 // @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射 modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"` @@ -137,6 +145,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.qwenFileIds = append(c.qwenFileIds, fileId.String()) } c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool() + c.ollamaServerHost = json.Get("ollamaServerHost").String() + c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint()) c.modelMapping = make(map[string]string) for k, v := range json.Get("modelMapping").Map() { c.modelMapping[k] = v.String()