mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
feat(ai-proxy): add OpenRouter provider support (#2823)
This commit is contained in:
@@ -36,8 +36,18 @@ type claudeToolChoice struct {
|
||||
}
|
||||
|
||||
type claudeChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content claudeChatMessageContentWr `json:"content"`
|
||||
}
|
||||
|
||||
// claudeChatMessageContentWr wraps the content to handle both string and array formats
|
||||
type claudeChatMessageContentWr struct {
|
||||
// StringValue holds simple text content
|
||||
StringValue string
|
||||
// ArrayValue holds multi-modal content
|
||||
ArrayValue []claudeChatMessageContent
|
||||
// IsString indicates whether this is a simple string or array
|
||||
IsString bool
|
||||
}
|
||||
|
||||
type claudeChatMessageContentSource struct {
|
||||
@@ -49,23 +59,154 @@ type claudeChatMessageContentSource struct {
|
||||
}
|
||||
|
||||
type claudeChatMessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *claudeChatMessageContentSource `json:"source,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *claudeChatMessageContentSource `json:"source,omitempty"`
|
||||
CacheControl map[string]interface{} `json:"cache_control,omitempty"`
|
||||
// Tool use fields
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
// Tool result fields
|
||||
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
|
||||
Content string `json:"content,omitempty"` // For tool_result
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
|
||||
func (ccw *claudeChatMessageContentWr) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var stringValue string
|
||||
if err := json.Unmarshal(data, &stringValue); err == nil {
|
||||
ccw.StringValue = stringValue
|
||||
ccw.IsString = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as array of content blocks
|
||||
var arrayValue []claudeChatMessageContent
|
||||
if err := json.Unmarshal(data, &arrayValue); err == nil {
|
||||
ccw.ArrayValue = arrayValue
|
||||
ccw.IsString = false
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content field must be either a string or an array of content blocks")
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for claudeChatMessageContentWr
|
||||
func (ccw claudeChatMessageContentWr) MarshalJSON() ([]byte, error) {
|
||||
if ccw.IsString {
|
||||
return json.Marshal(ccw.StringValue)
|
||||
}
|
||||
return json.Marshal(ccw.ArrayValue)
|
||||
}
|
||||
|
||||
// GetStringValue returns the string representation if it's a string, empty string otherwise
|
||||
func (ccw claudeChatMessageContentWr) GetStringValue() string {
|
||||
if ccw.IsString {
|
||||
return ccw.StringValue
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetArrayValue returns the array representation if it's an array, empty slice otherwise
|
||||
func (ccw claudeChatMessageContentWr) GetArrayValue() []claudeChatMessageContent {
|
||||
if !ccw.IsString {
|
||||
return ccw.ArrayValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewStringContent creates a new wrapper for string content
|
||||
func NewStringContent(content string) claudeChatMessageContentWr {
|
||||
return claudeChatMessageContentWr{
|
||||
StringValue: content,
|
||||
IsString: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayContent creates a new wrapper for array content
|
||||
func NewArrayContent(content []claudeChatMessageContent) claudeChatMessageContentWr {
|
||||
return claudeChatMessageContentWr{
|
||||
ArrayValue: content,
|
||||
IsString: false,
|
||||
}
|
||||
}
|
||||
|
||||
// claudeSystemPrompt represents the system field which can be either a string or an array of text blocks
|
||||
type claudeSystemPrompt struct {
|
||||
// Will be set to the string value if system is a simple string
|
||||
StringValue string
|
||||
// Will be set to the array value if system is an array of text blocks
|
||||
ArrayValue []claudeTextGenContent
|
||||
// Indicates which type this represents
|
||||
IsArray bool
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for claudeSystemPrompt
|
||||
func (csp *claudeSystemPrompt) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var stringValue string
|
||||
if err := json.Unmarshal(data, &stringValue); err == nil {
|
||||
csp.StringValue = stringValue
|
||||
csp.IsArray = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as array of text blocks
|
||||
var arrayValue []claudeTextGenContent
|
||||
if err := json.Unmarshal(data, &arrayValue); err == nil {
|
||||
csp.ArrayValue = arrayValue
|
||||
csp.IsArray = true
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("system field must be either a string or an array of text blocks")
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for claudeSystemPrompt
|
||||
func (csp claudeSystemPrompt) MarshalJSON() ([]byte, error) {
|
||||
if csp.IsArray {
|
||||
return json.Marshal(csp.ArrayValue)
|
||||
}
|
||||
return json.Marshal(csp.StringValue)
|
||||
}
|
||||
|
||||
// String returns the string representation of the system prompt
|
||||
func (csp claudeSystemPrompt) String() string {
|
||||
if csp.IsArray {
|
||||
// Concatenate all text blocks
|
||||
var parts []string
|
||||
for _, block := range csp.ArrayValue {
|
||||
if block.Text != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
return csp.StringValue
|
||||
}
|
||||
|
||||
// claudeThinkingConfig represents the thinking configuration for Claude
|
||||
type claudeThinkingConfig struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
@@ -81,8 +222,13 @@ type claudeTextGenResponse struct {
|
||||
}
|
||||
|
||||
type claudeTextGenContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
Signature string `json:"signature,omitempty"` // For thinking
|
||||
Thinking string `json:"thinking,omitempty"` // For thinking
|
||||
}
|
||||
|
||||
type claudeTextGenUsage struct {
|
||||
@@ -99,7 +245,7 @@ type claudeTextGenError struct {
|
||||
type claudeTextGenStreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Message *claudeTextGenResponse `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *claudeTextGenContent `json:"content_block,omitempty"`
|
||||
Delta *claudeTextGenDelta `json:"delta,omitempty"`
|
||||
Usage *claudeTextGenUsage `json:"usage,omitempty"`
|
||||
@@ -107,13 +253,13 @@ type claudeTextGenStreamResponse struct {
|
||||
|
||||
type claudeTextGenDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
Text string `json:"text,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
if len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
@@ -255,7 +401,10 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
|
||||
for _, message := range origRequest.Messages {
|
||||
if message.Role == roleSystem {
|
||||
claudeRequest.System = message.StringContent()
|
||||
claudeRequest.System = claudeSystemPrompt{
|
||||
StringValue: message.StringContent(),
|
||||
IsArray: false,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -263,7 +412,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
claudeMessage.Content = NewStringContent(message.StringContent())
|
||||
} else {
|
||||
chatMessageContents := make([]claudeChatMessageContent, 0)
|
||||
for _, messageContent := range message.ParseContent() {
|
||||
@@ -310,7 +459,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
continue
|
||||
}
|
||||
}
|
||||
claudeMessage.Content = chatMessageContents
|
||||
claudeMessage.Content = NewArrayContent(chatMessageContents)
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
@@ -342,19 +491,25 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
|
||||
}
|
||||
|
||||
return &chatCompletionResponse{
|
||||
response := &chatCompletionResponse{
|
||||
Id: origResponse.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: &usage{
|
||||
}
|
||||
|
||||
// Include usage information if available
|
||||
if origResponse.Usage.InputTokens > 0 || origResponse.Usage.OutputTokens > 0 {
|
||||
response.Usage = &usage{
|
||||
PromptTokens: origResponse.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Usage.OutputTokens,
|
||||
TotalTokens: origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason *string) string {
|
||||
@@ -376,31 +531,47 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
|
||||
switch origResponse.Type {
|
||||
case "message_start":
|
||||
c.messageId = origResponse.Message.Id
|
||||
c.usage = usage{
|
||||
PromptTokens: origResponse.Message.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Message.Usage.OutputTokens,
|
||||
if origResponse.Message != nil {
|
||||
c.messageId = origResponse.Message.Id
|
||||
c.usage = usage{
|
||||
PromptTokens: origResponse.Message.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Message.Usage.OutputTokens,
|
||||
}
|
||||
c.serviceTier = origResponse.Message.Usage.ServiceTier
|
||||
}
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
c.serviceTier = origResponse.Message.Usage.ServiceTier
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{Role: roleAssistant, Content: ""},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "content_block_delta":
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{Content: origResponse.Delta.Text},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "message_delta":
|
||||
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
|
||||
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
|
||||
if origResponse.Usage != nil {
|
||||
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
|
||||
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
|
||||
}
|
||||
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{},
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.Delta.StopReason)),
|
||||
}
|
||||
@@ -449,10 +620,17 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if request.System == "" {
|
||||
request.System = content
|
||||
systemStr := request.System.String()
|
||||
if systemStr == "" {
|
||||
request.System = claudeSystemPrompt{
|
||||
StringValue: content,
|
||||
IsArray: false,
|
||||
}
|
||||
} else {
|
||||
request.System = content + "\n" + request.System
|
||||
request.System = claudeSystemPrompt{
|
||||
StringValue: content + "\n" + systemStr,
|
||||
IsArray: false,
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
|
||||
Reference in New Issue
Block a user