[ai-cache] Implement a WASM plugin for LLM result retrieval based on vector similarity (#1290)

This commit is contained in:
Yang Beining
2024-10-27 08:21:04 +00:00
committed by GitHub
parent d309bf2e25
commit acec48ed8b
27 changed files with 2025 additions and 346 deletions

@@ -1,5 +1,5 @@
# File generated by hgctl. Modify as required.
docker-compose-test/
*
!/.gitignore

@@ -1,9 +1,15 @@
## Introduction
---
title: AI Cache
keywords: [higress,ai cache]
description: AI cache plugin configuration reference
---
**Note**
> The data plane's proxy wasm version must be 0.2.100 or later.
> The version tag must be passed at build time, for example: `tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./`
>
## Feature Description
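@@ -20,32 +26,112 @@ LLM result caching plugin; the default configuration can be used directly for the openai protocol
Plugin execution priority: `10`
## Configuration
The configuration has 3 parts: the vector database (vector), the text embedding interface (embedding), and the cache database (cache). Fine-grained parameters for extracting content from LLM requests and responses are also provided.
## Configuration
This plugin supports both semantic caching backed by a vector database and caching based on string matching. If both a vector database and a cache database are configured, the vector database takes precedence.
*Note*: The vector database (vector) and the cache database (cache) cannot both be empty; otherwise the plugin cannot provide caching.
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| vector | string | optional | "" | Vector storage service provider type, e.g. dashvector |
| embedding | string | optional | "" | Text embedding service type, e.g. dashscope |
| cache | string | optional | "" | Cache service type, e.g. redis |
| cacheKeyStrategy | string | optional | "lastQuestion" | Strategy for generating the cache key from the questions in the conversation history. Allowed values: "lastQuestion" (use the last question), "allQuestions" (concatenate all questions), or "disabled" (disable caching) |
| enableSemanticCache | bool | optional | true | Whether to enable semantic caching. If disabled, cache lookup uses string matching, which requires the cache service to be configured |
Depending on whether semantic caching is needed, it is enough to configure one of the following component combinations:
1. `cache`: enables string-match caching only
2. `vector (+ embedding)`: enables semantic caching; if the `vector` service does not provide string representation (text embedding) itself, an `embedding` service must be configured separately
3. `vector (+ embedding) + cache`: enables semantic caching and stores LLM responses in the cache service for faster retrieval
Note: if a component is not configured, its `required` fields can be ignored.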
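The component combinations listed above can be summarized as a small decision function. This is a minimal sketch, not plugin code: `cacheMode` and its return strings are hypothetical names chosen for illustration, and the sketch ignores `enableSemanticCache`, which can additionally switch semantic lookup off.

```go
package main

import "fmt"

// cacheMode is a hypothetical helper (not part of the plugin) that names the
// lookup mode implied by which components are configured.
func cacheMode(hasVector, hasCache bool) string {
	switch {
	case hasVector && hasCache:
		return "semantic cache with answers stored in the cache service"
	case hasVector:
		return "semantic cache only"
	case hasCache:
		return "string-match cache only"
	default:
		// The README states that vector and cache cannot both be empty.
		return "invalid: vector and cache cannot both be empty"
	}
}

func main() {
	fmt.Println(cacheMode(false, true))
	fmt.Println(cacheMode(true, false))
	fmt.Println(cacheMode(true, true))
	fmt.Println(cacheMode(false, false))
}
```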
## Vector database service (vector)
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| vector.type | string | required | "" | Vector storage service provider type, e.g. dashvector |
| vector.serviceName | string | required | "" | Vector storage service name |
| vector.serviceHost | string | required | "" | Vector storage service domain |
| vector.servicePort | int64 | optional | 443 | Vector storage service port |
| vector.apiKey | string | optional | "" | Vector storage service API key |
| vector.topK | int | optional | 1 | Number of top-K results to return; defaults to 1 |
| vector.timeout | uint32 | optional | 10000 | Timeout for requests to the vector storage service, in milliseconds. Defaults to 10000 (10 seconds) |
| vector.collectionID | string | optional | "" | Collection ID of the dashvector vector storage service |
| vector.threshold | float64 | optional | 1000 | Vector similarity threshold |
| vector.thresholdRelation | string | optional | lt | Similarity metrics include `Cosine`, `DotProduct`, `Euclidean`, and others. For the first two, a larger value means higher similarity; for the last, a smaller value means higher similarity. Choose `gt` for `Cosine` and `DotProduct`, and `lt` for `Euclidean`. Defaults to `lt`. Supported relations are `lt` (less than), `lte` (less than or equal to), `gt` (greater than), and `gte` (greater than or equal to) |
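How `vector.threshold` and `vector.thresholdRelation` combine can be illustrated with a short sketch. It mirrors the four relations listed in the table; the function name `accept` and the sample scores are assumptions made for illustration only.

```go
package main

import "fmt"

// accept reports whether a similarity score passes the configured threshold
// under the given relation. Unknown relations are rejected.
func accept(relation string, score, threshold float64) bool {
	switch relation {
	case "gt":
		return score > threshold
	case "gte":
		return score >= threshold
	case "lt":
		return score < threshold
	case "lte":
		return score <= threshold
	default:
		return false
	}
}

func main() {
	// Cosine similarity: higher means closer, so "gt" is the natural choice.
	fmt.Println(accept("gt", 0.92, 0.8))
	// Euclidean distance: lower means closer, so "lt" is the natural choice.
	fmt.Println(accept("lt", 1500, 1000))
}
```

With a cosine-style metric and the `lt` default left in place, close matches would be rejected and distant ones accepted, so the relation should be set together with the threshold.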
## Text embedding service (embedding)
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| embedding.type | string | required | "" | Text embedding service type, e.g. dashscope |
| embedding.serviceName | string | required | "" | Text embedding service name |
| embedding.serviceHost | string | optional | "" | Text embedding service domain |
| embedding.servicePort | int64 | optional | 443 | Text embedding service port |
| embedding.apiKey | string | optional | "" | API key of the text embedding service |
| embedding.timeout | uint32 | optional | 10000 | Timeout for requests to the text embedding service, in milliseconds. Defaults to 10000 (10 seconds) |
| embedding.model | string | optional | "" | Model name used by the text embedding service |
## Cache service (cache)
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| cache.type | string | required | "" | Cache service type, e.g. redis |
| cache.serviceName | string | required | "" | Cache service name |
| cache.serviceHost | string | required | "" | Cache service domain |
| cache.servicePort | int64 | optional | 6379 | Cache service port |
| cache.username | string | optional | "" | Cache service username |
| cache.password | string | optional | "" | Cache service password |
| cache.timeout | uint32 | optional | 10000 | Timeout for the cache service, in milliseconds. Defaults to 10000 (10 seconds) |
| cache.cacheTTL | int | optional | 0 | Cache expiration time, in seconds. Defaults to 0 (never expires) |
| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Prefix for cache keys; defaults to "higress-ai-cache:" |
## Other configuration
| Name | Type | Requirement | Default | Description |
| --- | --- | --- | --- | --- |
| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | Extracts a string from the request body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheValueFrom | string | optional | "choices.0.message.content" | Extracts a string from the response body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | Extracts a string from the streaming response body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | Extracts a string from the request body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | Template for the returned HTTP response; %s marks the part to be replaced by the cache value |
| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | Template for the returned streaming HTTP response; %s marks the part to be replaced by the cache value |
| Name | Type | Requirement | Default | Description |
| -------- | -------- | -------- | -------- | -------- |
| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | Extracts a string from the request body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | Extracts a string from the response body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | Extracts a string from the streaming response body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax |
| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Prefix for Redis cache keys |
| cacheTTL | integer | optional | 0 | Cache expiration time, in seconds. Defaults to 0 (never expires) |
| redis.serviceName | string | required | - | Redis service name; the full FQDN including the service type, e.g. my-redis.dns, redis.my-ns.svc.cluster.local |
| redis.servicePort | integer | optional | 6379 | Redis service port |
| redis.timeout | integer | optional | 1000 | Timeout for requests to Redis, in milliseconds |
| redis.username | string | optional | - | Username for logging in to Redis |
| redis.password | string | optional | - | Password for logging in to Redis |
| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | Template for the returned HTTP response; %s marks the part to be replaced by the cache value |
| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | Template for the returned streaming HTTP response; %s marks the part to be replaced by the cache value |
## Configuration Examples
### Basic configuration
```yaml
embedding:
  type: dashscope
  serviceName: my_dashscope.dns
  apiKey: [Your Key]
vector:
  type: dashvector
  serviceName: my_dashvector.dns
  collectionID: [Your Collection ID]
  serviceDomain: [Your domain]
  apiKey: [Your key]
cache:
  type: redis
  serviceName: my_redis.dns
  servicePort: 6379
  timeout: 100
```
Legacy configuration compatibility
```yaml
redis:
serviceName: my-redis.dns
timeout: 2000
serviceName: my_redis.dns
servicePort: 6379
timeout: 100
```
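## Advanced Usage
The default cache key is extracted with the GJSON PATH expression `messages.@reverse.0.content`, which reverses the messages array and takes the content of the first item.
GJSON PATH supports conditional syntax. For example, to use the content of the last message whose role is user as the key, write `messages.@reverse.#(role=="user").content`.
@@ -55,3 +141,7 @@ GJSON PATH supports conditional syntax; for example, to take the last role user
Pipe syntax is also supported. For example, to use the content of the second-to-last message whose role is user as the key, write `messages.@reverse.#(role=="user")#.content|1`.
For more usage, see the [official documentation](https://github.com/tidwall/gjson/blob/master/SYNTAX.md); the [GJSON Playground](https://gjson.dev/) can be used to test expressions.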
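To make the selection semantics of the three expressions above concrete, here is a minimal sketch that reproduces them with only the Go standard library. It does not use GJSON itself; the type names, helper names, and the sample request body are assumptions made for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatRequest struct {
	Messages []message `json:"messages"`
}

// sampleBody is a hypothetical OpenAI-style request body.
const sampleBody = `{"messages":[
 {"role":"system","content":"You are helpful"},
 {"role":"user","content":"What is Higress?"},
 {"role":"assistant","content":"A cloud-native gateway."},
 {"role":"user","content":"Does it support WASM?"}]}`

func parseMessages(body string) ([]message, error) {
	var req chatRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		return nil, err
	}
	return req.Messages, nil
}

// lastContent selects what messages.@reverse.0.content selects:
// the content of the final message, whatever its role.
func lastContent(msgs []message) string {
	if len(msgs) == 0 {
		return ""
	}
	return msgs[len(msgs)-1].Content
}

// nthLastUserContent walks the messages from the end and returns the content
// of the n-th user message found (n = 0 is the last user message, n = 1 the
// second-to-last), matching the conditional and pipe expressions.
func nthLastUserContent(msgs []message, n int) string {
	seen := 0
	for i := len(msgs) - 1; i >= 0; i-- {
		if msgs[i].Role != "user" {
			continue
		}
		if seen == n {
			return msgs[i].Content
		}
		seen++
	}
	return ""
}

func main() {
	msgs, err := parseMessages(sampleBody)
	if err != nil {
		panic(err)
	}
	fmt.Println(lastContent(msgs))
	fmt.Println(nthLastUserContent(msgs, 0))
	fmt.Println(nthLastUserContent(msgs, 1))
}
```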
## FAQ
1. If the returned error is `error status returned by host: bad argument`, check that `serviceName` correctly includes the service type suffix (such as .dns).

@@ -0,0 +1,135 @@
package cache
import (
"errors"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
PROVIDER_TYPE_REDIS = "redis"
DEFAULT_CACHE_PREFIX = "higress-ai-cache:"
)
type providerInitializer interface {
ValidateConfig(ProviderConfig) error
CreateProvider(ProviderConfig) (Provider, error)
}
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_REDIS: &redisProviderInitializer{},
}
)
type ProviderConfig struct {
// @Title en-US Redis cache service provider type
// @Description en-US Cache service provider type, e.g. redis
typ string
// @Title en-US Redis cache service name
// @Description en-US Cache service name
serviceName string
// @Title en-US Redis cache service port
// @Description en-US Cache service port; defaults to 6379
servicePort int
// @Title en-US Redis cache service address
// @Description en-US Cache service address; optional
serviceHost string
// @Title en-US Cache service username
// @Description en-US Cache service username; optional
username string
// @Title en-US Cache service password
// @Description en-US Cache service password; optional
password string
// @Title en-US Request timeout
// @Description en-US Timeout for requests to the cache service, in milliseconds. Defaults to 10000 (10 seconds)
timeout uint32
// @Title en-US Cache expiration time
// @Description en-US Cache expiration time, in seconds. Defaults to 0 (never expires)
cacheTTL int
// @Title Cache key prefix
// @Description Prefix for cache keys; defaults to "higress-ai-cache:"
cacheKeyPrefix string
}
func (c *ProviderConfig) GetProviderType() string {
return c.typ
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
c.serviceName = json.Get("serviceName").String()
c.servicePort = int(json.Get("servicePort").Int())
if !json.Get("servicePort").Exists() {
c.servicePort = 6379
}
c.serviceHost = json.Get("serviceHost").String()
c.username = json.Get("username").String()
if !json.Get("username").Exists() {
c.username = ""
}
c.password = json.Get("password").String()
if !json.Get("password").Exists() {
c.password = ""
}
c.timeout = uint32(json.Get("timeout").Int())
if !json.Get("timeout").Exists() {
c.timeout = 10000
}
c.cacheTTL = int(json.Get("cacheTTL").Int())
if !json.Get("cacheTTL").Exists() {
c.cacheTTL = 0
// c.cacheTTL = 3600000
}
if json.Get("cacheKeyPrefix").Exists() {
c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String()
} else {
c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX
}
}
func (c *ProviderConfig) ConvertLegacyJson(json gjson.Result) {
c.FromJson(json.Get("redis"))
c.typ = "redis"
if json.Get("cacheTTL").Exists() {
c.cacheTTL = int(json.Get("cacheTTL").Int())
}
}
func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("cache service type is required")
}
if c.serviceName == "" {
return errors.New("cache service name is required")
}
if c.cacheTTL < 0 {
return errors.New("cache TTL must be greater than or equal to 0")
}
initializer, has := providerInitializers[c.typ]
if !has {
return errors.New("unknown cache service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
return err
}
return nil
}
func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
return nil, errors.New("unknown provider type: " + pc.typ)
}
return initializer.CreateProvider(pc)
}
type Provider interface {
GetProviderType() string
Init(username string, password string, timeout uint32) error
Get(key string, cb wrapper.RedisResponseCallback) error
Set(key string, value string, cb wrapper.RedisResponseCallback) error
GetCacheKeyPrefix() string
}
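The defaulting rule used by `FromJson` in this file (apply a default only when the key is absent) can be sketched with the standard library alone. This is an illustrative sketch under stated assumptions: it uses `encoding/json` with pointer fields instead of gjson's `Exists()`, and `rawCacheConfig`, `cacheConfig`, and `parseCacheConfig` are hypothetical names.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rawCacheConfig uses pointer fields so that an absent key (nil) can be told
// apart from an explicit zero value.
type rawCacheConfig struct {
	ServicePort    *int    `json:"servicePort"`
	Timeout        *uint32 `json:"timeout"`
	CacheTTL       *int    `json:"cacheTTL"`
	CacheKeyPrefix *string `json:"cacheKeyPrefix"`
}

type cacheConfig struct {
	ServicePort    int
	Timeout        uint32
	CacheTTL       int
	CacheKeyPrefix string
}

// parseCacheConfig starts from the documented defaults and overrides only the
// fields that are present in the input.
func parseCacheConfig(data []byte) (cacheConfig, error) {
	cfg := cacheConfig{
		ServicePort:    6379,
		Timeout:        10000,
		CacheTTL:       0,
		CacheKeyPrefix: "higress-ai-cache:",
	}
	var raw rawCacheConfig
	if err := json.Unmarshal(data, &raw); err != nil {
		return cfg, err
	}
	if raw.ServicePort != nil {
		cfg.ServicePort = *raw.ServicePort
	}
	if raw.Timeout != nil {
		cfg.Timeout = *raw.Timeout
	}
	if raw.CacheTTL != nil {
		cfg.CacheTTL = *raw.CacheTTL
	}
	if raw.CacheKeyPrefix != nil {
		cfg.CacheKeyPrefix = *raw.CacheKeyPrefix
	}
	return cfg, nil
}

func main() {
	cfg, err := parseCacheConfig([]byte(`{"timeout":100}`))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg)
}
```

Checking for presence rather than for a zero value keeps an explicitly configured empty prefix or a TTL of 0 from being silently replaced by the default.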

@@ -0,0 +1,58 @@
package cache
import (
"errors"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type redisProviderInitializer struct {
}
func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error {
if len(cf.serviceName) == 0 {
return errors.New("cache service name is required")
}
return nil
}
func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, error) {
rp := redisProvider{
config: cf,
client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: cf.serviceName,
Host: cf.serviceHost,
Port: int64(cf.servicePort)}),
}
err := rp.Init(cf.username, cf.password, cf.timeout)
return &rp, err
}
type redisProvider struct {
config ProviderConfig
client wrapper.RedisClient
}
func (rp *redisProvider) GetProviderType() string {
return PROVIDER_TYPE_REDIS
}
func (rp *redisProvider) Init(username string, password string, timeout uint32) error {
return rp.client.Init(rp.config.username, rp.config.password, int64(rp.config.timeout))
}
func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error {
return rp.client.Get(key, cb)
}
func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error {
if rp.config.cacheTTL == 0 {
return rp.client.Set(key, value, cb)
} else {
return rp.client.SetEx(key, value, rp.config.cacheTTL, cb)
}
}
func (rp *redisProvider) GetCacheKeyPrefix() string {
return rp.config.cacheKeyPrefix
}
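The TTL handling in `Set` above (plain `Set` when the TTL is 0, `SetEx` otherwise) can be exercised in isolation. The sketch below is a simplified illustration: `kvClient` and `recordingClient` are hypothetical stand-ins, and the callback parameter of the real client is omitted.

```go
package main

import "fmt"

// kvClient is a hypothetical, simplified stand-in for the Redis client used
// by the plugin; the real client also takes a response callback.
type kvClient interface {
	Set(key, value string) error
	SetEx(key, value string, ttlSeconds int) error
}

// recordingClient records the commands it receives instead of talking to Redis.
type recordingClient struct {
	calls []string
}

func (c *recordingClient) Set(key, value string) error {
	c.calls = append(c.calls, "SET "+key)
	return nil
}

func (c *recordingClient) SetEx(key, value string, ttlSeconds int) error {
	c.calls = append(c.calls, fmt.Sprintf("SETEX %s %d", key, ttlSeconds))
	return nil
}

// store writes without expiry when ttlSeconds is 0 and with expiry otherwise.
func store(c kvClient, key, value string, ttlSeconds int) error {
	if ttlSeconds == 0 {
		return c.Set(key, value)
	}
	return c.SetEx(key, value, ttlSeconds)
}

func main() {
	c := &recordingClient{}
	_ = store(c, "higress-ai-cache:q1", "a1", 0)
	_ = store(c, "higress-ai-cache:q2", "a2", 3600)
	for _, call := range c.calls {
		fmt.Println(call)
	}
}
```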

@@ -0,0 +1,225 @@
package config
import (
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
CACHE_KEY_STRATEGY_LAST_QUESTION = "lastQuestion"
CACHE_KEY_STRATEGY_ALL_QUESTIONS = "allQuestions"
CACHE_KEY_STRATEGY_DISABLED = "disabled"
)
type PluginConfig struct {
// @Title en-US Template for the returned HTTP response
// @Description en-US Use %s to mark the part to be replaced by the cache value
ResponseTemplate string
// @Title en-US Template for the returned streaming HTTP response
// @Description en-US Use %s to mark the part to be replaced by the cache value
StreamResponseTemplate string
cacheProvider cache.Provider
embeddingProvider embedding.Provider
vectorProvider vector.Provider
embeddingProviderConfig embedding.ProviderConfig
vectorProviderConfig vector.ProviderConfig
cacheProviderConfig cache.ProviderConfig
CacheKeyFrom string
CacheValueFrom string
CacheStreamValueFrom string
CacheToolCallsFrom string
// @Title en-US Enable semantic cache
// @Description en-US Controls whether semantic caching is enabled. true enables it; false disables it.
EnableSemanticCache bool
// @Title en-US Cache key strategy
// @Description en-US Strategy for generating the cache key. Allowed values: "lastQuestion" (use the last question), "allQuestions" (concatenate all questions), or "disabled" (disable caching)
CacheKeyStrategy string
}
func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
c.vectorProviderConfig.FromJson(json.Get("vector"))
c.embeddingProviderConfig.FromJson(json.Get("embedding"))
c.cacheProviderConfig.FromJson(json.Get("cache"))
if json.Get("redis").Exists() {
// compatible with legacy config
c.cacheProviderConfig.ConvertLegacyJson(json)
}
c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String()
if c.CacheKeyStrategy == "" {
c.CacheKeyStrategy = CACHE_KEY_STRATEGY_LAST_QUESTION // set default value
}
c.CacheKeyFrom = json.Get("cacheKeyFrom").String()
if c.CacheKeyFrom == "" {
c.CacheKeyFrom = "messages.@reverse.0.content"
}
c.CacheValueFrom = json.Get("cacheValueFrom").String()
if c.CacheValueFrom == "" {
c.CacheValueFrom = "choices.0.message.content"
}
c.CacheStreamValueFrom = json.Get("cacheStreamValueFrom").String()
if c.CacheStreamValueFrom == "" {
c.CacheStreamValueFrom = "choices.0.delta.content"
}
c.CacheToolCallsFrom = json.Get("cacheToolCallsFrom").String()
if c.CacheToolCallsFrom == "" {
c.CacheToolCallsFrom = "choices.0.delta.content.tool_calls"
}
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" {
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
}
c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" {
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}
if json.Get("enableSemanticCache").Exists() {
c.EnableSemanticCache = json.Get("enableSemanticCache").Bool()
} else {
c.EnableSemanticCache = true // set default value to true
}
// compatible with legacy config
convertLegacyMapFields(c, json, log)
}
func (c *PluginConfig) Validate() error {
// if cache provider is configured, validate it
if c.cacheProviderConfig.GetProviderType() != "" {
if err := c.cacheProviderConfig.Validate(); err != nil {
return err
}
}
if c.embeddingProviderConfig.GetProviderType() != "" {
if err := c.embeddingProviderConfig.Validate(); err != nil {
return err
}
}
if c.vectorProviderConfig.GetProviderType() != "" {
if err := c.vectorProviderConfig.Validate(); err != nil {
return err
}
}
// cache, vector, and embedding cannot all be empty
if c.vectorProviderConfig.GetProviderType() == "" &&
c.embeddingProviderConfig.GetProviderType() == "" &&
c.cacheProviderConfig.GetProviderType() == "" {
return fmt.Errorf("vector, embedding and cache provider cannot be all empty")
}
// Validate the value of CacheKeyStrategy
if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION &&
c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS &&
c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED {
return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy)
}
// If semantic cache is enabled, ensure necessary components are configured
// if c.EnableSemanticCache {
// if c.embeddingProviderConfig.GetProviderType() == "" {
// return fmt.Errorf("semantic cache is enabled but embedding provider is not configured")
// }
// // if only configure cache, just warn the user
// }
return nil
}
func (c *PluginConfig) Complete(log wrapper.Log) error {
var err error
if c.embeddingProviderConfig.GetProviderType() != "" {
log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType())
c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig)
if err != nil {
return err
}
} else {
log.Info("embedding provider is not configured")
c.embeddingProvider = nil
}
if c.cacheProviderConfig.GetProviderType() != "" {
log.Debugf("cache provider is set to %s", c.cacheProviderConfig.GetProviderType())
c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig)
if err != nil {
return err
}
} else {
log.Info("cache provider is not configured")
c.cacheProvider = nil
}
if c.vectorProviderConfig.GetProviderType() != "" {
log.Debugf("vector provider is set to %s", c.vectorProviderConfig.GetProviderType())
c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig)
if err != nil {
return err
}
} else {
log.Info("vector provider is not configured")
c.vectorProvider = nil
}
return nil
}
func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider {
return c.embeddingProvider
}
func (c *PluginConfig) GetVectorProvider() vector.Provider {
return c.vectorProvider
}
func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig {
return c.vectorProviderConfig
}
func (c *PluginConfig) GetCacheProvider() cache.Provider {
return c.cacheProvider
}
func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) {
keyMap := map[string]string{
"cacheKeyFrom.requestBody": "cacheKeyFrom",
// The README documents the legacy value-extraction keys as "*.responseBody".
"cacheValueFrom.responseBody": "cacheValueFrom",
"cacheStreamValueFrom.responseBody": "cacheStreamValueFrom",
"returnResponseTemplate": "responseTemplate",
"returnStreamResponseTemplate": "streamResponseTemplate",
}
for oldKey, newKey := range keyMap {
if json.Get(oldKey).Exists() {
log.Debugf("[convertLegacyMapFields] mapping %s to %s", oldKey, newKey)
setField(c, newKey, json.Get(oldKey).String(), log)
} else {
log.Debugf("[convertLegacyMapFields] %s not exists", oldKey)
}
}
}
func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) {
switch fieldName {
case "cacheKeyFrom":
c.CacheKeyFrom = value
case "cacheValueFrom":
c.CacheValueFrom = value
case "cacheStreamValueFrom":
c.CacheStreamValueFrom = value
case "responseTemplate":
c.ResponseTemplate = value
case "streamResponseTemplate":
c.StreamResponseTemplate = value
}
log.Debugf("[setField] set %s to %s", fieldName, value)
}

@@ -0,0 +1,275 @@
package main
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/resp"
)
// CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found.
func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error {
activeCacheProvider := c.GetCacheProvider()
if activeCacheProvider == nil {
log.Debugf("[%s] [CheckCacheForKey] no cache provider configured, performing similarity search", PLUGIN_NAME)
return performSimilaritySearch(key, ctx, c, log, key, stream)
}
queryKey := activeCacheProvider.GetCacheKeyPrefix() + key
log.Debugf("[%s] [CheckCacheForKey] querying cache with key: %s", PLUGIN_NAME, queryKey)
err := activeCacheProvider.Get(queryKey, func(response resp.Value) {
handleCacheResponse(key, response, ctx, log, stream, c, useSimilaritySearch)
})
if err != nil {
log.Errorf("[%s] [CheckCacheForKey] failed to retrieve key: %s from cache, error: %v", PLUGIN_NAME, key, err)
return err
}
return nil
}
// handleCacheResponse processes cache response and handles cache hits and misses.
func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) {
if err := response.Error(); err == nil && !response.IsNull() {
log.Infof("[%s] cache hit for key: %s", PLUGIN_NAME, key)
processCacheHit(key, response.String(), stream, ctx, c, log)
return
}
log.Infof("[%s] [handleCacheResponse] cache miss for key: %s", PLUGIN_NAME, key)
if err := response.Error(); err != nil {
log.Errorf("[%s] [handleCacheResponse] error retrieving key: %s from cache, error: %v", PLUGIN_NAME, key, err)
}
if useSimilaritySearch && c.EnableSemanticCache {
if err := performSimilaritySearch(key, ctx, c, log, key, stream); err != nil {
log.Errorf("[%s] [handleCacheResponse] failed to perform similarity search for key: %s, error: %v", PLUGIN_NAME, key, err)
proxywasm.ResumeHttpRequest()
}
} else {
proxywasm.ResumeHttpRequest()
}
}
// processCacheHit handles a successful cache hit.
func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) {
if strings.TrimSpace(response) == "" {
log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key)
proxywasm.ResumeHttpRequest()
return
}
log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response)
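// Escape the response for embedding in the JSON template. strconv.Quote wraps
// the value in exactly one pair of double quotes, so slice that pair off;
// strings.Trim would also strip an escaped quote at the very start or end.
quoted := strconv.Quote(response)
escapedResponse := quoted[1 : len(quoted)-1]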
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)
if stream {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1)
} else {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, escapedResponse)), -1)
}
}
// performSimilaritySearch determines the appropriate similarity search method to use.
func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error {
activeVectorProvider := c.GetVectorProvider()
if activeVectorProvider == nil {
return logAndReturnError(log, "[performSimilaritySearch] no vector provider configured for similarity search")
}
// Check if the active vector provider implements the StringQuerier interface.
if _, ok := activeVectorProvider.(vector.StringQuerier); ok {
log.Debugf("[%s] [performSimilaritySearch] active vector provider implements StringQuerier interface, performing string query", PLUGIN_NAME)
return performStringQuery(key, queryString, ctx, c, log, stream)
}
// Check if the active vector provider implements the EmbeddingQuerier interface.
if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok {
log.Debugf("[%s] [performSimilaritySearch] active vector provider implements EmbeddingQuerier interface, performing embedding query", PLUGIN_NAME)
return performEmbeddingQuery(key, ctx, c, log, stream)
}
return logAndReturnError(log, "[performSimilaritySearch] no suitable querier or embedding provider available for similarity search")
}
// performStringQuery executes the string-based similarity search.
func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error {
stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier)
if !ok {
return logAndReturnError(log, "[performStringQuery] active vector provider does not implement StringQuerier interface")
}
return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) {
handleQueryResults(key, results, ctx, log, stream, c, err)
})
}
// performEmbeddingQuery executes the embedding-based similarity search.
func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error {
embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier)
if !ok {
return logAndReturnError(log, "[performEmbeddingQuery] active vector provider does not implement EmbeddingQuerier interface")
}
activeEmbeddingProvider := c.GetEmbeddingProvider()
if activeEmbeddingProvider == nil {
return logAndReturnError(log, "[performEmbeddingQuery] no embedding provider configured for similarity search")
}
return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) {
log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err)
if err != nil {
handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log)
return
}
ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding)
err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) {
handleQueryResults(key, results, ctx, log, stream, c, err)
})
if err != nil {
handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error querying vector database for key: %s", PLUGIN_NAME, key), log)
}
})
}
// handleQueryResults processes the results of similarity search and determines next actions.
func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) {
if err != nil {
handleInternalError(err, fmt.Sprintf("[%s] [handleQueryResults] error querying vector database for key: %s", PLUGIN_NAME, key), log)
return
}
if len(results) == 0 {
log.Warnf("[%s] [handleQueryResults] no similar keys found for key: %s", PLUGIN_NAME, key)
proxywasm.ResumeHttpRequest()
return
}
mostSimilarData := results[0]
log.Debugf("[%s] [handleQueryResults] for key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score)
simThreshold := c.GetVectorProviderConfig().Threshold
simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation
if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) {
log.Infof("[%s] key accepted: %s with score: %f", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score)
if mostSimilarData.Answer != "" {
// direct return the answer if available
cacheResponse(ctx, c, key, mostSimilarData.Answer, log)
processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log)
} else {
if c.GetCacheProvider() != nil {
CheckCacheForKey(mostSimilarData.Text, ctx, c, log, stream, false)
} else {
// Otherwise, do not check the cache, directly return
log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text)
proxywasm.ResumeHttpRequest()
}
}
} else {
log.Infof("[%s] score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score)
proxywasm.ResumeHttpRequest()
}
}
// logAndReturnError logs an error and returns it.
func logAndReturnError(log wrapper.Log, message string) error {
message = fmt.Sprintf("[%s] %s", PLUGIN_NAME, message)
log.Errorf("%s", message)
return errors.New(message)
}
// handleInternalError logs an error and resumes the HTTP request.
func handleInternalError(err error, message string, log wrapper.Log) {
if err != nil {
log.Errorf("[%s] [handleInternalError] %s: %v", PLUGIN_NAME, message, err)
} else {
log.Errorf("[%s] [handleInternalError] %s", PLUGIN_NAME, message)
}
// proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1)
proxywasm.ResumeHttpRequest()
}
// Caches the response value
func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) {
if strings.TrimSpace(value) == "" {
log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key)
return
}
activeCacheProvider := c.GetCacheProvider()
if activeCacheProvider != nil {
queryKey := activeCacheProvider.GetCacheKeyPrefix() + key
_ = activeCacheProvider.Set(queryKey, value, nil)
log.Debugf("[%s] [cacheResponse] cache set success, key: %s, length of value: %d", PLUGIN_NAME, queryKey, len(value))
}
}
// Handles embedding upload if available
func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) {
embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY)
if embedding == nil {
return
}
emb, ok := embedding.([]float64)
if !ok {
log.Errorf("[%s] [uploadEmbeddingAndAnswer] embedding is not of expected type []float64", PLUGIN_NAME)
return
}
activeVectorProvider := c.GetVectorProvider()
if activeVectorProvider == nil {
log.Debugf("[%s] [uploadEmbeddingAndAnswer] no vector provider configured for uploading embedding", PLUGIN_NAME)
return
}
// Attempt to upload answer embedding first
if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok {
log.Infof("[%s] uploading answer embedding for key: %s", PLUGIN_NAME, key)
err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil)
if err != nil {
log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload answer embedding for key: %s, error: %v", PLUGIN_NAME, key, err)
} else {
return // If successful, return early
}
}
// If answer embedding upload fails, attempt normal embedding upload
if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok {
log.Infof("[%s] uploading embedding for key: %s", PLUGIN_NAME, key)
err := embUploader.UploadEmbedding(key, emb, ctx, log, nil)
if err != nil {
log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v", PLUGIN_NAME, key, err)
}
}
}
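// Mainly used for similarity / distance / dot-product comparisons.
// Cosine similarity measures how closely two vectors align in direction. The higher the similarity, the closer the two vectors.
// Distance measures how far apart two vectors are in space. The smaller the distance, the closer the two vectors.
// compare evaluates the given operator and returns the result.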
func compare(operator string, value1 float64, value2 float64) bool {
switch operator {
case "gt":
return value1 > value2
case "gte":
return value1 >= value2
case "lt":
return value1 < value2
case "lte":
return value1 <= value2
default:
return false
}
}

@@ -0,0 +1,187 @@
package embedding
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
const (
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
DASHSCOPE_PORT = 443
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v2"
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)
type dashScopeProviderInitializer struct {
}
func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiKey == "" {
return errors.New("[DashScope] apiKey is required")
}
return nil
}
func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = DASHSCOPE_PORT
}
if c.serviceHost == "" {
c.serviceHost = DASHSCOPE_DOMAIN
}
return &DSProvider{
config: c,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.serviceName,
Host: c.serviceHost,
Port: int64(c.servicePort),
}),
}, nil
}
func (d *DSProvider) GetProviderType() string {
return PROVIDER_TYPE_DASHSCOPE
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type Input struct {
Texts []string `json:"texts"`
}
type Params struct {
TextType string `json:"text_type"`
}
type Response struct {
RequestID string `json:"request_id"`
Output Output `json:"output"`
Usage Usage `json:"usage"`
}
type Output struct {
Embeddings []Embedding `json:"embeddings"`
}
type Usage struct {
TotalTokens int `json:"total_tokens"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Params `json:"parameters"`
}
type Document struct {
Vector []float64 `json:"vector"`
Fields map[string]string `json:"fields"`
}
type DSProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
model := d.config.model
if model == "" {
model = DASHSCOPE_DEFAULT_MODEL_NAME
}
data := EmbeddingRequest{
Model: model,
Input: Input{
Texts: texts,
},
Parameters: Params{
TextType: "query",
},
}
requestBody, err := json.Marshal(data)
if err != nil {
log.Errorf("failed to marshal request data: %v", err)
return "", nil, nil, err
}
if d.config.apiKey == "" {
err := errors.New("dashScopeKey is empty")
log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err
}
headers := [][2]string{
{"Authorization", "Bearer " + d.config.apiKey},
{"Content-Type", "application/json"},
}
return DASHSCOPE_ENDPOINT, headers, requestBody, nil
}
type Result struct {
ID string `json:"id"`
Vector []float64 `json:"vector,omitempty"`
Fields map[string]interface{} `json:"fields"`
Score float64 `json:"score"`
}
func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
var resp Response
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}
func (d *DSProvider) GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error {
embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log)
if err != nil {
log.Errorf("failed to construct parameters: %v", err)
return err
}
var resp *Response
err = d.client.Post(embUrl, embHeaders, embRequestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode))
callback(nil, err)
return
}
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
resp, err = d.parseTextEmbedding(responseBody)
if err != nil {
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}
if len(resp.Output.Embeddings) == 0 {
err = errors.New("no embedding found in response")
callback(nil, err)
return
}
callback(resp.Output.Embeddings[0].Embedding, nil)
}, d.config.timeout)
return err
}
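The request and response shapes used by constructParameters and parseTextEmbedding can be exercised without any network call. The following sketch assumes only the JSON field names shown in the structs above. The helper names `buildRequestBody` and `firstEmbedding` are hypothetical, and any response JSON fed to it is a hand-written sample rather than captured DashScope output.

```go
package main

import (
	"encoding/json"
	"errors"
)

// embeddingRequest follows the field names of EmbeddingRequest above.
type embeddingRequest struct {
	Model string `json:"model"`
	Input struct {
		Texts []string `json:"texts"`
	} `json:"input"`
	Parameters struct {
		TextType string `json:"text_type"`
	} `json:"parameters"`
}

// buildRequestBody serialises a query-type embedding request.
func buildRequestBody(model string, texts []string) ([]byte, error) {
	var req embeddingRequest
	req.Model = model
	req.Input.Texts = texts
	req.Parameters.TextType = "query"
	return json.Marshal(req)
}

// embeddingResponse keeps only the fields needed to read the vector.
type embeddingResponse struct {
	Output struct {
		Embeddings []struct {
			Embedding []float64 `json:"embedding"`
			TextIndex int       `json:"text_index"`
		} `json:"embeddings"`
	} `json:"output"`
}

// firstEmbedding returns the first vector in the response, or an error when
// the body is not valid JSON or contains no embeddings.
func firstEmbedding(body []byte) ([]float64, error) {
	var resp embeddingResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Output.Embeddings) == 0 {
		return nil, errors.New("no embedding found in response")
	}
	return resp.Output.Embeddings[0].Embedding, nil
}
```

Checking for an empty `embeddings` array before indexing mirrors the guard in GetEmbedding, so a well-formed but empty reply produces an error rather than an index-out-of-range panic.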


@@ -0,0 +1,101 @@
package embedding
import (
"errors"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
PROVIDER_TYPE_DASHSCOPE = "dashscope"
)
type providerInitializer interface {
ValidateConfig(ProviderConfig) error
CreateProvider(ProviderConfig) (Provider, error)
}
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
}
)
type ProviderConfig struct {
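// @Title en-US Text embedding service provider type
// @Description en-US Type of the text embedding service provider, e.g. DashScope
typ string
// @Title en-US Text embedding service name
// @Description en-US Name of the text embedding service
serviceName string
// @Title en-US Text embedding service host
// @Description en-US Domain name of the text embedding service
serviceHost string
// @Title en-US Text embedding service port
// @Description en-US Port of the text embedding service
servicePort int64
// @Title en-US Text embedding service API key
// @Description en-US API key for the text embedding service
apiKey string
// @Title en-US Text embedding service timeout
// @Description en-US Timeout for requests to the text embedding service; defaults to 10000
timeout uint32
// @Title en-US Model used by the text embedding service
// @Description en-US Name of the model used for text embedding; defaults to "text-embedding-v2" for DashScope
model string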
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String()
if c.timeout == 0 {
c.timeout = 10000
}
}
func (c *ProviderConfig) Validate() error {
if c.serviceName == "" {
return errors.New("embedding service name is required")
}
if c.apiKey == "" {
return errors.New("embedding service API key is required")
}
if c.typ == "" {
return errors.New("embedding service type is required")
}
initializer, has := providerInitializers[c.typ]
if !has {
return errors.New("unknown embedding service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
return err
}
return nil
}
func (c *ProviderConfig) GetProviderType() string {
return c.typ
}
func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
return nil, errors.New("unknown provider type: " + pc.typ)
}
return initializer.CreateProvider(pc)
}
type Provider interface {
GetProviderType() string
GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error
}
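The provider file combines a registry map with a two-step validate-then-create interface. The sketch below shows that pattern in isolation. Every name in it (`provider`, `initializer`, `echoInitializer`, `createProvider`) is hypothetical, and the configuration is reduced to a type string and an API key, so it illustrates the structure rather than reproducing the plugin's API.

```go
package main

import "errors"

// provider is a hypothetical, minimal counterpart of the Provider interface.
type provider interface {
	Type() string
}

// initializer validates a configuration and then builds a provider from it.
type initializer interface {
	Validate(apiKey string) error
	Create(apiKey string) (provider, error)
}

type echoProvider struct{}

func (echoProvider) Type() string { return "echo" }

type echoInitializer struct{}

func (echoInitializer) Validate(apiKey string) error {
	if apiKey == "" {
		return errors.New("[echo] apiKey is required")
	}
	return nil
}

func (echoInitializer) Create(apiKey string) (provider, error) {
	return echoProvider{}, nil
}

// initializers maps a provider type name to its initializer. Supporting a new
// backend only requires adding an entry here.
var initializers = map[string]initializer{
	"echo": echoInitializer{},
}

// createProvider looks up the initializer, validates, and then creates.
func createProvider(typ, apiKey string) (provider, error) {
	ini, ok := initializers[typ]
	if !ok {
		return nil, errors.New("unknown provider type: " + typ)
	}
	if err := ini.Validate(apiKey); err != nil {
		return nil, err
	}
	return ini.Create(apiKey)
}
```

Unlike the plugin, which validates in ProviderConfig.Validate and creates in CreateProvider as separate calls, this sketch folds both steps into one function to keep the example short.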


@@ -0,0 +1,27 @@
package embedding
// import (
// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
// )
// const (
// weaviateURL = "172.17.0.1:8081"
// )
// type weaviateProviderInitializer struct {
// }
// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
// return nil
// }
// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
// return &DSProvider{
// config: config,
// client: wrapper.NewClusterClient(wrapper.DnsCluster{
// ServiceName: config.ServiceName,
// Port: dashScopePort,
// Domain: dashScopeDomain,
// }),
// }, nil
// }


@@ -7,17 +7,18 @@ go 1.19
replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441
github.com/alibaba/higress/plugins/wasm-go v1.4.2
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
github.com/tidwall/sjson v1.2.5
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
)


@@ -1,24 +1,21 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=


@@ -1,33 +1,33 @@
// File generated by hgctl. Modify as required.
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
// This file implements four callbacks: OnHttpRequestHeaders, OnHttpRequestBody, OnHttpResponseHeaders, and OnHttpResponseBody.
// The caching flow calls the logic in cache.go, which in turn calls the textEmbeddingProvider and vectorStoreProvider implementations.
package main
import (
"errors"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
CacheKeyContextKey = "cacheKey"
CacheContentContextKey = "cacheContent"
PartialMessageContextKey = "partialMessage"
ToolCallsContextKey = "toolCalls"
StreamContextKey = "stream"
DefaultCacheKeyPrefix = "higress-ai-cache:"
SkipCacheHeader = "x-higress-skip-ai-cache"
PLUGIN_NAME = "ai-cache"
CACHE_KEY_CONTEXT_KEY = "cacheKey"
CACHE_KEY_EMBEDDING_KEY = "cacheKeyEmbedding"
CACHE_CONTENT_CONTEXT_KEY = "cacheContent"
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
TOOL_CALLS_CONTEXT_KEY = "toolCalls"
STREAM_CONTEXT_KEY = "stream"
SKIP_CACHE_HEADER = "x-higress-skip-ai-cache"
ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage"
)
func main() {
// CreateClient()
wrapper.SetCtx(
"ai-cache",
PLUGIN_NAME,
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
@@ -36,146 +36,26 @@ func main() {
)
}
// @Name ai-cache
// @Category protocol
// @Phase AUTHN
// @Priority 10
// @Title en-US AI Cache
// @Description en-US LLM result caching
// @IconUrl
// @Version 0.1.0
//
// @Contact.name johnlanni
// @Contact.url
// @Contact.email
//
// @Example
// redis:
// serviceName: my-redis.dns
// timeout: 2000
// cacheKeyFrom:
// requestBody: "messages.@reverse.0.content"
// cacheValueFrom:
// responseBody: "choices.0.message.content"
// cacheStreamValueFrom:
// responseBody: "choices.0.delta.content"
// returnResponseTemplate: |
// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}
// returnStreamResponseTemplate: |
// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}
//
// data:[DONE]
//
// @End
type RedisInfo struct {
// @Title en-US Redis service name
// @Description en-US Full FQDN including the service type, e.g. my-redis.dns or redis.my-ns.svc.cluster.local
ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
// @Title en-US Redis service port
// @Description en-US Defaults to 6379
ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"`
// @Title en-US Username
// @Description en-US Username for logging in to Redis; optional
Username string `required:"false" yaml:"username" json:"username"`
// @Title en-US Password
// @Description en-US Password for logging in to Redis; optional, and may be supplied without a username
Password string `required:"false" yaml:"password" json:"password"`
// @Title en-US Request timeout
// @Description en-US Timeout for Redis requests, in milliseconds. Defaults to 1000 (1 second)
Timeout int `required:"false" yaml:"timeout" json:"timeout"`
func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) error {
// config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider"))
// config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider"))
// config.RedisConfig.FromJson(json.Get("redis"))
c.FromJson(json, log)
if err := c.Validate(); err != nil {
return err
}
// Note that initializing the client during the parseConfig phase may cause errors, such as Redis not being usable in Docker Compose.
if err := c.Complete(log); err != nil {
log.Errorf("complete config failed: %v", err)
return err
}
return nil
}
type KVExtractor struct {
// @Title en-US Extract a string from the request body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax
RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
// @Title en-US Extract a string from the response body using [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) syntax
ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
}
type PluginConfig struct {
// @Title en-US Redis address information
// @Description en-US Address of the Redis instance used to store cached results
RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"`
// @Title en-US Source of the cache key
// @Description en-US How the key is extracted when writing to Redis
CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"`
// @Title en-US Source of the cache value
// @Description en-US How the value is extracted when writing to Redis
CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"`
// @Title en-US Source of the cache value for streaming responses
// @Description en-US How the value is extracted when writing to Redis
CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"`
// @Title en-US Template for the returned HTTP response
// @Description en-US Use %s to mark the part to be replaced by the cached value
ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"`
// @Title en-US Template for the returned streaming HTTP response
// @Description en-US Use %s to mark the part to be replaced by the cached value
ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"`
// @Title en-US Cache expiration time
// @Description en-US In seconds. Defaults to 0, meaning entries never expire
CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"`
// @Title en-US Prefix for Redis cache keys
// @Description en-US Defaults to "higress-ai-cache:"
CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"`
redisClient wrapper.RedisClient `yaml:"-" json:"-"`
}
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
c.RedisInfo.ServiceName = json.Get("redis.serviceName").String()
if c.RedisInfo.ServiceName == "" {
return errors.New("redis service name must not be empty")
}
c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int())
if c.RedisInfo.ServicePort == 0 {
if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") {
// static services use the default logical port, which is 80
c.RedisInfo.ServicePort = 80
} else {
c.RedisInfo.ServicePort = 6379
}
}
c.RedisInfo.Username = json.Get("redis.username").String()
c.RedisInfo.Password = json.Get("redis.password").String()
c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int())
if c.RedisInfo.Timeout == 0 {
c.RedisInfo.Timeout = 1000
}
c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String()
if c.CacheKeyFrom.RequestBody == "" {
c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content"
}
c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String()
if c.CacheValueFrom.ResponseBody == "" {
c.CacheValueFrom.ResponseBody = "choices.0.message.content"
}
c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String()
if c.CacheStreamValueFrom.ResponseBody == "" {
c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content"
}
c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String()
if c.ReturnResponseTemplate == "" {
c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}
c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String()
if c.ReturnStreamResponseTemplate == "" {
c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
}
c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String()
if c.CacheKeyPrefix == "" {
c.CacheKeyPrefix = DefaultCacheKeyPrefix
}
c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: c.RedisInfo.ServiceName,
Port: int64(c.RedisInfo.ServicePort),
})
return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout))
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
skipCache, _ := proxywasm.GetHttpRequestHeader(SkipCacheHeader)
func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
skipCache, _ := proxywasm.GetHttpRequestHeader(SKIP_CACHE_HEADER)
if skipCache == "on" {
ctx.SetContext(SkipCacheHeader, struct{}{})
ctx.SetContext(SKIP_CACHE_HEADER, struct{}{})
ctx.DontReadRequestBody()
return types.ActionContinue
}
@@ -185,199 +65,123 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap
return types.ActionContinue
}
if !strings.Contains(contentType, "application/json") {
log.Warnf("content is not json, can't process:%s", contentType)
log.Warnf("content is not json, can't process: %s", contentType)
ctx.DontReadRequestBody()
return types.ActionContinue
}
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
// The request has a body and requires delaying the header transmission until a cache miss occurs,
// at which point the header should be sent.
return types.HeaderStopIteration
}
func TrimQuote(source string) string {
return strings.Trim(source, `"`)
}
func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log wrapper.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
bodyJson := gjson.ParseBytes(body)
// TODO: It may be necessary to support stream mode determination for different LLM providers.
stream := false
if bodyJson.Get("stream").Bool() {
stream = true
ctx.SetContext(StreamContextKey, struct{}{})
} else if ctx.GetContext(StreamContextKey) != nil {
stream = true
ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{})
}
key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw)
var key string
if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_LAST_QUESTION {
log.Debugf("[onHttpRequestBody] cache key strategy is last question, cache key from: %s", c.CacheKeyFrom)
key = bodyJson.Get(c.CacheKeyFrom).String()
} else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_ALL_QUESTIONS {
log.Debugf("[onHttpRequestBody] cache key strategy is all questions, cache key from: messages")
messages := bodyJson.Get("messages").Array()
var userMessages []string
for _, msg := range messages {
if msg.Get("role").String() == "user" {
userMessages = append(userMessages, msg.Get("content").String())
}
}
key = strings.Join(userMessages, "\n")
} else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_DISABLED {
log.Info("[onHttpRequestBody] cache key strategy is disabled")
ctx.DontReadRequestBody()
return types.ActionContinue
} else {
log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", c.CacheKeyStrategy)
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key)
log.Debugf("[onHttpRequestBody] key: %s", key)
if key == "" {
log.Debug("parse key from request body failed")
log.Debug("[onHttpRequestBody] parse key from request body failed")
ctx.DontReadResponseBody()
return types.ActionContinue
}
ctx.SetContext(CacheKeyContextKey, key)
err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("redis get key:%s failed, err:%v", key, err)
proxywasm.ResumeHttpRequest()
return
}
if response.IsNull() {
log.Debugf("cache miss, key:%s", key)
proxywasm.ResumeHttpRequest()
return
}
log.Debugf("cache hit, key:%s", key)
ctx.SetContext(CacheKeyContextKey, nil)
if !stream {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1)
} else {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1)
}
})
if err != nil {
log.Error("redis access failed")
if err := CheckCacheForKey(key, ctx, c, log, stream, true); err != nil {
log.Errorf("[onHttpRequestBody] check cache for key: %s failed, error: %v", key, err)
return types.ActionContinue
}
return types.ActionPause
}
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
subMessages := strings.Split(sseMessage, "\n")
var message string
for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") {
message = msg
break
}
}
if len(message) < 6 {
log.Errorf("invalid message:%s", message)
return ""
}
// skip the prefix "data:"
bodyJson := message[5:]
if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() {
tempContentI := ctx.GetContext(CacheContentContextKey)
if tempContentI == nil {
content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw)
ctx.SetContext(CacheContentContextKey, content)
return content
}
append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw)
content := tempContentI.(string) + append
ctx.SetContext(CacheContentContextKey, content)
return content
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
// TODO: compatible with other providers
ctx.SetContext(ToolCallsContextKey, struct{}{})
return ""
}
log.Debugf("unknown message:%s", bodyJson)
return ""
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
skipCache := ctx.GetContext(SkipCacheHeader)
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
skipCache := ctx.GetContext(SKIP_CACHE_HEADER)
if skipCache != nil {
ctx.DontReadResponseBody()
return types.ActionContinue
}
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if strings.Contains(contentType, "text/event-stream") {
ctx.SetContext(StreamContextKey, struct{}{})
ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{})
}
if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil {
ctx.DontReadResponseBody()
return types.ActionContinue
}
return types.ActionContinue
}
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
if ctx.GetContext(ToolCallsContextKey) != nil {
// we should not cache tool call result
return chunk
}
keyI := ctx.GetContext(CacheKeyContextKey)
if keyI == nil {
return chunk
}
if !isLastChunk {
stream := ctx.GetContext(StreamContextKey)
if stream == nil {
tempContentI := ctx.GetContext(CacheContentContextKey)
if tempContentI == nil {
ctx.SetContext(CacheContentContextKey, chunk)
return chunk
}
tempContent := tempContentI.([]byte)
tempContent = append(tempContent, chunk...)
ctx.SetContext(CacheContentContextKey, tempContent)
} else {
var partialMessage []byte
partialMessageI := ctx.GetContext(PartialMessageContextKey)
if partialMessageI != nil {
partialMessage = append(partialMessageI.([]byte), chunk...)
} else {
partialMessage = chunk
}
messages := strings.Split(string(partialMessage), "\n\n")
for i, msg := range messages {
if i < len(messages)-1 {
// process complete message
processSSEMessage(ctx, config, msg, log)
}
}
if !strings.HasSuffix(string(partialMessage), "\n\n") {
ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1]))
} else {
ctx.SetContext(PartialMessageContextKey, nil)
}
}
return chunk
}
// last chunk
key := keyI.(string)
stream := ctx.GetContext(StreamContextKey)
var value string
if stream == nil {
var body []byte
tempContentI := ctx.GetContext(CacheContentContextKey)
if tempContentI != nil {
body = append(tempContentI.([]byte), chunk...)
} else {
body = chunk
}
bodyJson := gjson.ParseBytes(body)
func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk)
log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk))
value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw)
if value == "" {
log.Warnf("parse value from response body failed, body:%s", body)
return chunk
if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil {
return chunk
}
key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY)
if key == nil {
log.Debug("[onHttpResponseBody] key is nil, skip cache")
return chunk
}
if !isLastChunk {
if err := handleNonLastChunk(ctx, c, chunk, log); err != nil {
log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err)
// Set an empty struct in the context to indicate an error in processing the partial message
ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{})
}
return chunk
}
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
var value string
var err error
if stream == nil {
value, err = processNonStreamLastChunk(ctx, c, chunk, log)
} else {
if len(chunk) > 0 {
var lastMessage []byte
partialMessageI := ctx.GetContext(PartialMessageContextKey)
if partialMessageI != nil {
lastMessage = append(partialMessageI.([]byte), chunk...)
} else {
lastMessage = chunk
}
if !strings.HasSuffix(string(lastMessage), "\n\n") {
log.Warnf("invalid lastMessage:%s", lastMessage)
return chunk
}
// remove the last \n\n
lastMessage = lastMessage[:len(lastMessage)-2]
value = processSSEMessage(ctx, config, string(lastMessage), log)
} else {
tempContentI := ctx.GetContext(CacheContentContextKey)
if tempContentI == nil {
return chunk
}
value = tempContentI.(string)
}
value, err = processStreamLastChunk(ctx, c, chunk, log)
}
config.redisClient.Set(config.CacheKeyPrefix+key, value, nil)
if config.CacheTTL != 0 {
config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil)
if err != nil {
log.Errorf("[onHttpResponseBody] process last chunk failed, error: %v", err)
return chunk
}
cacheResponse(ctx, c, key.(string), value, log)
uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log)
return chunk
}
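The cache key selection in onHttpRequestBody can be reproduced with the standard library alone. This is a sketch under assumptions: `cacheKey` and `chatRequest` are hypothetical names, the strategy strings are taken from the README table, message content is assumed to be a plain string, and the "lastQuestion" branch simply takes the final message instead of evaluating the configurable GJSON path in c.CacheKeyFrom.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// chatRequest models only the parts of an OpenAI-style body that matter here.
type chatRequest struct {
	Messages []struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"messages"`
}

// cacheKey derives the cache key from the request body for a given strategy.
// An empty key with a nil error means "do not cache this request".
func cacheKey(body []byte, strategy string) (string, error) {
	var req chatRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return "", err
	}
	switch strategy {
	case "lastQuestion":
		// Assumption: the last message carries the question.
		if len(req.Messages) == 0 {
			return "", nil
		}
		return req.Messages[len(req.Messages)-1].Content, nil
	case "allQuestions":
		// Join every user message so the key reflects the whole conversation.
		var users []string
		for _, m := range req.Messages {
			if m.Role == "user" {
				users = append(users, m.Content)
			}
		}
		return strings.Join(users, "\n"), nil
	case "disabled":
		return "", nil
	default:
		return "", fmt.Errorf("unknown cache key strategy: %s", strategy)
	}
}
```

For a conversation whose user turns are "hi" and "bye", the two strategies produce different keys: "lastQuestion" gives `bye`, while "allQuestions" gives the two turns joined by a newline, so the same final question asked in a different context does not share a cache entry.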


@@ -0,0 +1,155 @@
package main
import (
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
if stream == nil {
return handleNonStreamChunk(ctx, c, chunk, log)
}
return handleStreamChunk(ctx, c, chunk, log)
}
func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
if tempContentI == nil {
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk)
return nil
}
tempContent := tempContentI.([]byte)
tempContent = append(tempContent, chunk...)
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent)
return nil
}
func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
log.Debugf("[handleStreamChunk] cache content: %v", ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY))
if partialMessageI != nil {
partialMessage = append(partialMessageI.([]byte), chunk...)
} else {
partialMessage = chunk
}
messages := strings.Split(string(partialMessage), "\n\n")
for i, msg := range messages {
if i < len(messages)-1 {
_, err := processSSEMessage(ctx, c, msg, log)
if err != nil {
return fmt.Errorf("[handleStreamChunk] processSSEMessage failed, error: %v", err)
}
}
}
if !strings.HasSuffix(string(partialMessage), "\n\n") {
ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1]))
} else {
ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil)
}
return nil
}
func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
var body []byte
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
if tempContentI != nil {
body = append(tempContentI.([]byte), chunk...)
} else {
body = chunk
}
bodyJson := gjson.ParseBytes(body)
value := bodyJson.Get(c.CacheValueFrom).String()
if strings.TrimSpace(value) == "" {
return "", fmt.Errorf("[processNonStreamLastChunk] parse value from response body failed, body:%s", body)
}
return value, nil
}
func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
if len(chunk) > 0 {
var lastMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
if partialMessageI != nil {
lastMessage = append(partialMessageI.([]byte), chunk...)
} else {
lastMessage = chunk
}
if !strings.HasSuffix(string(lastMessage), "\n\n") {
return "", fmt.Errorf("[processStreamLastChunk] invalid lastMessage:%s", lastMessage)
}
lastMessage = lastMessage[:len(lastMessage)-2]
value, err := processSSEMessage(ctx, c, string(lastMessage), log)
if err != nil {
return "", fmt.Errorf("[processStreamLastChunk] processSSEMessage failed, error: %v", err)
}
return value, nil
}
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
if tempContentI == nil {
return "", nil
}
return tempContentI.(string), nil
}
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
subMessages := strings.Split(sseMessage, "\n")
var message string
for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") {
message = msg
break
}
}
if len(message) < 6 {
return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message)
}
// skip the prefix "data:"
bodyJson := message[5:]
if strings.TrimSpace(bodyJson) == "[DONE]" {
return "", nil
}
// Extract values from JSON fields
responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
if toolCalls.Exists() {
// TODO: Temporarily store the tool_calls value in the context for processing
ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
}
// Check if the ResponseBody field exists
if !responseBody.Exists() {
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
return "", nil
}
return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
} else {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
// If there is no content in the cache, initialize and set the content
if tempContentI == nil {
content := responseBody.String()
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
return content, nil
}
// Update the content in the cache
appendMsg := responseBody.String()
content := tempContentI.(string) + appendMsg
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
return content, nil
}
}
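
The streaming handlers above split the buffered body on blank lines, process every complete SSE event, and carry the trailing incomplete event over to the next chunk. The following is a minimal standalone sketch of that idea, not the plugin's code: `sseAccumulator`, `chunkPayload`, and the OpenAI-style `choices[0].delta.content` path are assumptions made for illustration (the plugin reads the path from `CacheStreamValueFrom` and keeps state in the HTTP context).

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// sseAccumulator is a hypothetical stand-in for the per-request context.
type sseAccumulator struct {
	partial string // bytes received after the last complete "\n\n" boundary
	content string // concatenated delta content seen so far
}

// chunkPayload mirrors only the OpenAI-style fields this sketch reads (an assumption).
type chunkPayload struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// feed appends a network chunk, handles every complete SSE event,
// and keeps the trailing incomplete event for the next call.
func (a *sseAccumulator) feed(chunk string) error {
	events := strings.Split(a.partial+chunk, "\n\n")
	// The last element is "" when the buffer ended on a boundary,
	// otherwise it is an incomplete event.
	a.partial = events[len(events)-1]
	for _, ev := range events[:len(events)-1] {
		if err := a.handleEvent(ev); err != nil {
			return err
		}
	}
	return nil
}

// handleEvent extracts the first "data:" line of one event and appends its content.
func (a *sseAccumulator) handleEvent(ev string) error {
	for _, line := range strings.Split(ev, "\n") {
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			return nil
		}
		var p chunkPayload
		if err := json.Unmarshal([]byte(data), &p); err != nil {
			return fmt.Errorf("invalid SSE payload %q: %w", data, err)
		}
		if len(p.Choices) > 0 {
			a.content += p.Choices[0].Delta.Content
		}
		return nil
	}
	return nil
}

func main() {
	acc := &sseAccumulator{}
	chunks := []string{
		"data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\ndata: {\"cho",
		"ices\":[{\"delta\":{\"content\":\"lo\"}}]}\n\n",
		"data: [DONE]\n\n",
	}
	for _, c := range chunks {
		if err := acc.feed(c); err != nil {
			fmt.Println("error:", err)
			return
		}
	}
	fmt.Println(acc.content)
}
```

Keeping the unfinished tail in `partial` is what lets an event that is cut in the middle of its JSON payload be parsed once the rest arrives.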


@@ -0,0 +1,256 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type dashVectorProviderInitializer struct {
}
func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.apiKey) == 0 {
return errors.New("[DashVector] apiKey is required")
}
if len(config.collectionID) == 0 {
return errors.New("[DashVector] collectionID is required")
}
if len(config.serviceName) == 0 {
return errors.New("[DashVector] serviceName is required")
}
if len(config.serviceHost) == 0 {
return errors.New("[DashVector] serviceHost is required")
}
return nil
}
func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &DvProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type DvProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (d *DvProvider) GetProviderType() string {
return PROVIDER_TYPE_DASH_VECTOR
}
// type embeddingRequest struct {
// Model string `json:"model"`
// Input input `json:"input"`
// Parameters params `json:"parameters"`
// }
// type params struct {
// TextType string `json:"text_type"`
// }
// type input struct {
// Texts []string `json:"texts"`
// }
// queryResponse defines the structure of a query response.
type queryResponse struct {
Code int `json:"code"`
RequestID string `json:"request_id"`
Message string `json:"message"`
Output []result `json:"output"`
}
// queryRequest defines the structure of a query request.
type queryRequest struct {
Vector []float64 `json:"vector"`
TopK int `json:"topk"`
IncludeVector bool `json:"include_vector"`
}
// result defines the structure of a single query result.
type result struct {
ID string `json:"id"`
Vector []float64 `json:"vector,omitempty"` // omitempty: the field is not serialized when vector is empty
Fields map[string]interface{} `json:"fields"`
Score float64 `json:"score"`
}
func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) {
url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID)
requestData := queryRequest{
Vector: vector,
TopK: d.config.topK,
IncludeVector: false,
}
requestBody, err := json.Marshal(requestData)
if err != nil {
return "", nil, nil, err
}
header := [][2]string{
{"Content-Type", "application/json"},
{"dashvector-auth-token", d.config.apiKey},
}
return url, requestBody, header, nil
}
func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) {
var queryResp queryResponse
err := json.Unmarshal(responseBody, &queryResp)
if err != nil {
return queryResponse{}, err
}
return queryResp, nil
}
func (d *DvProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
url, body, headers, err := d.constructEmbeddingQueryParameters(emb)
if err != nil {
return fmt.Errorf("failed to construct embedding query parameters: %v", err)
}
// Log only after the parameters are known to be valid; headers are omitted because they carry the API key.
log.Debugf("url:%s, body:%s", url, string(body))
err = d.client.Post(url, headers, body,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
err = nil
if statusCode != http.StatusOK {
err = fmt.Errorf("failed to query embedding: %d", statusCode)
callback(nil, ctx, log, err)
return
}
log.Debugf("query embedding response: %d, %s", statusCode, responseBody)
results, err := d.ParseQueryResponse(responseBody, ctx, log)
if err != nil {
err = fmt.Errorf("failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout)
if err != nil {
err = fmt.Errorf("failed to query embedding: %v", err)
}
return err
}
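func getStringValue(fields map[string]interface{}, key string) string {
// Use the comma-ok assertion so a missing key or a non-string value yields "" instead of panicking.
if val, ok := fields[key].(string); ok {
return val
}
return ""
}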
func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) {
resp, err := d.parseQueryResponse(responseBody)
if err != nil {
return nil, err
}
if len(resp.Output) == 0 {
return nil, errors.New("no query results found in response")
}
results := make([]QueryResult, 0, len(resp.Output))
for _, output := range resp.Output {
result := QueryResult{
Text: getStringValue(output.Fields, "query"),
Embedding: output.Vector,
Score: output.Score,
Answer: getStringValue(output.Fields, "answer"),
}
results = append(results, result)
}
return results, nil
}
type document struct {
Vector []float64 `json:"vector"`
Fields map[string]string `json:"fields"`
}
type insertRequest struct {
Docs []document `json:"docs"`
}
func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) {
url := "/v1/collections/" + d.config.collectionID + "/docs"
doc := document{
Vector: emb,
Fields: map[string]string{
"query": queryString,
"answer": answer,
},
}
requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}})
if err != nil {
return "", nil, nil, err
}
header := [][2]string{
{"Content-Type", "application/json"},
{"dashvector-auth-token", d.config.apiKey},
}
return url, requestBody, header, err
}
func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "")
if err != nil {
return err
}
err = d.client.Post(
url,
headers,
body,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
if statusCode != http.StatusOK {
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
}
callback(ctx, log, err)
},
d.config.timeout)
return err
}
func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer)
if err != nil {
return err
}
err = d.client.Post(
url,
headers,
body,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody))
if statusCode != http.StatusOK {
err = fmt.Errorf("failed to upload embedding: %d", statusCode)
}
callback(ctx, log, err)
},
d.config.timeout)
return err
}
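
As a rough illustration of the request and response handling in this file, the sketch below marshals a query body and extracts the `answer` field from a response using only the standard library. The struct shapes are copied from the diff (trimmed to the fields used here); `buildQueryBody`, `parseAnswers`, `stringField`, and the sample response JSON are hypothetical and exist only for this example. The sample is not taken from DashVector documentation.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Shapes copied from the provider above, trimmed to the fields used here.
type queryRequest struct {
	Vector        []float64 `json:"vector"`
	TopK          int       `json:"topk"`
	IncludeVector bool      `json:"include_vector"`
}

type result struct {
	ID     string                 `json:"id"`
	Fields map[string]interface{} `json:"fields"`
	Score  float64                `json:"score"`
}

type queryResponse struct {
	Code   int      `json:"code"`
	Output []result `json:"output"`
}

// buildQueryBody (hypothetical helper) serializes a top-k query without returning vectors.
func buildQueryBody(vector []float64, topK int) ([]byte, error) {
	return json.Marshal(queryRequest{Vector: vector, TopK: topK, IncludeVector: false})
}

// stringField returns "" when the key is missing or the value is not a string.
func stringField(fields map[string]interface{}, key string) string {
	s, _ := fields[key].(string)
	return s
}

// parseAnswers (hypothetical helper) returns the "answer" field of every result.
func parseAnswers(body []byte) ([]string, error) {
	var resp queryResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	answers := make([]string, 0, len(resp.Output))
	for _, out := range resp.Output {
		answers = append(answers, stringField(out.Fields, "answer"))
	}
	return answers, nil
}

func main() {
	body, err := buildQueryBody([]float64{0.1, 0.2}, 1)
	if err != nil {
		fmt.Println("marshal error:", err)
		return
	}
	fmt.Println(string(body))

	// Illustrative response only; the field values are made up.
	sample := `{"code":0,"output":[{"id":"1","fields":{"query":"hi","answer":"hello"},"score":0.12}]}`
	answers, err := parseAnswers([]byte(sample))
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println(answers)
}
```

Because `Fields` is decoded into `map[string]interface{}`, a numeric or missing `answer` is possible, which is why the sketch uses the comma-ok form of the type assertion.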


@@ -0,0 +1,167 @@
package vector
import (
"errors"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
PROVIDER_TYPE_DASH_VECTOR = "dashvector"
PROVIDER_TYPE_CHROMA = "chroma"
)
type providerInitializer interface {
ValidateConfig(ProviderConfig) error
CreateProvider(ProviderConfig) (Provider, error)
}
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{},
// PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
}
)
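// QueryResult defines the generic structure of a query result.
type QueryResult struct {
Text string // the similar text
Embedding []float64 // the vector of the similar text
Score float64 // similarity or distance metric between the vectors
Answer string // the LLM-generated answer associated with the similar text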
}
type Provider interface {
GetProviderType() string
}
type EmbeddingQuerier interface {
QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}
type EmbeddingUploader interface {
UploadEmbedding(
queryString string,
queryEmb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}
type AnswerAndEmbeddingUploader interface {
UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
answer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}
type StringQuerier interface {
QueryString(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}
type SimilarityThresholdProvider interface {
GetSimilarityThreshold() float64
}
type ProviderConfig struct {
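// @Title en-US Vector storage service provider type
// @Description en-US Vector storage service provider type, for example dashvector or chroma
typ string
// @Title en-US Vector storage service name
// @Description en-US Vector storage service name
serviceName string
// @Title en-US Vector storage service host
// @Description en-US Domain name of the vector storage service
serviceHost string
// @Title en-US Vector storage service port
// @Description en-US Vector storage service port
servicePort int64
// @Title en-US Vector storage service API key
// @Description en-US Vector storage service API key
apiKey string
// @Title en-US Number of top-K results to return
// @Description en-US Number of top-K results to return. Defaults to 1
topK int
// @Title en-US Request timeout
// @Description en-US Timeout for requests to the vector storage service, in milliseconds. Defaults to 10000 (10 seconds)
timeout uint32
// @Title en-US DashVector collection ID
// @Description en-US Collection ID of the DashVector vector storage service
collectionID string
// @Title en-US Similarity threshold
// @Description en-US Default similarity threshold. Defaults to 1000.
Threshold float64
// @Title en-US Similarity comparison relation
// @Description en-US How a score is compared against the threshold. Defaults to less than.
// Similarity metrics include Cosine, DotProduct, Euclidean, and others. For the first two, a larger value means higher similarity; for the last, a smaller value means higher similarity.
// The comparison is therefore configurable: choose gt for Cosine and DotProduct, and lt for Euclidean.
// Defaults to lt. Supported values are lt (less than), lte (less than or equal to), gt (greater than), and gte (greater than or equal to).
ThresholdRelation string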
}
func (c *ProviderConfig) GetProviderType() string {
return c.typ
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
// DashVector
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = int64(json.Get("servicePort").Int())
if c.servicePort == 0 {
c.servicePort = 443
}
c.apiKey = json.Get("apiKey").String()
c.collectionID = json.Get("collectionID").String()
c.topK = int(json.Get("topK").Int())
if c.topK == 0 {
c.topK = 1
}
c.timeout = uint32(json.Get("timeout").Int())
if c.timeout == 0 {
c.timeout = 10000
}
c.Threshold = json.Get("threshold").Float()
if c.Threshold == 0 {
c.Threshold = 1000
}
c.ThresholdRelation = json.Get("thresholdRelation").String()
if c.ThresholdRelation == "" {
c.ThresholdRelation = "lt"
}
}
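
The `threshold` and `thresholdRelation` settings parsed above imply a comparison between a query score and the configured threshold. How the plugin performs that comparison is not shown in this part of the diff, so the following is only a sketch under that assumption; `meetsThreshold` is a hypothetical helper name.

```go
package main

import "fmt"

// meetsThreshold (hypothetical) reports whether score satisfies the configured relation.
// Distance metrics such as Euclidean typically use "lt"/"lte";
// similarity metrics such as Cosine or DotProduct typically use "gt"/"gte".
func meetsThreshold(score, threshold float64, relation string) (bool, error) {
	switch relation {
	case "lt":
		return score < threshold, nil
	case "lte":
		return score <= threshold, nil
	case "gt":
		return score > threshold, nil
	case "gte":
		return score >= threshold, nil
	default:
		return false, fmt.Errorf("unknown threshold relation: %q", relation)
	}
}

func main() {
	// With the defaults from FromJson (threshold 1000, relation "lt"),
	// a small distance counts as a cache hit.
	hit, err := meetsThreshold(0.2, 1000, "lt")
	fmt.Println(hit, err)
}
```

Returning an error for an unrecognized relation, rather than silently treating it as `lt`, makes a misconfigured `thresholdRelation` visible at the point of use.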
func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("vector database service is required")
}
initializer, has := providerInitializers[c.typ]
if !has {
return errors.New("unknown vector database service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
return err
}
return nil
}
func CreateProvider(pc ProviderConfig) (Provider, error) {
initializer, has := providerInitializers[pc.typ]
if !has {
return nil, errors.New("unknown provider type: " + pc.typ)
}
return initializer.CreateProvider(pc)
}
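
`Provider` itself only exposes `GetProviderType`; the optional capabilities (`EmbeddingQuerier`, `StringQuerier`, and so on) are separate interfaces, which suggests callers discover them with type assertions. The sketch below shows that pattern with deliberately simplified method signatures (no context, log, or callback parameters). The two provider types and `needsEmbeddingService` are hypothetical.

```go
package main

import "fmt"

// Simplified versions of the interfaces above (signatures shortened for the sketch).
type Provider interface {
	GetProviderType() string
}

type EmbeddingQuerier interface {
	QueryEmbedding(emb []float64) ([]string, error)
}

type StringQuerier interface {
	QueryString(query string) ([]string, error)
}

// embeddingOnly (hypothetical) can only be queried with a precomputed vector.
type embeddingOnly struct{}

func (embeddingOnly) GetProviderType() string { return "embedding-only" }
func (embeddingOnly) QueryEmbedding(emb []float64) ([]string, error) {
	return []string{"cached answer"}, nil
}

// textCapable (hypothetical) can also embed the query text itself.
type textCapable struct{ embeddingOnly }

func (textCapable) GetProviderType() string { return "text-capable" }
func (textCapable) QueryString(query string) ([]string, error) {
	return []string{"cached answer"}, nil
}

// needsEmbeddingService reports whether a separate embedding service is required,
// that is, whether the provider cannot accept raw query strings.
func needsEmbeddingService(p Provider) bool {
	_, ok := p.(StringQuerier)
	return !ok
}

func main() {
	for _, p := range []Provider{embeddingOnly{}, textCapable{}} {
		fmt.Printf("%s needs embedding service: %v\n", p.GetProviderType(), needsEmbeddingService(p))
	}
}
```

Small single-method interfaces keep each vector store free to implement only what its API supports, while the registry in `providerInitializers` stays unaware of those differences.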


@@ -0,0 +1,4 @@
.DEFAULT:
build:
	tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go
	mv ai-proxy.wasm ../../../../docker-compose-test/


@@ -10,7 +10,7 @@ require (
github.com/alibaba/higress/plugins/wasm-go v0.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/stretchr/testify v1.8.4
-github.com/tidwall/gjson v1.14.3
+github.com/tidwall/gjson v1.17.3
)
require (


@@ -13,8 +13,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
-github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
+github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=


@@ -82,7 +82,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
providerConfig := pluginConfig.GetProviderConfig()
if apiName == "" && !providerConfig.IsOriginal() {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
-_ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
+// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
+log.Debugf("[onHttpRequestHeader] no send response")
return types.ActionContinue
}
ctx.SetContext(ctxKeyApiName, apiName)


@@ -0,0 +1,2 @@
FROM scratch
COPY main.wasm plugin.wasm


@@ -0,0 +1,4 @@
.DEFAULT:
build:
	tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer' ./main.go
	mv main.wasm ../../../../docker-compose-test/


@@ -177,7 +177,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, lo
}
func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action {
log.Infof("My request-block body: %s\n", string(body))
bodyStr := string(body)
if !config.caseSensitive {
bodyStr = strings.ToLower(bodyStr)
}


@@ -7,7 +7,7 @@ require (
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/stretchr/testify v1.8.4
-github.com/tidwall/gjson v1.14.3
+github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
)


@@ -4,6 +4,12 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43 h1:dCw7F/9ciw4NZN7w68wQRaygZ2zGOWMTIEoRvP1tlWs=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
@@ -14,6 +20,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=


@@ -235,10 +235,11 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis
func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
args := make([]interface{}, 0)
-args = append(args, "setex")
+args = append(args, "set")
args = append(args, key)
-args = append(args, ttl)
args = append(args, value)
+args = append(args, "ex")
+args = append(args, ttl)
return RedisCall(c.cluster, respString(args), callback)
}
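
The hunk above changes `SetEx` from issuing `SETEX key ttl value` to `SET key value EX ttl`. The following sketch shows what the latter looks like on the wire as a RESP array of bulk strings. `encodeRESP` and `setWithTTL` are hypothetical helpers written for this example; they are not the wrapper's `respString`, whose implementation is not shown here.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// encodeRESP (hypothetical) encodes a command as a RESP array of bulk strings:
// "*<count>\r\n" followed by "$<len>\r\n<bytes>\r\n" for each argument.
func encodeRESP(args ...interface{}) string {
	var b strings.Builder
	b.WriteString("*" + strconv.Itoa(len(args)) + "\r\n")
	for _, arg := range args {
		s := fmt.Sprint(arg)
		b.WriteString("$" + strconv.Itoa(len(s)) + "\r\n" + s + "\r\n")
	}
	return b.String()
}

// setWithTTL (hypothetical) builds "SET key value EX ttl" in the same
// argument order as the updated SetEx.
func setWithTTL(key string, value interface{}, ttlSeconds int) string {
	return encodeRESP("set", key, value, "ex", ttlSeconds)
}

func main() {
	fmt.Printf("%q\n", setWithTTL("k", "v", 60))
}
```

Note the argument order: with `SETEX` the TTL precedes the value, whereas with `SET ... EX` it follows the `ex` keyword after the value.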


@@ -0,0 +1,76 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tests
import (
"testing"
"github.com/alibaba/higress/test/e2e/conformance/utils/http"
"github.com/alibaba/higress/test/e2e/conformance/utils/suite"
)
func init() {
Register(WasmPluginsAiCache)
}
var WasmPluginsAiCache = suite.ConformanceTest{
ShortName: "WasmPluginAiCache",
Description: "The Ingress in the higress-conformance-infra namespace tests the ai-cache WASM plugin.",
Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature},
Manifests: []string{"tests/go-wasm-ai-cache.yaml"},
Test: func(t *testing.T, suite *suite.ConformanceTestSuite) {
testcases := []http.Assertion{
{
Meta: http.AssertionMeta{
TestCaseName: "case 1: basic",
TargetBackend: "infra-backend-v1",
TargetNamespace: "higress-conformance-infra",
},
Request: http.AssertionRequest{
ActualRequest: http.Request{
Host: "dashscope.aliyuncs.com",
Path: "/v1/chat/completions",
Method: "POST",
ContentType: http.ContentTypeApplicationJson,
Body: []byte(`{
"model": "qwen-long",
"messages": [{"role":"user","content":"hi"}]}`),
},
ExpectedRequest: &http.ExpectedRequest{
Request: http.Request{
Host: "dashscope.aliyuncs.com",
Path: "/compatible-mode/v1/chat/completions",
Method: "POST",
ContentType: http.ContentTypeApplicationJson,
Body: []byte(`{
"model": "qwen-long",
"messages": [{"role":"user","content":"hi"}]}`),
},
},
},
Response: http.AssertionResponse{
ExpectedResponse: http.Response{
StatusCode: 200,
},
},
},
}
t.Run("WasmPlugins ai-cache", func(t *testing.T) {
for _, testcase := range testcases {
http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase)
}
})
},
}


@@ -0,0 +1,103 @@
# Copyright (c) 2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  annotations:
  name: wasmplugin-ai-cache-openai
  namespace: higress-conformance-infra
spec:
  ingressClassName: higress
  rules:
    - host: "dashscope.aliyuncs.com"
      http:
        paths:
          - pathType: Prefix
            path: "/"
            backend:
              service:
                name: infra-backend-v1
                port:
                  number: 8080
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  annotations:
  name: wasmplugin-ai-cache-qwen
  namespace: higress-conformance-infra
spec:
  ingressClassName: higress
  rules:
    - host: "qwen.ai.com"
      http:
        paths:
          - pathType: Prefix
            path: "/"
            backend:
              service:
                name: infra-backend-v1
                port:
                  number: 8080
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
  name: ai-cache
  namespace: higress-system
spec:
  priority: 400
  matchRules:
    - config:
        embedding:
          type: "dashscope"
          serviceName: "qwen"
          apiKey: "{{secret.qwenApiKey}}"
          timeout: 12000
        vector:
          type: "dashvector"
          serviceName: "dashvector"
          collectionID: "{{secret.collectionID}}"
          serviceHost: "{{secret.serviceDomain}}"
          apiKey: "{{secret.apiKey}}"
          timeout: 12000
        cache:
      ingress:
        - higress-conformance-infra/wasmplugin-ai-cache-openai
        - higress-conformance-infra/wasmplugin-ai-cache-qwen
  # url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm
  url: oci://registry.cn-shanghai.aliyuncs.com/suchunsv/higress_ai:1.18
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
  name: ai-proxy
  namespace: higress-system
spec:
  priority: 201
  matchRules:
    - config:
        provider:
          type: "qwen"
          qwenEnableCompatible: true
          apiTokens:
            - "{{secret.qwenApiKey}}"
          timeout: 1200000
          modelMapping:
            "*": "qwen-long"
      ingress:
        - higress-conformance-infra/wasmplugin-ai-cache-openai
        - higress-conformance-infra/wasmplugin-ai-cache-qwen
  url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-proxy:1.0.0