Files
higress/plugins/wasm-go/extensions/ai-intent/main.go
2024-08-27 19:25:18 +08:00

358 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// File generated by hgctl. Modify as required.
// See:
package main
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
const (
	// DefaultPrompt is the prompt template used when scene.prompt is not
	// configured. It carries two %s verbs which onHttpRequestBody fills, in
	// order, with the user's question and the configured category string.
	DefaultPrompt = "你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:'%s',预设的类别为'%s',直接返回一种具体类别,如果没有找到就返回'NotFound'。"
	// defaultTimeout is the fallback applied when llm.proxyTimeout is not
	// configured (or is zero).
	defaultTimeout = 10 * 1000 // ms
)
// main registers the plugin with the wasm-go wrapper under the name
// "ai-intent", wiring the configuration parser and the request/response
// callbacks defined in this file.
func main() {
	wrapper.SetCtx(
		"ai-intent",
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
// @Name ai-intent
// @Category protocol
// @Phase AUTHN
// @Priority 1000
// @Title zh-CN AI intent
// @Description zh-CN 大模型意图识别
// @IconUrl
// @Version 0.1.0
//
// @Contact.name jose
// @Contact.url
// @Contact.email
//@Example
// scene:
// category: "金融|电商|法律|Higress"
// prompt: "你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:%s,预设的类别为%s,直接返回一种具体类别,如果没有找到就返回'NotFound'。"
// 例:"你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:今天天气怎么样?,预设的类别为 ["金融","电商","法律"],直接返回一种具体类别,如果没有找到就返回"NotFound"。"
// SceneInfo holds the intent-recognition scene settings read from the
// "scene" section of the plugin configuration.
type SceneInfo struct {
	// Category is the raw, '|'-separated list of preset categories.
	Category string `require:"true" yaml:"category" json:"category"`
	// Prompt is the prompt template; parseConfig falls back to DefaultPrompt
	// when it is empty.
	Prompt string `require:"false" yaml:"prompt" json:"prompt"`
	// CategoryArr is the list obtained by parsing Category; it is filled in
	// by parseConfig and never read from the configuration directly.
	CategoryArr []string `yaml:"-" json:"-"`
}
// LLMInfo describes the LLM service used for intent recognition, read from
// the "llm" section of the plugin configuration.
type LLMInfo struct {
	// ProxyServiceName is used as the FQDN of the cluster client built in
	// parseConfig.
	ProxyServiceName string `require:"true" yaml:"proxyServiceName" json:"proxyServiceName"`
	// ProxyUrl is the full URL of the LLM endpoint; parseConfig derives
	// ProxyPath from it and, when not configured explicitly, ProxyDomain and
	// ProxyPort as well.
	ProxyUrl string `require:"false" yaml:"proxyUrl" json:"proxyUrl"`
	// ProxyModel is the model name sent in the request; parseConfig defaults
	// it to "qwen-long".
	ProxyModel string `require:"false" yaml:"proxyModel" json:"proxyModel"`
	// ProxyPort is the port of the LLM service.
	// @Title zh-CN 大模型服务端口
	// @Description zh-CN 服务端口
	ProxyPort int64 `required:"false" yaml:"proxyPort" json:"proxyPort"`
	// ProxyDomain is the domain (host) of the LLM service.
	// @Title zh-CN 大模型服务域名
	// @Description zh-CN 大模型服务域名
	ProxyDomain string `required:"false" yaml:"proxyDomain" json:"proxyDomain"`
	// ProxyTimeout is the timeout passed to the HTTP client call;
	// parseConfig defaults it to defaultTimeout (milliseconds).
	ProxyTimeout uint32 `require:"false" yaml:"proxyTimeout" json:"proxyTimeout"`
	// ProxyApiKey is the API key of the LLM service; it is sent as a Bearer
	// token by generateProxyRequest.
	// @Title zh-CN 大模型服务的API_KEY
	// @Description zh-CN 大模型服务的API_KEY
	ProxyApiKey string `require:"false" yaml:"proxyApiKey" json:"proxyApiKey"`
	// ProxyClient is the HTTP client created by parseConfig; it is not part
	// of the serialized configuration.
	ProxyClient wrapper.HttpClient `yaml:"-" json:"-"`
	// ProxyPath is the request path of the LLM API, taken from the path
	// component of ProxyUrl by parseConfig.
	// @Title zh-CN 大模型接口路径
	// @Description zh-CN 大模型接口路径
	ProxyPath string `yaml:"-" json:"-"`
}
// PluginConfig is the complete configuration of the ai-intent plugin as
// populated by parseConfig.
type PluginConfig struct {
	// SceneInfo holds the intent-related settings ("scene" section).
	// @Title zh-CN 意图相关配置
	// @Description zh-CN SceneInfo
	SceneInfo SceneInfo `required:"true" yaml:"scene" json:"scene"`
	// LLMInfo holds the LLM service settings ("llm" section).
	// @Title zh-CN 大模型相关配置
	// @Description zh-CN LLMInfo
	LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"`
	// KeyFrom describes where the question and the answer are extracted
	// from ("keyFrom" section).
	// @Title zh-CN key 的来源
	// @Description zh-CN 使用的 key 的提取方式
	KeyFrom KVExtractor `required:"true" yaml:"keyFrom" json:"keyFrom"`
}
// KVExtractor holds the GJSON paths used to pull strings out of the request
// and response bodies.
type KVExtractor struct {
	// RequestBody is the GJSON path applied to the request body;
	// parseConfig defaults it to "messages.@reverse.0.content".
	// @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
	// ResponseBody is the GJSON path applied to the response body;
	// parseConfig defaults it to "choices.0.message.content".
	// @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
}
// parseConfig fills c from the plugin's JSON configuration and builds the
// HTTP client used to call the LLM service.
//
// Required keys: scene.category, llm.proxyServiceName, llm.proxyUrl.
// Optional keys and their fallbacks:
//   - scene.prompt: DefaultPrompt
//   - llm.proxyDomain: host name of llm.proxyUrl
//   - llm.proxyPort: port of llm.proxyUrl, else 443 for https and 80 otherwise
//   - llm.proxyModel: "qwen-long"
//   - llm.proxyTimeout (ms): defaultTimeout
//   - keyFrom.requestBody: "messages.@reverse.0.content"
//   - keyFrom.responseBody: "choices.0.message.content"
//
// It returns an error when a required key is missing or a value is invalid.
// The raw configuration and the API key are deliberately not written to the
// log because they contain a secret.
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
	// init scene
	c.SceneInfo.Category = json.Get("scene.category").String()
	log.Infof("SceneInfo.Category:%s", c.SceneInfo.Category)
	if c.SceneInfo.Category == "" {
		return errors.New("scene.category must not be empty")
	}
	// strings.Split never returns an empty slice, so validate the entries
	// themselves: trim surrounding whitespace and drop blank ones so that
	// "a | b" and "a||b" behave like "a|b" when matching the LLM answer.
	parts := strings.Split(c.SceneInfo.Category, "|")
	categories := make([]string, 0, len(parts))
	for _, part := range parts {
		if part = strings.TrimSpace(part); part != "" {
			categories = append(categories, part)
		}
	}
	if len(categories) == 0 {
		return errors.New("scene.category resolve exception, should use '|' split")
	}
	c.SceneInfo.CategoryArr = categories

	c.SceneInfo.Prompt = json.Get("scene.prompt").String()
	if c.SceneInfo.Prompt == "" {
		c.SceneInfo.Prompt = DefaultPrompt
	}
	log.Infof("SceneInfo.Prompt:%s", c.SceneInfo.Prompt)

	// init llmProxy
	log.Debug("Start to init proxyService's http client.")
	c.LLMInfo.ProxyServiceName = json.Get("llm.proxyServiceName").String()
	log.Infof("ProxyServiceName: %s", c.LLMInfo.ProxyServiceName)
	if c.LLMInfo.ProxyServiceName == "" {
		return errors.New("llm.proxyServiceName must not be empty")
	}

	c.LLMInfo.ProxyUrl = json.Get("llm.proxyUrl").String()
	log.Infof("c.LLMInfo.ProxyUrl:%s", c.LLMInfo.ProxyUrl)
	if c.LLMInfo.ProxyUrl == "" {
		return errors.New("llm.proxyUrl must not be empty")
	}
	// Split the URL into host, port and path.
	parsedURL, err := url.Parse(c.LLMInfo.ProxyUrl)
	if err != nil {
		return fmt.Errorf("llm.proxyUrl parsing error: %w", err)
	}
	c.LLMInfo.ProxyPath = parsedURL.Path
	log.Infof("c.LLMInfo.ProxyPath:%s", c.LLMInfo.ProxyPath)

	// When llm.proxyDomain is not configured, take the host from proxyUrl.
	c.LLMInfo.ProxyDomain = json.Get("llm.proxyDomain").String()
	if c.LLMInfo.ProxyDomain == "" {
		hostName := parsedURL.Hostname()
		log.Infof("llm.proxyUrl.hostName:%s", hostName)
		if hostName != "" {
			c.LLMInfo.ProxyDomain = hostName
		}
	}
	log.Infof("c.LLMInfo.ProxyDomain:%s", c.LLMInfo.ProxyDomain)

	// When llm.proxyPort is not configured, take the port from proxyUrl; if
	// the URL has no explicit port, use the scheme default.
	c.LLMInfo.ProxyPort = json.Get("llm.proxyPort").Int()
	if c.LLMInfo.ProxyPort <= 0 {
		port := parsedURL.Port()
		log.Infof("llm.proxyUrl.port:%s", port)
		switch {
		case port != "":
			portNum, parseErr := strconv.ParseInt(port, 10, 64)
			if parseErr != nil {
				return fmt.Errorf("llm.proxyUrl.port parsing error: %w", parseErr)
			}
			c.LLMInfo.ProxyPort = portNum
		case parsedURL.Scheme == "https":
			c.LLMInfo.ProxyPort = 443
		default:
			c.LLMInfo.ProxyPort = 80
		}
	}
	if c.LLMInfo.ProxyPort <= 0 || c.LLMInfo.ProxyPort > 65535 {
		return fmt.Errorf("llm.proxyPort %d is out of range (1-65535)", c.LLMInfo.ProxyPort)
	}
	log.Infof("c.LLMInfo.ProxyPort:%d", c.LLMInfo.ProxyPort)

	c.LLMInfo.ProxyClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
		FQDN: c.LLMInfo.ProxyServiceName,
		Port: c.LLMInfo.ProxyPort,
		Host: c.LLMInfo.ProxyDomain,
	})

	c.LLMInfo.ProxyModel = json.Get("llm.proxyModel").String()
	if c.LLMInfo.ProxyModel == "" {
		c.LLMInfo.ProxyModel = "qwen-long"
	}
	log.Infof("c.LLMInfo.ProxyModel:%s", c.LLMInfo.ProxyModel)

	// Clamp instead of letting the uint64 -> uint32 conversion wrap around.
	const maxTimeout = uint64(^uint32(0))
	timeout := json.Get("llm.proxyTimeout").Uint()
	switch {
	case timeout == 0:
		c.LLMInfo.ProxyTimeout = defaultTimeout
	case timeout > maxTimeout:
		c.LLMInfo.ProxyTimeout = uint32(maxTimeout)
	default:
		c.LLMInfo.ProxyTimeout = uint32(timeout)
	}
	log.Infof("c.LLMInfo.ProxyTimeout:%d", c.LLMInfo.ProxyTimeout)

	// Only record whether a key is present; never log the secret itself.
	c.LLMInfo.ProxyApiKey = json.Get("llm.proxyApiKey").String()
	log.Infof("c.LLMInfo.ProxyApiKey configured:%t", c.LLMInfo.ProxyApiKey != "")

	c.KeyFrom.RequestBody = json.Get("keyFrom.requestBody").String()
	if c.KeyFrom.RequestBody == "" {
		c.KeyFrom.RequestBody = "messages.@reverse.0.content"
	}
	c.KeyFrom.ResponseBody = json.Get("keyFrom.responseBody").String()
	if c.KeyFrom.ResponseBody == "" {
		c.KeyFrom.ResponseBody = "choices.0.message.content"
	}
	log.Debug("Init ai intent's components successfully.")
	return nil
}
// onHttpRequestHeaders does no header processing; it only emits debug logs
// and returns types.HeaderStopIteration.
// NOTE(review): presumably this holds the headers until onHttpRequestBody has
// run — confirm how the wrapper treats requests that carry no body.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	log.Debug("start onHttpRequestHeaders function.")
	log.Debug("end onHttpRequestHeaders function.")
	return types.HeaderStopIteration
}
// onHttpRequestBody extracts the user's question from the request body, asks
// the configured LLM which preset category it belongs to and, when the answer
// names one of the configured categories, stores that category in the
// "intent_category" property.
//
// It returns types.ActionPause while the LLM call is in flight; the request is
// resumed from the response callback. If the LLM request cannot be built or
// dispatched, it returns types.ActionContinue so the request proceeds without
// an intent.
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	log.Debug("start onHttpRequestBody function.")

	// Result.String() returns the JSON-decoded value, so \uXXXX, \n and \"
	// escapes are resolved correctly. Trimming quotes off the raw token would
	// corrupt content ending in an escaped quote and leave escapes undecoded.
	originalQuestion := gjson.ParseBytes(body).Get(config.KeyFrom.RequestBody).String()
	log.Infof("[onHttpRequestBody] originalQuestion is: %s", originalQuestion)

	// Build the prompt by substituting the question and the preset categories
	// into the template.
	promptStr := fmt.Sprintf(config.SceneInfo.Prompt, originalQuestion, config.SceneInfo.Category)
	log.Infof("[onHttpRequestBody] after prompt is: %s", promptStr)

	proxyUrl, proxyRequestBody, proxyRequestHeader := generateProxyRequest(&config, []string{promptStr}, log)
	if proxyRequestBody == nil {
		// generateProxyRequest has already logged the reason; there is
		// nothing to send, so let the request through.
		return types.ActionContinue
	}
	log.Infof("[onHttpRequestBody] proxyUrl is: %s", proxyUrl)
	log.Infof("[onHttpRequestBody] proxyRequestBody is: %s", string(proxyRequestBody))

	// Call the LLM to obtain the intent category.
	llmProxyErr := config.LLMInfo.ProxyClient.Post(
		proxyUrl,
		proxyRequestHeader,
		proxyRequestBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			log.Debug("Start llm.llmProxyClient func")
			log.Infof("llm.llmProxyClient statusCode is:%d", statusCode)
			log.Infof("llm.llmProxyClient intent responseBody is: %s", string(responseBody))
			if statusCode == http.StatusOK {
				// proxyResponseHandler logs its own decode failures.
				proxyResponseBody, _ := proxyResponseHandler(responseBody, log)
				if proxyResponseBody != nil && len(proxyResponseBody.Choices) > 0 {
					category := proxyResponseBody.Choices[0].Message.Content
					log.Infof("llmProxyClient intent response category is: %s", category)
					// Only accept answers that name one of the preset categories.
					if matched := matchIntentCategory(config.SceneInfo.CategoryArr, category); matched != "" {
						log.Debug("llmProxyClient intent category set to Property")
						if proErr := proxywasm.SetProperty([]string{"intent_category"}, []byte(matched)); proErr != nil {
							log.Errorf("llmProxyClient proxywasm SetProperty error: %s", proErr.Error())
						}
					}
				}
			}
			// Best effort: the paused request must be released whatever the
			// LLM returned.
			_ = proxywasm.ResumeHttpRequest()
		}, config.LLMInfo.ProxyTimeout)
	if llmProxyErr != nil {
		// No callback will fire, so continue right away instead of pausing.
		log.Errorf("llmProxy intent error: %s", llmProxyErr.Error())
		return types.ActionContinue
	}
	log.Debug("end onHttpRequestBody function.")
	return types.ActionPause
}

// matchIntentCategory returns the first entry of categories that equals or is
// contained in answer, or "" when none matches. Blank entries are ignored.
func matchIntentCategory(categories []string, answer string) string {
	for _, candidate := range categories {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		if strings.Contains(answer, candidate) {
			return candidate
		}
	}
	return ""
}
// onHttpResponseHeaders leaves the response headers untouched; it only emits
// debug logs and returns types.ActionContinue.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	log.Debug("start onHttpResponseHeaders function.")
	log.Debug("end onHttpResponseHeaders function.")
	return types.ActionContinue
}
// onStreamingResponseBody passes each streamed response chunk through
// unchanged; it only emits debug logs.
func onStreamingResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
	log.Debug("start onStreamingResponseBody function.")
	log.Debug("end onStreamingResponseBody function.")
	return chunk
}
// onHttpResponseBody leaves the response body untouched; it only emits debug
// logs and returns types.ActionContinue.
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	log.Debug("start onHttpResponseBody function.")
	log.Debug("end onHttpResponseBody function.")
	return types.ActionContinue
}
// ProxyRequest is the JSON payload sent to the LLM service by
// generateProxyRequest.
type ProxyRequest struct {
	// Model is the model name (LLMInfo.ProxyModel).
	Model string `json:"model"`
	// Messages is the list of chat messages sent to the model.
	Messages []ProxyRequestMessage `json:"messages"`
}
// ProxyRequestMessage is a single message inside a ProxyRequest.
type ProxyRequestMessage struct {
	// Role is the speaker of the message; generateProxyRequest sets "user".
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
// generateProxyRequest builds the request sent to the LLM service.
//
// Only texts[0] is used; it becomes the content of a single "user" message.
// It returns the request path (LLMInfo.ProxyPath), the JSON-encoded body and
// the request headers. When texts is empty or the body cannot be encoded, the
// failure is logged and ("", nil, nil) is returned.
func generateProxyRequest(c *PluginConfig, texts []string, log wrapper.Log) (string, []byte, [][2]string) {
	// Guard against an empty slice instead of panicking on texts[0].
	if len(texts) == 0 {
		log.Errorf("[generateProxyRequest] no prompt text supplied")
		return "", nil, nil
	}

	data := ProxyRequest{
		Model: c.LLMInfo.ProxyModel,
		Messages: []ProxyRequestMessage{
			{Role: "user", Content: texts[0]},
		},
	}
	requestBody, err := json.Marshal(data)
	if err != nil {
		log.Errorf("[generateProxyRequest] Marshal json error:%s, data:%s.", err, data)
		return "", nil, nil
	}

	headers := [][2]string{
		{"Content-Type", "application/json"},
		{"Authorization", "Bearer " + c.LLMInfo.ProxyApiKey},
	}
	// The path is returned directly; a local named "url" would shadow the
	// imported net/url package.
	return c.LLMInfo.ProxyPath, requestBody, headers
}
// zhToUnicode decodes literal \uXXXX sequences contained in raw into the
// characters they denote and returns the result.
//
// The input is first turned into a Go double-quoted literal, which doubles
// every backslash; collapsing `\\u` back to `\u` lets strconv.Unquote
// interpret those sequences as Unicode escapes. If the resulting literal is
// not valid (for example a truncated \u escape), nil and the unquote error
// are returned.
func zhToUnicode(raw []byte) ([]byte, error) {
	quoted := strconv.Quote(string(raw))
	literal := strings.ReplaceAll(quoted, `\\u`, `\u`)
	decoded, err := strconv.Unquote(literal)
	if err != nil {
		return nil, err
	}
	return []byte(decoded), nil
}
// ProxyResponse is the subset of the LLM service's JSON response that this
// plugin decodes.
type ProxyResponse struct {
	// Status is decoded from the "code" field of the response.
	Status int `json:"code"`
	// Id is decoded from the "id" field of the response.
	Id string `json:"id"`
	// Choices holds the returned choices; onHttpRequestBody reads the first one.
	Choices []ProxyResponseOutputChoices `json:"choices"`
}
// ProxyResponseOutputChoices is one entry of ProxyResponse.Choices.
type ProxyResponseOutputChoices struct {
	// FinishReason is decoded from the "finish_reason" field.
	FinishReason string `json:"finish_reason"`
	// Message is the message carried by this choice.
	Message ProxyResponseOutputChoicesMessage `json:"message"`
}
// ProxyResponseOutputChoicesMessage is the message inside a response choice.
type ProxyResponseOutputChoicesMessage struct {
	// Role is decoded from the "role" field.
	Role string `json:"role"`
	// Content is the message text; onHttpRequestBody treats it as the
	// category returned by the LLM.
	Content string `json:"content"`
}
// proxyResponseHandler decodes responseBody, the JSON reply of the LLM
// service, into a ProxyResponse. On a decode failure the error is logged and
// returned together with a nil response.
func proxyResponseHandler(responseBody []byte, log wrapper.Log) (*ProxyResponse, error) {
	parsed := new(ProxyResponse)
	if err := json.Unmarshal(responseBody, parsed); err != nil {
		log.Errorf("[proxyResponseHandler]Unmarshal json error:%s", err)
		return nil, err
	}
	return parsed, nil
}
// getProxyResponseByExtractor returns the value found in responseBody at the
// GJSON path c.KeyFrom.ResponseBody, i.e. the answer returned by the LLM. It
// returns "" when the path does not exist.
//
// Result.String() is used so that JSON escapes (\uXXXX, \n, \") are decoded
// correctly. Trimming quotes off the raw token would corrupt content ending
// in an escaped quote, and a failed manual unescape would silently blank the
// answer.
func getProxyResponseByExtractor(c *PluginConfig, responseBody []byte, log wrapper.Log) string {
	return gjson.ParseBytes(responseBody).Get(c.KeyFrom.ResponseBody).String()
}