mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 02:35:02 +08:00
fix(ai-proxy): add ids for Vertex tool calls (#3990)
Signed-off-by: DENG <33118163+XinhhD@users.noreply.github.com> Co-authored-by: woody <yaodiwu618@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
@@ -47,6 +48,7 @@ const (
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
contextVertexRawMarker = "isVertexRawRequest"
|
||||
contextVertexStreamDoneMarker = "vertexStreamDoneSent"
|
||||
contextVertexStreamToolCallIDs = "vertexStreamToolCallIDs"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||
)
|
||||
@@ -871,6 +873,7 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
|
||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||
choice.Message.ToolCalls = []toolCall{
|
||||
{
|
||||
Id: newOpenAIToolCallID(),
|
||||
Type: "function",
|
||||
ThoughtSignature: part.ThoughtSignature,
|
||||
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
||||
@@ -981,6 +984,7 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
||||
choice.Delta = &chatMessage{
|
||||
ToolCalls: []toolCall{
|
||||
{
|
||||
Id: getVertexStreamToolCallID(ctx, 0),
|
||||
Type: "function",
|
||||
ThoughtSignature: part.ThoughtSignature,
|
||||
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
||||
@@ -1025,6 +1029,24 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
||||
return &streamResponse
|
||||
}
|
||||
|
||||
func newOpenAIToolCallID() string {
|
||||
return fmt.Sprintf("call_%s", uuid.New().String())
|
||||
}
|
||||
|
||||
func getVertexStreamToolCallID(ctx wrapper.HttpContext, index int) string {
|
||||
toolCallIDs, _ := ctx.GetContext(contextVertexStreamToolCallIDs).(map[int]string)
|
||||
if toolCallIDs == nil {
|
||||
toolCallIDs = make(map[int]string)
|
||||
ctx.SetContext(contextVertexStreamToolCallIDs, toolCallIDs)
|
||||
}
|
||||
if id := toolCallIDs[index]; id != "" {
|
||||
return id
|
||||
}
|
||||
id := newOpenAIToolCallID()
|
||||
toolCallIDs[index] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
@@ -609,6 +610,8 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) {
|
||||
require.Len(t, response.Choices, 1)
|
||||
require.NotNil(t, response.Choices[0].Message)
|
||||
require.Len(t, response.Choices[0].Message.ToolCalls, 1)
|
||||
assert.NotEmpty(t, response.Choices[0].Message.ToolCalls[0].Id)
|
||||
assert.True(t, strings.HasPrefix(response.Choices[0].Message.ToolCalls[0].Id, "call_"))
|
||||
assert.Equal(t, "thought-signature-from-vertex", response.Choices[0].Message.ToolCalls[0].ThoughtSignature)
|
||||
assert.Equal(
|
||||
t,
|
||||
@@ -617,6 +620,48 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestVertexProviderStreamToolCallIncludesStableID(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
ctx := newMockMultipartHttpContext()
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, "gemini-3.1-pro-preview")
|
||||
vertexResp := &vertexChatResponse{
|
||||
ResponseId: "vertex-response-id",
|
||||
Candidates: []vertexChatCandidate{
|
||||
{
|
||||
Index: 0,
|
||||
Content: vertexChatContent{
|
||||
Role: "model",
|
||||
Parts: []vertexPart{
|
||||
{
|
||||
FunctionCall: &vertexFunctionCall{
|
||||
Name: "lookup",
|
||||
Args: map[string]interface{}{"query": "test"},
|
||||
},
|
||||
ThoughtSignature: "thought-signature-from-vertex",
|
||||
},
|
||||
},
|
||||
},
|
||||
FinishReason: "STOP",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
first := v.buildChatCompletionStreamResponse(ctx, vertexResp)
|
||||
second := v.buildChatCompletionStreamResponse(ctx, vertexResp)
|
||||
|
||||
require.Len(t, first.Choices, 1)
|
||||
require.NotNil(t, first.Choices[0].Delta)
|
||||
require.Len(t, first.Choices[0].Delta.ToolCalls, 1)
|
||||
firstID := first.Choices[0].Delta.ToolCalls[0].Id
|
||||
assert.NotEmpty(t, firstID)
|
||||
assert.True(t, strings.HasPrefix(firstID, "call_"))
|
||||
|
||||
require.Len(t, second.Choices, 1)
|
||||
require.NotNil(t, second.Choices[0].Delta)
|
||||
require.Len(t, second.Choices[0].Delta.ToolCalls, 1)
|
||||
assert.Equal(t, firstID, second.Choices[0].Delta.ToolCalls[0].Id)
|
||||
}
|
||||
|
||||
func TestVertexProviderRestoresFunctionCallThoughtSignature(t *testing.T) {
|
||||
v := &vertexProvider{}
|
||||
req := &chatCompletionRequest{
|
||||
|
||||
Reference in New Issue
Block a user