mirror of
https://github.com/alibaba/higress.git
synced 2026-02-28 06:30:49 +08:00
358 lines
14 KiB
Go
358 lines
14 KiB
Go
// File generated by hgctl. Modify as required.
|
||
// See:
|
||
|
||
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
const (
	// DefaultPrompt is the classification prompt template used when
	// scene.prompt is not configured. The two %s placeholders are filled with
	// the user's question and the preset category list; the (Chinese) text
	// instructs the LLM to return exactly one concrete category, or
	// 'NotFound' when nothing matches.
	DefaultPrompt = "你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:'%s',预设的类别为'%s',直接返回一种具体类别,如果没有找到就返回'NotFound'。"
	// defaultTimeout is the fallback LLM call timeout, in milliseconds.
	defaultTimeout = 10 * 1000 // ms
)
|
||
|
||
// main registers the ai-intent wasm plugin with the Higress wrapper SDK,
// wiring up configuration parsing and the HTTP request/response callbacks.
func main() {
	wrapper.SetCtx(
		// Plugin name.
		"ai-intent",
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
|
||
|
||
// @Name ai-intent
|
||
// @Category protocol
|
||
// @Phase AUTHN
|
||
// @Priority 1000
|
||
// @Title zh-CN AI intent
|
||
// @Description zh-CN 大模型意图识别
|
||
// @IconUrl
|
||
// @Version 0.1.0
|
||
//
|
||
// @Contact.name jose
|
||
// @Contact.url
|
||
// @Contact.email
|
||
// @Example
|
||
|
||
// scene:
|
||
// category: "金融|电商|法律|Higress"
|
||
// prompt:"你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:%s,预设的类别为%s,直接返回一种具体类别,如果没有找到就返回'NotFound'。"
|
||
// 例:"你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:今天天气怎么样?,预设的类别为 ["金融","电商","法律"],直接返回一种具体类别,如果没有找到就返回"NotFound"。"
|
||
|
||
type SceneInfo struct {
|
||
Category string `require:"true" yaml:"category" json:"category"`
|
||
Prompt string `require:"false" yaml:"prompt" json:"prompt"`
|
||
//解析category后的数组
|
||
CategoryArr []string `yaml:"-" json:"-"`
|
||
}
|
||
|
||
type LLMInfo struct {
|
||
ProxyServiceName string `require:"true" yaml:"proxyServiceName" json:"proxyServiceName"`
|
||
ProxyUrl string `require:"false" yaml:"proxyUrl" json:"proxyUrl"`
|
||
ProxyModel string `require:"false" yaml:"proxyModel" json:"proxyModel"`
|
||
// @Title zh-CN 大模型服务端口
|
||
// @Description zh-CN 服务端口
|
||
ProxyPort int64 `required:"false" yaml:"proxyPort" json:"proxyPort"`
|
||
// @Title zh-CN 大模型服务域名
|
||
// @Description zh-CN 大模型服务域名
|
||
ProxyDomain string `required:"false" yaml:"proxyDomain" json:"proxyDomain"`
|
||
ProxyTimeout uint32 `require:"false" yaml:"proxyTimeout" json:"proxyTimeout"`
|
||
// @Title zh-CN 大模型服务的API_KEY
|
||
// @Description zh-CN 大模型服务的API_KEY
|
||
ProxyApiKey string `require:"false" yaml:"proxyApiKey" json:"proxyApiKey"`
|
||
ProxyClient wrapper.HttpClient `yaml:"-" json:"-"`
|
||
// @Title zh-CN 大模型接口路径
|
||
// @Description zh-CN 大模型接口路径
|
||
ProxyPath string `yaml:"-" json:"-"`
|
||
}
|
||
|
||
// PluginConfig is the root configuration of the ai-intent plugin, populated
// and validated by parseConfig.
type PluginConfig struct {
	// @Title zh-CN 意图相关配置
	// @Description zh-CN SceneInfo
	// Scene settings: preset categories and prompt template.
	SceneInfo SceneInfo `required:"true" yaml:"scene" json:"scene"`
	// @Title zh-CN 大模型相关配置
	// @Description zh-CN LLMInfo
	// LLM proxy service settings and the client built from them.
	LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"`
	// @Title zh-CN key 的来源
	// @Description zh-CN 使用的 key 的提取方式
	// GJSON paths used to extract the question/answer from HTTP bodies.
	KeyFrom KVExtractor `required:"true" yaml:"keyFrom" json:"keyFrom"`
}
|
||
|
||
// KVExtractor holds the GJSON paths used to pull the user's question out of
// the request body and the answer out of the response body.
type KVExtractor struct {
	// @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	// GJSON path into the request body; defaults to "messages.@reverse.0.content".
	RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"`
	// @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串
	// GJSON path into the response body; defaults to "choices.0.message.content".
	ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
}
|
||
|
||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||
log.Infof("config:%s", json.Raw)
|
||
// init scene
|
||
c.SceneInfo.Category = json.Get("scene.category").String()
|
||
log.Infof("SceneInfo.Category:%s", c.SceneInfo.Category)
|
||
if c.SceneInfo.Category == "" {
|
||
return errors.New("scene.category must not by empty")
|
||
}
|
||
c.SceneInfo.CategoryArr = strings.Split(c.SceneInfo.Category, "|")
|
||
if len(c.SceneInfo.CategoryArr) <= 0 {
|
||
return errors.New("scene.category resolve exception, should use '|' split")
|
||
}
|
||
c.SceneInfo.Prompt = json.Get("scene.prompt").String()
|
||
if c.SceneInfo.Prompt == "" {
|
||
c.SceneInfo.Prompt = DefaultPrompt
|
||
}
|
||
log.Infof("SceneInfo.Prompt:%s", c.SceneInfo.Prompt)
|
||
// init llmProxy
|
||
log.Debug("Start to init proxyService's http client.")
|
||
c.LLMInfo.ProxyServiceName = json.Get("llm.proxyServiceName").String()
|
||
log.Infof("ProxyServiceName: %s", c.LLMInfo.ProxyServiceName)
|
||
if c.LLMInfo.ProxyServiceName == "" {
|
||
return errors.New("llm.proxyServiceName must not by empty")
|
||
}
|
||
c.LLMInfo.ProxyUrl = json.Get("llm.proxyUrl").String()
|
||
log.Infof("c.LLMInfo.ProxyUrl:%s", c.LLMInfo.ProxyUrl)
|
||
if c.LLMInfo.ProxyUrl == "" {
|
||
return errors.New("llm.proxyUrl must not by empty")
|
||
}
|
||
//解析域名和path
|
||
parsedURL, err := url.Parse(c.LLMInfo.ProxyUrl)
|
||
if err != nil {
|
||
return errors.New("llm.proxyUrl parsing error")
|
||
}
|
||
c.LLMInfo.ProxyPath = parsedURL.Path
|
||
log.Infof("c.LLMInfo.ProxyPath:%s", c.LLMInfo.ProxyPath)
|
||
c.LLMInfo.ProxyDomain = json.Get("llm.proxyDomain").String()
|
||
//没有配置llm.proxyDomain时,则从proxyUrl中解析获取
|
||
if c.LLMInfo.ProxyDomain == "" {
|
||
hostName := parsedURL.Hostname()
|
||
log.Infof("llm.proxyUrl.hostName:%s", hostName)
|
||
if hostName != "" {
|
||
c.LLMInfo.ProxyDomain = hostName
|
||
}
|
||
}
|
||
log.Infof("c.LLMInfo.ProxyDomain:%s", c.LLMInfo.ProxyDomain)
|
||
c.LLMInfo.ProxyPort = json.Get("llm.proxyPort").Int()
|
||
// 没有配置llm.proxyPort时,则从proxyUrl中解析获取,如果解析的port为空,则http协议端口默认80,https端口默认443
|
||
if c.LLMInfo.ProxyPort <= 0 {
|
||
port := parsedURL.Port()
|
||
log.Infof("llm.proxyUrl.port:%s", port)
|
||
if port == "" {
|
||
c.LLMInfo.ProxyPort = 80
|
||
if parsedURL.Scheme == "https" {
|
||
c.LLMInfo.ProxyPort = 443
|
||
}
|
||
} else {
|
||
portNum, err := strconv.ParseInt(port, 10, 64)
|
||
if err != nil {
|
||
return errors.New("llm.proxyUrl.port parsing error")
|
||
}
|
||
c.LLMInfo.ProxyPort = portNum
|
||
}
|
||
}
|
||
log.Infof("c.LLMInfo.ProxyPort:%s", c.LLMInfo.ProxyPort)
|
||
c.LLMInfo.ProxyClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||
FQDN: c.LLMInfo.ProxyServiceName,
|
||
Port: c.LLMInfo.ProxyPort,
|
||
Host: c.LLMInfo.ProxyDomain,
|
||
})
|
||
c.LLMInfo.ProxyModel = json.Get("llm.proxyModel").String()
|
||
log.Infof("c.LLMInfo.ProxyModel:%s", c.LLMInfo.ProxyModel)
|
||
if c.LLMInfo.ProxyModel == "" {
|
||
c.LLMInfo.ProxyModel = "qwen-long"
|
||
}
|
||
c.LLMInfo.ProxyTimeout = uint32(json.Get("llm.proxyTimeout").Uint())
|
||
log.Infof("c.LLMInfo.ProxyTimeout:%s", c.LLMInfo.ProxyTimeout)
|
||
if c.LLMInfo.ProxyTimeout <= 0 {
|
||
c.LLMInfo.ProxyTimeout = defaultTimeout
|
||
}
|
||
c.LLMInfo.ProxyApiKey = json.Get("llm.proxyApiKey").String()
|
||
log.Infof("c.LLMInfo.ProxyApiKey:%s", c.LLMInfo.ProxyApiKey)
|
||
c.KeyFrom.RequestBody = json.Get("keyFrom.requestBody").String()
|
||
if c.KeyFrom.RequestBody == "" {
|
||
c.KeyFrom.RequestBody = "messages.@reverse.0.content"
|
||
}
|
||
c.KeyFrom.ResponseBody = json.Get("keyFrom.responseBody").String()
|
||
if c.KeyFrom.ResponseBody == "" {
|
||
c.KeyFrom.ResponseBody = "choices.0.message.content"
|
||
}
|
||
log.Debug("Init ai intent's components successfully.")
|
||
return nil
|
||
}
|
||
|
||
// onHttpRequestHeaders does no header processing of its own; it returns
// types.HeaderStopIteration so that processing defers to the request-body
// phase handled by onHttpRequestBody.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	log.Debug("start onHttpRequestHeaders function.")

	log.Debug("end onHttpRequestHeaders function.")
	return types.HeaderStopIteration
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Debug("start onHttpRequestBody function.")
|
||
bodyJson := gjson.ParseBytes(body)
|
||
TempKey := strings.Trim(bodyJson.Get(config.KeyFrom.RequestBody).Raw, `"`)
|
||
//原始问题
|
||
originalQuestion, _ := zhToUnicode([]byte(TempKey))
|
||
log.Infof("[onHttpRequestBody] originalQuestion is: %s", string(originalQuestion))
|
||
//prompt拼接,替换问题和预设的场景类别,参数占位替换
|
||
promptStr := fmt.Sprintf(config.SceneInfo.Prompt, string(originalQuestion), config.SceneInfo.Category)
|
||
log.Infof("[onHttpRequestBody] after prompt is: %s", promptStr)
|
||
proxyUrl, proxyRequestBody, proxyRequestHeader := generateProxyRequest(&config, []string{string(promptStr)}, log)
|
||
log.Infof("[onHttpRequestBody] proxyUrl is: %s", proxyUrl)
|
||
log.Infof("[onHttpRequestBody] proxyRequestBody is: %s", string(proxyRequestBody))
|
||
//调用大模型 获取意向类型
|
||
llmProxyErr := config.LLMInfo.ProxyClient.Post(
|
||
proxyUrl,
|
||
proxyRequestHeader,
|
||
proxyRequestBody,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
log.Debug("Start llm.llmProxyClient func")
|
||
log.Infof("llm.llmProxyClient statusCode is:%s", statusCode)
|
||
log.Infof("llm.llmProxyClient intent responseBody is: %s", string(responseBody))
|
||
if statusCode == 200 {
|
||
proxyResponseBody, _ := proxyResponseHandler(responseBody, log)
|
||
//大模型返回的识别到的意图类型
|
||
if nil != proxyResponseBody && nil != proxyResponseBody.Choices && len(proxyResponseBody.Choices) > 0 {
|
||
category := proxyResponseBody.Choices[0].Message.Content
|
||
log.Infof("llmProxyClient intent response category is: %s", category)
|
||
//验证返回结果是否为 定义的枚举值结果集合,判断返回结果是否在预设的类型中。
|
||
for i := range config.SceneInfo.CategoryArr {
|
||
//防止空格、空字符串
|
||
if strings.TrimSpace(config.SceneInfo.CategoryArr[i]) == "" {
|
||
continue
|
||
}
|
||
//2种判定条件,1.返回的category与该预设的场景完全一致 2.返回的category包含该预设的场景
|
||
if config.SceneInfo.CategoryArr[i] == category || strings.Contains(category, config.SceneInfo.CategoryArr[i]) {
|
||
// 把意图类型加入到Property中
|
||
log.Debug("llmProxyClient intent category set to Property")
|
||
proErr := proxywasm.SetProperty([]string{"intent_category"}, []byte(config.SceneInfo.CategoryArr[i]))
|
||
if proErr != nil {
|
||
log.Errorf("llmProxyClient proxywasm SetProperty error: %s", proErr.Error())
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return
|
||
}, config.LLMInfo.ProxyTimeout)
|
||
if llmProxyErr != nil {
|
||
log.Errorf("llmProxy intent error: %s", llmProxyErr.Error())
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
}
|
||
log.Debug("end onHttpRequestHeaders function.")
|
||
return types.ActionPause
|
||
}
|
||
|
||
// onHttpResponseHeaders is a pass-through placeholder: it performs no
// response-header processing and continues the filter chain.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	log.Debug("start onHttpResponseHeaders function.")

	log.Debug("end onHttpResponseHeaders function.")
	return types.ActionContinue
}
|
||
|
||
// onStreamingResponseBody is a pass-through placeholder: each streamed
// response chunk is returned unchanged.
func onStreamingResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
	log.Debug("start onStreamingResponseBody function.")

	log.Debug("end onStreamingResponseBody function.")
	return chunk
}
|
||
|
||
// onHttpResponseBody is a pass-through placeholder: it performs no
// response-body processing and continues the filter chain.
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	log.Debug("start onHttpResponseBody function.")

	log.Debug("end onHttpResponseBody function.")
	return types.ActionContinue
}
|
||
|
||
// ProxyRequest is the chat-style request body sent to the LLM service.
type ProxyRequest struct {
	// Model is the model name (e.g. "qwen-long").
	Model string `json:"model"`
	// Messages is the conversation; this plugin sends a single user message.
	Messages []ProxyRequestMessage `json:"messages"`
}
|
||
|
||
// ProxyRequestMessage is a single chat message in a ProxyRequest.
type ProxyRequestMessage struct {
	// Role of the message author (this plugin uses "user").
	Role string `json:"role"`
	// Content is the message text (the filled-in classification prompt).
	Content string `json:"content"`
}
|
||
|
||
func generateProxyRequest(c *PluginConfig, texts []string, log wrapper.Log) (string, []byte, [][2]string) {
|
||
url := c.LLMInfo.ProxyPath
|
||
var userMessage ProxyRequestMessage
|
||
userMessage.Role = "user"
|
||
userMessage.Content = texts[0]
|
||
var messages []ProxyRequestMessage
|
||
messages = append(messages, userMessage)
|
||
data := ProxyRequest{
|
||
Model: c.LLMInfo.ProxyModel,
|
||
Messages: messages,
|
||
}
|
||
requestBody, err := json.Marshal(data)
|
||
if err != nil {
|
||
log.Errorf("[generateProxyRequest] Marshal json error:%s, data:%s.", err, data)
|
||
return "", nil, nil
|
||
}
|
||
|
||
headers := [][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{"Authorization", "Bearer " + c.LLMInfo.ProxyApiKey},
|
||
}
|
||
return url, requestBody, headers
|
||
}
|
||
|
||
// zhToUnicode decodes literal `\uXXXX` escape sequences in raw into their
// actual runes. It quotes the input (doubling backslashes), rewrites the
// doubled `\\u` back to `\u`, then unquotes so strconv performs the unicode
// decoding. Input without escapes is returned unchanged.
func zhToUnicode(raw []byte) ([]byte, error) {
	quoted := strconv.Quote(string(raw))
	quoted = strings.Replace(quoted, `\\u`, `\u`, -1)
	decoded, err := strconv.Unquote(quoted)
	if err != nil {
		return nil, err
	}
	return []byte(decoded), nil
}
|
||
|
||
// ProxyResponse is the chat-style response body returned by the LLM service.
type ProxyResponse struct {
	// Status maps to the JSON field "code" despite the Go name.
	Status int `json:"code"`
	// Id is the response identifier assigned by the LLM service.
	Id string `json:"id"`
	// Choices holds the generated completions; this plugin reads Choices[0].
	Choices []ProxyResponseOutputChoices `json:"choices"`
}
|
||
|
||
// ProxyResponseOutputChoices is a single completion choice in a ProxyResponse.
type ProxyResponseOutputChoices struct {
	// FinishReason explains why generation stopped.
	FinishReason string `json:"finish_reason"`
	// Message is the generated assistant message.
	Message ProxyResponseOutputChoicesMessage `json:"message"`
}
|
||
|
||
// ProxyResponseOutputChoicesMessage is the message payload of a choice; its
// Content carries the category string recognized by the LLM.
type ProxyResponseOutputChoicesMessage struct {
	// Role of the message author (typically "assistant").
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
||
|
||
func proxyResponseHandler(responseBody []byte, log wrapper.Log) (*ProxyResponse, error) {
|
||
var response ProxyResponse
|
||
err := json.Unmarshal(responseBody, &response)
|
||
if err != nil {
|
||
log.Errorf("[proxyResponseHandler]Unmarshal json error:%s", err)
|
||
return nil, err
|
||
}
|
||
return &response, nil
|
||
}
|
||
|
||
func getProxyResponseByExtractor(c *PluginConfig, responseBody []byte, log wrapper.Log) string {
|
||
bodyJson := gjson.ParseBytes(responseBody)
|
||
responseContent := strings.Trim(bodyJson.Get(c.KeyFrom.ResponseBody).Raw, `"`)
|
||
// llm返回的结果
|
||
originalAnswer, _ := zhToUnicode([]byte(responseContent))
|
||
return string(originalAnswer)
|
||
}
|