diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 7c7769cad..460ede929 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -17,6 +17,7 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/google/uuid" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" @@ -47,6 +48,7 @@ const ( contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" contextVertexRawMarker = "isVertexRawRequest" contextVertexStreamDoneMarker = "vertexStreamDoneSent" + contextVertexStreamToolCallIDs = "vertexStreamToolCallIDs" vertexAnthropicVersion = "vertex-2023-10-16" vertexImageVariationDefaultPrompt = "Create variations of the provided image." ) @@ -871,6 +873,7 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re args, _ := json.Marshal(part.FunctionCall.Args) choice.Message.ToolCalls = []toolCall{ { + Id: newOpenAIToolCallID(), Type: "function", ThoughtSignature: part.ThoughtSignature, ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature), @@ -981,6 +984,7 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte choice.Delta = &chatMessage{ ToolCalls: []toolCall{ { + Id: getVertexStreamToolCallID(ctx, 0), Type: "function", ThoughtSignature: part.ThoughtSignature, ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature), @@ -1025,6 +1029,24 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte return &streamResponse } +func newOpenAIToolCallID() string { + return fmt.Sprintf("call_%s", uuid.New().String()) +} + +func getVertexStreamToolCallID(ctx wrapper.HttpContext, index int) string { + toolCallIDs, _ := ctx.GetContext(contextVertexStreamToolCallIDs).(map[int]string) + if toolCallIDs == nil { + toolCallIDs = make(map[int]string) + ctx.SetContext(contextVertexStreamToolCallIDs, toolCallIDs) + } + if id := toolCallIDs[index]; id != "" { + return id + } + id := newOpenAIToolCallID() + toolCallIDs[index] = id + return id +} + func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go index d7a88b333..51fad3477 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex_test.go @@ -2,6 +2,7 @@ package provider import ( "net/http" + "strings" "testing" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -609,6 +610,8 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) { require.Len(t, response.Choices, 1) require.NotNil(t, response.Choices[0].Message) require.Len(t, response.Choices[0].Message.ToolCalls, 1) + assert.NotEmpty(t, response.Choices[0].Message.ToolCalls[0].Id) + assert.True(t, strings.HasPrefix(response.Choices[0].Message.ToolCalls[0].Id, "call_")) assert.Equal(t, "thought-signature-from-vertex", response.Choices[0].Message.ToolCalls[0].ThoughtSignature) assert.Equal( t, @@ -617,6 +620,48 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) { ) } +func TestVertexProviderStreamToolCallIncludesStableID(t *testing.T) { + v := &vertexProvider{} + ctx := newMockMultipartHttpContext() + ctx.SetContext(ctxKeyFinalRequestModel, "gemini-3.1-pro-preview") + vertexResp := &vertexChatResponse{ + ResponseId: "vertex-response-id", + Candidates: []vertexChatCandidate{ + { + Index: 0, + Content: vertexChatContent{ + Role: "model", + Parts: []vertexPart{ + { + FunctionCall: &vertexFunctionCall{ + Name: "lookup", + Args: map[string]interface{}{"query": "test"}, + }, + ThoughtSignature: "thought-signature-from-vertex", + }, + }, + }, + FinishReason: "STOP", + }, + }, + } + + first := v.buildChatCompletionStreamResponse(ctx, vertexResp) + second := v.buildChatCompletionStreamResponse(ctx, vertexResp) + + require.Len(t, first.Choices, 1) + require.NotNil(t, first.Choices[0].Delta) + require.Len(t, first.Choices[0].Delta.ToolCalls, 1) + firstID := first.Choices[0].Delta.ToolCalls[0].Id + assert.NotEmpty(t, firstID) + assert.True(t, strings.HasPrefix(firstID, "call_")) + + require.Len(t, second.Choices, 1) + require.NotNil(t, second.Choices[0].Delta) + require.Len(t, second.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, firstID, second.Choices[0].Delta.ToolCalls[0].Id) +} + func TestVertexProviderRestoresFunctionCallThoughtSignature(t *testing.T) { v := &vertexProvider{} req := &chatCompletionRequest{