// File generated by hgctl. Modify as required. // See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 package main import ( "errors" "fmt" "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/resp" ) const ( CacheKeyContextKey = "cacheKey" CacheContentContextKey = "cacheContent" PartialMessageContextKey = "partialMessage" ToolCallsContextKey = "toolCalls" StreamContextKey = "stream" DefaultCacheKeyPrefix = "higress-ai-cache:" ) func main() { wrapper.SetCtx( "ai-cache", wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBodyBy(onHttpResponseBody), ) } // @Name ai-cache // @Category protocol // @Phase AUTHN // @Priority 10 // @Title zh-CN AI Cache // @Description zh-CN 大模型结果缓存 // @IconUrl // @Version 0.1.0 // // @Contact.name johnlanni // @Contact.url // @Contact.email // // @Example // redis: // serviceName: my-redis.dns // timeout: 2000 // cacheKeyFrom: // requestBody: "messages.@reverse.0.content" // cacheValueFrom: // responseBody: "choices.0.message.content" // cacheStreamValueFrom: // responseBody: "choices.0.delta.content" // returnResponseTemplate: | // {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} // returnStreamResponseTemplate: | // data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} // // data:[DONE] // // @End type RedisInfo struct { // @Title zh-CN redis 服务名称 // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` // @Title zh-CN redis 服务端口 // @Description zh-CN 默认值为6379 ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` // @Title zh-CN 用户名 // @Description zh-CN 登陆 redis 的用户名,非必填 Username string `required:"false" yaml:"username" json:"username"` // @Title zh-CN 密码 // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 Password string `required:"false" yaml:"password" json:"password"` // @Title zh-CN 请求超时 // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 Timeout int `required:"false" yaml:"timeout" json:"timeout"` } type KVExtractor struct { // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` } type PluginConfig struct { // @Title zh-CN Redis 地址信息 // @Description zh-CN 用于存储缓存结果的 Redis 地址 RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` // @Title zh-CN 缓存 key 的来源 // @Description zh-CN 往 redis 里存时,使用的 key 的提取方式 CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` // @Title zh-CN 缓存 value 的来源 // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` // @Title zh-CN 流式响应下,缓存 value 的来源 // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` // @Title zh-CN 返回流式 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` // @Title zh-CN 缓存的过期时间 // @Description zh-CN 单位是秒,默认值为0,即永不过期 CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` // @Title zh-CN Redis缓存Key的前缀 // @Description zh-CN 默认值是"higress-ai-cache:" CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` redisClient wrapper.RedisClient `yaml:"-" json:"-"` } func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() if c.RedisInfo.ServiceName == "" { return errors.New("redis service name must not by empty") } c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) if c.RedisInfo.ServicePort == 0 { if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { // use default logic port which is 80 for static service c.RedisInfo.ServicePort = 80 } else { c.RedisInfo.ServicePort = 6379 } } c.RedisInfo.Username = json.Get("redis.username").String() c.RedisInfo.Password = json.Get("redis.password").String() c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) if c.RedisInfo.Timeout == 0 { c.RedisInfo.Timeout = 1000 } c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() if c.CacheKeyFrom.RequestBody == "" { c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" } c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() if c.CacheValueFrom.ResponseBody == "" { c.CacheValueFrom.ResponseBody = "choices.0.message.content" } c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() if c.CacheStreamValueFrom.ResponseBody == "" { c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" } c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() if c.ReturnResponseTemplate == "" { c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() if c.ReturnStreamResponseTemplate == "" { c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() if c.CacheKeyPrefix == "" { c.CacheKeyPrefix = DefaultCacheKeyPrefix } c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ FQDN: c.RedisInfo.ServiceName, Port: int64(c.RedisInfo.ServicePort), }) return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) } func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") // The request does not have a body. if contentType == "" { return types.ActionContinue } if !strings.Contains(contentType, "application/json") { log.Warnf("content is not json, can't process:%s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } proxywasm.RemoveHttpRequestHeader("Accept-Encoding") // The request has a body and requires delaying the header transmission until a cache miss occurs, // at which point the header should be sent. return types.HeaderStopIteration } func TrimQuote(source string) string { return strings.Trim(source, `"`) } func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. stream := false if bodyJson.Get("stream").Bool() { stream = true ctx.SetContext(StreamContextKey, struct{}{}) } else if ctx.GetContext(StreamContextKey) != nil { stream = true } key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) if key == "" { log.Debug("parse key from request body failed") return types.ActionContinue } ctx.SetContext(CacheKeyContextKey, key) err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) { if err := response.Error(); err != nil { log.Errorf("redis get key:%s failed, err:%v", key, err) proxywasm.ResumeHttpRequest() return } if response.IsNull() { log.Debugf("cache miss, key:%s", key) proxywasm.ResumeHttpRequest() return } log.Debugf("cache hit, key:%s", key) ctx.SetContext(CacheKeyContextKey, nil) if !stream { proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) } else { proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) } }) if err != nil { log.Error("redis access failed") return types.ActionContinue } return types.ActionPause } func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { if strings.HasPrefix(msg, "data:") { message = msg break } } if len(message) < 6 { log.Errorf("invalid message:%s", message) return "" } // skip the prefix "data:" bodyJson := message[5:] if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI == nil { content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) ctx.SetContext(CacheContentContextKey, content) return content } append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) content := tempContentI.(string) + append ctx.SetContext(CacheContentContextKey, content) return content } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { // TODO: compatible with other providers ctx.SetContext(ToolCallsContextKey, struct{}{}) return "" } log.Debugf("unknown message:%s", bodyJson) return "" } func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(StreamContextKey, struct{}{}) } return types.ActionContinue } func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk } keyI := ctx.GetContext(CacheKeyContextKey) if keyI == nil { return chunk } if !isLastChunk { stream := ctx.GetContext(StreamContextKey) if stream == nil { tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI == nil { ctx.SetContext(CacheContentContextKey, chunk) return chunk } tempContent := tempContentI.([]byte) tempContent = append(tempContent, chunk...) ctx.SetContext(CacheContentContextKey, tempContent) } else { var partialMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { partialMessage = chunk } messages := strings.Split(string(partialMessage), "\n\n") for i, msg := range messages { if i < len(messages)-1 { // process complete message processSSEMessage(ctx, config, msg, log) } } if !strings.HasSuffix(string(partialMessage), "\n\n") { ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) } else { ctx.SetContext(PartialMessageContextKey, nil) } } return chunk } // last chunk key := keyI.(string) stream := ctx.GetContext(StreamContextKey) var value string if stream == nil { var body []byte tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI != nil { body = append(tempContentI.([]byte), chunk...) } else { body = chunk } bodyJson := gjson.ParseBytes(body) value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw) if value == "" { log.Warnf("parse value from response body failded, body:%s", body) return chunk } } else { if len(chunk) > 0 { var lastMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) if partialMessageI != nil { lastMessage = append(partialMessageI.([]byte), chunk...) } else { lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { log.Warnf("invalid lastMessage:%s", lastMessage) return chunk } // remove the last \n\n lastMessage = lastMessage[:len(lastMessage)-2] value = processSSEMessage(ctx, config, string(lastMessage), log) } else { tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI == nil { return chunk } value = tempContentI.(string) } } config.redisClient.Set(config.CacheKeyPrefix+key, value, nil) if config.CacheTTL != 0 { config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil) } return chunk }