[ai-cache] Implement a WASM plugin for LLM result retrieval based on vector similarity (#1290)

This commit is contained in:
Yang Beining
2024-10-27 08:21:04 +00:00
committed by GitHub
parent d309bf2e25
commit acec48ed8b
27 changed files with 2025 additions and 346 deletions

View File

@@ -0,0 +1,225 @@
package config
import (
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
CACHE_KEY_STRATEGY_LAST_QUESTION = "lastQuestion"
CACHE_KEY_STRATEGY_ALL_QUESTIONS = "allQuestions"
CACHE_KEY_STRATEGY_DISABLED = "disabled"
)
type PluginConfig struct {
// @Title zh-CN 返回 HTTP 响应的模版
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
ResponseTemplate string
// @Title zh-CN 返回流式 HTTP 响应的模版
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
StreamResponseTemplate string
cacheProvider cache.Provider
embeddingProvider embedding.Provider
vectorProvider vector.Provider
embeddingProviderConfig embedding.ProviderConfig
vectorProviderConfig vector.ProviderConfig
cacheProviderConfig cache.ProviderConfig
CacheKeyFrom string
CacheValueFrom string
CacheStreamValueFrom string
CacheToolCallsFrom string
// @Title zh-CN 启用语义化缓存
// @Description zh-CN 控制是否启用语义化缓存功能。true 表示启用false 表示禁用。
EnableSemanticCache bool
// @Title zh-CN 缓存键策略
// @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存)
CacheKeyStrategy string
}
func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
c.vectorProviderConfig.FromJson(json.Get("vector"))
c.embeddingProviderConfig.FromJson(json.Get("embedding"))
c.cacheProviderConfig.FromJson(json.Get("cache"))
if json.Get("redis").Exists() {
// compatible with legacy config
c.cacheProviderConfig.ConvertLegacyJson(json)
}
c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String()
if c.CacheKeyStrategy == "" {
c.CacheKeyStrategy = CACHE_KEY_STRATEGY_LAST_QUESTION // set default value
}
c.CacheKeyFrom = json.Get("cacheKeyFrom").String()
if c.CacheKeyFrom == "" {
c.CacheKeyFrom = "messages.@reverse.0.content"
}
c.CacheValueFrom = json.Get("cacheValueFrom").String()
if c.CacheValueFrom == "" {
c.CacheValueFrom = "choices.0.message.content"
}
c.CacheStreamValueFrom = json.Get("cacheStreamValueFrom").String()
if c.CacheStreamValueFrom == "" {
c.CacheStreamValueFrom = "choices.0.delta.content"
}
c.CacheToolCallsFrom = json.Get("cacheToolCallsFrom").String()
if c.CacheToolCallsFrom == "" {
c.CacheToolCallsFrom = "choices.0.delta.content.tool_calls"
}
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" {
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
}
c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" {
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}
if json.Get("enableSemanticCache").Exists() {
c.EnableSemanticCache = json.Get("enableSemanticCache").Bool()
} else {
c.EnableSemanticCache = true // set default value to true
}
// compatible with legacy config
convertLegacyMapFields(c, json, log)
}
func (c *PluginConfig) Validate() error {
// if cache provider is configured, validate it
if c.cacheProviderConfig.GetProviderType() != "" {
if err := c.cacheProviderConfig.Validate(); err != nil {
return err
}
}
if c.embeddingProviderConfig.GetProviderType() != "" {
if err := c.embeddingProviderConfig.Validate(); err != nil {
return err
}
}
if c.vectorProviderConfig.GetProviderType() != "" {
if err := c.vectorProviderConfig.Validate(); err != nil {
return err
}
}
// cache, vector, and embedding cannot all be empty
if c.vectorProviderConfig.GetProviderType() == "" &&
c.embeddingProviderConfig.GetProviderType() == "" &&
c.cacheProviderConfig.GetProviderType() == "" {
return fmt.Errorf("vector, embedding and cache provider cannot be all empty")
}
// Validate the value of CacheKeyStrategy
if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION &&
c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS &&
c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED {
return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy)
}
// If semantic cache is enabled, ensure necessary components are configured
// if c.EnableSemanticCache {
// if c.embeddingProviderConfig.GetProviderType() == "" {
// return fmt.Errorf("semantic cache is enabled but embedding provider is not configured")
// }
// // if only configure cache, just warn the user
// }
return nil
}
func (c *PluginConfig) Complete(log wrapper.Log) error {
var err error
if c.embeddingProviderConfig.GetProviderType() != "" {
log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType())
c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig)
if err != nil {
return err
}
} else {
log.Info("embedding provider is not configured")
c.embeddingProvider = nil
}
if c.cacheProviderConfig.GetProviderType() != "" {
log.Debugf("cache provider is set to %s", c.cacheProviderConfig.GetProviderType())
c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig)
if err != nil {
return err
}
} else {
log.Info("cache provider is not configured")
c.cacheProvider = nil
}
if c.vectorProviderConfig.GetProviderType() != "" {
log.Debugf("vector provider is set to %s", c.vectorProviderConfig.GetProviderType())
c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig)
if err != nil {
return err
}
} else {
log.Info("vector provider is not configured")
c.vectorProvider = nil
}
return nil
}
func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider {
return c.embeddingProvider
}
func (c *PluginConfig) GetVectorProvider() vector.Provider {
return c.vectorProvider
}
func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig {
return c.vectorProviderConfig
}
func (c *PluginConfig) GetCacheProvider() cache.Provider {
return c.cacheProvider
}
func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) {
keyMap := map[string]string{
"cacheKeyFrom.requestBody": "cacheKeyFrom",
"cacheValueFrom.requestBody": "cacheValueFrom",
"cacheStreamValueFrom.requestBody": "cacheStreamValueFrom",
"returnResponseTemplate": "responseTemplate",
"returnStreamResponseTemplate": "streamResponseTemplate",
}
for oldKey, newKey := range keyMap {
if json.Get(oldKey).Exists() {
log.Debugf("[convertLegacyMapFields] mapping %s to %s", oldKey, newKey)
setField(c, newKey, json.Get(oldKey).String(), log)
} else {
log.Debugf("[convertLegacyMapFields] %s not exists", oldKey)
}
}
}
func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) {
switch fieldName {
case "cacheKeyFrom":
c.CacheKeyFrom = value
case "cacheValueFrom":
c.CacheValueFrom = value
case "cacheStreamValueFrom":
c.CacheStreamValueFrom = value
case "responseTemplate":
c.ResponseTemplate = value
case "streamResponseTemplate":
c.StreamResponseTemplate = value
}
log.Debugf("[setField] set %s to %s", fieldName, value)
}