From 69d877c11627aecffe739a3eca9ad7df24a9f0d8 Mon Sep 17 00:00:00 2001 From: Xijun Dai Date: Tue, 10 Jun 2025 15:11:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai-proxy):=20=E6=B7=BB=E5=8A=A0=20Claude?= =?UTF-8?q?=20=E5=9B=BE=E7=89=87=E7=90=86=E8=A7=A3=E4=B8=8E=20Tools=20?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E8=83=BD=E5=8A=9B=20||=20feat(ai-proxy):=20A?= =?UTF-8?q?dd=20Claude=20image=20understanding=20and=20Tools=20calling=20c?= =?UTF-8?q?apabilities=20(#2385)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Xijun Dai --- .../extensions/ai-proxy/provider/claude.go | 213 +++++++++++++++--- .../extensions/ai-proxy/provider/model.go | 97 ++++++-- .../wasm-go/extensions/ai-proxy/util/http.go | 2 + 3 files changed, 262 insertions(+), 50 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index cae6d85e7..82d1b2e86 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -25,16 +25,49 @@ const ( type claudeProviderInitializer struct{} +type claudeTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"input_schema,omitempty"` +} + +type claudeToolChoice struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"` +} + +type claudeChatMessage struct { + Role string `json:"role"` + Content any `json:"content"` +} + +type claudeChatMessageContentSource struct { + Type string `json:"type"` + MediaType string `json:"media_type,omitempty"` + Data string `json:"data,omitempty"` + Url string `json:"url,omitempty"` + FileId string `json:"file_id,omitempty"` +} + +type claudeChatMessageContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *claudeChatMessageContentSource `json:"source,omitempty"` +} type claudeTextGenRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - System string `json:"system,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` + Model string `json:"model"` + Messages []claudeChatMessage `json:"messages"` + System string `json:"system,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"` + Tools []claudeTool `json:"tools,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` } type claudeTextGenResponse struct { @@ -50,13 +83,14 @@ type claudeTextGenResponse struct { } type claudeTextGenContent struct { - Type string `json:"type"` + Type string `json:"type,omitempty"` Text string `json:"text,omitempty"` } type claudeTextGenUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` } type claudeTextGenError struct { @@ -65,12 +99,12 @@ type claudeTextGenError struct { } type claudeTextGenStreamResponse struct { - Type string `json:"type"` - Message claudeTextGenResponse `json:"message"` - Index int `json:"index"` - ContentBlock *claudeTextGenContent `json:"content_block"` - Delta *claudeTextGenDelta `json:"delta"` - Usage claudeTextGenUsage `json:"usage"` + Type string `json:"type"` + Message *claudeTextGenResponse `json:"message,omitempty"` + Index int `json:"index,omitempty"` + ContentBlock *claudeTextGenContent `json:"content_block,omitempty"` + Delta *claudeTextGenDelta `json:"delta,omitempty"` + Usage *claudeTextGenUsage `json:"usage,omitempty"` } type claudeTextGenDelta struct { @@ -93,6 +127,7 @@ func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string { string(ApiNameCompletion): claudeCompletionPath, // docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameModels): PathOpenAIModels, } } @@ -107,6 +142,10 @@ func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provi type claudeProvider struct { config ProviderConfig contextCache *contextCache + + messageId string + usage usage + serviceTier string } func (c *claudeProvider) GetProviderType() string { @@ -133,7 +172,7 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { if !c.config.isSupportedAPI(apiName) { - return types.ActionContinue, errUnsupportedApiName + return types.ActionContinue, nil } return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body) } @@ -205,11 +244,12 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRequest) *claudeTextGenRequest { claudeRequest := claudeTextGenRequest{ Model: origRequest.Model, - MaxTokens: origRequest.MaxTokens, + MaxTokens: origRequest.getMaxTokens(), StopSequences: origRequest.Stop, Stream: origRequest.Stream, Temperature: origRequest.Temperature, TopP: origRequest.TopP, + // ServiceTier: origRequest.ServiceTier, } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = claudeDefaultMaxTokens @@ -220,12 +260,80 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe claudeRequest.System = message.StringContent() continue } - claudeMessage := chatMessage{ - Role: message.Role, - Content: message.Content, + + claudeMessage := claudeChatMessage{ + Role: message.Role, + } + if message.IsStringContent() { + claudeMessage.Content = message.StringContent() + } else { + chatMessageContents := make([]claudeChatMessageContent, 0) + for _, messageContent := range message.ParseContent() { + switch messageContent.Type { + case contentTypeText: + chatMessageContents = append(chatMessageContents, claudeChatMessageContent{ + Type: contentTypeText, + Text: messageContent.Text, + }) + case contentTypeImageUrl: + if strings.HasPrefix(messageContent.ImageUrl.Url, "data:") { + parts := strings.SplitN(messageContent.ImageUrl.Url, ";", 2) + if len(parts) != 2 { + log.Errorf("invalid image url format: %s", messageContent.ImageUrl.Url) + continue + } + chatMessageContents = append(chatMessageContents, claudeChatMessageContent{ + Type: "image", + Source: &claudeChatMessageContentSource{ + Type: "base64", + MediaType: strings.TrimPrefix(parts[0], "data:"), + Data: strings.TrimPrefix(parts[1], "base64,"), + }, + }) + } else { + chatMessageContents = append(chatMessageContents, claudeChatMessageContent{ + Type: "image", + Source: &claudeChatMessageContentSource{ + Type: "url", + Url: messageContent.ImageUrl.Url, + }, + }) + } + case contentTypeFile: + chatMessageContents = append(chatMessageContents, claudeChatMessageContent{ + Type: "file", + Source: &claudeChatMessageContentSource{ + Type: "url", + FileId: messageContent.File.FileId, + }, + }) + default: + log.Errorf("Unsupported content type: %s", messageContent.Type) + continue + } + } + claudeMessage.Content = chatMessageContents } claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) } + + for _, tool := range origRequest.Tools { + claudeTool := claudeTool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: tool.Function.Parameters, + } + claudeRequest.Tools = append(claudeRequest.Tools, claudeTool) + } + + if tc := origRequest.getToolChoiceObject(); tc != nil { + claudeRequest.ToolChoice = &claudeToolChoice{ + Name: tc.Function.Name, + Type: tc.Type, + DisableParallelToolUse: !origRequest.ParallelToolCalls, + } + } + return &claudeRequest } @@ -270,27 +378,50 @@ func stopReasonClaude2OpenAI(reason *string) string { func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse { switch origResponse.Type { case "message_start": + c.messageId = origResponse.Message.Id + c.usage = usage{ + PromptTokens: origResponse.Message.Usage.InputTokens, + CompletionTokens: origResponse.Message.Usage.OutputTokens, + } + c.serviceTier = origResponse.Message.Usage.ServiceTier choice := chatCompletionChoice{ - Index: 0, + Index: origResponse.Index, Delta: &chatMessage{Role: roleAssistant, Content: ""}, } - return createChatCompletionResponse(ctx, origResponse, choice) + return c.createChatCompletionResponse(ctx, origResponse, choice) case "content_block_delta": choice := chatCompletionChoice{ - Index: 0, + Index: origResponse.Index, Delta: &chatMessage{Content: origResponse.Delta.Text}, } - return createChatCompletionResponse(ctx, origResponse, choice) + return c.createChatCompletionResponse(ctx, origResponse, choice) case "message_delta": + c.usage.CompletionTokens += origResponse.Usage.OutputTokens + c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens + choice := chatCompletionChoice{ - Index: 0, + Index: origResponse.Index, Delta: &chatMessage{}, FinishReason: stopReasonClaude2OpenAI(origResponse.Delta.StopReason), } - return createChatCompletionResponse(ctx, origResponse, choice) - case "content_block_stop", "message_stop": + return c.createChatCompletionResponse(ctx, origResponse, choice) + case "message_stop": + return &chatCompletionResponse{ + Id: c.messageId, + Created: time.Now().UnixMilli() / 1000, + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + Object: objectChatCompletionChunk, + Choices: []chatCompletionChoice{}, + ServiceTier: c.serviceTier, + Usage: usage{ + PromptTokens: c.usage.PromptTokens, + CompletionTokens: c.usage.CompletionTokens, + TotalTokens: c.usage.TotalTokens, + }, + } + case "content_block_stop", "ping", "content_block_start": log.Debugf("skip processing response type: %s", origResponse.Type) return nil default: @@ -299,13 +430,14 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or } } -func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse { +func (c *claudeProvider) createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse { return &chatCompletionResponse{ - Id: response.Message.Id, - Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), - Object: objectChatCompletionChunk, - Choices: []chatCompletionChoice{choice}, + Id: c.messageId, + Created: time.Now().UnixMilli() / 1000, + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), + Object: objectChatCompletionChunk, + Choices: []chatCompletionChoice{choice}, + ServiceTier: c.serviceTier, } } @@ -332,5 +464,14 @@ func (c *claudeProvider) GetApiName(path string) ApiName { if strings.Contains(path, claudeChatCompletionPath) { return ApiNameChatCompletion } + if strings.Contains(path, claudeCompletionPath) { + return ApiNameCompletion + } + if strings.Contains(path, PathOpenAIModels) { + return ApiNameModels + } + if strings.Contains(path, PathOpenAIEmbeddings) { + return ApiNameEmbeddings + } return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 33de57293..510832bd3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -20,8 +20,10 @@ const ( httpStatus200 = "200" - contentTypeText = "text" - contentTypeImageUrl = "image_url" + contentTypeText = "text" + contentTypeImageUrl = "image_url" + contentTypeInputAudio = "input_audio" + contentTypeFile = "file" reasoningStartTag = "" reasoningEndTag = "" @@ -53,11 +55,40 @@ type chatCompletionRequest struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` Tools []tool `json:"tools,omitempty"` - ToolChoice *toolChoice `json:"tool_choice,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` User string `json:"user,omitempty"` } +func (c *chatCompletionRequest) getMaxTokens() int { + if c.MaxCompletionTokens > 0 { + return c.MaxCompletionTokens + } + return c.MaxTokens +} + +func (c *chatCompletionRequest) getToolChoiceString() string { + if c.ToolChoice == nil { + return "" + } + + if tc, ok := c.ToolChoice.(string); ok { + return tc + } + return "" +} + +func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice { + if c.ToolChoice == nil { + return nil + } + + if tc, ok := c.ToolChoice.(*toolChoice); ok { + return tc + } + return nil +} + type CompletionRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` @@ -200,13 +231,26 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r } } -type messageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageUrl *imageUrl `json:"image_url,omitempty"` +type chatMessageContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageUrl *chatMessageContentImageUrl `json:"image_url,omitempty"` + File *chatMessageContentFile `json:"file,omitempty"` + InputAudio *chatMessageContentAudio `json:"input_audio,omitempty"` } -type imageUrl struct { +type chatMessageContentAudio struct { + Data string `json:"data"` + Format string `json:"format"` +} + +type chatMessageContentFile struct { + FileData string `json:"file_data,omitempty"` + FileId string `json:"file_id,omitempty"` + FileName string `json:"file_name,omitempty"` +} + +type chatMessageContentImageUrl struct { Url string `json:"url,omitempty"` Detail string `json:"detail,omitempty"` } @@ -266,11 +310,11 @@ func (m *chatMessage) StringContent() string { return "" } -func (m *chatMessage) ParseContent() []messageContent { - var contentList []messageContent +func (m *chatMessage) ParseContent() []chatMessageContent { + var contentList []chatMessageContent content, ok := m.Content.(string) if ok { - contentList = append(contentList, messageContent{ + contentList = append(contentList, chatMessageContent{ Type: contentTypeText, Text: content, }) @@ -286,18 +330,43 @@ func (m *chatMessage) ParseContent() []messageContent { switch contentMap["type"] { case contentTypeText: if subStr, ok := contentMap[contentTypeText].(string); ok { - contentList = append(contentList, messageContent{ + contentList = append(contentList, chatMessageContent{ Type: contentTypeText, Text: subStr, }) } case contentTypeImageUrl: if subObj, ok := contentMap[contentTypeImageUrl].(map[string]any); ok { - contentList = append(contentList, messageContent{ + msg := chatMessageContent{ Type: contentTypeImageUrl, - ImageUrl: &imageUrl{ + ImageUrl: &chatMessageContentImageUrl{ Url: subObj["url"].(string), }, + } + if detail, ok := subObj["detail"].(string); ok { + msg.ImageUrl.Detail = detail + } + contentList = append(contentList, msg) + } + case contentTypeInputAudio: + if subObj, ok := contentMap[contentTypeInputAudio].(map[string]any); ok { + contentList = append(contentList, chatMessageContent{ + Type: contentTypeInputAudio, + InputAudio: &chatMessageContentAudio{ + Data: subObj["data"].(string), + Format: subObj["format"].(string), + }, + }) + } + case contentTypeFile: + if subObj, ok := contentMap[contentTypeFile].(map[string]any); ok { + contentList = append(contentList, chatMessageContent{ + Type: contentTypeFile, + File: &chatMessageContentFile{ + FileId: subObj["file_id"].(string), + // FileName: subObj["file_name"].(string), + // FileData: subObj["file_data"].(string), + }, }) } } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 24f799ecf..336a1ce7e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -5,6 +5,7 @@ import ( "regexp" "strings" + "github.com/alibaba/higress/plugins/wasm-go/pkg/log" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) @@ -109,6 +110,7 @@ func OverwriteRequestPathHeaderByCapability(headers http.Header, apiName string, } } headers.Set(":path", mappedPath) + log.Debugf("[OverwriteRequestPath] originPath=%s, mappedPath=%s", originPath, mappedPath) } func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {