// Mirror of https://github.com/alibaba/higress.git (synced 2026-03-09 11:10:49 +08:00; 481 lines, 16 KiB, Go).
//
// File generated by hgctl. Modify as required.
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6

package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/resp"
|
||
)
|
||
|
||
// Keys under which per-request state is stored on the HTTP context.
const (
	// QuestionContextKey holds the question text extracted from the request body.
	QuestionContextKey = "question"
	// AnswerContentContextKey holds the answer data accumulated so far while
	// the response body is being received.
	AnswerContentContextKey = "answer"
	// PartialMessageContextKey holds the trailing, not yet terminated SSE
	// message that is carried over to the next response chunk.
	PartialMessageContextKey = "partialMessage"
	// ToolCallsContextKey marks a response that was detected to carry tool calls.
	ToolCallsContextKey = "toolCalls"
	// StreamContextKey marks an exchange that is handled as a stream.
	StreamContextKey = "stream"
	// IdentityKey holds the caller identity taken from the identity header.
	IdentityKey = "identity"
	// ChatHistories holds the raw history JSON loaded from Redis.
	ChatHistories = "chatHistories"
)

// DefaultCacheKeyPrefix is the Redis key prefix used when the configuration
// does not provide cacheKeyPrefix.
const DefaultCacheKeyPrefix = "higress-ai-history:"
|
||
|
||
func main() {
|
||
wrapper.SetCtx(
|
||
"ai-history",
|
||
wrapper.ParseConfigBy(parseConfig),
|
||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamResponseBody),
|
||
)
|
||
}
|
||
|
||
// @Name ai-history
// @Category protocol
// @Phase AUTHN
// @Priority 10
// @Title zh-CN AI History
// @Description zh-CN 大模型对话历史缓存
// @IconUrl
// @Version 0.1.0
//
// @Contact.name sakura
// @Contact.url
// @Contact.email
//
// @Example
// redis:
//   serviceName: my-redis.dns
//   timeout: 2000
//
// @End
|
||
// RedisInfo describes how to reach the Redis service that stores chat history.
//
// The "@Title"/"@Description" lines are locale-tagged annotations and are kept
// verbatim; the plain comments give the English meaning.
type RedisInfo struct {
	// ServiceName is the full service name including its type suffix,
	// e.g. my-redis.dns or redis.my-ns.svc.cluster.local.
	// @Title zh-CN redis 服务名称
	// @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local
	ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
	// ServicePort is the Redis service port; the annotation documents 6379 as default.
	// @Title zh-CN redis 服务端口
	// @Description zh-CN 默认值为6379
	ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"`
	// Username is the optional Redis login user.
	// @Title zh-CN 用户名
	// @Description zh-CN 登陆 redis 的用户名,非必填
	Username string `required:"false" yaml:"username" json:"username"`
	// Password is the optional Redis login password; it may be set without a username.
	// @Title zh-CN 密码
	// @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码
	Password string `required:"false" yaml:"password" json:"password"`
	// Timeout is the Redis request timeout in milliseconds; the annotation
	// documents 1000 (one second) as default.
	// @Title zh-CN 请求超时
	// @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒
	Timeout int `required:"false" yaml:"timeout" json:"timeout"`
}
|
||
|
||
// KVExtractor names the GJSON paths used to pull a string out of the request
// body and the response body.
type KVExtractor struct {
	// RequestBody is a GJSON path evaluated against the request body.
	// @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
	// ResponseBody is a GJSON path evaluated against the response body.
	// @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
}
|
||
|
||
type PluginConfig struct {
|
||
// @Title zh-CN Redis 地址信息
|
||
// @Description zh-CN 用于存储缓存结果的 Redis 地址
|
||
RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"`
|
||
// @Title zh-CN 缓存 key 的来源
|
||
// @Description zh-CN 往 redis 里存时,问题的提取方式
|
||
QuestionFrom KVExtractor `required:"true" yaml:"questionFrom" json:"questionFrom"`
|
||
// @Title zh-CN 缓存 value 的来源
|
||
// @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式
|
||
AnswerValueFrom KVExtractor `required:"true" yaml:"answerValueFrom" json:"answerValueFrom"`
|
||
// @Title zh-CN 流式响应下,缓存 value 的来源
|
||
// @Description zh-CN 往 redis 里存时,使用的 answer 的提取方式
|
||
AnswerStreamValueFrom KVExtractor `required:"true" yaml:"answerStreamValueFrom" json:"answerStreamValueFrom"`
|
||
// @Title zh-CN Redis缓存Key的前缀
|
||
// @Description zh-CN 默认值是"higress-ai-cache:"
|
||
CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"`
|
||
// @Title zh-CN 身份解析方式
|
||
// @Description zh-CN 默认值是"Authorization"
|
||
IdentityHeader string `required:"false" yaml:"identityHeader" json:"identityHeader"`
|
||
// @Title zh-CN 默认填充历史对话轮数
|
||
// @Description zh-CN 默认值是 3
|
||
FillHistoryCnt int `required:"false" yaml:"fillHistoryCnt" json:"fillHistoryCnt"`
|
||
// @Title zh-CN 缓存的过期时间
|
||
// @Description zh-CN 单位是秒,默认值为0,即永不过期
|
||
CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"`
|
||
redisClient wrapper.RedisClient `yaml:"-" json:"-"`
|
||
}
|
||
|
||
// ChatHistory is a single chat message as it is stored in Redis and as it
// appears in the request's "messages" array.
type ChatHistory struct {
	// Role is the speaker of the message; this file writes "user" and "assistant".
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
||
|
||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||
c.RedisInfo.ServiceName = json.Get("redis.serviceName").String()
|
||
if c.RedisInfo.ServiceName == "" {
|
||
return errors.New("redis service name must not be empty")
|
||
}
|
||
c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int())
|
||
if c.RedisInfo.ServicePort == 0 {
|
||
if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") {
|
||
// use default logic port which is 80 for static service
|
||
c.RedisInfo.ServicePort = 80
|
||
} else {
|
||
c.RedisInfo.ServicePort = 6379
|
||
}
|
||
}
|
||
c.RedisInfo.Username = json.Get("redis.username").String()
|
||
c.RedisInfo.Password = json.Get("redis.password").String()
|
||
c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int())
|
||
if c.RedisInfo.Timeout == 0 {
|
||
c.RedisInfo.Timeout = 1000
|
||
}
|
||
c.QuestionFrom.RequestBody = "messages.@reverse.0.content"
|
||
c.AnswerValueFrom.ResponseBody = "choices.0.message.content"
|
||
c.AnswerStreamValueFrom.ResponseBody = "choices.0.delta.content"
|
||
|
||
c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String()
|
||
if c.CacheKeyPrefix == "" {
|
||
c.CacheKeyPrefix = DefaultCacheKeyPrefix
|
||
}
|
||
c.IdentityHeader = json.Get("identityHeader").String()
|
||
if c.IdentityHeader == "" {
|
||
c.IdentityHeader = "Authorization"
|
||
}
|
||
c.FillHistoryCnt = int(json.Get("fillHistoryCnt").Int())
|
||
if c.FillHistoryCnt == 0 {
|
||
c.FillHistoryCnt = 3
|
||
}
|
||
c.CacheTTL = int(json.Get("cacheTTL").Int())
|
||
c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||
FQDN: c.RedisInfo.ServiceName,
|
||
Port: int64(c.RedisInfo.ServicePort),
|
||
})
|
||
return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout))
|
||
}
|
||
|
||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||
if !strings.Contains(contentType, "application/json") {
|
||
log.Warnf("content is not json, can't process:%s", contentType)
|
||
ctx.DontReadRequestBody()
|
||
return types.ActionContinue
|
||
}
|
||
// get identity key
|
||
identityKey, _ := proxywasm.GetHttpRequestHeader(config.IdentityHeader)
|
||
if identityKey == "" {
|
||
log.Warnf("identity key is empty")
|
||
return types.ActionContinue
|
||
}
|
||
identityKey = strings.ReplaceAll(identityKey, " ", "")
|
||
ctx.SetContext(IdentityKey, identityKey)
|
||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||
// The request has a body and requires delaying the header transmission until a cache miss occurs,
|
||
// at which point the header should be sent.
|
||
return types.HeaderStopIteration
|
||
}
|
||
|
||
// TrimQuote returns source with every leading and trailing double-quote
// character removed; quotes in the interior of the string are kept.
func TrimQuote(source string) string {
	isQuote := func(r rune) bool { return r == '"' }
	return strings.TrimFunc(source, isQuote)
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
bodyJson := gjson.ParseBytes(body)
|
||
if bodyJson.Get("stream").Bool() {
|
||
ctx.SetContext(StreamContextKey, struct{}{})
|
||
}
|
||
identityKey := ctx.GetStringContext(IdentityKey, "")
|
||
err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
|
||
if err := response.Error(); err != nil {
|
||
log.Errorf("redis get failed, err:%v", err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
if response.IsNull() {
|
||
log.Debugf("cache miss, identityKey:%s", identityKey)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
chatHistories := response.String()
|
||
ctx.SetContext(ChatHistories, chatHistories)
|
||
var chat []ChatHistory
|
||
err := json.Unmarshal([]byte(chatHistories), &chat)
|
||
if err != nil {
|
||
log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
path := ctx.Path()
|
||
if isQueryHistory(path) {
|
||
cnt := getIntQueryParameter("cnt", path, len(chat)/2) * 2
|
||
if cnt > len(chat) {
|
||
cnt = len(chat)
|
||
}
|
||
chat = chat[len(chat)-cnt:]
|
||
res, err := json.Marshal(chat)
|
||
if err != nil {
|
||
log.Errorf("marshal chat:%v failed, err:%v", chat, err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
|
||
return
|
||
}
|
||
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
|
||
if question == "" {
|
||
log.Debug("parse question from request body failed")
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
ctx.SetContext(QuestionContextKey, question)
|
||
fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
|
||
currJson := bodyJson.Get("messages").String()
|
||
var currMessage []ChatHistory
|
||
err = json.Unmarshal([]byte(currJson), &currMessage)
|
||
if err != nil {
|
||
log.Errorf("unmarshal currMessage:%s failed, err:%v", currJson, err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
finalChat := fillHistory(chat, currMessage, fillHistoryCnt)
|
||
var parameter map[string]any
|
||
err = json.Unmarshal(body, ¶meter)
|
||
if err != nil {
|
||
log.Errorf("unmarshal body:%s failed, err:%v", body, err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
parameter["messages"] = finalChat
|
||
parameterJson, err := json.Marshal(parameter)
|
||
if err != nil {
|
||
log.Errorf("marshal parameter:%v failed, err:%v", parameter, err)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}
|
||
log.Infof("start to replace request body, parameter:%s", string(parameterJson))
|
||
_ = proxywasm.ReplaceHttpRequestBody(parameterJson)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
})
|
||
if err != nil {
|
||
log.Error("redis access failed")
|
||
return types.ActionContinue
|
||
}
|
||
return types.ActionPause
|
||
}
|
||
|
||
func fillHistory(chat []ChatHistory, currMessage []ChatHistory, fillHistoryCnt int) []ChatHistory {
|
||
userInputCnt := 0
|
||
for i := 0; i < len(currMessage); i++ {
|
||
if currMessage[i].Role == "user" {
|
||
userInputCnt++
|
||
}
|
||
}
|
||
if userInputCnt > 1 {
|
||
return currMessage
|
||
}
|
||
if fillHistoryCnt > len(chat) {
|
||
fillHistoryCnt = len(chat)
|
||
}
|
||
finalChat := append(chat[len(chat)-fillHistoryCnt:], currMessage...)
|
||
return finalChat
|
||
}
|
||
|
||
// isQueryHistory reports whether path addresses the history-query endpoint,
// i.e. whether it contains "ai-history/query" anywhere.
func isQueryHistory(path string) bool {
	const marker = "ai-history/query"
	return strings.Contains(path, marker)
}
|
||
|
||
// getIntQueryParameter returns the integer value of query parameter name in
// path. defaultValue is returned when path cannot be parsed as a request URI,
// when the parameter is absent or empty, or when it is not a valid integer.
// The value is returned as parsed, so it may be negative.
func getIntQueryParameter(name string, path string, defaultValue int) int {
	parsedURL, err := url.ParseRequestURI(path)
	if err != nil {
		fmt.Println("Error parsing URL:", err)
		return defaultValue
	}

	raw := parsedURL.Query().Get(name)
	if raw == "" {
		return defaultValue
	}
	if num, convErr := strconv.Atoi(raw); convErr == nil {
		return num
	}
	return defaultValue
}
|
||
|
||
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
|
||
subMessages := strings.Split(sseMessage, "\n")
|
||
var message string
|
||
for _, msg := range subMessages {
|
||
if strings.HasPrefix(msg, "data:") {
|
||
message = msg
|
||
break
|
||
}
|
||
}
|
||
if len(message) < 6 {
|
||
log.Errorf("invalid message:%s", message)
|
||
return ""
|
||
}
|
||
// skip the prefix "data:"
|
||
bodyJson := message[5:]
|
||
if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
|
||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||
if tempContentI == nil {
|
||
content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||
ctx.SetContext(AnswerContentContextKey, content)
|
||
return content
|
||
}
|
||
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||
content := tempContentI.(string) + append
|
||
ctx.SetContext(AnswerContentContextKey, content)
|
||
return content
|
||
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
|
||
// TODO: compatible with other providers
|
||
ctx.SetContext(ToolCallsContextKey, struct{}{})
|
||
return ""
|
||
}
|
||
log.Debugf("unknown message:%s", bodyJson)
|
||
return ""
|
||
}
|
||
|
||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||
if strings.Contains(contentType, "text/event-stream") {
|
||
ctx.SetContext(StreamContextKey, struct{}{})
|
||
}
|
||
return types.ActionContinue
|
||
}
|
||
func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||
if ctx.GetContext(ToolCallsContextKey) != nil {
|
||
// we should not cache tool call result
|
||
return chunk
|
||
}
|
||
questionI := ctx.GetContext(QuestionContextKey)
|
||
if questionI == nil {
|
||
return chunk
|
||
}
|
||
if isQueryHistory(ctx.Path()) {
|
||
return chunk
|
||
}
|
||
if !isLastChunk {
|
||
stream := ctx.GetContext(StreamContextKey)
|
||
if stream == nil {
|
||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||
if tempContentI == nil {
|
||
ctx.SetContext(AnswerContentContextKey, chunk)
|
||
return chunk
|
||
}
|
||
tempContent := tempContentI.([]byte)
|
||
tempContent = append(tempContent, chunk...)
|
||
ctx.SetContext(AnswerContentContextKey, tempContent)
|
||
} else {
|
||
var partialMessage []byte
|
||
partialMessageI := ctx.GetContext(PartialMessageContextKey)
|
||
if partialMessageI != nil {
|
||
partialMessage = append(partialMessageI.([]byte), chunk...)
|
||
} else {
|
||
partialMessage = chunk
|
||
}
|
||
messages := strings.Split(string(partialMessage), "\n\n")
|
||
for i, msg := range messages {
|
||
if i < len(messages)-1 {
|
||
// process complete message
|
||
processSSEMessage(ctx, config, msg, log)
|
||
}
|
||
}
|
||
if !strings.HasSuffix(string(partialMessage), "\n\n") {
|
||
ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1]))
|
||
} else {
|
||
ctx.SetContext(PartialMessageContextKey, nil)
|
||
}
|
||
}
|
||
return chunk
|
||
}
|
||
|
||
stream := ctx.GetContext(StreamContextKey)
|
||
var value string
|
||
if stream == nil {
|
||
var body []byte
|
||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||
if tempContentI != nil {
|
||
body = append(tempContentI.([]byte), chunk...)
|
||
} else {
|
||
body = chunk
|
||
}
|
||
bodyJson := gjson.ParseBytes(body)
|
||
|
||
value = TrimQuote(bodyJson.Get(config.AnswerValueFrom.ResponseBody).Raw)
|
||
if value == "" {
|
||
log.Warnf("parse value from response body failded, body:%s", body)
|
||
return chunk
|
||
}
|
||
} else {
|
||
if len(chunk) > 0 {
|
||
var lastMessage []byte
|
||
partialMessageI := ctx.GetContext(PartialMessageContextKey)
|
||
if partialMessageI != nil {
|
||
lastMessage = append(partialMessageI.([]byte), chunk...)
|
||
} else {
|
||
lastMessage = chunk
|
||
}
|
||
if !strings.HasSuffix(string(lastMessage), "\n\n") {
|
||
log.Warnf("invalid lastMessage:%s", lastMessage)
|
||
return chunk
|
||
}
|
||
// remove the last \n\n
|
||
lastMessage = lastMessage[:len(lastMessage)-2]
|
||
value = processSSEMessage(ctx, config, string(lastMessage), log)
|
||
} else {
|
||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||
if tempContentI == nil {
|
||
return chunk
|
||
}
|
||
value = tempContentI.(string)
|
||
}
|
||
}
|
||
saveChatHistory(ctx, config, questionI, value, log)
|
||
return chunk
|
||
}
|
||
|
||
func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log wrapper.Log) {
|
||
question := questionI.(string)
|
||
identityKey := ctx.GetStringContext(IdentityKey, "")
|
||
var chat []ChatHistory
|
||
chatHistories := ctx.GetStringContext(ChatHistories, "")
|
||
if chatHistories != "" {
|
||
err := json.Unmarshal([]byte(chatHistories), &chat)
|
||
if err != nil {
|
||
log.Errorf("unmarshal chatHistories:%s failed, err:%v", chatHistories, err)
|
||
return
|
||
}
|
||
}
|
||
chat = append(chat, ChatHistory{Role: "user", Content: question})
|
||
chat = append(chat, ChatHistory{Role: "assistant", Content: value})
|
||
if len(chat) > config.FillHistoryCnt*2 {
|
||
chat = chat[len(chat)-config.FillHistoryCnt*2:]
|
||
}
|
||
str, err := json.Marshal(chat)
|
||
if err != nil {
|
||
log.Errorf("marshal chat:%v failed, err:%v", chat, err)
|
||
return
|
||
}
|
||
log.Infof("start to Set history, identityKey:%s, chat:%s", identityKey, string(str))
|
||
_ = config.redisClient.Set(config.CacheKeyPrefix+identityKey, string(str), nil)
|
||
if config.CacheTTL != 0 {
|
||
_ = config.redisClient.Expire(config.CacheKeyPrefix+identityKey, config.CacheTTL, nil)
|
||
}
|
||
}
|