Files
higress/plugins/wasm-go/extensions/ai-history/main.go
2024-08-27 19:24:39 +08:00

481 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// File generated by hgctl. Modify as required.
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
package main
import (
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
// Keys under which per-request state is stored on the wrapper.HttpContext,
// plus the plugin-wide default for the Redis key prefix.
const (
// QuestionContextKey holds the question extracted from the request body.
QuestionContextKey = "question"
// AnswerContentContextKey holds the answer accumulated so far: raw body
// bytes for plain responses, the concatenated delta text for SSE responses.
AnswerContentContextKey = "answer"
// PartialMessageContextKey holds the trailing bytes of an SSE chunk that do
// not yet form a complete event.
PartialMessageContextKey = "partialMessage"
// ToolCallsContextKey is set as a marker once a tool-call event is seen.
ToolCallsContextKey = "toolCalls"
// StreamContextKey is set as a marker when the exchange is streaming.
StreamContextKey = "stream"
// DefaultCacheKeyPrefix is used when cacheKeyPrefix is not configured.
DefaultCacheKeyPrefix = "higress-ai-history:"
// IdentityKey holds the caller identity taken from the identity header.
IdentityKey = "identity"
// ChatHistories holds the raw JSON history string loaded from Redis.
ChatHistories = "chatHistories"
)
// main registers the "ai-history" plugin with the wasm-go wrapper, wiring up
// the configuration parser and the request/response callbacks defined in this
// file.
func main() {
wrapper.SetCtx(
"ai-history",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamResponseBody),
)
}
// @Name ai-history
// @Category protocol
// @Phase AUTHN
// @Priority 10
// @Title zh-CN AI History
// @Description zh-CN 大模型对话历史缓存
// @IconUrl
// @Version 0.1.0
//
// @Contact.name sakura
// @Contact.url
// @Contact.email
//
// @Example
// redis:
// serviceName: my-redis.dns
// timeout: 2000
//
// @End
// RedisInfo describes how to reach and authenticate against the Redis service
// that stores the chat history. The "@Title"/"@Description" lines are
// locale-tagged plugin metadata and are kept verbatim.
type RedisInfo struct {
// ServiceName is the full FQDN of the Redis service including the service
// type suffix; parseConfig rejects an empty value.
// @Title zh-CN redis 服务名称
// @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local
ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
// ServicePort is the Redis port. When unset, parseConfig uses 80 for
// ".static" services and 6379 otherwise.
// @Title zh-CN redis 服务端口
// @Description zh-CN 默认值为6379
ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"`
// Username is the optional Redis login user.
// @Title zh-CN 用户名
// @Description zh-CN 登陆 redis 的用户名,非必填
Username string `required:"false" yaml:"username" json:"username"`
// Password is the optional Redis password; it may be set without a username.
// @Title zh-CN 密码
// @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码
Password string `required:"false" yaml:"password" json:"password"`
// Timeout is the Redis request timeout in milliseconds; parseConfig
// defaults it to 1000 when unset.
// @Title zh-CN 请求超时
// @Description zh-CN 请求 redis 的超时时间单位为毫秒。默认值是1000即1秒
Timeout int `required:"false" yaml:"timeout" json:"timeout"`
}
// KVExtractor holds GJSON path expressions used to pull a string out of a
// request or response body.
type KVExtractor struct {
// RequestBody is the GJSON path evaluated against the request body.
// @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
// ResponseBody is the GJSON path evaluated against the response body.
// @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
}
// PluginConfig is the parsed configuration of the ai-history plugin together
// with the Redis client built from it.
type PluginConfig struct {
// RedisInfo is the address and credentials of the Redis service that
// stores the history.
// @Title zh-CN Redis 地址信息
// @Description zh-CN 用于存储缓存结果的 Redis 地址
RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"`
// QuestionFrom locates the question in the request body. parseConfig
// currently sets it to a fixed path rather than reading it from the
// configuration.
// @Title zh-CN 缓存 key 的来源
// @Description zh-CN 往 redis 里存时,问题的提取方式
QuestionFrom KVExtractor `required:"true" yaml:"questionFrom" json:"questionFrom"`
// AnswerValueFrom locates the answer in a non-streaming response body.
// parseConfig currently sets it to a fixed path.
// @Title zh-CN 缓存 value 的来源
// @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式
AnswerValueFrom KVExtractor `required:"true" yaml:"answerValueFrom" json:"answerValueFrom"`
// AnswerStreamValueFrom locates the answer delta in a streaming response
// event. parseConfig currently sets it to a fixed path.
// @Title zh-CN 流式响应下,缓存 value 的来源
// @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式
AnswerStreamValueFrom KVExtractor `required:"true" yaml:"answerStreamValueFrom" json:"answerStreamValueFrom"`
// CacheKeyPrefix is prepended to the identity to form the Redis key; the
// default is DefaultCacheKeyPrefix ("higress-ai-history:").
// @Title zh-CN Redis缓存Key的前缀
// @Description zh-CN 默认值是"higress-ai-history:"
CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"`
// IdentityHeader is the request header whose value identifies the caller;
// the default is "Authorization".
// @Title zh-CN 身份解析方式
// @Description zh-CN 默认值是"Authorization"
IdentityHeader string `required:"false" yaml:"identityHeader" json:"identityHeader"`
// FillHistoryCnt is the default number of past conversation rounds to
// prepend to a request; the default is 3.
// @Title zh-CN 默认填充历史对话轮数
// @Description zh-CN 默认值是 3
FillHistoryCnt int `required:"false" yaml:"fillHistoryCnt" json:"fillHistoryCnt"`
// CacheTTL is the expiry of the stored history in seconds; 0 (the default)
// means no expiry is set.
// @Title zh-CN 缓存的过期时间
// @Description zh-CN 单位是秒默认值为0即永不过期
CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"`
// redisClient is created by parseConfig and is never (un)marshalled.
redisClient wrapper.RedisClient `yaml:"-" json:"-"`
}
// ChatHistory is a single chat message as stored in Redis and as sent in the
// "messages" array of a request.
type ChatHistory struct {
// Role is the author of the message; this file uses "user" and "assistant".
Role string `json:"role"`
// Content is the text of the message.
Content string `json:"content"`
}
// parseConfig fills c from the plugin's JSON configuration and initialises the
// Redis client.
//
// Defaults applied when a value is absent or zero: servicePort 6379 (80 for
// ".static" services), timeout 1000 ms, cacheKeyPrefix DefaultCacheKeyPrefix,
// identityHeader "Authorization", fillHistoryCnt 3. The question/answer
// extraction paths are fixed and not read from the configuration.
//
// An error is returned when the Redis service name is empty, when a numeric
// setting is out of range (a negative fillHistoryCnt would otherwise cause an
// out-of-range slice at request time), or when the Redis client fails to
// initialise.
func parseConfig(conf gjson.Result, c *PluginConfig, log wrapper.Log) error {
	c.RedisInfo.ServiceName = conf.Get("redis.serviceName").String()
	if c.RedisInfo.ServiceName == "" {
		return errors.New("redis service name must not be empty")
	}

	c.RedisInfo.ServicePort = int(conf.Get("redis.servicePort").Int())
	if c.RedisInfo.ServicePort < 0 || c.RedisInfo.ServicePort > 65535 {
		return fmt.Errorf("redis service port %d is out of range", c.RedisInfo.ServicePort)
	}
	if c.RedisInfo.ServicePort == 0 {
		if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") {
			// use default logic port which is 80 for static service
			c.RedisInfo.ServicePort = 80
		} else {
			c.RedisInfo.ServicePort = 6379
		}
	}

	c.RedisInfo.Username = conf.Get("redis.username").String()
	c.RedisInfo.Password = conf.Get("redis.password").String()

	c.RedisInfo.Timeout = int(conf.Get("redis.timeout").Int())
	if c.RedisInfo.Timeout < 0 {
		return fmt.Errorf("redis timeout %d must not be negative", c.RedisInfo.Timeout)
	}
	if c.RedisInfo.Timeout == 0 {
		c.RedisInfo.Timeout = 1000
	}

	// The extraction paths follow the OpenAI chat-completions layout.
	c.QuestionFrom.RequestBody = "messages.@reverse.0.content"
	c.AnswerValueFrom.ResponseBody = "choices.0.message.content"
	c.AnswerStreamValueFrom.ResponseBody = "choices.0.delta.content"

	c.CacheKeyPrefix = conf.Get("cacheKeyPrefix").String()
	if c.CacheKeyPrefix == "" {
		c.CacheKeyPrefix = DefaultCacheKeyPrefix
	}

	c.IdentityHeader = conf.Get("identityHeader").String()
	if c.IdentityHeader == "" {
		c.IdentityHeader = "Authorization"
	}

	c.FillHistoryCnt = int(conf.Get("fillHistoryCnt").Int())
	if c.FillHistoryCnt < 0 {
		return fmt.Errorf("fillHistoryCnt %d must not be negative", c.FillHistoryCnt)
	}
	if c.FillHistoryCnt == 0 {
		c.FillHistoryCnt = 3
	}

	c.CacheTTL = int(conf.Get("cacheTTL").Int())
	if c.CacheTTL < 0 {
		return fmt.Errorf("cacheTTL %d must not be negative", c.CacheTTL)
	}

	c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
		FQDN: c.RedisInfo.ServiceName,
		Port: int64(c.RedisInfo.ServicePort),
	})
	return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout))
}
// onHttpRequestHeaders decides whether the request takes part in history
// handling.
//
// Requests that are not JSON, or that carry no value in the configured
// identity header, are passed through and their body is not read: without an
// identity the Redis key would be the bare prefix and therefore shared by all
// anonymous callers. For all other requests the identity (with spaces removed)
// is stored on the context, the Accept-Encoding and Content-Length headers are
// removed, and header forwarding is held back until the body has been handled.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
	if !strings.Contains(contentType, "application/json") {
		log.Warnf("content is not json, can't process:%s", contentType)
		ctx.DontReadRequestBody()
		return types.ActionContinue
	}

	// get identity key
	identityKey, _ := proxywasm.GetHttpRequestHeader(config.IdentityHeader)
	if identityKey == "" {
		log.Warnf("identity key is empty")
		ctx.DontReadRequestBody()
		return types.ActionContinue
	}
	identityKey = strings.ReplaceAll(identityKey, " ", "")
	ctx.SetContext(IdentityKey, identityKey)

	// The body may be rewritten with history, so its length changes.
	_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
	_ = proxywasm.RemoveHttpRequestHeader("Content-Length")

	// The request has a body and requires delaying the header transmission until a cache miss occurs,
	// at which point the header should be sent.
	return types.HeaderStopIteration
}
// TrimQuote returns source with all leading and trailing double-quote
// characters removed. Quotes inside the string are left untouched.
func TrimQuote(source string) string {
	const quote = "\""
	trimmed := strings.Trim(source, quote)
	return trimmed
}
// onHttpRequestBody loads the caller's chat history from Redis and either
// answers a history query directly or prepends history to the request's
// "messages" before forwarding it.
//
// It returns ActionPause while the Redis lookup is in flight; the callback
// resumes the request (or sends a local response for history queries). It
// returns ActionContinue when there is no identity or the lookup could not be
// dispatched.
//
// On a cache miss the question is still recorded on the context so that the
// response handler can create the first history entry; without this no history
// would ever be written for a new identity.
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	bodyJson := gjson.ParseBytes(body)
	if bodyJson.Get("stream").Bool() {
		ctx.SetContext(StreamContextKey, struct{}{})
	}

	identityKey := ctx.GetStringContext(IdentityKey, "")
	if identityKey == "" {
		// Without an identity the key would be the bare prefix, shared by
		// every anonymous caller.
		log.Warnf("identity key is empty, skip chat history")
		return types.ActionContinue
	}

	question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())

	err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
		if err := response.Error(); err != nil {
			log.Errorf("redis get failed, err:%v", err)
			_ = proxywasm.ResumeHttpRequest()
			return
		}
		path := ctx.Path()
		if response.IsNull() {
			log.Debugf("cache miss, identityKey:%s", identityKey)
			// Remember the question so the answer can start a new history.
			if question != "" && !isQueryHistory(path) {
				ctx.SetContext(QuestionContextKey, question)
			}
			_ = proxywasm.ResumeHttpRequest()
			return
		}

		chatHistories := response.String()
		ctx.SetContext(ChatHistories, chatHistories)
		var chat []ChatHistory
		err := json.Unmarshal([]byte(chatHistories), &chat)
		if err != nil {
			log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err)
			_ = proxywasm.ResumeHttpRequest()
			return
		}

		if isQueryHistory(path) {
			// cnt counts rounds; each round is a user and an assistant message.
			cnt := getIntQueryParameter("cnt", path, len(chat)/2) * 2
			// A negative value (bad input or overflow) would slice out of range.
			if cnt < 0 || cnt > len(chat) {
				cnt = len(chat)
			}
			chat = chat[len(chat)-cnt:]
			res, err := json.Marshal(chat)
			if err != nil {
				log.Errorf("marshal chat:%v failed, err:%v", chat, err)
				_ = proxywasm.ResumeHttpRequest()
				return
			}
			_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
			return
		}

		if question == "" {
			log.Debug("parse question from request body failed")
			_ = proxywasm.ResumeHttpRequest()
			return
		}
		ctx.SetContext(QuestionContextKey, question)

		fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
		if fillHistoryCnt < 0 {
			fillHistoryCnt = 0
		}
		currJson := bodyJson.Get("messages").String()
		var currMessage []ChatHistory
		err = json.Unmarshal([]byte(currJson), &currMessage)
		if err != nil {
			log.Errorf("unmarshal currMessage:%s failed, err:%v", currJson, err)
			_ = proxywasm.ResumeHttpRequest()
			return
		}
		finalChat := fillHistory(chat, currMessage, fillHistoryCnt)

		// Re-encode the whole body so every other request field is preserved.
		var parameter map[string]any
		err = json.Unmarshal(body, &parameter)
		if err != nil {
			log.Errorf("unmarshal body:%s failed, err:%v", body, err)
			_ = proxywasm.ResumeHttpRequest()
			return
		}
		parameter["messages"] = finalChat
		parameterJson, err := json.Marshal(parameter)
		if err != nil {
			log.Errorf("marshal parameter:%v failed, err:%v", parameter, err)
			_ = proxywasm.ResumeHttpRequest()
			return
		}
		log.Infof("start to replace request body, parameter:%s", string(parameterJson))
		_ = proxywasm.ReplaceHttpRequestBody(parameterJson)
		_ = proxywasm.ResumeHttpRequest()
	})
	if err != nil {
		log.Errorf("redis access failed, err:%v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}
// fillHistory returns the messages to send upstream: the last fillHistoryCnt
// entries of chat followed by currMessage.
//
// When currMessage already contains more than one "user" message the caller is
// assumed to supply its own context and currMessage is returned unchanged.
// fillHistoryCnt is clamped to [0, len(chat)], so a negative value no longer
// causes an out-of-range slice. The result is built in a fresh slice so that
// appending never writes into chat's backing array.
func fillHistory(chat []ChatHistory, currMessage []ChatHistory, fillHistoryCnt int) []ChatHistory {
	userInputCnt := 0
	for _, msg := range currMessage {
		if msg.Role == "user" {
			userInputCnt++
		}
	}
	if userInputCnt > 1 {
		return currMessage
	}

	if fillHistoryCnt < 0 {
		fillHistoryCnt = 0
	}
	if fillHistoryCnt > len(chat) {
		fillHistoryCnt = len(chat)
	}

	finalChat := make([]ChatHistory, 0, fillHistoryCnt+len(currMessage))
	finalChat = append(finalChat, chat[len(chat)-fillHistoryCnt:]...)
	finalChat = append(finalChat, currMessage...)
	return finalChat
}
// isQueryHistory reports whether path addresses the history query endpoint,
// i.e. whether it contains "ai-history/query" anywhere.
func isQueryHistory(path string) bool {
	const queryMarker = "ai-history/query"
	found := strings.Contains(path, queryMarker)
	return found
}
// getIntQueryParameter returns the value of the query parameter name in path
// (a request URI such as "/v1/chat?cnt=2") as a non-negative integer.
//
// defaultValue is returned when path cannot be parsed, when the parameter is
// absent or empty, when it is not a valid integer, or when it is negative.
// Both callers use the result as a count that is later used to slice the
// history, so a negative value is never meaningful and previously caused an
// out-of-range slice.
func getIntQueryParameter(name string, path string, defaultValue int) int {
	// Parse the request URI.
	parsedURL, err := url.ParseRequestURI(path)
	if err != nil {
		fmt.Println("Error parsing URL:", err)
		return defaultValue
	}

	// Look up the requested query parameter.
	queryStr := parsedURL.Query().Get(name)
	if queryStr == "" {
		return defaultValue
	}

	num, err := strconv.Atoi(queryStr)
	if err != nil || num < 0 {
		return defaultValue
	}
	return num
}
// processSSEMessage handles one server-sent event and returns the answer text
// accumulated so far, or "" when the event carries no answer text.
//
// Only the first "data:" line of sseMessage is considered. If the event
// carries tool calls, ToolCallsContextKey is set and "" is returned. Otherwise
// the string found at config.AnswerStreamValueFrom.ResponseBody is appended to
// the text stored under AnswerContentContextKey.
//
// The delta is read with gjson's String(), which decodes JSON escapes. Reading
// the raw token and stripping quotes kept sequences such as \n or \uXXXX as
// literal text (escaped a second time when the history was saved) and turned a
// null delta into the text "null".
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
	var message string
	for _, line := range strings.Split(sseMessage, "\n") {
		if strings.HasPrefix(line, "data:") {
			message = line
			break
		}
	}
	if len(message) < 6 {
		log.Errorf("invalid message:%s", message)
		return ""
	}
	// skip the prefix "data:"
	bodyJson := message[5:]

	// Tool calls are checked first: such events usually also carry a null or
	// empty content field, which would otherwise be taken as answer text.
	// TODO: compatible with other providers
	toolCalls := gjson.Get(bodyJson, "choices.0.delta.tool_calls")
	if (toolCalls.IsArray() && len(toolCalls.Array()) > 0) ||
		gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
		ctx.SetContext(ToolCallsContextKey, struct{}{})
		return ""
	}

	delta := gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody)
	if !delta.Exists() {
		log.Debugf("unknown message:%s", bodyJson)
		return ""
	}

	content := delta.String()
	if previous, ok := ctx.GetContext(AnswerContentContextKey).(string); ok {
		content = previous + content
	}
	ctx.SetContext(AnswerContentContextKey, content)
	return content
}
// onHttpResponseHeaders marks the exchange as streaming when the response
// content type contains "text/event-stream". It always lets the headers
// through.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	if respType, _ := proxywasm.GetHttpResponseHeader("content-type"); strings.Contains(respType, "text/event-stream") {
		ctx.SetContext(StreamContextKey, struct{}{})
	}
	return types.ActionContinue
}
// onHttpStreamResponseBody observes the response body chunk by chunk, extracts
// the model's answer and, on the last chunk, stores the question/answer pair
// in the chat history. The chunk is always returned unmodified.
//
// Nothing is recorded when a tool call was seen, when no question was stored
// for the request, or when the request is a history query.
//
// Plain (non-streaming) responses are buffered and parsed as one JSON document
// on the last chunk. SSE responses are split on blank lines; each complete
// event is fed to processSSEMessage and an incomplete tail is kept for the
// next chunk. On the last chunk every remaining event is processed (not only
// the first one) and the answer is taken from the accumulated text, so a
// trailing "data: [DONE]" event no longer produces an empty answer. An empty
// answer is never saved.
func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
	if ctx.GetContext(ToolCallsContextKey) != nil {
		// we should not cache tool call result
		return chunk
	}
	questionI := ctx.GetContext(QuestionContextKey)
	if questionI == nil {
		return chunk
	}
	if isQueryHistory(ctx.Path()) {
		return chunk
	}
	streaming := ctx.GetContext(StreamContextKey) != nil

	if !isLastChunk {
		if !streaming {
			// Buffer the plain body until it is complete.
			buffered, _ := ctx.GetContext(AnswerContentContextKey).([]byte)
			ctx.SetContext(AnswerContentContextKey, append(buffered, chunk...))
			return chunk
		}
		pending := chunk
		if partial, ok := ctx.GetContext(PartialMessageContextKey).([]byte); ok {
			pending = append(partial, chunk...)
		}
		events := strings.Split(string(pending), "\n\n")
		// Everything before the last separator is a complete event.
		for _, event := range events[:len(events)-1] {
			if strings.TrimSpace(event) == "" {
				continue
			}
			processSSEMessage(ctx, config, event, log)
		}
		// The last element is empty when pending ended on a separator.
		if rest := events[len(events)-1]; rest != "" {
			ctx.SetContext(PartialMessageContextKey, []byte(rest))
		} else {
			ctx.SetContext(PartialMessageContextKey, nil)
		}
		return chunk
	}

	var value string
	if !streaming {
		body := chunk
		if buffered, ok := ctx.GetContext(AnswerContentContextKey).([]byte); ok {
			body = append(buffered, chunk...)
		}
		// String() decodes JSON escapes; the raw token would be escaped a
		// second time when the history is marshalled.
		value = gjson.ParseBytes(body).Get(config.AnswerValueFrom.ResponseBody).String()
		if value == "" {
			log.Warnf("parse value from response body failed, body:%s", body)
			return chunk
		}
	} else {
		if len(chunk) > 0 {
			pending := chunk
			if partial, ok := ctx.GetContext(PartialMessageContextKey).([]byte); ok {
				pending = append(partial, chunk...)
			}
			if !strings.HasSuffix(string(pending), "\n\n") {
				log.Warnf("invalid lastMessage:%s", pending)
				return chunk
			}
			for _, event := range strings.Split(string(pending), "\n\n") {
				if strings.TrimSpace(event) == "" {
					continue
				}
				processSSEMessage(ctx, config, event, log)
			}
			if ctx.GetContext(ToolCallsContextKey) != nil {
				// we should not cache tool call result
				return chunk
			}
		}
		answer, _ := ctx.GetContext(AnswerContentContextKey).(string)
		if answer == "" {
			return chunk
		}
		value = answer
	}

	saveChatHistory(ctx, config, questionI, value, log)
	return chunk
}
// saveChatHistory appends the question/answer pair to the caller's history and
// writes it back to Redis.
//
// The existing history is taken from the ChatHistories context value (the JSON
// loaded by the request handler), not re-read from Redis. After appending, the
// history is trimmed to the most recent config.FillHistoryCnt rounds (two
// messages per round). When config.CacheTTL is non-zero the key's expiry is
// refreshed. Failures are logged; nothing is returned.
func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log wrapper.Log) {
	question, ok := questionI.(string)
	if !ok {
		log.Errorf("unexpected question type:%T", questionI)
		return
	}
	identityKey := ctx.GetStringContext(IdentityKey, "")

	var chat []ChatHistory
	chatHistories := ctx.GetStringContext(ChatHistories, "")
	if chatHistories != "" {
		err := json.Unmarshal([]byte(chatHistories), &chat)
		if err != nil {
			log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err)
			return
		}
	}
	chat = append(chat, ChatHistory{Role: "user", Content: question})
	chat = append(chat, ChatHistory{Role: "assistant", Content: value})

	// A negative limit would slice out of range.
	limit := config.FillHistoryCnt * 2
	if limit < 0 {
		limit = 0
	}
	if len(chat) > limit {
		chat = chat[len(chat)-limit:]
	}

	str, err := json.Marshal(chat)
	if err != nil {
		log.Errorf("marshal chat:%v failed, err:%v", chat, err)
		return
	}
	key := config.CacheKeyPrefix + identityKey
	log.Infof("start to Set history, identityKey:%s, chat:%s", identityKey, string(str))
	if err := config.redisClient.Set(key, string(str), nil); err != nil {
		log.Errorf("redis set failed, identityKey:%s, err:%v", identityKey, err)
		return
	}
	if config.CacheTTL != 0 {
		if err := config.redisClient.Expire(key, config.CacheTTL, nil); err != nil {
			log.Errorf("redis expire failed, identityKey:%s, err:%v", identityKey, err)
		}
	}
}