Files
higress/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go

213 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package provider
import (
"encoding/json"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
)
// decodeChatCompletionRequest unmarshals body into request and rejects
// requests whose message list is empty.
func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) error {
	err := json.Unmarshal(body, request)
	switch {
	case err != nil:
		return fmt.Errorf("unable to unmarshal request: %v", err)
	case len(request.Messages) == 0:
		// len of a nil slice is 0, so this also covers a missing "messages" field.
		return fmt.Errorf("no message found in the request body: %s", body)
	}
	return nil
}
// decodeEmbeddingsRequest unmarshals the JSON body into request.
func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
// decodeImageGenerationRequest unmarshals the JSON body into request.
func decodeImageGenerationRequest(body []byte, request *imageGenerationRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
// decodeImageEditRequest unmarshals the JSON body into request.
func decodeImageEditRequest(body []byte, request *imageEditRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
// decodeImageVariationRequest unmarshals the JSON body into request.
func decodeImageVariationRequest(body []byte, request *imageVariationRequest) error {
	err := json.Unmarshal(body, request)
	if err == nil {
		return nil
	}
	return fmt.Errorf("unable to unmarshal request: %v", err)
}
// replaceJsonRequestBody marshals request to JSON, logs the result at debug
// level, and passes it to proxywasm.ReplaceHttpRequestBody.
// It returns an error if marshalling or the body replacement fails.
func replaceJsonRequestBody(request interface{}) error {
	payload, marshalErr := json.Marshal(request)
	if marshalErr != nil {
		return fmt.Errorf("unable to marshal request: %v", marshalErr)
	}
	log.Debugf("request body: %s", string(payload))
	if replaceErr := proxywasm.ReplaceHttpRequestBody(payload); replaceErr != nil {
		return fmt.Errorf("unable to replace the original request body: %v", replaceErr)
	}
	return nil
}
// replaceRequestBody logs body at debug level and passes it to
// proxywasm.ReplaceHttpRequestBody, wrapping any error it returns.
func replaceRequestBody(body []byte) error {
	log.Debugf("request body: %s", string(body))
	if replaceErr := proxywasm.ReplaceHttpRequestBody(body); replaceErr != nil {
		return fmt.Errorf("unable to replace the original request body: %v", replaceErr)
	}
	return nil
}
// insertContextMessage inserts a system-role message carrying content into
// request.Messages, immediately before the first non-system message, so that
// it follows any leading system messages.
//
// If every existing message is a system message (or there are no messages at
// all), the new message is appended at the end.
func insertContextMessage(request *chatCompletionRequest, content string) {
	contextMessage := chatMessage{
		Role:    roleSystem,
		Content: content,
	}

	// Default to the end of the list: when no non-system message exists the
	// context message must still come after the existing system messages,
	// not before them.
	insertAt := len(request.Messages)
	for i, message := range request.Messages {
		if message.Role != roleSystem {
			insertAt = i
			break
		}
	}

	// Build the result in a fresh slice so the backing array of the original
	// message list is never written to.
	messages := make([]chatMessage, 0, len(request.Messages)+1)
	messages = append(messages, request.Messages[:insertAt]...)
	messages = append(messages, contextMessage)
	messages = append(messages, request.Messages[insertAt:]...)
	request.Messages = messages
}
// cleanupContextMessages trims the conversation history according to the
// configured cleanup commands.
//
// It looks for the last user message whose content exactly equals one of
// cleanupCommands. That message and every non-system message before it are
// removed; system messages before it and all messages after it are kept.
//
// The original body is returned unchanged when cleanupCommands is empty, when
// the request has no messages, or when no user message matches. If body cannot
// be unmarshalled, the original body is returned together with an error.
func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, error) {
	if len(cleanupCommands) == 0 {
		return body, nil
	}
	request := &chatCompletionRequest{}
	if err := json.Unmarshal(body, request); err != nil {
		return body, fmt.Errorf("unable to unmarshal request for context cleanup: %v", err)
	}
	if len(request.Messages) == 0 {
		return body, nil
	}

	// Scan from the end for the last user message matching any cleanup command.
	cleanupIndex := -1
search:
	for i := len(request.Messages) - 1; i >= 0; i-- {
		msg := request.Messages[i]
		if msg.Role != roleUser {
			continue
		}
		content := msg.StringContent()
		for _, cmd := range cleanupCommands {
			if content == cmd {
				cleanupIndex = i
				break search
			}
		}
	}

	// No cleanup command found.
	if cleanupIndex == -1 {
		return body, nil
	}
	log.Debugf("[contextCleanup] found cleanup command at index %d, cleaning up messages", cleanupIndex)

	// Build the new message list:
	//  1. keep only the system messages that precede cleanupIndex;
	//  2. drop the cleanup command message itself;
	//  3. keep every message after cleanupIndex.
	// The slice is allocated non-nil so that an empty result is not marshalled
	// from a nil slice, which encoding/json would encode as JSON null.
	newMessages := make([]chatMessage, 0, len(request.Messages)-1)
	for i := 0; i < cleanupIndex; i++ {
		msg := request.Messages[i]
		if msg.Role == roleSystem {
			newMessages = append(newMessages, msg)
		}
	}
	newMessages = append(newMessages, request.Messages[cleanupIndex+1:]...)

	request.Messages = newMessages
	log.Debugf("[contextCleanup] messages after cleanup: %d", len(newMessages))
	return json.Marshal(request)
}
// mergeConsecutiveMessages collapses runs of adjacent messages that share the
// same role, for the user and assistant roles only.
// Many LLM providers require strict user↔assistant alternation and reject
// requests in which two messages of the same role appear back to back, so the
// contents of such neighbours are combined into a single message via
// mergeMessageContent.
//
// The original body is returned unchanged when there is at most one message or
// when nothing was merged. If body cannot be unmarshalled, the original body
// is returned together with an error.
func mergeConsecutiveMessages(body []byte) ([]byte, error) {
	var request chatCompletionRequest
	if err := json.Unmarshal(body, &request); err != nil {
		return body, fmt.Errorf("unable to unmarshal request for message merging: %v", err)
	}
	if len(request.Messages) < 2 {
		return body, nil
	}

	compacted := make([]chatMessage, 0, len(request.Messages))
	for _, current := range request.Messages {
		tail := len(compacted) - 1
		mergeable := current.Role == roleUser || current.Role == roleAssistant
		if tail < 0 || !mergeable || compacted[tail].Role != current.Role {
			compacted = append(compacted, current)
			continue
		}
		// Same mergeable role as the previous kept message: fold into it.
		compacted[tail].Content = mergeMessageContent(compacted[tail].Content, current.Content)
	}

	// Every merge skips exactly one append, so equal lengths mean no merge happened.
	if len(compacted) == len(request.Messages) {
		return body, nil
	}
	request.Messages = compacted
	return json.Marshal(&request)
}
// mergeMessageContent combines the content values of two messages.
// When both values are plain strings the result is the two strings separated
// by a blank line. In every other case each value is run through
// chatMessage.ParseContent and the two resulting slices are concatenated,
// prev first.
func mergeMessageContent(prev, curr any) any {
	if left, ok := prev.(string); ok {
		if right, ok := curr.(string); ok {
			return left + "\n\n" + right
		}
	}
	leftMessage := chatMessage{Content: prev}
	rightMessage := chatMessage{Content: curr}
	leftParts := leftMessage.ParseContent()
	rightParts := rightMessage.ParseContent()
	return append(leftParts, rightParts...)
}
// ReplaceResponseBody logs body at debug level and passes it to
// proxywasm.ReplaceHttpResponseBody, wrapping any error it returns.
func ReplaceResponseBody(body []byte) error {
	log.Debugf("response body: %s", string(body))
	if replaceErr := proxywasm.ReplaceHttpResponseBody(body); replaceErr != nil {
		return fmt.Errorf("unable to replace the original response body: %v", replaceErr)
	}
	return nil
}