mirror of
https://github.com/alibaba/higress.git
synced 2026-02-24 12:40:48 +08:00
feat(ai-proxy): convert developer role to system for unsupported providers (#3479)
This commit is contained in:
@@ -2,6 +2,7 @@ package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -151,6 +152,7 @@ const (
|
||||
protocolOriginal = "original"
|
||||
|
||||
roleSystem = "system"
|
||||
roleDeveloper = "developer"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
roleTool = "tool"
|
||||
@@ -193,6 +195,12 @@ type providerInitializer interface {
|
||||
var (
|
||||
errUnsupportedApiName = errors.New("unsupported API name")
|
||||
|
||||
// Providers that support the "developer" role. Other providers will have "developer" roles converted to "system".
|
||||
developerRoleSupportedProviders = map[string]bool{
|
||||
providerTypeOpenAI: true,
|
||||
providerTypeAzure: true,
|
||||
}
|
||||
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
providerTypeMoonshot: &moonshotProviderInitializer{},
|
||||
providerTypeAzure: &azureProviderInitializer{},
|
||||
@@ -838,6 +846,34 @@ func doGetMappedModel(model string, modelMapping map[string]string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// isDeveloperRoleSupported checks if the provider supports the "developer" role.
|
||||
func isDeveloperRoleSupported(providerType string) bool {
|
||||
return developerRoleSupportedProviders[providerType]
|
||||
}
|
||||
|
||||
// convertDeveloperRoleToSystem converts "developer" roles to "system" role in the request body.
|
||||
// This is used for providers that don't support the "developer" role.
|
||||
func convertDeveloperRoleToSystem(body []byte) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return body, fmt.Errorf("unable to unmarshal request for developer role conversion: %v", err)
|
||||
}
|
||||
|
||||
converted := false
|
||||
for i := range request.Messages {
|
||||
if request.Messages[i].Role == roleDeveloper {
|
||||
request.Messages[i].Role = roleSystem
|
||||
converted = true
|
||||
}
|
||||
}
|
||||
|
||||
if converted {
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
|
||||
body := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
@@ -976,6 +1012,18 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
}
|
||||
}
|
||||
|
||||
// convert developer role to system role for providers that don't support it
|
||||
if apiName == ApiNameChatCompletion && !isDeveloperRoleSupported(c.typ) {
|
||||
body, err = convertDeveloperRoleToSystem(body)
|
||||
if err != nil {
|
||||
log.Warnf("[developerRole] failed to convert developer role to system: %v", err)
|
||||
// Continue processing even if conversion fails
|
||||
err = nil
|
||||
} else {
|
||||
log.Debugf("[developerRole] converted developer role to system for provider: %s", c.typ)
|
||||
}
|
||||
}
|
||||
|
||||
// use openai protocol (either original openai or converted from claude)
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
|
||||
Reference in New Issue
Block a user