mirror of
https://github.com/alibaba/higress.git
synced 2026-03-17 00:40:48 +08:00
feat(ai-proxy): add mergeConsecutiveMessages option to merge consecutive same-role messages (#3598)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
6
.github/workflows/wasm-plugin-unit-test.yml
vendored
6
.github/workflows/wasm-plugin-unit-test.yml
vendored
@@ -199,12 +199,12 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go 1.24
|
||||
- name: Set up Go 1.25
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.24
|
||||
go-version: 1.25
|
||||
cache: true
|
||||
|
||||
|
||||
- name: Install required tools
|
||||
run: |
|
||||
go install github.com/wadey/gocovmerge@latest
|
||||
|
||||
@@ -462,6 +462,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 智谱AI Code Plan 模式
|
||||
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
||||
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
||||
// @Title zh-CN 合并连续同角色消息
|
||||
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -681,6 +684,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
||||
}
|
||||
}
|
||||
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -1120,6 +1124,17 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// merge consecutive same-role messages for providers that require strict role alternation
|
||||
if apiName == ApiNameChatCompletion && c.mergeConsecutiveMessages {
|
||||
body, err = mergeConsecutiveMessages(body)
|
||||
if err != nil {
|
||||
log.Warnf("[mergeConsecutiveMessages] failed to merge messages: %v", err)
|
||||
err = nil
|
||||
} else {
|
||||
log.Debugf("[mergeConsecutiveMessages] merged consecutive messages for provider: %s", c.typ)
|
||||
}
|
||||
}
|
||||
|
||||
// convert developer role to system role for providers that don't support it
|
||||
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
||||
body, err = convertDeveloperRoleToSystem(body)
|
||||
|
||||
@@ -154,6 +154,54 @@ func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, erro
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
// mergeConsecutiveMessages merges consecutive messages of the same role (user or assistant).
|
||||
// Many LLM providers require strict user↔assistant alternation and reject requests where
|
||||
// two messages of the same role appear consecutively. When enabled, consecutive same-role
|
||||
// messages have their content concatenated into a single message.
|
||||
func mergeConsecutiveMessages(body []byte) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal request for message merging: %v", err)
|
||||
}
|
||||
if len(request.Messages) <= 1 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
merged := false
|
||||
result := make([]chatMessage, 0, len(request.Messages))
|
||||
for _, msg := range request.Messages {
|
||||
if len(result) > 0 &&
|
||||
result[len(result)-1].Role == msg.Role &&
|
||||
(msg.Role == roleUser || msg.Role == roleAssistant) {
|
||||
last := &result[len(result)-1]
|
||||
last.Content = mergeMessageContent(last.Content, msg.Content)
|
||||
merged = true
|
||||
continue
|
||||
}
|
||||
result = append(result, msg)
|
||||
}
|
||||
|
||||
if !merged {
|
||||
return body, nil
|
||||
}
|
||||
request.Messages = result
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
// mergeMessageContent concatenates two message content values.
|
||||
// If both are plain strings they are joined with a blank line.
|
||||
// Otherwise both are converted to content-block arrays and concatenated.
|
||||
func mergeMessageContent(prev, curr any) any {
|
||||
prevStr, prevIsStr := prev.(string)
|
||||
currStr, currIsStr := curr.(string)
|
||||
if prevIsStr && currIsStr {
|
||||
return prevStr + "\n\n" + currStr
|
||||
}
|
||||
prevParts := (&chatMessage{Content: prev}).ParseContent()
|
||||
currParts := (&chatMessage{Content: curr}).ParseContent()
|
||||
return append(prevParts, currParts...)
|
||||
}
|
||||
|
||||
func ReplaceResponseBody(body []byte) error {
|
||||
log.Debugf("response body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||
|
||||
@@ -8,6 +8,131 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeConsecutiveMessages(t *testing.T) {
|
||||
t.Run("no_consecutive_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "你好"},
|
||||
{Role: "assistant", Content: "你好!"},
|
||||
{Role: "user", Content: "再见"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
// No merging needed, returned body should be identical
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("merges_consecutive_user_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "第一条"},
|
||||
{Role: "user", Content: "第二条"},
|
||||
{Role: "assistant", Content: "回复"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "user", output.Messages[0].Role)
|
||||
assert.Equal(t, "第一条\n\n第二条", output.Messages[0].Content)
|
||||
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||
})
|
||||
|
||||
t.Run("merges_consecutive_assistant_messages", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "问题"},
|
||||
{Role: "assistant", Content: "第一段"},
|
||||
{Role: "assistant", Content: "第二段"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "user", output.Messages[0].Role)
|
||||
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||
assert.Equal(t, "第一段\n\n第二段", output.Messages[1].Content)
|
||||
})
|
||||
|
||||
t.Run("merges_multiple_consecutive_same_role", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "A"},
|
||||
{Role: "user", Content: "B"},
|
||||
{Role: "user", Content: "C"},
|
||||
{Role: "assistant", Content: "回复"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var output chatCompletionRequest
|
||||
require.NoError(t, json.Unmarshal(result, &output))
|
||||
|
||||
assert.Len(t, output.Messages, 2)
|
||||
assert.Equal(t, "A\n\nB\n\nC", output.Messages[0].Content)
|
||||
})
|
||||
|
||||
t.Run("system_messages_not_merged", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: "系统提示1"},
|
||||
{Role: "system", Content: "系统提示2"},
|
||||
{Role: "user", Content: "问题"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
// system messages are not merged, body unchanged
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("single_message_unchanged", func(t *testing.T) {
|
||||
input := chatCompletionRequest{
|
||||
Messages: []chatMessage{
|
||||
{Role: "user", Content: "只有一条"},
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
|
||||
t.Run("invalid_json_body", func(t *testing.T) {
|
||||
body := []byte(`invalid json`)
|
||||
result, err := mergeConsecutiveMessages(body)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCleanupContextMessages(t *testing.T) {
|
||||
t.Run("empty_cleanup_commands", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
Reference in New Issue
Block a user