fix(ai-proxy): harden Claude stream conversion compatibility (#3733)
@@ -424,6 +424,13 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext
             continue
         }
 
+        // Some providers keep sending duplicate usage chunks after the stream
+        // has already been finalized. Ignore those trailing chunks.
+        if c.messageStopSent {
+            log.Debugf("[OpenAI->Claude] Ignoring chunk after message_stop: %s", data)
+            continue
+        }
+
         var openaiStreamResponse chatCompletionResponse
         if err := json.Unmarshal([]byte(data), &openaiStreamResponse); err != nil {
             log.Debugf("unable to unmarshal openai stream response: %v, data: %s", err, data)
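
For context, the failure mode this guard addresses looks roughly like the SSE tail below (an illustrative sequence modeled on the new tests further down, not captured provider output). Once the first usage-bearing chunk has finalized the stream, the byte-identical repeat would previously re-enter the conversion path:

    data: {"id":"s","choices":[],"object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":100,"total_tokens":110}}

    data: {"id":"s","choices":[],"object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":100,"total_tokens":110}}

    data: [DONE]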
@@ -451,6 +458,19 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext
     return claudeChunk, nil
 }
 
+func normalizeFinishReason(finishReason *string) (string, bool) {
+    if finishReason == nil {
+        return "", false
+    }
+
+    normalized := strings.TrimSpace(*finishReason)
+    if normalized == "" || strings.EqualFold(normalized, "null") {
+        return "", false
+    }
+
+    return normalized, true
+}
+
 // buildClaudeStreamResponse builds Claude streaming responses from OpenAI streaming response
 func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext, openaiResponse *chatCompletionResponse) []*claudeTextGenStreamResponse {
     var choice chatCompletionChoice
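
A minimal sketch of the helper's contract, mirroring the cases exercised by TestNormalizeFinishReason further down (the snippet itself is illustrative, not part of the change; stringPtr is the test helper added below):

    normalizeFinishReason(nil)                 // -> "", false
    normalizeFinishReason(stringPtr(""))       // -> "", false
    normalizeFinishReason(stringPtr(" "))      // -> "", false (whitespace only)
    normalizeFinishReason(stringPtr("null"))   // -> "", false (case-insensitive "null")
    normalizeFinishReason(stringPtr("NULL"))   // -> "", false
    normalizeFinishReason(stringPtr("length")) // -> "length", true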
@@ -469,7 +489,7 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext
     // Log what we're processing
     hasRole := choice.Delta != nil && choice.Delta.Role != ""
     hasContent := choice.Delta != nil && choice.Delta.Content != ""
-    hasFinishReason := choice.FinishReason != nil
+    finishReason, hasFinishReason := normalizeFinishReason(choice.FinishReason)
     hasUsage := openaiResponse.Usage != nil
 
     log.Debugf("[OpenAI->Claude] Processing OpenAI chunk - Role: %v, Content: %v, FinishReason: %v, Usage: %v",
@@ -688,9 +708,9 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext
     }
 
     // Handle finish reason
-    if choice.FinishReason != nil {
-        claudeFinishReason := openAIFinishReasonToClaude(*choice.FinishReason)
-        log.Debugf("[OpenAI->Claude] Processing finish_reason: %s -> %s", *choice.FinishReason, claudeFinishReason)
+    if hasFinishReason {
+        claudeFinishReason := openAIFinishReasonToClaude(finishReason)
+        log.Debugf("[OpenAI->Claude] Processing finish_reason: %s -> %s", finishReason, claudeFinishReason)
 
         // Send content_block_stop for any active content blocks
         if c.thinkingBlockStarted && !c.thinkingBlockStopped {
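
With this hunk, chunks whose finish_reason is "", whitespace, or the literal string "null" no longer enter the finish handling at all; only normalized values reach openAIFinishReasonToClaude. A hedged example of the resulting flow (names taken from this diff; the "length" -> "max_tokens" mapping is confirmed by the assertions in the new tests):

    finishReason, ok := normalizeFinishReason(stringPtr("length")) // ok == true
    if ok {
        _ = openAIFinishReasonToClaude(finishReason) // "max_tokens" per the tests below
    }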
@@ -764,6 +784,7 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext
     // Note: Some providers may send usage in the same chunk as finish_reason,
     // so we check for usage regardless of whether finish_reason is present
     if openaiResponse.Usage != nil {
+        stopReasonIncluded := false
         log.Debugf("[OpenAI->Claude] Processing usage info - input: %d, output: %d",
             openaiResponse.Usage.PromptTokens, openaiResponse.Usage.CompletionTokens)
 
@@ -784,13 +805,15 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext
             log.Debugf("[OpenAI->Claude] Combining cached stop_reason %s with usage", *c.pendingStopReason)
             messageDelta.Delta.StopReason = c.pendingStopReason
             c.pendingStopReason = nil // Clear cache
+            stopReasonIncluded = true
         }
 
         log.Debugf("[OpenAI->Claude] Generated message_delta event with usage and stop_reason")
         responses = append(responses, messageDelta)
 
-        // Send message_stop after combined message_delta
-        if !c.messageStopSent {
+        // Send message_stop only when this usage chunk carries a real stop_reason.
+        // Some providers report incremental usage in every chunk before finishing.
+        if stopReasonIncluded && !c.messageStopSent {
             c.messageStopSent = true
             log.Debugf("[OpenAI->Claude] Generated message_stop event")
             responses = append(responses, &claudeTextGenStreamResponse{
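
Taken together, the last two hunks pin down the per-chunk event ordering: every usage-bearing chunk still produces a message_delta with token counts, but message_stop is now tied to an actual stop_reason. A condensed view of the decision (a sketch of the control flow above, not the literal code):

    // For each chunk where openaiResponse.Usage != nil:
    //   emit message_delta with usage             -> always
    //   attach stop_reason to that message_delta  -> only if a pendingStopReason was cached
    //   emit message_stop (at most once)          -> only if a stop_reason was attached

The remaining hunks below are in the accompanying test file of the same provider package.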
@@ -2,6 +2,7 @@ package provider
 
 import (
     "encoding/json"
+    "strings"
     "testing"
 
     "github.com/higress-group/wasm-go/pkg/log"
@@ -978,3 +979,188 @@ func TestStripCchFromBillingHeader(t *testing.T) {
         })
     }
 }
+
+func TestNormalizeFinishReason(t *testing.T) {
+    tests := []struct {
+        name       string
+        input      *string
+        wantReason string
+        wantValid  bool
+    }{
+        {
+            name:      "nil finish reason",
+            input:     nil,
+            wantValid: false,
+        },
+        {
+            name:      "empty finish reason",
+            input:     stringPtr(""),
+            wantValid: false,
+        },
+        {
+            name:      "whitespace finish reason",
+            input:     stringPtr(" "),
+            wantValid: false,
+        },
+        {
+            name:      "string null finish reason",
+            input:     stringPtr("null"),
+            wantValid: false,
+        },
+        {
+            name:      "uppercase string null finish reason",
+            input:     stringPtr("NULL"),
+            wantValid: false,
+        },
+        {
+            name:       "valid finish reason",
+            input:      stringPtr("length"),
+            wantReason: "length",
+            wantValid:  true,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            gotReason, gotValid := normalizeFinishReason(tt.input)
+            assert.Equal(t, tt.wantReason, gotReason)
+            assert.Equal(t, tt.wantValid, gotValid)
+        })
+    }
+}
+
+func TestClaudeToOpenAIConverter_ConvertOpenAIStreamResponseToClaude_Compatibility(t *testing.T) {
+    t.Run("finish_reason empty string should not stop stream", func(t *testing.T) {
+        converter := &ClaudeToOpenAIConverter{}
+
+        chunk1 := `data: {"id":"stream-1","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":""}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
+        out1, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunk1))
+        require.NoError(t, err)
+        events1 := parseClaudeSSEEvents(t, out1)
+        require.Len(t, events1, 1)
+        assert.Equal(t, "message_start", events1[0].Name)
+
+        chunk2 := `data: {"id":"stream-1","choices":[{"index":0,"delta":{"reasoning_content":"Let"},"finish_reason":""}],"created":1,"model":"m","object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":1,"total_tokens":11}}` + "\n\n"
+        out2, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunk2))
+        require.NoError(t, err)
+        events2 := parseClaudeSSEEvents(t, out2)
+        require.Len(t, events2, 3)
+        assert.Equal(t, "content_block_start", events2[0].Name)
+        assert.Equal(t, "content_block_delta", events2[1].Name)
+        assert.Equal(t, "message_delta", events2[2].Name)
+        assert.Nil(t, events2[2].Payload.Delta.StopReason, "usage chunk without real finish_reason must not carry stop_reason")
+
+        eventNames := []string{events2[0].Name, events2[1].Name, events2[2].Name}
+        assert.NotContains(t, eventNames, "content_block_stop")
+        assert.NotContains(t, eventNames, "message_stop")
+    })
+
+    t.Run("usage in every chunk should not trigger early message_stop", func(t *testing.T) {
+        converter := &ClaudeToOpenAIConverter{}
+
+        chunkStart := `data: {"id":"stream-2","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
+        _, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkStart))
+        require.NoError(t, err)
+
+        chunkThinking1 := `data: {"id":"stream-2","choices":[{"index":0,"delta":{"reasoning_content":"Let"},"finish_reason":null}],"created":1,"model":"m","object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":1,"total_tokens":11}}` + "\n\n"
+        outThinking1, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkThinking1))
+        require.NoError(t, err)
+        eventsThinking1 := parseClaudeSSEEvents(t, outThinking1)
+        require.Len(t, eventsThinking1, 3)
+        assert.Equal(t, "message_delta", eventsThinking1[2].Name)
+        assert.Nil(t, eventsThinking1[2].Payload.Delta.StopReason)
+
+        chunkThinking2 := `data: {"id":"stream-2","choices":[{"index":0,"delta":{"reasoning_content":" me"},"finish_reason":null}],"created":1,"model":"m","object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":2,"total_tokens":12}}` + "\n\n"
+        outThinking2, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkThinking2))
+        require.NoError(t, err)
+        eventsThinking2 := parseClaudeSSEEvents(t, outThinking2)
+        require.Len(t, eventsThinking2, 2)
+        assert.Equal(t, "content_block_delta", eventsThinking2[0].Name)
+        assert.Equal(t, "message_delta", eventsThinking2[1].Name)
+        assert.Nil(t, eventsThinking2[1].Payload.Delta.StopReason)
+
+        chunkFinishNoUsage := `data: {"id":"stream-2","choices":[{"index":0,"delta":{"content":"","reasoning_content":""},"finish_reason":"length"}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
+        outFinishNoUsage, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkFinishNoUsage))
+        require.NoError(t, err)
+        eventsFinishNoUsage := parseClaudeSSEEvents(t, outFinishNoUsage)
+        require.Len(t, eventsFinishNoUsage, 1)
+        assert.Equal(t, "content_block_stop", eventsFinishNoUsage[0].Name)
+
+        chunkFinalUsage := `data: {"id":"stream-2","choices":[],"created":1,"model":"m","object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":100,"total_tokens":110}}` + "\n\n"
+        outFinalUsage, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkFinalUsage))
+        require.NoError(t, err)
+        eventsFinalUsage := parseClaudeSSEEvents(t, outFinalUsage)
+        require.Len(t, eventsFinalUsage, 2)
+        assert.Equal(t, "message_delta", eventsFinalUsage[0].Name)
+        require.NotNil(t, eventsFinalUsage[0].Payload.Delta.StopReason)
+        assert.Equal(t, "max_tokens", *eventsFinalUsage[0].Payload.Delta.StopReason)
+        assert.Equal(t, "message_stop", eventsFinalUsage[1].Name)
+
+        chunkDuplicateUsage := `data: {"id":"stream-2","choices":[],"created":1,"model":"m","object":"chat.completion.chunk","usage":{"prompt_tokens":10,"completion_tokens":100,"total_tokens":110}}` + "\n\n"
+        outDuplicateUsage, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(chunkDuplicateUsage))
+        require.NoError(t, err)
+        assert.Empty(t, strings.TrimSpace(string(outDuplicateUsage)), "duplicate trailing chunks after message_stop should be ignored")
+
+        doneChunk := "data: [DONE]\n\n"
+        outDone, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(doneChunk))
+        require.NoError(t, err)
+        assert.Empty(t, strings.TrimSpace(string(outDone)))
+
+        nextRequestChunk := `data: {"id":"stream-3","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}],"created":1,"model":"m","object":"chat.completion.chunk"}` + "\n\n"
+        outNextRequest, err := converter.ConvertOpenAIStreamResponseToClaude(nil, []byte(nextRequestChunk))
+        require.NoError(t, err)
+        eventsNextRequest := parseClaudeSSEEvents(t, outNextRequest)
+        require.Len(t, eventsNextRequest, 1)
+        assert.Equal(t, "message_start", eventsNextRequest[0].Name)
+    })
+}
+
+type parsedClaudeSSEEvent struct {
+    Name    string
+    Payload claudeTextGenStreamResponse
+}
+
+func parseClaudeSSEEvents(t *testing.T, raw []byte) []parsedClaudeSSEEvent {
+    t.Helper()
+
+    text := strings.TrimSpace(string(raw))
+    if text == "" {
+        return nil
+    }
+
+    blocks := strings.Split(text, "\n\n")
+    events := make([]parsedClaudeSSEEvent, 0, len(blocks))
+    for _, block := range blocks {
+        block = strings.TrimSpace(block)
+        if block == "" {
+            continue
+        }
+
+        var eventName string
+        var dataPayload string
+        for _, line := range strings.Split(block, "\n") {
+            if strings.HasPrefix(line, "event: ") {
+                eventName = strings.TrimPrefix(line, "event: ")
+            }
+            if strings.HasPrefix(line, "data: ") {
+                dataPayload = strings.TrimPrefix(line, "data: ")
+            }
+        }
+
+        require.NotEmpty(t, eventName)
+        require.NotEmpty(t, dataPayload)
+
+        var payload claudeTextGenStreamResponse
+        require.NoError(t, json.Unmarshal([]byte(dataPayload), &payload))
+        events = append(events, parsedClaudeSSEEvent{
+            Name:    eventName,
+            Payload: payload,
+        })
+    }
+
+    return events
+}
+
+func stringPtr(value string) *string {
+    return &value
+}
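
These tests live alongside the converter in the provider package; assuming the usual higress wasm-plugin layout (the file paths are not shown in this diff), they can be run in isolation with something like:

    go test ./plugins/wasm-go/extensions/ai-proxy/... -run 'TestNormalizeFinishReason|TestClaudeToOpenAIConverter_ConvertOpenAIStreamResponseToClaude_Compatibility'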