feat(ai-proxy): add OpenRouter provider support (#2823)

This commit is contained in:
澄潭
2025-08-28 19:26:21 +08:00
committed by GitHub
parent b2ffeff7b8
commit 44c33617fa
11 changed files with 5684 additions and 184 deletions

View File

@@ -36,8 +36,18 @@ type claudeToolChoice struct {
}
type claudeChatMessage struct {
Role string `json:"role"`
Content any `json:"content"`
Role string `json:"role"`
Content claudeChatMessageContentWr `json:"content"`
}
// claudeChatMessageContentWr wraps the content to handle both string and array formats
type claudeChatMessageContentWr struct {
// StringValue holds simple text content
StringValue string
// ArrayValue holds multi-modal content
ArrayValue []claudeChatMessageContent
// IsString indicates whether this is a simple string or array
IsString bool
}
type claudeChatMessageContentSource struct {
@@ -49,23 +59,154 @@ type claudeChatMessageContentSource struct {
}
type claudeChatMessageContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *claudeChatMessageContentSource `json:"source,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *claudeChatMessageContentSource `json:"source,omitempty"`
CacheControl map[string]interface{} `json:"cache_control,omitempty"`
// Tool use fields
Id string `json:"id,omitempty"` // For tool_use
Name string `json:"name,omitempty"` // For tool_use
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
// Tool result fields
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
Content string `json:"content,omitempty"` // For tool_result
}
// UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
func (ccw *claudeChatMessageContentWr) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var stringValue string
if err := json.Unmarshal(data, &stringValue); err == nil {
ccw.StringValue = stringValue
ccw.IsString = true
return nil
}
// Try to unmarshal as array of content blocks
var arrayValue []claudeChatMessageContent
if err := json.Unmarshal(data, &arrayValue); err == nil {
ccw.ArrayValue = arrayValue
ccw.IsString = false
return nil
}
return fmt.Errorf("content field must be either a string or an array of content blocks")
}
// MarshalJSON implements custom JSON marshaling for claudeChatMessageContentWr
func (ccw claudeChatMessageContentWr) MarshalJSON() ([]byte, error) {
if ccw.IsString {
return json.Marshal(ccw.StringValue)
}
return json.Marshal(ccw.ArrayValue)
}
// GetStringValue returns the string representation if it's a string, empty string otherwise
func (ccw claudeChatMessageContentWr) GetStringValue() string {
if ccw.IsString {
return ccw.StringValue
}
return ""
}
// GetArrayValue returns the array representation if it's an array, empty slice otherwise
func (ccw claudeChatMessageContentWr) GetArrayValue() []claudeChatMessageContent {
if !ccw.IsString {
return ccw.ArrayValue
}
return nil
}
// NewStringContent creates a new wrapper for string content
func NewStringContent(content string) claudeChatMessageContentWr {
return claudeChatMessageContentWr{
StringValue: content,
IsString: true,
}
}
// NewArrayContent creates a new wrapper for array content
func NewArrayContent(content []claudeChatMessageContent) claudeChatMessageContentWr {
return claudeChatMessageContentWr{
ArrayValue: content,
IsString: false,
}
}
// claudeSystemPrompt represents the system field which can be either a string or an array of text blocks
type claudeSystemPrompt struct {
// Will be set to the string value if system is a simple string
StringValue string
// Will be set to the array value if system is an array of text blocks
ArrayValue []claudeTextGenContent
// Indicates which type this represents
IsArray bool
}
// UnmarshalJSON implements custom JSON unmarshaling for claudeSystemPrompt
func (csp *claudeSystemPrompt) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var stringValue string
if err := json.Unmarshal(data, &stringValue); err == nil {
csp.StringValue = stringValue
csp.IsArray = false
return nil
}
// Try to unmarshal as array of text blocks
var arrayValue []claudeTextGenContent
if err := json.Unmarshal(data, &arrayValue); err == nil {
csp.ArrayValue = arrayValue
csp.IsArray = true
return nil
}
return fmt.Errorf("system field must be either a string or an array of text blocks")
}
// MarshalJSON implements custom JSON marshaling for claudeSystemPrompt
func (csp claudeSystemPrompt) MarshalJSON() ([]byte, error) {
if csp.IsArray {
return json.Marshal(csp.ArrayValue)
}
return json.Marshal(csp.StringValue)
}
// String returns the string representation of the system prompt
func (csp claudeSystemPrompt) String() string {
if csp.IsArray {
// Concatenate all text blocks
var parts []string
for _, block := range csp.ArrayValue {
if block.Text != "" {
parts = append(parts, block.Text)
}
}
return strings.Join(parts, "\n")
}
return csp.StringValue
}
// claudeThinkingConfig represents the thinking configuration for Claude
type claudeThinkingConfig struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens,omitempty"`
}
type claudeTextGenRequest struct {
Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"`
System claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
}
type claudeTextGenResponse struct {
@@ -81,8 +222,13 @@ type claudeTextGenResponse struct {
}
type claudeTextGenContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
Id string `json:"id,omitempty"` // For tool_use
Name string `json:"name,omitempty"` // For tool_use
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
Signature string `json:"signature,omitempty"` // For thinking
Thinking string `json:"thinking,omitempty"` // For thinking
}
type claudeTextGenUsage struct {
@@ -99,7 +245,7 @@ type claudeTextGenError struct {
type claudeTextGenStreamResponse struct {
Type string `json:"type"`
Message *claudeTextGenResponse `json:"message,omitempty"`
Index int `json:"index,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *claudeTextGenContent `json:"content_block,omitempty"`
Delta *claudeTextGenDelta `json:"delta,omitempty"`
Usage *claudeTextGenUsage `json:"usage,omitempty"`
@@ -107,13 +253,13 @@ type claudeTextGenStreamResponse struct {
type claudeTextGenDelta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Text string `json:"text,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
}
func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
if len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
@@ -255,7 +401,10 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
for _, message := range origRequest.Messages {
if message.Role == roleSystem {
claudeRequest.System = message.StringContent()
claudeRequest.System = claudeSystemPrompt{
StringValue: message.StringContent(),
IsArray: false,
}
continue
}
@@ -263,7 +412,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = message.StringContent()
claudeMessage.Content = NewStringContent(message.StringContent())
} else {
chatMessageContents := make([]claudeChatMessageContent, 0)
for _, messageContent := range message.ParseContent() {
@@ -310,7 +459,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
continue
}
}
claudeMessage.Content = chatMessageContents
claudeMessage.Content = NewArrayContent(chatMessageContents)
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
@@ -342,19 +491,25 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
}
return &chatCompletionResponse{
response := &chatCompletionResponse{
Id: origResponse.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
Usage: &usage{
}
// Include usage information if available
if origResponse.Usage.InputTokens > 0 || origResponse.Usage.OutputTokens > 0 {
response.Usage = &usage{
PromptTokens: origResponse.Usage.InputTokens,
CompletionTokens: origResponse.Usage.OutputTokens,
TotalTokens: origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
},
}
}
return response
}
func stopReasonClaude2OpenAI(reason *string) string {
@@ -376,31 +531,47 @@ func stopReasonClaude2OpenAI(reason *string) string {
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
switch origResponse.Type {
case "message_start":
c.messageId = origResponse.Message.Id
c.usage = usage{
PromptTokens: origResponse.Message.Usage.InputTokens,
CompletionTokens: origResponse.Message.Usage.OutputTokens,
if origResponse.Message != nil {
c.messageId = origResponse.Message.Id
c.usage = usage{
PromptTokens: origResponse.Message.Usage.InputTokens,
CompletionTokens: origResponse.Message.Usage.OutputTokens,
}
c.serviceTier = origResponse.Message.Usage.ServiceTier
}
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
c.serviceTier = origResponse.Message.Usage.ServiceTier
choice := chatCompletionChoice{
Index: origResponse.Index,
Index: index,
Delta: &chatMessage{Role: roleAssistant, Content: ""},
}
return c.createChatCompletionResponse(ctx, origResponse, choice)
case "content_block_delta":
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
choice := chatCompletionChoice{
Index: origResponse.Index,
Index: index,
Delta: &chatMessage{Content: origResponse.Delta.Text},
}
return c.createChatCompletionResponse(ctx, origResponse, choice)
case "message_delta":
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
if origResponse.Usage != nil {
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
}
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
choice := chatCompletionChoice{
Index: origResponse.Index,
Index: index,
Delta: &chatMessage{},
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.Delta.StopReason)),
}
@@ -449,10 +620,17 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.System == "" {
request.System = content
systemStr := request.System.String()
if systemStr == "" {
request.System = claudeSystemPrompt{
StringValue: content,
IsArray: false,
}
} else {
request.System = content + "\n" + request.System
request.System = claudeSystemPrompt{
StringValue: content + "\n" + systemStr,
IsArray: false,
}
}
return json.Marshal(request)