fix(ai-proxy): resolve Claude streaming response conversion and SSE event chunking issues (#2882)

This commit is contained in:
澄潭
2025-09-08 09:54:18 +08:00
committed by GitHub
parent 20b68c039c
commit 4a429bf147
9 changed files with 356 additions and 143 deletions

View File

@@ -406,7 +406,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
return chunk return chunk
} }
if len(outputEvents) == 0 { if len(outputEvents) == 0 {
responseBuilder.WriteString(event.ToHttpString()) // no need convert, keep original events
responseBuilder.WriteString(event.RawEvent)
} else { } else {
for _, outputEvent := range outputEvents { for _, outputEvent := range outputEvents {
responseBuilder.WriteString(outputEvent.ToHttpString()) responseBuilder.WriteString(outputEvent.ToHttpString())

View File

@@ -138,7 +138,7 @@ type claudeSystemPrompt struct {
// Will be set to the string value if system is a simple string // Will be set to the string value if system is a simple string
StringValue string StringValue string
// Will be set to the array value if system is an array of text blocks // Will be set to the array value if system is an array of text blocks
ArrayValue []claudeTextGenContent ArrayValue []claudeChatMessageContent
// Indicates which type this represents // Indicates which type this represents
IsArray bool IsArray bool
} }
@@ -154,7 +154,7 @@ func (csp *claudeSystemPrompt) UnmarshalJSON(data []byte) error {
} }
// Try to unmarshal as array of text blocks // Try to unmarshal as array of text blocks
var arrayValue []claudeTextGenContent var arrayValue []claudeChatMessageContent
if err := json.Unmarshal(data, &arrayValue); err == nil { if err := json.Unmarshal(data, &arrayValue); err == nil {
csp.ArrayValue = arrayValue csp.ArrayValue = arrayValue
csp.IsArray = true csp.IsArray = true
@@ -196,7 +196,7 @@ type claudeThinkingConfig struct {
type claudeTextGenRequest struct { type claudeTextGenRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"` Messages []claudeChatMessage `json:"messages"`
System claudeSystemPrompt `json:"system,omitempty"` System *claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
@@ -232,9 +232,11 @@ type claudeTextGenContent struct {
} }
type claudeTextGenUsage struct { type claudeTextGenUsage struct {
InputTokens int `json:"input_tokens,omitempty"` InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"` OutputTokens int `json:"output_tokens,omitempty"`
ServiceTier string `json:"service_tier,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
} }
type claudeTextGenError struct { type claudeTextGenError struct {
@@ -254,6 +256,7 @@ type claudeTextGenStreamResponse struct {
type claudeTextGenDelta struct { type claudeTextGenDelta struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` StopReason *string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"` StopSequence *string `json:"stop_sequence,omitempty"`
} }
@@ -401,7 +404,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
for _, message := range origRequest.Messages { for _, message := range origRequest.Messages {
if message.Role == roleSystem { if message.Role == roleSystem {
claudeRequest.System = claudeSystemPrompt{ claudeRequest.System = &claudeSystemPrompt{
StringValue: message.StringContent(), StringValue: message.StringContent(),
IsArray: false, IsArray: false,
} }
@@ -622,12 +625,12 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
systemStr := request.System.String() systemStr := request.System.String()
if systemStr == "" { if systemStr == "" {
request.System = claudeSystemPrompt{ request.System = &claudeSystemPrompt{
StringValue: content, StringValue: content,
IsArray: false, IsArray: false,
} }
} else { } else {
request.System = claudeSystemPrompt{ request.System = &claudeSystemPrompt{
StringValue: content + "\n" + systemStr, StringValue: content + "\n" + systemStr,
IsArray: false, IsArray: false,
} }

View File

@@ -29,7 +29,18 @@ type ClaudeToOpenAIConverter struct {
toolBlockStarted bool toolBlockStarted bool
toolBlockStopped bool toolBlockStopped bool
// Tool call state tracking // Tool call state tracking
toolCallStates map[string]*toolCallState toolCallStates map[int]*toolCallInfo // Map of OpenAI index to tool call state
activeToolIndex *int // Currently active tool call index (for Claude serialization)
}
// toolCallInfo tracks tool call state
type toolCallInfo struct {
id string // Tool call ID
name string // Tool call name
claudeContentIndex int // Claude content block index
contentBlockStarted bool // Whether content_block_start has been sent
contentBlockStopped bool // Whether content_block_stop has been sent
cachedArguments string // Cache arguments for this tool call
} }
// contentConversionResult represents the result of converting Claude content to OpenAI format // contentConversionResult represents the result of converting Claude content to OpenAI format
@@ -41,14 +52,6 @@ type contentConversionResult struct {
hasNonTextContent bool hasNonTextContent bool
} }
// toolCallState tracks the state of a tool call during streaming
type toolCallState struct {
id string
name string
argumentsBuffer string
isComplete bool
}
// ConvertClaudeRequestToOpenAI converts a Claude chat completion request to OpenAI format // ConvertClaudeRequestToOpenAI converts a Claude chat completion request to OpenAI format
func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]byte, error) { func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]byte, error) {
log.Debugf("[Claude->OpenAI] Original Claude request body: %s", string(body)) log.Debugf("[Claude->OpenAI] Original Claude request body: %s", string(body))
@@ -114,18 +117,9 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
// Handle regular content if no tool calls or tool results // Handle regular content if no tool calls or tool results
if len(conversionResult.toolCalls) == 0 && len(conversionResult.toolResults) == 0 { if len(conversionResult.toolCalls) == 0 && len(conversionResult.toolResults) == 0 {
var content interface{}
if !conversionResult.hasNonTextContent && len(conversionResult.textParts) > 0 {
// Simple text content
content = strings.Join(conversionResult.textParts, "\n\n")
} else {
// Multi-modal content or empty content
content = conversionResult.openaiContents
}
openaiMsg := chatMessage{ openaiMsg := chatMessage{
Role: claudeMsg.Role, Role: claudeMsg.Role,
Content: content, Content: conversionResult.openaiContents,
} }
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg) openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
} }
@@ -133,11 +127,13 @@ func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]b
} }
// Handle system message - Claude has separate system field // Handle system message - Claude has separate system field
systemStr := claudeRequest.System.String() if claudeRequest.System != nil {
if systemStr != "" { systemMsg := chatMessage{Role: roleSystem}
systemMsg := chatMessage{ if !claudeRequest.System.IsArray {
Role: roleSystem, systemMsg.Content = claudeRequest.System.StringValue
Content: systemStr, } else {
conversionResult := c.convertContentArray(claudeRequest.System.ArrayValue)
systemMsg.Content = conversionResult.openaiContents
} }
// Insert system message at the beginning // Insert system message at the beginning
openaiRequest.Messages = append([]chatMessage{systemMsg}, openaiRequest.Messages...) openaiRequest.Messages = append([]chatMessage{systemMsg}, openaiRequest.Messages...)
@@ -231,6 +227,9 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
InputTokens: openaiResponse.Usage.PromptTokens, InputTokens: openaiResponse.Usage.PromptTokens,
OutputTokens: openaiResponse.Usage.CompletionTokens, OutputTokens: openaiResponse.Usage.CompletionTokens,
} }
if openaiResponse.Usage.PromptTokensDetails != nil {
claudeResponse.Usage.CacheReadInputTokens = openaiResponse.Usage.PromptTokensDetails.CachedTokens
}
} }
// Convert the first choice content // Convert the first choice content
@@ -312,11 +311,6 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.Http
func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext, chunk []byte) ([]byte, error) { func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext, chunk []byte) ([]byte, error) {
log.Debugf("[OpenAI->Claude] Original OpenAI streaming chunk: %s", string(chunk)) log.Debugf("[OpenAI->Claude] Original OpenAI streaming chunk: %s", string(chunk))
// Initialize tool call states if needed
if c.toolCallStates == nil {
c.toolCallStates = make(map[string]*toolCallState)
}
// For streaming responses, we need to handle the Server-Sent Events format // For streaming responses, we need to handle the Server-Sent Events format
lines := strings.Split(string(chunk), "\n") lines := strings.Split(string(chunk), "\n")
var result strings.Builder var result strings.Builder
@@ -350,15 +344,18 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
stopData, _ := json.Marshal(stopEvent) stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData)) result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
} }
if c.toolBlockStarted && !c.toolBlockStopped { // Send final content_block_stop events for any remaining unclosed tool calls
c.toolBlockStopped = true for index, toolCall := range c.toolCallStates {
log.Debugf("[OpenAI->Claude] Sending final tool content_block_stop event at index %d", c.toolBlockIndex) if toolCall.contentBlockStarted && !toolCall.contentBlockStopped {
stopEvent := &claudeTextGenStreamResponse{ log.Debugf("[OpenAI->Claude] Sending final tool content_block_stop event for index %d at Claude index %d",
Type: "content_block_stop", index, toolCall.claudeContentIndex)
Index: &c.toolBlockIndex, stopEvent := &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &toolCall.claudeContentIndex,
}
stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
} }
stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
} }
// If we have a pending stop_reason but no usage, send message_delta with just stop_reason // If we have a pending stop_reason but no usage, send message_delta with just stop_reason
@@ -401,7 +398,8 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
c.toolBlockIndex = -1 c.toolBlockIndex = -1
c.toolBlockStarted = false c.toolBlockStarted = false
c.toolBlockStopped = false c.toolBlockStopped = false
c.toolCallStates = make(map[string]*toolCallState) c.toolCallStates = make(map[int]*toolCallInfo)
c.activeToolIndex = nil
log.Debugf("[OpenAI->Claude] Reset converter state for next request") log.Debugf("[OpenAI->Claude] Reset converter state for next request")
continue continue
@@ -424,7 +422,7 @@ func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrappe
continue continue
} }
log.Debugf("[OpenAI->Claude] Stream event [%d/%d]: %s", i+1, len(claudeStreamResponses), string(responseData)) log.Debugf("[OpenAI->Claude] Stream event [%d/%d]: %s", i+1, len(claudeStreamResponses), string(responseData))
result.WriteString(fmt.Sprintf("data: %s\n\n", responseData)) result.WriteString(fmt.Sprintf("event: %s\ndata: %s\n\n", claudeStreamResponse.Type, responseData))
} }
} }
} }
@@ -579,79 +577,60 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
// Handle tool calls in streaming response // Handle tool calls in streaming response
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 { if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
// Initialize toolCallStates if needed
if c.toolCallStates == nil {
c.toolCallStates = make(map[int]*toolCallInfo)
}
for _, toolCall := range choice.Delta.ToolCalls { for _, toolCall := range choice.Delta.ToolCalls {
if !toolCall.Function.IsEmpty() { log.Debugf("[OpenAI->Claude] Processing tool call delta: index=%d, id=%s, name=%s, args=%s",
log.Debugf("[OpenAI->Claude] Processing tool call delta") toolCall.Index, toolCall.Id, toolCall.Function.Name, toolCall.Function.Arguments)
// Get or create tool call state // Handle new tool call (has id and name)
state := c.toolCallStates[toolCall.Id] if toolCall.Id != "" && toolCall.Function.Name != "" {
if state == nil { // Create or update tool call state
state = &toolCallState{ if _, exists := c.toolCallStates[toolCall.Index]; !exists {
id: toolCall.Id, c.toolCallStates[toolCall.Index] = &toolCallInfo{
name: toolCall.Function.Name, id: toolCall.Id,
argumentsBuffer: "", name: toolCall.Function.Name,
isComplete: false, contentBlockStarted: false,
contentBlockStopped: false,
cachedArguments: "",
} }
c.toolCallStates[toolCall.Id] = state
log.Debugf("[OpenAI->Claude] Created new tool call state for id: %s, name: %s", toolCall.Id, toolCall.Function.Name)
} }
// Accumulate arguments toolState := c.toolCallStates[toolCall.Index]
if toolCall.Function.Arguments != "" {
state.argumentsBuffer += toolCall.Function.Arguments // Check if we can start this tool call (Claude requires serialization)
log.Debugf("[OpenAI->Claude] Accumulated tool arguments: %s", state.argumentsBuffer) if c.activeToolIndex == nil {
// No active tool call, start this one
c.activeToolIndex = &toolCall.Index
toolCallResponses := c.startToolCall(toolState)
responses = append(responses, toolCallResponses...)
} }
// If another tool call is already active, this one is only recorded here; it is started later, at finish, when the remaining unstarted tool calls are started in index order
}
// Try to parse accumulated arguments as JSON to check if complete // Handle arguments for any tool call - cache all arguments regardless of active state
var input map[string]interface{} if toolCall.Function.Arguments != "" {
if state.argumentsBuffer != "" { if toolState, exists := c.toolCallStates[toolCall.Index]; exists {
if err := json.Unmarshal([]byte(state.argumentsBuffer), &input); err == nil { // Always cache arguments for this tool call
// Successfully parsed - arguments are complete toolState.cachedArguments += toolCall.Function.Arguments
if !state.isComplete { log.Debugf("[OpenAI->Claude] Cached arguments for tool index %d: %s (total: %s)",
state.isComplete = true toolCall.Index, toolCall.Function.Arguments, toolState.cachedArguments)
log.Debugf("[OpenAI->Claude] Tool call arguments complete for %s: %s", state.name, state.argumentsBuffer)
// Close thinking content block if it's still open // Send input_json_delta event only if this tool is currently active and content block started
if c.thinkingBlockStarted && !c.thinkingBlockStopped { if c.activeToolIndex != nil && *c.activeToolIndex == toolCall.Index && toolState.contentBlockStarted {
c.thinkingBlockStopped = true log.Debugf("[OpenAI->Claude] Generated input_json_delta event for active tool index %d: %s",
log.Debugf("[OpenAI->Claude] Closing thinking content block before tool use") toolCall.Index, toolCall.Function.Arguments)
responses = append(responses, &claudeTextGenStreamResponse{ responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop", Type: "content_block_delta",
Index: &c.thinkingBlockIndex, Index: &toolState.claudeContentIndex,
}) Delta: &claudeTextGenDelta{
} Type: "input_json_delta",
PartialJson: toolCall.Function.Arguments,
// Close text content block if it's still open },
if c.textBlockStarted && !c.textBlockStopped { })
c.textBlockStopped = true
log.Debugf("[OpenAI->Claude] Closing text content block before tool use")
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.textBlockIndex,
})
}
// Send content_block_start for tool_use only when we have complete arguments with dynamic index
if !c.toolBlockStarted {
c.toolBlockIndex = c.nextContentIndex
c.nextContentIndex++
c.toolBlockStarted = true
log.Debugf("[OpenAI->Claude] Generated content_block_start event for tool_use at index %d", c.toolBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_start",
Index: &c.toolBlockIndex,
ContentBlock: &claudeTextGenContent{
Type: "tool_use",
Id: toolCall.Id,
Name: state.name,
Input: input,
},
})
}
}
} else {
// Still accumulating arguments
log.Debugf("[OpenAI->Claude] Tool arguments not yet complete, continuing to accumulate: %v", err)
} }
} }
} }
@@ -680,15 +659,52 @@ func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpCont
Index: &c.textBlockIndex, Index: &c.textBlockIndex,
}) })
} }
if c.toolBlockStarted && !c.toolBlockStopped {
c.toolBlockStopped = true // First, start any remaining unstarted tool calls (they may have no arguments)
log.Debugf("[OpenAI->Claude] Generated tool content_block_stop event at index %d", c.toolBlockIndex) // Process in order to maintain Claude's sequential requirement
responses = append(responses, &claudeTextGenStreamResponse{ var sortedIndices []int
Type: "content_block_stop", for index := range c.toolCallStates {
Index: &c.toolBlockIndex, sortedIndices = append(sortedIndices, index)
})
} }
// Sort indices to process in order
for i := 0; i < len(sortedIndices)-1; i++ {
for j := i + 1; j < len(sortedIndices); j++ {
if sortedIndices[i] > sortedIndices[j] {
sortedIndices[i], sortedIndices[j] = sortedIndices[j], sortedIndices[i]
}
}
}
for _, index := range sortedIndices {
toolCall := c.toolCallStates[index]
if !toolCall.contentBlockStarted {
log.Debugf("[OpenAI->Claude] Starting remaining tool call at finish: index=%d, id=%s, name=%s",
index, toolCall.id, toolCall.name)
c.activeToolIndex = &index
toolCallResponses := c.startToolCall(toolCall)
responses = append(responses, toolCallResponses...)
c.activeToolIndex = nil // Clear immediately since tool is now fully started
}
}
// Then send content_block_stop for all started tool calls in order
for _, index := range sortedIndices {
toolCall := c.toolCallStates[index]
if toolCall.contentBlockStarted && !toolCall.contentBlockStopped {
log.Debugf("[OpenAI->Claude] Generated content_block_stop for tool at index %d, Claude index %d",
index, toolCall.claudeContentIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &toolCall.claudeContentIndex,
})
toolCall.contentBlockStopped = true
}
}
// Clear active tool index
c.activeToolIndex = nil
// Cache stop_reason until we get usage info (Claude protocol requires them together) // Cache stop_reason until we get usage info (Claude protocol requires them together)
c.pendingStopReason = &claudeFinishReason c.pendingStopReason = &claudeFinishReason
log.Debugf("[OpenAI->Claude] Cached stop_reason: %s, waiting for usage", claudeFinishReason) log.Debugf("[OpenAI->Claude] Cached stop_reason: %s, waiting for usage", claudeFinishReason)
@@ -764,8 +780,9 @@ func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeCha
if claudeContent.Text != "" { if claudeContent.Text != "" {
result.textParts = append(result.textParts, claudeContent.Text) result.textParts = append(result.textParts, claudeContent.Text)
result.openaiContents = append(result.openaiContents, chatMessageContent{ result.openaiContents = append(result.openaiContents, chatMessageContent{
Type: contentTypeText, Type: contentTypeText,
Text: claudeContent.Text, Text: claudeContent.Text,
CacheControl: claudeContent.CacheControl,
}) })
} }
case "image": case "image":
@@ -822,3 +839,63 @@ func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeCha
return result return result
} }
// startToolCall opens the Claude tool_use content block for toolState and
// returns the stream events to emit for it, in order:
//   - content_block_stop for a thinking block that is started but not stopped
//   - content_block_stop for a text block that is started but not stopped
//   - content_block_start for the tool_use block (with an empty input object)
//   - one content_block_delta (input_json_delta) carrying any arguments that
//     were cached on toolState before the block was started
//
// As side effects it marks the thinking/text blocks as stopped, takes the next
// content index from c.nextContentIndex for the tool block, and flags
// toolState as started.
func (c *ClaudeToOpenAIConverter) startToolCall(toolState *toolCallInfo) []*claudeTextGenStreamResponse {
	var events []*claudeTextGenStreamResponse

	// An open thinking block has to be closed before the tool block begins.
	if c.thinkingBlockStarted && !c.thinkingBlockStopped {
		c.thinkingBlockStopped = true
		log.Debugf("[OpenAI->Claude] Closing thinking content block before tool use")
		thinkingStop := &claudeTextGenStreamResponse{
			Type:  "content_block_stop",
			Index: &c.thinkingBlockIndex,
		}
		events = append(events, thinkingStop)
	}

	// Likewise for an open text block.
	if c.textBlockStarted && !c.textBlockStopped {
		c.textBlockStopped = true
		log.Debugf("[OpenAI->Claude] Closing text content block before tool use")
		textStop := &claudeTextGenStreamResponse{
			Type:  "content_block_stop",
			Index: &c.textBlockIndex,
		}
		events = append(events, textStop)
	}

	// Reserve the next Claude content index for this tool call and mark it started.
	toolState.contentBlockStarted = true
	toolState.claudeContentIndex = c.nextContentIndex
	c.nextContentIndex++
	log.Debugf("[OpenAI->Claude] Started tool call: Claude index=%d, id=%s, name=%s",
		toolState.claudeContentIndex, toolState.id, toolState.name)

	// The block is announced with an empty input; arguments follow as deltas.
	blockStart := &claudeTextGenStreamResponse{
		Type:  "content_block_start",
		Index: &toolState.claudeContentIndex,
		ContentBlock: &claudeTextGenContent{
			Type:  "tool_use",
			Id:    toolState.id,
			Name:  toolState.name,
			Input: map[string]interface{}{}, // Empty input as per Claude spec
		},
	}
	events = append(events, blockStart)

	// Nothing was cached before the block started: we are done.
	if toolState.cachedArguments == "" {
		return events
	}

	// Flush everything cached so far as a single input_json_delta.
	log.Debugf("[OpenAI->Claude] Outputting cached arguments for tool: %s", toolState.cachedArguments)
	cachedDelta := &claudeTextGenStreamResponse{
		Type:  "content_block_delta",
		Index: &toolState.claudeContentIndex,
		Delta: &claudeTextGenDelta{
			Type:        "input_json_delta",
			PartialJson: toolState.cachedArguments,
		},
	}
	return append(events, cachedDelta)
}

View File

@@ -35,7 +35,8 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
converter := &ClaudeToOpenAIConverter{} converter := &ClaudeToOpenAIConverter{}
t.Run("convert_multiple_text_content_blocks", func(t *testing.T) { t.Run("convert_multiple_text_content_blocks", func(t *testing.T) {
// Test case for the bug fix: multiple text content blocks should be merged into a single string // Test case: multiple text content blocks should remain as separate array elements with cache control support
// Both system and user messages should handle array content format
claudeRequest := `{ claudeRequest := `{
"max_tokens": 32000, "max_tokens": 32000,
"messages": [{ "messages": [{
@@ -98,15 +99,64 @@ func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
// First message should be system message (converted from Claude's system field) // First message should be system message (converted from Claude's system field)
systemMsg := openaiRequest.Messages[0] systemMsg := openaiRequest.Messages[0]
assert.Equal(t, roleSystem, systemMsg.Role) assert.Equal(t, roleSystem, systemMsg.Role)
assert.Equal(t, "xxx\nyyy", systemMsg.Content) // Claude system uses single \n
// Second message should be user message with merged text content // System content should now also be an array for multiple text blocks
systemContentArray, ok := systemMsg.Content.([]interface{})
require.True(t, ok, "System content should be an array for multiple text blocks")
require.Len(t, systemContentArray, 2)
// First system text block
firstSystemElement, ok := systemContentArray[0].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, firstSystemElement["type"])
assert.Equal(t, "xxx", firstSystemElement["text"])
assert.NotNil(t, firstSystemElement["cache_control"]) // Has cache control
systemCacheControl1, ok := firstSystemElement["cache_control"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "ephemeral", systemCacheControl1["type"])
// Second system text block
secondSystemElement, ok := systemContentArray[1].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, secondSystemElement["type"])
assert.Equal(t, "yyy", secondSystemElement["text"])
assert.NotNil(t, secondSystemElement["cache_control"]) // Has cache control
systemCacheControl2, ok := secondSystemElement["cache_control"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "ephemeral", systemCacheControl2["type"])
// Second message should be user message with text content as array
userMsg := openaiRequest.Messages[1] userMsg := openaiRequest.Messages[1]
assert.Equal(t, "user", userMsg.Role) assert.Equal(t, "user", userMsg.Role)
// The key fix: multiple text blocks should be merged into a single string // The content should now be an array of separate text blocks, not merged
expectedContent := "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>\n\n<system-reminder>\nyyy</system-reminder>\n\n你是谁" contentArray, ok := userMsg.Content.([]interface{})
assert.Equal(t, expectedContent, userMsg.Content) require.True(t, ok, "Content should be an array for multiple text blocks")
require.Len(t, contentArray, 3)
// First text block
firstElement, ok := contentArray[0].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, firstElement["type"])
assert.Equal(t, "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>", firstElement["text"])
assert.Nil(t, firstElement["cache_control"]) // No cache control for first block
// Second text block
secondElement, ok := contentArray[1].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, secondElement["type"])
assert.Equal(t, "<system-reminder>\nyyy</system-reminder>", secondElement["text"])
assert.Nil(t, secondElement["cache_control"]) // No cache control for second block
// Third text block with cache control
thirdElement, ok := contentArray[2].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, thirdElement["type"])
assert.Equal(t, "你是谁", thirdElement["text"])
assert.NotNil(t, thirdElement["cache_control"]) // Has cache control
cacheControl, ok := thirdElement["cache_control"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "ephemeral", cacheControl["type"])
}) })
t.Run("convert_mixed_content_with_image", func(t *testing.T) { t.Run("convert_mixed_content_with_image", func(t *testing.T) {

View File

@@ -159,10 +159,17 @@ type usage struct {
CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"` TotalTokens int `json:"total_tokens,omitempty"`
CompletionTokensDetails *completionTokensDetails `json:"completion_tokens_details,omitempty"` CompletionTokensDetails *completionTokensDetails `json:"completion_tokens_details,omitempty"`
PromptTokensDetails *promptTokensDetails `json:"prompt_tokens_details,omitempty"`
}
type promptTokensDetails struct {
AudioTokens int `json:"audio_tokens,omitempty"`
CachedTokens int `json:"cached_tokens,omitempty"`
} }
type completionTokensDetails struct { type completionTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens,omitempty"` ReasoningTokens int `json:"reasoning_tokens,omitempty"`
AudioTokens int `json:"audio_tokens,omitempty"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
} }
@@ -240,11 +247,12 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
} }
type chatMessageContent struct { type chatMessageContent struct {
Type string `json:"type,omitempty"` CacheControl map[string]interface{} `json:"cache_control,omitempty"`
Text string `json:"text"` Type string `json:"type,omitempty"`
ImageUrl *chatMessageContentImageUrl `json:"image_url,omitempty"` Text string `json:"text"`
File *chatMessageContentFile `json:"file,omitempty"` ImageUrl *chatMessageContentImageUrl `json:"image_url,omitempty"`
InputAudio *chatMessageContentAudio `json:"input_audio,omitempty"` File *chatMessageContentFile `json:"file,omitempty"`
InputAudio *chatMessageContentAudio `json:"input_audio,omitempty"`
} }
type chatMessageContentAudio struct { type chatMessageContentAudio struct {
@@ -402,6 +410,7 @@ func (m *functionCall) IsEmpty() bool {
} }
type StreamEvent struct { type StreamEvent struct {
RawEvent string `json:"-"`
Id string `json:"id"` Id string `json:"id"`
Event string `json:"event"` Event string `json:"event"`
Data string `json:"data"` Data string `json:"data"`

View File

@@ -813,6 +813,7 @@ func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent
value := string(body[valueStartIndex:i]) value := string(body[valueStartIndex:i])
currentEvent.SetValue(currentKey, value) currentEvent.SetValue(currentKey, value)
} else { } else {
currentEvent.RawEvent = string(body[eventStartIndex : i+1])
// Extra new line. The current event is complete. // Extra new line. The current event is complete.
events = append(events, *currentEvent) events = append(events, *currentEvent)
// Reset event parsing state. // Reset event parsing state.

View File

@@ -6,14 +6,15 @@ import (
"strings" "strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
) )
const ( const (
zhipuAiDomain = "open.bigmodel.cn" zhipuAiDomain = "open.bigmodel.cn"
zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions" zhipuAiChatCompletionPath = "/api/paas/v4/chat/completions"
zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings" zhipuAiEmbeddingsPath = "/api/paas/v4/embeddings"
zhipuAiAnthropicMessagesPath = "/api/anthropic/v1/messages"
) )
type zhipuAiProviderInitializer struct{} type zhipuAiProviderInitializer struct{}
@@ -27,8 +28,9 @@ func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) erro
func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string { func (m *zhipuAiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{ return map[string]string{
string(ApiNameChatCompletion): zhipuAiChatCompletionPath, string(ApiNameChatCompletion): zhipuAiChatCompletionPath,
string(ApiNameEmbeddings): zhipuAiEmbeddingsPath, string(ApiNameEmbeddings): zhipuAiEmbeddingsPath,
string(ApiNameAnthropicMessages): zhipuAiAnthropicMessagesPath,
} }
} }
@@ -75,5 +77,8 @@ func (m *zhipuAiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, zhipuAiEmbeddingsPath) { if strings.Contains(path, zhipuAiEmbeddingsPath) {
return ApiNameEmbeddings return ApiNameEmbeddings
} }
if strings.Contains(path, zhipuAiAnthropicMessagesPath) {
return ApiNameAnthropicMessages
}
return "" return ""
} }

View File

@@ -83,6 +83,23 @@ const (
RuleFirst = "first" RuleFirst = "first"
RuleReplace = "replace" RuleReplace = "replace"
RuleAppend = "append" RuleAppend = "append"
// Built-in attributes
BuiltinQuestionKey = "question"
BuiltinAnswerKey = "answer"
// Built-in attribute paths
// Question paths (from request body)
QuestionPathOpenAI = "messages.@reverse.0.content"
QuestionPathClaude = "messages.@reverse.0.content" // Claude uses same format
// Answer paths (from response body - non-streaming)
AnswerPathOpenAINonStreaming = "choices.0.message.content"
AnswerPathClaudeNonStreaming = "content.0.text"
// Answer paths (from response streaming body)
AnswerPathOpenAIStreaming = "choices.0.delta.content"
AnswerPathClaudeStreaming = "delta.text"
) )
// TracingSpan is the tracing span configuration. // TracingSpan is the tracing span configuration.
@@ -325,12 +342,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
userPromptCount := 0 userPromptCount := 0
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
// OpenAI and Claude/Anthropic format - both use "messages" array with "role" field
for _, msg := range messages.Array() { for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" { if msg.Get("role").String() == "user" {
userPromptCount += 1 userPromptCount += 1
} }
} }
} else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() { // Google Gemini GenerateContent } else if contents := gjson.GetBytes(body, "contents"); contents.Exists() && contents.IsArray() {
// Google Gemini GenerateContent
for _, content := range contents.Array() { for _, content := range contents.Array() {
if !content.Get("role").Exists() || content.Get("role").String() == "user" { if !content.Get("role").Exists() || content.Get("role").String() == "user" {
userPromptCount += 1 userPromptCount += 1
@@ -387,7 +406,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
"id", "id",
"response.id", "response.id",
"responseId", // Gemini generateContent "responseId", // Gemini generateContent
"message.id", // anthropic messages "message.id", // anthropic/claude messages
}); chatID != nil { }); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String()) ctx.SetUserAttribute(ChatID, chatID.String())
} }
@@ -456,7 +475,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
"id", "id",
"response.id", "response.id",
"responseId", // Gemini generateContent "responseId", // Gemini generateContent
"message.id", // anthropic messages "message.id", // anthropic/claude messages
}); chatID != nil { }); chatID != nil {
ctx.SetUserAttribute(ChatID, chatID.String()) ctx.SetUserAttribute(ChatID, chatID.String())
} }
@@ -507,6 +526,15 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
value = gjson.GetBytes(body, attribute.Value).Value() value = gjson.GetBytes(body, attribute.Value).Value()
default: default:
} }
// Handle built-in attributes with Claude/OpenAI protocol fallback logic
if (value == nil || value == "") && isBuiltinAttribute(key) {
value = getBuiltinAttributeFallback(ctx, config, key, source, body, attribute.Rule)
if value != nil && value != "" {
log.Debugf("[attribute] Used protocol fallback for %s: %+v", key, value)
}
}
if (value == nil || value == "") && attribute.DefaultValue != "" { if (value == nil || value == "") && attribute.DefaultValue != "" {
value = attribute.DefaultValue value = attribute.DefaultValue
} }
@@ -538,6 +566,45 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
} }
} }
// isBuiltinAttribute reports whether key names one of the built-in
// attributes, i.e. it equals BuiltinQuestionKey or BuiltinAnswerKey.
func isBuiltinAttribute(key string) bool {
	switch key {
	case BuiltinQuestionKey, BuiltinAnswerKey:
		return true
	default:
		return false
	}
}
// getBuiltinAttributeFallback looks up a built-in attribute (question or
// answer) using the well-known OpenAI and Claude JSON paths.
//
// Lookup rules:
//   - BuiltinQuestionKey is only resolved when source is RequestBody, using
//     QuestionPathOpenAI (the Claude path is the same string).
//   - BuiltinAnswerKey is resolved when source is ResponseStreamingBody or
//     ResponseBody; the OpenAI path is tried first, then the Claude path.
//     For streaming bodies the value is extracted with
//     extractStreamingBodyByJsonPath and the given rule.
//
// The first value that is neither nil nor the empty string is returned. In
// every other case (unknown key, non-matching source, nothing found) the
// result is nil. ctx and config are accepted but not used by this function.
func getBuiltinAttributeFallback(ctx wrapper.HttpContext, config AIStatisticsConfig, key, source string, body []byte, rule string) interface{} {
	// present mirrors the "value != nil && value != \"\"" check applied to
	// every candidate.
	present := func(candidate interface{}) bool {
		return candidate != nil && candidate != ""
	}

	if key == BuiltinQuestionKey {
		if source != RequestBody {
			return nil
		}
		// OpenAI and Claude share the same messages structure.
		if question := gjson.GetBytes(body, QuestionPathOpenAI).Value(); present(question) {
			return question
		}
		return nil
	}

	if key != BuiltinAnswerKey {
		return nil
	}

	switch source {
	case ResponseStreamingBody:
		// Order matters: OpenAI format first, Claude format second.
		for _, jsonPath := range []string{AnswerPathOpenAIStreaming, AnswerPathClaudeStreaming} {
			if answer := extractStreamingBodyByJsonPath(body, jsonPath, rule); present(answer) {
				return answer
			}
		}
	case ResponseBody:
		// Order matters: OpenAI format first, Claude format second.
		for _, jsonPath := range []string{AnswerPathOpenAINonStreaming, AnswerPathClaudeNonStreaming} {
			if answer := gjson.GetBytes(body, jsonPath).Value(); present(answer) {
				return answer
			}
		}
	}
	return nil
}
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string) interface{} { func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string) interface{} {
chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n"))
var value interface{} var value interface{}