// File generated by hgctl. Modify as required. // See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 package main import ( "encoding/json" "errors" "fmt" "net/url" "strconv" "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" "github.com/tidwall/resp" ) const ( QuestionContextKey = "question" AnswerContentContextKey = "answer" PartialMessageContextKey = "partialMessage" ToolCallsContextKey = "toolCalls" StreamContextKey = "stream" DefaultCacheKeyPrefix = "higress-ai-history:" IdentityKey = "identity" ChatHistories = "chatHistories" ) func main() { wrapper.SetCtx( "ai-history", wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders), wrapper.ProcessStreamingResponseBodyBy(onHttpStreamResponseBody), ) } // @Name ai-history // @Category protocol // @Phase AUTHN // @Priority 10 // @Title zh-CN AI History // @Description zh-CN 大模型对话历史缓存 // @IconUrl // @Version 0.1.0 // // @Contact.name sakura // @Contact.url // @Contact.email // // @Example // redis: // serviceName: my-redis.dns // timeout: 2000 // // @End type RedisInfo struct { // @Title zh-CN redis 服务名称 // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` // @Title zh-CN redis 服务端口 // @Description zh-CN 默认值为6379 ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` // @Title zh-CN 用户名 // @Description zh-CN 登陆 redis 的用户名,非必填 Username string `required:"false" yaml:"username" json:"username"` // @Title zh-CN 密码 // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 Password string `required:"false" yaml:"password" json:"password"` // @Title zh-CN 请求超时 // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 Timeout int `required:"false" yaml:"timeout" json:"timeout"` // @Title zh-CN Database // @Description zh-CN redis database Database int `required:"false" yaml:"database" json:"database"` } type KVExtractor struct { // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` } type PluginConfig struct { // @Title zh-CN Redis 地址信息 // @Description zh-CN 用于存储缓存结果的 Redis 地址 RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` // @Title zh-CN 缓存 key 的来源 // @Description zh-CN 往 redis 里存时,问题的提取方式 QuestionFrom KVExtractor `required:"true" yaml:"questionFrom" json:"questionFrom"` // @Title zh-CN 缓存 value 的来源 // @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式 AnswerValueFrom KVExtractor `required:"true" yaml:"answerValueFrom" json:"answerValueFrom"` // @Title zh-CN 流式响应下,缓存 value 的来源 // @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式 AnswerStreamValueFrom KVExtractor `required:"true" yaml:"answerStreamValueFrom" json:"answerStreamValueFrom"` // @Title zh-CN Redis缓存Key的前缀 // @Description zh-CN 默认值是"higress-ai-cache:" CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` // @Title zh-CN 身份解析方式 // @Description zh-CN 默认值是"Authorization" IdentityHeader string `required:"false" yaml:"identityHeader" json:"identityHeader"` // @Title zh-CN 默认填充历史对话轮数 // @Description zh-CN 默认值是 3 FillHistoryCnt int `required:"false" yaml:"fillHistoryCnt" json:"fillHistoryCnt"` // @Title zh-CN 缓存的过期时间 // @Description zh-CN 单位是秒,默认值为0,即永不过期 CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` redisClient wrapper.RedisClient `yaml:"-" json:"-"` } type ChatHistory struct { Role string `json:"role"` Content string `json:"content"` } func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() if c.RedisInfo.ServiceName == "" { return errors.New("redis service name must not be empty") } c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) if c.RedisInfo.ServicePort == 0 { if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { // use default logic port which is 80 for static service c.RedisInfo.ServicePort = 80 } else { c.RedisInfo.ServicePort = 6379 } } c.RedisInfo.Username = json.Get("redis.username").String() c.RedisInfo.Password = json.Get("redis.password").String() c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) if c.RedisInfo.Timeout == 0 { c.RedisInfo.Timeout = 1000 } c.RedisInfo.Database = int(json.Get("redis.database").Int()) c.QuestionFrom.RequestBody = "messages.@reverse.0.content" c.AnswerValueFrom.ResponseBody = "choices.0.message.content" c.AnswerStreamValueFrom.ResponseBody = "choices.0.delta.content" c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() if c.CacheKeyPrefix == "" { c.CacheKeyPrefix = DefaultCacheKeyPrefix } c.IdentityHeader = json.Get("identityHeader").String() if c.IdentityHeader == "" { c.IdentityHeader = "Authorization" } c.FillHistoryCnt = int(json.Get("fillHistoryCnt").Int()) if c.FillHistoryCnt == 0 { c.FillHistoryCnt = 3 } c.CacheTTL = int(json.Get("cacheTTL").Int()) c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ FQDN: c.RedisInfo.ServiceName, Port: int64(c.RedisInfo.ServicePort), }) return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout), wrapper.WithDataBase(c.RedisInfo.Database)) } func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") if !strings.Contains(contentType, "application/json") { log.Warnf("content is not json, can't process:%s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } // get identity key identityKey, _ := proxywasm.GetHttpRequestHeader(config.IdentityHeader) if identityKey == "" { log.Warnf("identity key is empty") return types.ActionContinue } identityKey = strings.ReplaceAll(identityKey, " ", "") ctx.SetContext(IdentityKey, identityKey) _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Content-Length") // The request has a body and requires delaying the header transmission until a cache miss occurs, // at which point the header should be sent. return types.HeaderStopIteration } func TrimQuote(source string) string { return strings.Trim(source, `"`) } func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) if bodyJson.Get("stream").Bool() { ctx.SetContext(StreamContextKey, struct{}{}) } identityKey := ctx.GetStringContext(IdentityKey, "") question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String()) if question == "" { log.Debug("parse question from request body failed") return types.ActionContinue } ctx.SetContext(QuestionContextKey, question) err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { if err := response.Error(); err != nil { log.Errorf("redis get failed, err:%v", err) _ = proxywasm.ResumeHttpRequest() return } if response.IsNull() { log.Debugf("cache miss, identityKey:%s", identityKey) _ = proxywasm.ResumeHttpRequest() return } chatHistories := response.String() ctx.SetContext(ChatHistories, chatHistories) var chat []ChatHistory err := json.Unmarshal([]byte(chatHistories), &chat) if err != nil { log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err) _ = proxywasm.ResumeHttpRequest() return } path := ctx.Path() if isQueryHistory(path) { cnt := getIntQueryParameter("cnt", path, len(chat)/2) * 2 if cnt > len(chat) { cnt = len(chat) } chat = chat[len(chat)-cnt:] res, err := json.Marshal(chat) if err != nil { log.Errorf("marshal chat:%v failed, err:%v", chat, err) _ = proxywasm.ResumeHttpRequest() return } _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) return } fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 currJson := bodyJson.Get("messages").String() var currMessage []ChatHistory err = json.Unmarshal([]byte(currJson), &currMessage) if err != nil { log.Errorf("unmarshal currMessage:%s failed, err:%v", currJson, err) _ = proxywasm.ResumeHttpRequest() return } finalChat := fillHistory(chat, currMessage, fillHistoryCnt) var parameter map[string]any err = json.Unmarshal(body, ¶meter) if err != nil { log.Errorf("unmarshal body:%s failed, err:%v", body, err) _ = proxywasm.ResumeHttpRequest() return } parameter["messages"] = finalChat parameterJson, err := json.Marshal(parameter) if err != nil { log.Errorf("marshal parameter:%v failed, err:%v", parameter, err) _ = proxywasm.ResumeHttpRequest() return } log.Infof("start to replace request body, parameter:%s", string(parameterJson)) _ = proxywasm.ReplaceHttpRequestBody(parameterJson) _ = proxywasm.ResumeHttpRequest() }) if err != nil { log.Error("redis access failed") return types.ActionContinue } return types.ActionPause } func fillHistory(chat []ChatHistory, currMessage []ChatHistory, fillHistoryCnt int) []ChatHistory { userInputCnt := 0 for i := 0; i < len(currMessage); i++ { if currMessage[i].Role == "user" { userInputCnt++ } } if userInputCnt > 1 { return currMessage } if fillHistoryCnt > len(chat) { fillHistoryCnt = len(chat) } finalChat := append(chat[len(chat)-fillHistoryCnt:], currMessage...) return finalChat } func isQueryHistory(path string) bool { return strings.Contains(path, "ai-history/query") } func getIntQueryParameter(name string, path string, defaultValue int) int { // 解析 URL parsedURL, err := url.ParseRequestURI(path) if err != nil { fmt.Println("Error parsing URL:", err) return defaultValue } // 获取查询参数 values := parsedURL.Query() // 获取特定的查询参数 "defaultValue" queryStr := values.Get(name) if queryStr == "" { return defaultValue } num, err := strconv.Atoi(queryStr) if err != nil { return defaultValue } return num } func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { content := "" for _, chunk := range strings.Split(sseMessage, "\n\n") { subMessages := strings.Split(chunk, "\n") var message string for _, msg := range subMessages { if strings.HasPrefix(msg, "data:") { message = msg break } } if len(message) < 6 { log.Errorf("invalid message:%s", message) return content } // skip the prefix "data:" bodyJson := message[5:] if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() { tempContentI := ctx.GetContext(AnswerContentContextKey) if tempContentI == nil { content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) ctx.SetContext(AnswerContentContextKey, content) } else { append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) content = tempContentI.(string) + append ctx.SetContext(AnswerContentContextKey, content) } } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { // TODO: compatible with other providers ctx.SetContext(ToolCallsContextKey, struct{}{}) } log.Debugf("unknown message:%s", bodyJson) } return content } func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(StreamContextKey, struct{}{}) } return types.ActionContinue } func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk } questionI := ctx.GetContext(QuestionContextKey) if questionI == nil { return chunk } if isQueryHistory(ctx.Path()) { return chunk } if !isLastChunk { stream := ctx.GetContext(StreamContextKey) if stream == nil { tempContentI := ctx.GetContext(AnswerContentContextKey) if tempContentI == nil { ctx.SetContext(AnswerContentContextKey, chunk) return chunk } tempContent := tempContentI.([]byte) tempContent = append(tempContent, chunk...) ctx.SetContext(AnswerContentContextKey, tempContent) } else { var partialMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { partialMessage = chunk } messages := strings.Split(string(partialMessage), "\n\n") for i, msg := range messages { if i < len(messages)-1 { // process complete message processSSEMessage(ctx, config, msg, log) } } if !strings.HasSuffix(string(partialMessage), "\n\n") { ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) } else { ctx.SetContext(PartialMessageContextKey, nil) } } return chunk } stream := ctx.GetContext(StreamContextKey) var value string if stream == nil { var body []byte tempContentI := ctx.GetContext(AnswerContentContextKey) if tempContentI != nil { body = append(tempContentI.([]byte), chunk...) } else { body = chunk } bodyJson := gjson.ParseBytes(body) value = TrimQuote(bodyJson.Get(config.AnswerValueFrom.ResponseBody).Raw) if value == "" { log.Warnf("parse value from response body failded, body:%s", body) return chunk } } else { if len(chunk) > 0 { var lastMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) if partialMessageI != nil { lastMessage = append(partialMessageI.([]byte), chunk...) } else { lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { log.Warnf("invalid lastMessage:%s", lastMessage) return chunk } // remove the last \n\n lastMessage = lastMessage[:len(lastMessage)-2] value = processSSEMessage(ctx, config, string(lastMessage), log) } else { tempContentI := ctx.GetContext(AnswerContentContextKey) if tempContentI == nil { return chunk } value = tempContentI.(string) } } saveChatHistory(ctx, config, questionI, value, log) return chunk } func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log wrapper.Log) { question := questionI.(string) identityKey := ctx.GetStringContext(IdentityKey, "") var chat []ChatHistory chatHistories := ctx.GetStringContext(ChatHistories, "") if chatHistories != "" { err := json.Unmarshal([]byte(chatHistories), &chat) if err != nil { log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err) return } } chat = append(chat, ChatHistory{Role: "user", Content: question}) chat = append(chat, ChatHistory{Role: "assistant", Content: value}) if len(chat) > config.FillHistoryCnt*2 { chat = chat[len(chat)-config.FillHistoryCnt*2:] } str, err := json.Marshal(chat) if err != nil { log.Errorf("marshal chat:%v failed, err:%v", chat, err) return } log.Infof("start to Set history, identityKey:%s, chat:%s", identityKey, string(str)) _ = config.redisClient.Set(config.CacheKeyPrefix+identityKey, string(str), nil) if config.CacheTTL != 0 { _ = config.redisClient.Expire(config.CacheKeyPrefix+identityKey, config.CacheTTL, nil) } }