mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat(ai-proxy): add mergeConsecutiveMessages option to merge consecutive same-role messages (#3598)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
6
.github/workflows/wasm-plugin-unit-test.yml
vendored
6
.github/workflows/wasm-plugin-unit-test.yml
vendored
@@ -199,12 +199,12 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Go 1.24
|
- name: Set up Go 1.25
|
||||||
uses: actions/setup-go@v4
|
uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
go-version: 1.24
|
go-version: 1.25
|
||||||
cache: true
|
cache: true
|
||||||
|
|
||||||
- name: Install required tools
|
- name: Install required tools
|
||||||
run: |
|
run: |
|
||||||
go install github.com/wadey/gocovmerge@latest
|
go install github.com/wadey/gocovmerge@latest
|
||||||
|
|||||||
@@ -462,6 +462,9 @@ type ProviderConfig struct {
|
|||||||
// @Title zh-CN 智谱AI Code Plan 模式
|
// @Title zh-CN 智谱AI Code Plan 模式
|
||||||
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
// @Description zh-CN 仅适用于智谱AI服务。启用后将使用 /api/coding/paas/v4/chat/completions 接口
|
||||||
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
zhipuCodePlanMode bool `required:"false" yaml:"zhipuCodePlanMode" json:"zhipuCodePlanMode"`
|
||||||
|
// @Title zh-CN 合并连续同角色消息
|
||||||
|
// @Description zh-CN 开启后,若请求的 messages 中存在连续的同角色消息(如连续两条 user 消息),将其内容合并为一条,以满足要求严格轮流交替(user→assistant→user→...)的模型服务商的要求。
|
||||||
|
mergeConsecutiveMessages bool `required:"false" yaml:"mergeConsecutiveMessages" json:"mergeConsecutiveMessages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) GetId() string {
|
func (c *ProviderConfig) GetId() string {
|
||||||
@@ -681,6 +684,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
|||||||
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
c.contextCleanupCommands = append(c.contextCleanupCommands, cmd.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
c.mergeConsecutiveMessages = json.Get("mergeConsecutiveMessages").Bool()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ProviderConfig) Validate() error {
|
func (c *ProviderConfig) Validate() error {
|
||||||
@@ -1120,6 +1124,17 @@ func (c *ProviderConfig) handleRequestBody(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// merge consecutive same-role messages for providers that require strict role alternation
|
||||||
|
if apiName == ApiNameChatCompletion && c.mergeConsecutiveMessages {
|
||||||
|
body, err = mergeConsecutiveMessages(body)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("[mergeConsecutiveMessages] failed to merge messages: %v", err)
|
||||||
|
err = nil
|
||||||
|
} else {
|
||||||
|
log.Debugf("[mergeConsecutiveMessages] merged consecutive messages for provider: %s", c.typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convert developer role to system role for providers that don't support it
|
// convert developer role to system role for providers that don't support it
|
||||||
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
||||||
body, err = convertDeveloperRoleToSystem(body)
|
body, err = convertDeveloperRoleToSystem(body)
|
||||||
|
|||||||
@@ -154,6 +154,54 @@ func cleanupContextMessages(body []byte, cleanupCommands []string) ([]byte, erro
|
|||||||
return json.Marshal(request)
|
return json.Marshal(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeConsecutiveMessages merges consecutive messages of the same role (user or assistant).
|
||||||
|
// Many LLM providers require strict user↔assistant alternation and reject requests where
|
||||||
|
// two messages of the same role appear consecutively. When enabled, consecutive same-role
|
||||||
|
// messages have their content concatenated into a single message.
|
||||||
|
func mergeConsecutiveMessages(body []byte) ([]byte, error) {
|
||||||
|
request := &chatCompletionRequest{}
|
||||||
|
if err := json.Unmarshal(body, request); err != nil {
|
||||||
|
return body, fmt.Errorf("unable to unmarshal request for message merging: %v", err)
|
||||||
|
}
|
||||||
|
if len(request.Messages) <= 1 {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := false
|
||||||
|
result := make([]chatMessage, 0, len(request.Messages))
|
||||||
|
for _, msg := range request.Messages {
|
||||||
|
if len(result) > 0 &&
|
||||||
|
result[len(result)-1].Role == msg.Role &&
|
||||||
|
(msg.Role == roleUser || msg.Role == roleAssistant) {
|
||||||
|
last := &result[len(result)-1]
|
||||||
|
last.Content = mergeMessageContent(last.Content, msg.Content)
|
||||||
|
merged = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !merged {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
request.Messages = result
|
||||||
|
return json.Marshal(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeMessageContent concatenates two message content values.
|
||||||
|
// If both are plain strings they are joined with a blank line.
|
||||||
|
// Otherwise both are converted to content-block arrays and concatenated.
|
||||||
|
func mergeMessageContent(prev, curr any) any {
|
||||||
|
prevStr, prevIsStr := prev.(string)
|
||||||
|
currStr, currIsStr := curr.(string)
|
||||||
|
if prevIsStr && currIsStr {
|
||||||
|
return prevStr + "\n\n" + currStr
|
||||||
|
}
|
||||||
|
prevParts := (&chatMessage{Content: prev}).ParseContent()
|
||||||
|
currParts := (&chatMessage{Content: curr}).ParseContent()
|
||||||
|
return append(prevParts, currParts...)
|
||||||
|
}
|
||||||
|
|
||||||
func ReplaceResponseBody(body []byte) error {
|
func ReplaceResponseBody(body []byte) error {
|
||||||
log.Debugf("response body: %s", string(body))
|
log.Debugf("response body: %s", string(body))
|
||||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||||
|
|||||||
@@ -8,6 +8,131 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMergeConsecutiveMessages(t *testing.T) {
|
||||||
|
t.Run("no_consecutive_messages", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "user", Content: "你好"},
|
||||||
|
{Role: "assistant", Content: "你好!"},
|
||||||
|
{Role: "user", Content: "再见"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// No merging needed, returned body should be identical
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merges_consecutive_user_messages", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "user", Content: "第一条"},
|
||||||
|
{Role: "user", Content: "第二条"},
|
||||||
|
{Role: "assistant", Content: "回复"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var output chatCompletionRequest
|
||||||
|
require.NoError(t, json.Unmarshal(result, &output))
|
||||||
|
|
||||||
|
assert.Len(t, output.Messages, 2)
|
||||||
|
assert.Equal(t, "user", output.Messages[0].Role)
|
||||||
|
assert.Equal(t, "第一条\n\n第二条", output.Messages[0].Content)
|
||||||
|
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merges_consecutive_assistant_messages", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "user", Content: "问题"},
|
||||||
|
{Role: "assistant", Content: "第一段"},
|
||||||
|
{Role: "assistant", Content: "第二段"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var output chatCompletionRequest
|
||||||
|
require.NoError(t, json.Unmarshal(result, &output))
|
||||||
|
|
||||||
|
assert.Len(t, output.Messages, 2)
|
||||||
|
assert.Equal(t, "user", output.Messages[0].Role)
|
||||||
|
assert.Equal(t, "assistant", output.Messages[1].Role)
|
||||||
|
assert.Equal(t, "第一段\n\n第二段", output.Messages[1].Content)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merges_multiple_consecutive_same_role", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "user", Content: "A"},
|
||||||
|
{Role: "user", Content: "B"},
|
||||||
|
{Role: "user", Content: "C"},
|
||||||
|
{Role: "assistant", Content: "回复"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var output chatCompletionRequest
|
||||||
|
require.NoError(t, json.Unmarshal(result, &output))
|
||||||
|
|
||||||
|
assert.Len(t, output.Messages, 2)
|
||||||
|
assert.Equal(t, "A\n\nB\n\nC", output.Messages[0].Content)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("system_messages_not_merged", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "system", Content: "系统提示1"},
|
||||||
|
{Role: "system", Content: "系统提示2"},
|
||||||
|
{Role: "user", Content: "问题"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// system messages are not merged, body unchanged
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single_message_unchanged", func(t *testing.T) {
|
||||||
|
input := chatCompletionRequest{
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "user", Content: "只有一条"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json_body", func(t *testing.T) {
|
||||||
|
body := []byte(`invalid json`)
|
||||||
|
result, err := mergeConsecutiveMessages(body)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestCleanupContextMessages(t *testing.T) {
|
func TestCleanupContextMessages(t *testing.T) {
|
||||||
t.Run("empty_cleanup_commands", func(t *testing.T) {
|
t.Run("empty_cleanup_commands", func(t *testing.T) {
|
||||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
|||||||
Reference in New Issue
Block a user