update template decorator (#1142)

This commit is contained in:
rinfx
2024-07-22 17:04:55 +08:00
committed by GitHub
parent ef31e09310
commit 8c48fcb423
3 changed files with 60 additions and 90 deletions

View File

@@ -1,8 +1,7 @@
package main
import (
"errors"
"strings"
"encoding/json"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -20,66 +19,53 @@ func main() {
)
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type AIPromptDecoratorConfig struct {
decorators map[string]string
Prepend []Message `json:"prepend"`
Append []Message `json:"append"`
}
func removeBrackets(raw string) (string, error) {
startIndex := strings.Index(raw, "{")
endIndex := strings.LastIndex(raw, "}")
if startIndex == -1 || endIndex == -1 {
return raw, errors.New("message format is wrong!")
} else {
return raw[startIndex : endIndex+1], nil
}
}
func parseConfig(json gjson.Result, config *AIPromptDecoratorConfig, log wrapper.Log) error {
config.decorators = make(map[string]string)
for _, v := range json.Get("decorators").Array() {
config.decorators[v.Get("name").String()] = v.Get("decorator").Raw
// log.Info(v.Get("decorator").Raw)
}
return nil
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig, log wrapper.Log) error {
return json.Unmarshal([]byte(jsonConfig.Raw), config)
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log wrapper.Log) types.Action {
decorator, _ := proxywasm.GetHttpRequestHeader("decorator")
if decorator == "" {
ctx.DontReadRequestBody()
return types.ActionContinue
}
ctx.SetContext("decorator", decorator)
proxywasm.RemoveHttpRequestHeader("decorator")
proxywasm.RemoveHttpRequestHeader("content-length")
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log wrapper.Log) types.Action {
decoratorName := ctx.GetContext("decorator").(string)
decorator := config.decorators[decoratorName]
messageJson := `{"messages":[]}`
prependMessage := gjson.Get(decorator, "prepend")
if prependMessage.Exists() {
for _, entry := range prependMessage.Array() {
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
for _, entry := range config.Prepend {
msg, err := json.Marshal(entry)
if err != nil {
log.Errorf("Failed to add prepend message, error: %v", err)
return types.ActionContinue
}
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", string(msg))
}
rawMessage := gjson.GetBytes(body, "messages")
if rawMessage.Exists() {
for _, entry := range rawMessage.Array() {
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
}
if !rawMessage.Exists() {
log.Errorf("Cannot find messages field in request body")
return types.ActionContinue
}
for _, entry := range rawMessage.Array() {
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
}
appendMessage := gjson.Get(decorator, "append")
if appendMessage.Exists() {
for _, entry := range appendMessage.Array() {
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", entry.Raw)
for _, entry := range config.Append {
msg, err := json.Marshal(entry)
if err != nil {
log.Errorf("Failed to add prepend message, error: %v", err)
return types.ActionContinue
}
messageJson, _ = sjson.SetRaw(messageJson, "messages.-1", string(msg))
}
newbody, err := sjson.SetRaw(string(body), "messages", gjson.Get(messageJson, "messages").Raw)