mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 02:35:02 +08:00
fix(ai-proxy): add ids for Vertex tool calls (#3990)
Signed-off-by: DENG <33118163+XinhhD@users.noreply.github.com> Co-authored-by: woody <yaodiwu618@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||||
"github.com/higress-group/wasm-go/pkg/log"
|
"github.com/higress-group/wasm-go/pkg/log"
|
||||||
@@ -47,6 +48,7 @@ const (
|
|||||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||||
contextVertexRawMarker = "isVertexRawRequest"
|
contextVertexRawMarker = "isVertexRawRequest"
|
||||||
contextVertexStreamDoneMarker = "vertexStreamDoneSent"
|
contextVertexStreamDoneMarker = "vertexStreamDoneSent"
|
||||||
|
contextVertexStreamToolCallIDs = "vertexStreamToolCallIDs"
|
||||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||||
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
vertexImageVariationDefaultPrompt = "Create variations of the provided image."
|
||||||
)
|
)
|
||||||
@@ -871,6 +873,7 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
|
|||||||
args, _ := json.Marshal(part.FunctionCall.Args)
|
args, _ := json.Marshal(part.FunctionCall.Args)
|
||||||
choice.Message.ToolCalls = []toolCall{
|
choice.Message.ToolCalls = []toolCall{
|
||||||
{
|
{
|
||||||
|
Id: newOpenAIToolCallID(),
|
||||||
Type: "function",
|
Type: "function",
|
||||||
ThoughtSignature: part.ThoughtSignature,
|
ThoughtSignature: part.ThoughtSignature,
|
||||||
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
||||||
@@ -981,6 +984,7 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
|||||||
choice.Delta = &chatMessage{
|
choice.Delta = &chatMessage{
|
||||||
ToolCalls: []toolCall{
|
ToolCalls: []toolCall{
|
||||||
{
|
{
|
||||||
|
Id: getVertexStreamToolCallID(ctx, 0),
|
||||||
Type: "function",
|
Type: "function",
|
||||||
ThoughtSignature: part.ThoughtSignature,
|
ThoughtSignature: part.ThoughtSignature,
|
||||||
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
ExtraContent: buildGoogleThoughtSignatureExtraContent(part.ThoughtSignature),
|
||||||
@@ -1025,6 +1029,24 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
|||||||
return &streamResponse
|
return &streamResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newOpenAIToolCallID() string {
|
||||||
|
return fmt.Sprintf("call_%s", uuid.New().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func getVertexStreamToolCallID(ctx wrapper.HttpContext, index int) string {
|
||||||
|
toolCallIDs, _ := ctx.GetContext(contextVertexStreamToolCallIDs).(map[int]string)
|
||||||
|
if toolCallIDs == nil {
|
||||||
|
toolCallIDs = make(map[int]string)
|
||||||
|
ctx.SetContext(contextVertexStreamToolCallIDs, toolCallIDs)
|
||||||
|
}
|
||||||
|
if id := toolCallIDs[index]; id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
id := newOpenAIToolCallID()
|
||||||
|
toolCallIDs[index] = id
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
|
||||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package provider
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||||
@@ -609,6 +610,8 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) {
|
|||||||
require.Len(t, response.Choices, 1)
|
require.Len(t, response.Choices, 1)
|
||||||
require.NotNil(t, response.Choices[0].Message)
|
require.NotNil(t, response.Choices[0].Message)
|
||||||
require.Len(t, response.Choices[0].Message.ToolCalls, 1)
|
require.Len(t, response.Choices[0].Message.ToolCalls, 1)
|
||||||
|
assert.NotEmpty(t, response.Choices[0].Message.ToolCalls[0].Id)
|
||||||
|
assert.True(t, strings.HasPrefix(response.Choices[0].Message.ToolCalls[0].Id, "call_"))
|
||||||
assert.Equal(t, "thought-signature-from-vertex", response.Choices[0].Message.ToolCalls[0].ThoughtSignature)
|
assert.Equal(t, "thought-signature-from-vertex", response.Choices[0].Message.ToolCalls[0].ThoughtSignature)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
@@ -617,6 +620,48 @@ func TestVertexProviderPreservesFunctionCallThoughtSignature(t *testing.T) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVertexProviderStreamToolCallIncludesStableID(t *testing.T) {
|
||||||
|
v := &vertexProvider{}
|
||||||
|
ctx := newMockMultipartHttpContext()
|
||||||
|
ctx.SetContext(ctxKeyFinalRequestModel, "gemini-3.1-pro-preview")
|
||||||
|
vertexResp := &vertexChatResponse{
|
||||||
|
ResponseId: "vertex-response-id",
|
||||||
|
Candidates: []vertexChatCandidate{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Content: vertexChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []vertexPart{
|
||||||
|
{
|
||||||
|
FunctionCall: &vertexFunctionCall{
|
||||||
|
Name: "lookup",
|
||||||
|
Args: map[string]interface{}{"query": "test"},
|
||||||
|
},
|
||||||
|
ThoughtSignature: "thought-signature-from-vertex",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FinishReason: "STOP",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
first := v.buildChatCompletionStreamResponse(ctx, vertexResp)
|
||||||
|
second := v.buildChatCompletionStreamResponse(ctx, vertexResp)
|
||||||
|
|
||||||
|
require.Len(t, first.Choices, 1)
|
||||||
|
require.NotNil(t, first.Choices[0].Delta)
|
||||||
|
require.Len(t, first.Choices[0].Delta.ToolCalls, 1)
|
||||||
|
firstID := first.Choices[0].Delta.ToolCalls[0].Id
|
||||||
|
assert.NotEmpty(t, firstID)
|
||||||
|
assert.True(t, strings.HasPrefix(firstID, "call_"))
|
||||||
|
|
||||||
|
require.Len(t, second.Choices, 1)
|
||||||
|
require.NotNil(t, second.Choices[0].Delta)
|
||||||
|
require.Len(t, second.Choices[0].Delta.ToolCalls, 1)
|
||||||
|
assert.Equal(t, firstID, second.Choices[0].Delta.ToolCalls[0].Id)
|
||||||
|
}
|
||||||
|
|
||||||
func TestVertexProviderRestoresFunctionCallThoughtSignature(t *testing.T) {
|
func TestVertexProviderRestoresFunctionCallThoughtSignature(t *testing.T) {
|
||||||
v := &vertexProvider{}
|
v := &vertexProvider{}
|
||||||
req := &chatCompletionRequest{
|
req := &chatCompletionRequest{
|
||||||
|
|||||||
Reference in New Issue
Block a user