mirror of
https://github.com/alibaba/higress.git
synced 2026-05-08 04:17:27 +08:00
feat(ai-proxy): add context cleanup command support (#3409)
This commit is contained in:
@@ -421,6 +421,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN generic Provider 对应的Host
|
||||
// @Description zh-CN 仅适用于generic provider,用于覆盖请求转发的目标Host
|
||||
genericHost string `required:"false" yaml:"genericHost" json:"genericHost"`
|
||||
// @Title zh-CN 上下文清理命令
|
||||
// @Description zh-CN 配置清理命令文本列表,当请求的 messages 中存在完全匹配任意一个命令的 user 消息时,将该消息及之前所有非 system 消息清理掉,实现主动清理上下文的效果
|
||||
contextCleanupCommands []string `required:"false" yaml:"contextCleanupCommands" json:"contextCleanupCommands"`
|
||||
// @Title zh-CN 首包超时
|
||||
// @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0,表示不开启首包超时
|
||||
firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"`
|
||||
@@ -461,6 +464,10 @@ func (c *ProviderConfig) GetVllmServerHost() string {
|
||||
return c.vllmServerHost
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetContextCleanupCommands() []string {
|
||||
return c.contextCleanupCommands
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsOpenAIProtocol() bool {
|
||||
return c.protocol == protocolOpenAI
|
||||
}
|
||||
@@ -639,6 +646,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.vllmServerHost = json.Get("vllmServerHost").String()
|
||||
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
|
||||
c.doubaoDomain = json.Get("doubaoDomain").String()
|
||||
c.contextCleanupCommands = make([]string, 0)
|
||||
for _, cmd := range json.Get("contextCleanupCommands").Array() {
|
||||
if cmd.String() != "" {
|
||||
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -949,6 +962,16 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
log.Debugf("[Auto Protocol] converted Claude request body to OpenAI format")
|
||||
}
|
||||
|
||||
// handle context cleanup command for chat completion requests
|
||||
if apiName == ApiNameChatCompletion && len(c.contextCleanupCommands) > 0 {
|
||||
body, err = cleanupContextMessages(body, c.contextCleanupCommands)
|
||||
if err != nil {
|
||||
log.Warnf("[contextCleanup] failed to cleanup context messages: %v", err)
|
||||
// Continue processing even if cleanup fails
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
// use openai protocol (either original openai or converted from claude)
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
|
||||
@@ -73,6 +73,73 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupContextMessages 根据配置的清理命令清理上下文消息
|
||||
// 查找最后一个完全匹配任意 cleanupCommands 的 user 消息,将该消息及之前所有非 system 消息清理掉,只保留 system 消息
|
||||
func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, error) {
|
||||
if len(cleanupCommands) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal request for context cleanup: %v", err)
|
||||
}
|
||||
|
||||
if len(request.Messages) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 从后往前查找最后一个匹配任意清理命令的 user 消息
|
||||
cleanupIndex := -1
|
||||
for i := len(request.Messages) - 1; i >= 0; i-- {
|
||||
msg := request.Messages[i]
|
||||
if msg.Role == roleUser {
|
||||
content := msg.StringContent()
|
||||
for _, cmd := range cleanupCommands {
|
||||
if content == cmd {
|
||||
cleanupIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if cleanupIndex != -1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 没有找到匹配的清理命令
|
||||
if cleanupIndex == -1 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
log.Debugf("[contextCleanup] found cleanup command at index %d, cleaning up messages", cleanupIndex)
|
||||
|
||||
// 构建新的消息列表:
|
||||
// 1. 保留 cleanupIndex 之前的 system 消息(只保留 system,其他都清理)
|
||||
// 2. 删除 cleanupIndex 位置的清理命令消息
|
||||
// 3. 保留 cleanupIndex 之后的所有消息
|
||||
var newMessages []chatMessage
|
||||
|
||||
// 处理 cleanupIndex 之前的消息,只保留 system
|
||||
for i := 0; i < cleanupIndex; i++ {
|
||||
msg := request.Messages[i]
|
||||
if msg.Role == roleSystem {
|
||||
newMessages = append(newMessages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// 跳过 cleanupIndex 位置的消息(清理命令本身)
|
||||
// 保留 cleanupIndex 之后的所有消息
|
||||
for i := cleanupIndex + 1; i < len(request.Messages); i++ {
|
||||
newMessages = append(newMessages, request.Messages[i])
|
||||
}
|
||||
|
||||
request.Messages = newMessages
|
||||
log.Debugf("[contextCleanup] messages after cleanup: %d", len(newMessages))
|
||||
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func ReplaceResponseBody(body []byte) error {
|
||||
log.Debugf("response body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||
|
||||
@@ -0,0 +1,253 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCleanupContextMessages(t *testing.T) {
|
||||
t.Run("empty_cleanup_commands", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
result, err := cleanupContextMessages(body, []string{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("no_matching_command", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"system","content":"你是助手"},{"role":"user","content":"hello"}]}`)
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文", "/clear"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("cleanup_with_single_command", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "新问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
assert.Equal(t, "你是一个助手", output.Messages[0].Content)
|
||||
assert.Equal(t, "user", output.Messages[1].Role)
|
||||
assert.Equal(t, "新问题", output.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("cleanup_with_multiple_commands_match_first", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "/clear"},
|
||||
{Role: "user", Content: "新问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文", "/clear", "重新开始"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
assert.Equal(t, "user", output.Messages[1].Role)
|
||||
assert.Equal(t, "新问题", output.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("cleanup_removes_tool_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "查天气"},
|
||||
{Role: "assistant", Content: ""},
|
||||
{Role: "tool", Content: "北京 25°C"},
|
||||
{Role: "assistant", Content: "北京今天25度"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "新问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
assert.Equal(t, "user", output.Messages[1].Role)
|
||||
})
|
||||
|
||||
t.Run("cleanup_keeps_multiple_system_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "系统提示1"},
|
||||
{Role: "system", Content: "系统提示2"},
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "新问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.Messages, 3)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
assert.Equal(t, "系统提示1", output.Messages[0].Content)
|
||||
assert.Equal(t, "system", output.Messages[1].Role)
|
||||
assert.Equal(t, "系统提示2", output.Messages[1].Content)
|
||||
assert.Equal(t, "user", output.Messages[2].Role)
|
||||
})
|
||||
|
||||
t.Run("cleanup_finds_last_matching_command", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "中间问题"},
|
||||
{Role: "assistant", Content: "中间回答"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "最后问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 应该匹配最后一个清理命令,保留 system 和 "最后问题"
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
assert.Equal(t, "user", output.Messages[1].Role)
|
||||
assert.Equal(t, "最后问题", output.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("cleanup_at_end_of_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 清理命令在最后,只保留 system
|
||||
assert.Len(t, output.Messages, 1)
|
||||
assert.Equal(t, "system", output.Messages[0].Role)
|
||||
})
|
||||
|
||||
t.Run("cleanup_without_system_message", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "清理上下文"},
|
||||
{Role: "user", Content: "新问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 没有 system 消息,只保留清理命令之后的消息
|
||||
assert.Len(t, output.Messages, 1)
|
||||
assert.Equal(t, "user", output.Messages[0].Role)
|
||||
assert.Equal(t, "新问题", output.Messages[0].Content)
|
||||
})
|
||||
|
||||
t.Run("cleanup_with_empty_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
err = json.Unmarshal(result, &output)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, output.Messages, 0)
|
||||
})
|
||||
|
||||
t.Run("cleanup_command_partial_match_not_triggered", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "你是一个助手"},
|
||||
{Role: "user", Content: "请清理上下文吧"},
|
||||
{Role: "assistant", Content: "好的"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 部分匹配不应触发清理
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("invalid_json_body", func(t *testing.T) {
|
||||
body := []byte(`invalid json`)
|
||||
result, err := cleanupContextMessages(body, []string{"清理上下文"})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user