Files
higress/plugins/wasm-go/extensions/ai-transformer/main.go
2024-06-18 17:51:38 +08:00

177 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// main registers the ai-transformer plugin with the Higress wasm wrapper,
// wiring the configuration parser and the four HTTP phase handlers.
func main() {
	const pluginName = "ai-transformer"
	wrapper.SetCtx(
		pluginName,
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
// AITransformerConfig is the parsed plugin configuration; parseConfig fills
// it in and every HTTP phase handler receives it by value.
type AITransformerConfig struct {
	// client issues the calls to the LLM provider's cluster.
	client wrapper.HttpClient

	// requestTransformEnable turns on request rewriting; it only takes effect
	// when requestTransformPrompt is non-empty.
	requestTransformEnable bool
	// requestTransformPrompt is the operator instruction sent with the request.
	requestTransformPrompt string

	// responseTransformEnable turns on response rewriting; it only takes
	// effect when responseTransformPrompt is non-empty.
	responseTransformEnable bool
	// responseTransformPrompt is the operator instruction sent with the response.
	responseTransformPrompt string

	// providerAPIKey is sent to the provider as "Authorization: Bearer <key>".
	providerAPIKey string
}
// llmRequestTemplate is the JSON body posted to the provider's
// text-generation endpoint. The body handlers fill two slots with sjson:
//   - input.messages.1.content: the operator's request/response prompt
//   - input.messages.2.content: the serialized HTTP message to transform
//
// The fixed first system message (in Chinese) reads roughly: "Assume you are
// an HTTP/1.1 protocol expert; your answer should contain only the HTTP
// message and nothing else."
// NOTE(review): the model name "qwen-max" is hard-coded here and is not
// configurable.
const llmRequestTemplate = `{
"model": "qwen-max",
"input":{
"messages":[
{
"role": "system",
"content": "假设你是一个http 1.1协议专家你的回答应该只包含http报文除此之外不要有任何其他内容。"
},
{
"role": "system",
"content": ""
},
{
"role": "user",
"content": ""
}
]
}
}`
// parseConfig reads the plugin configuration into config:
//
//	request.enable / request.prompt     request transformation toggle and prompt
//	response.enable / response.prompt   response transformation toggle and prompt
//	provider.apiKey                     bearer token sent to the LLM provider
//	provider.serviceName                service name of the provider DNS cluster
//	provider.domain                     domain of the provider DNS cluster
//	provider.port                       optional cluster port, defaults to 443
//
// It returns an error when provider.port is outside 1..65535, or when a
// transformation is effectively enabled (enable=true with a non-empty prompt)
// but the provider cluster is not described. Without this check the
// misconfiguration would only surface as failed provider calls at request
// time.
func parseConfig(json gjson.Result, config *AITransformerConfig, log wrapper.Log) error {
	config.requestTransformEnable = json.Get("request.enable").Bool()
	config.requestTransformPrompt = json.Get("request.prompt").String()
	config.responseTransformEnable = json.Get("response.enable").Bool()
	config.responseTransformPrompt = json.Get("response.prompt").String()
	config.providerAPIKey = json.Get("provider.apiKey").String()

	serviceName := json.Get("provider.serviceName").String()
	domain := json.Get("provider.domain").String()

	// 443 stays the default so existing configurations behave as before.
	port := int64(443)
	if p := json.Get("provider.port"); p.Exists() {
		port = p.Int()
		if port <= 0 || port > 65535 {
			return errors.New("provider.port must be between 1 and 65535")
		}
	}

	// Mirror the conditions the header handlers use to decide whether a
	// transformation actually runs.
	requestActive := config.requestTransformEnable && config.requestTransformPrompt != ""
	responseActive := config.responseTransformEnable && config.responseTransformPrompt != ""
	if (requestActive || responseActive) && (serviceName == "" || domain == "") {
		return errors.New("provider.serviceName and provider.domain are required when a transformation is enabled")
	}

	config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
		ServiceName: serviceName,
		Port:        port,
		Domain:      domain,
	})
	return nil
}
// getSplitPos returns the byte index of the colon that separates a header
// line's name from its value, or -1 if there is none. A colon at index 0 is
// skipped so HTTP/2 pseudo-headers such as ":method:GET" split after ":method".
func getSplitPos(header string) int {
	if len(header) < 2 {
		// Too short to hold a colon anywhere other than index 0.
		return -1
	}
	if i := strings.IndexByte(header[1:], ':'); i >= 0 {
		return i + 1
	}
	return -1
}

// extraceHttpFrame (sic: "extract") splits an HTTP message produced by the LLM
// into its header list and body.
//
// The header section is a sequence of "name:value" lines terminated by an
// empty line. Both LF and CRLF line endings are accepted. Optional
// whitespace around each value is trimmed, per HTTP field-line syntax.
// Everything after the empty line is returned verbatim as the body.
//
// An error is returned when there is no empty line ending the header section,
// when a header line has no name/value colon, or when the header section is
// empty. An empty header list would otherwise wipe every header of the
// message it replaces. Previously a missing "\n\n" sliced with index -1 and
// panicked, trapping the Wasm VM.
func extraceHttpFrame(frame string) ([][2]string, []byte, error) {
	var headers [][2]string
	rest := frame
	for {
		nl := strings.IndexByte(rest, '\n')
		if nl < 0 {
			return nil, nil, errors.New("invalid http frame: missing blank line after headers")
		}
		line := strings.TrimSuffix(rest[:nl], "\r")
		rest = rest[nl+1:]
		if line == "" {
			// Blank line: end of the header section; rest is the body.
			break
		}
		splitPos := getSplitPos(line)
		if splitPos == -1 {
			return nil, nil, errors.New("invalid http frame.")
		}
		value := strings.Trim(line[splitPos+1:], " \t")
		headers = append(headers, [2]string{line[:splitPos], value})
	}
	if len(headers) == 0 {
		return nil, nil, errors.New("invalid http frame.")
	}
	return headers, []byte(rest), nil
}
// onHttpRequestHeaders decides whether the request body must be buffered.
// When request transformation is enabled with a non-empty prompt, header
// iteration is stopped so the headers can still be replaced in
// onHttpRequestBody. Otherwise the body is not read and the request passes
// through untouched.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
	log.Info("onHttpRequestHeaders")
	transform := config.requestTransformEnable && config.requestTransformPrompt != ""
	if transform {
		return types.HeaderStopIteration
	}
	ctx.DontReadRequestBody()
	return types.ActionContinue
}
// onHttpRequestBody sends the buffered request (headers followed by a blank
// line and the body) to the LLM together with the configured request prompt.
// It then pauses the request until the LLM answers.
//
// If the LLM replies with HTTP 200 and its output.text parses as an HTTP
// frame, the parsed headers and body replace the original request's.
// Otherwise the original request is forwarded unchanged. The request is
// always resumed from the callback. If the call cannot be dispatched at all,
// the request continues immediately rather than stalling forever.
func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
	log.Info("onHttpRequestBody")
	headers, err := proxywasm.GetHttpRequestHeaders()
	if err != nil {
		log.Error("Failed to get http request headers.")
		return types.ActionContinue
	}

	// Serialize as "name:value" lines, the format extraceHttpFrame parses back.
	var frame strings.Builder
	for _, hd := range headers {
		frame.WriteString(hd[0])
		frame.WriteByte(':')
		frame.WriteString(hd[1])
		frame.WriteByte('\n')
	}
	frame.WriteByte('\n')
	frame.Write(body)
	httpFrame := frame.String()

	llmRequestBody, err := sjson.Set(llmRequestTemplate, "input.messages.1.content", config.requestTransformPrompt)
	if err == nil {
		llmRequestBody, err = sjson.Set(llmRequestBody, "input.messages.2.content", httpFrame)
	}
	if err != nil {
		log.Errorf("failed to build LLM request for request transformation: %v", err)
		return types.ActionContinue
	}

	hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
	// The frame carries headers such as Authorization/Cookie; keep it out of
	// info-level logs.
	log.Debug(httpFrame)

	err = config.client.Post(
		"/api/v1/services/aigc/text-generation/generation",
		hds,
		[]byte(llmRequestBody),
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			// Resume on every path, after any header/body replacement below.
			defer proxywasm.ResumeHttpRequest()
			if statusCode != http.StatusOK {
				log.Errorf("LLM request transformation failed with status %d: %s", statusCode, responseBody)
				return
			}
			newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
			if err != nil {
				log.Errorf("LLM output is not a valid http frame, forwarding original request: %v", err)
				return
			}
			if err := proxywasm.ReplaceHttpRequestHeaders(newHeaders); err != nil {
				log.Errorf("failed to replace request headers: %v", err)
			}
			if err := proxywasm.ReplaceHttpRequestBody(newBody); err != nil {
				log.Errorf("failed to replace request body: %v", err)
			}
		},
		50000, // call timeout; presumably milliseconds (50 s) — TODO confirm against wrapper.HttpClient
	)
	if err != nil {
		// The callback will never run, so pausing here would stall the request.
		log.Errorf("failed to dispatch LLM request transformation: %v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}
// onHttpResponseHeaders decides whether the response body must be buffered.
// When response transformation is enabled with a non-empty prompt, header
// iteration is stopped so the headers can still be replaced in
// onHttpResponseBody. Otherwise the body is not read and the response passes
// through untouched.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
	transform := config.responseTransformEnable && config.responseTransformPrompt != ""
	if transform {
		return types.HeaderStopIteration
	}
	ctx.DontReadResponseBody()
	return types.ActionContinue
}
// onHttpResponseBody sends the buffered response (headers followed by a blank
// line and the body) to the LLM together with the configured response prompt.
// It then pauses the response until the LLM answers.
//
// If the LLM replies with HTTP 200 and its output.text parses as an HTTP
// frame, the parsed headers and body replace the original response's.
// Otherwise the original response is forwarded unchanged. The response is
// always resumed from the callback. If the call cannot be dispatched at all,
// the response continues immediately rather than stalling forever.
func onHttpResponseBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
	headers, err := proxywasm.GetHttpResponseHeaders()
	if err != nil {
		log.Error("Failed to get http response headers.")
		return types.ActionContinue
	}

	// Serialize as "name:value" lines, the format extraceHttpFrame parses back.
	var frame strings.Builder
	for _, hd := range headers {
		frame.WriteString(hd[0])
		frame.WriteByte(':')
		frame.WriteString(hd[1])
		frame.WriteByte('\n')
	}
	frame.WriteByte('\n')
	frame.Write(body)
	httpFrame := frame.String()

	llmRequestBody, err := sjson.Set(llmRequestTemplate, "input.messages.1.content", config.responseTransformPrompt)
	if err == nil {
		llmRequestBody, err = sjson.Set(llmRequestBody, "input.messages.2.content", httpFrame)
	}
	if err != nil {
		log.Errorf("failed to build LLM request for response transformation: %v", err)
		return types.ActionContinue
	}

	hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
	// The frame may carry headers such as Set-Cookie; keep it out of
	// info-level logs.
	log.Debug(httpFrame)

	err = config.client.Post(
		"/api/v1/services/aigc/text-generation/generation",
		hds,
		[]byte(llmRequestBody),
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			// Resume on every path, after any header/body replacement below.
			defer proxywasm.ResumeHttpResponse()
			if statusCode != http.StatusOK {
				log.Errorf("LLM response transformation failed with status %d: %s", statusCode, responseBody)
				return
			}
			newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
			if err != nil {
				log.Errorf("LLM output is not a valid http frame, forwarding original response: %v", err)
				return
			}
			if err := proxywasm.ReplaceHttpResponseHeaders(newHeaders); err != nil {
				log.Errorf("failed to replace response headers: %v", err)
			}
			if err := proxywasm.ReplaceHttpResponseBody(newBody); err != nil {
				log.Errorf("failed to replace response body: %v", err)
			}
		},
		50000, // call timeout; presumably milliseconds (50 s) — TODO confirm against wrapper.HttpClient
	)
	if err != nil {
		// The callback will never run, so pausing here would stall the response.
		log.Errorf("failed to dispatch LLM response transformation: %v", err)
		return types.ActionContinue
	}
	return types.ActionPause
}