mirror of
https://github.com/alibaba/higress.git
synced 2026-03-04 16:40:50 +08:00
177 lines
5.5 KiB
Go
177 lines
5.5 KiB
Go
package main
|
||
|
||
import (
|
||
"errors"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/sjson"
|
||
)
|
||
|
||
func main() {
|
||
wrapper.SetCtx(
|
||
"ai-transformer",
|
||
wrapper.ParseConfigBy(parseConfig),
|
||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||
)
|
||
}
|
||
|
||
type AITransformerConfig struct {
|
||
client wrapper.HttpClient
|
||
requestTransformEnable bool
|
||
requestTransformPrompt string
|
||
responseTransformEnable bool
|
||
responseTransformPrompt string
|
||
providerAPIKey string
|
||
}
|
||
|
||
// llmRequestTemplate is the JSON body POSTed to the LLM provider. The model is
// fixed to "qwen-max". The three entries of input.messages are used as follows:
//   - messages[0]: fixed system prompt. In English it reads: "Assume you are an
//     HTTP 1.1 protocol expert; your answer should contain only the HTTP
//     message and nothing else."
//   - messages[1]: system message. Its content is overwritten (via sjson) with
//     the configured request or response transform prompt.
//   - messages[2]: user message. Its content is overwritten with the serialized
//     HTTP message to transform: "name:value" header lines, a blank line, then
//     the body.
//
// The literal's exact bytes (including whitespace) are what gets sent, so edit
// with care.
const llmRequestTemplate = `{
"model": "qwen-max",
"input":{
"messages":[
{
"role": "system",
"content": "假设你是一个http 1.1协议专家,你的回答应该只包含http报文,除此之外不要有任何其他内容。"
},
{
"role": "system",
"content": ""
},
{
"role": "user",
"content": ""
}
]
}
}`
|
||
|
||
func parseConfig(json gjson.Result, config *AITransformerConfig, log wrapper.Log) error {
|
||
config.requestTransformEnable = json.Get("request.enable").Bool()
|
||
config.requestTransformPrompt = json.Get("request.prompt").String()
|
||
config.responseTransformEnable = json.Get("response.enable").Bool()
|
||
config.responseTransformPrompt = json.Get("response.prompt").String()
|
||
config.providerAPIKey = json.Get("provider.apiKey").String()
|
||
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||
ServiceName: json.Get("provider.serviceName").String(),
|
||
Port: 443,
|
||
Domain: json.Get("provider.domain").String(),
|
||
})
|
||
return nil
|
||
}
|
||
|
||
// getSplitPos returns the byte index of the first ':' in header that is not at
// index 0, or -1 if there is none. Because a leading ':' is skipped, a line
// such as ":path:/x" splits after ":path". Presumably this is meant for HTTP/2
// pseudo-headers as returned by the host; verify against callers.
func getSplitPos(header string) int {
	if len(header) == 0 {
		return -1
	}
	// ':' is a single ASCII byte and can never occur inside a multi-byte UTF-8
	// sequence, so a byte search from offset 1 matches a rune-by-rune scan.
	idx := strings.IndexByte(header[1:], ':')
	if idx < 0 {
		return -1
	}
	return idx + 1
}
|
||
|
||
func extraceHttpFrame(frame string) ([][2]string, []byte, error) {
|
||
pos := strings.Index(frame, "\n\n")
|
||
headers := [][2]string{}
|
||
for _, header := range strings.Split(frame[:pos], "\n") {
|
||
splitPos := getSplitPos(header)
|
||
if splitPos == -1 {
|
||
return nil, nil, errors.New("invalid http frame.")
|
||
}
|
||
headers = append(headers, [2]string{header[:splitPos], header[splitPos+1:]})
|
||
}
|
||
body := []byte(frame[pos+2:])
|
||
return headers, body, nil
|
||
}
|
||
|
||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||
log.Info("onHttpRequestHeaders")
|
||
if !config.requestTransformEnable || config.requestTransformPrompt == "" {
|
||
ctx.DontReadRequestBody()
|
||
return types.ActionContinue
|
||
} else {
|
||
return types.HeaderStopIteration
|
||
}
|
||
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Info("onHttpRequestBody")
|
||
headers, err := proxywasm.GetHttpRequestHeaders()
|
||
if err != nil {
|
||
log.Error("Failed to get http response headers.")
|
||
return types.ActionContinue
|
||
}
|
||
headerStr := ""
|
||
for _, hd := range headers {
|
||
headerStr += hd[0] + ":" + hd[1] + "\n"
|
||
}
|
||
var llmRequestBody string
|
||
llmRequestBody, _ = sjson.Set(llmRequestTemplate, "input.messages.1.content", config.requestTransformPrompt)
|
||
llmRequestBody, _ = sjson.Set(llmRequestBody, "input.messages.2.content", headerStr+"\n"+string(body))
|
||
hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
|
||
log.Info(headerStr + "\n" + string(body))
|
||
config.client.Post(
|
||
"/api/v1/services/aigc/text-generation/generation",
|
||
hds,
|
||
[]byte(llmRequestBody),
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
|
||
if err == nil {
|
||
proxywasm.ReplaceHttpRequestHeaders(newHeaders)
|
||
proxywasm.ReplaceHttpRequestBody(newBody)
|
||
}
|
||
proxywasm.ResumeHttpRequest()
|
||
},
|
||
50000,
|
||
)
|
||
|
||
return types.ActionPause
|
||
}
|
||
|
||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||
if !config.responseTransformEnable || config.responseTransformPrompt == "" {
|
||
ctx.DontReadResponseBody()
|
||
return types.ActionContinue
|
||
} else {
|
||
return types.HeaderStopIteration
|
||
}
|
||
}
|
||
|
||
func onHttpResponseBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||
if err != nil {
|
||
log.Error("Failed to get http response headers.")
|
||
return types.ActionContinue
|
||
}
|
||
headerStr := ""
|
||
for _, hd := range headers {
|
||
headerStr += hd[0] + ":" + hd[1] + "\n"
|
||
}
|
||
var llmRequestBody string
|
||
llmRequestBody, _ = sjson.Set(llmRequestTemplate, "input.messages.1.content", config.responseTransformPrompt)
|
||
llmRequestBody, _ = sjson.Set(llmRequestBody, "input.messages.2.content", headerStr+"\n"+string(body))
|
||
hds := [][2]string{{"Authorization", "Bearer " + config.providerAPIKey}, {"Content-Type", "application/json"}}
|
||
log.Info(headerStr + "\n" + string(body))
|
||
config.client.Post(
|
||
"/api/v1/services/aigc/text-generation/generation",
|
||
hds,
|
||
[]byte(llmRequestBody),
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
newHeaders, newBody, err := extraceHttpFrame(gjson.GetBytes(responseBody, "output.text").String())
|
||
if err == nil {
|
||
proxywasm.ReplaceHttpResponseHeaders(newHeaders)
|
||
proxywasm.ReplaceHttpResponseBody(newBody)
|
||
}
|
||
proxywasm.ResumeHttpResponse()
|
||
},
|
||
50000,
|
||
)
|
||
|
||
return types.ActionPause
|
||
}
|