diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go new file mode 100644 index 000000000..a4b70302c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -0,0 +1,85 @@ +package provider + +import ( + "fmt" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" +) + +// groqProvider is the provider for Groq service. +const ( + groqDomain = "api.groq.com" + groqChatCompletionPath = "/openai/v1/chat/completions" +) + +type groqProviderInitializer struct{} + +func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { + return nil +} + +func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &groqProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type groqProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (m *groqProvider) GetProviderType() string { + return providerTypeGroq +} + +func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + _ = util.OverwriteRequestPath(groqChatCompletionPath) + _ = util.OverwriteRequestHost(groqDomain) + _ = proxywasm.ReplaceHttpRequestHeader("Authorization", "Bearer "+m.config.GetRandomToken()) + + if m.contextCache == nil { + ctx.DontReadRequestBody() + } else { + _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + } + + return types.ActionContinue, nil +} + +func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + if m.contextCache == nil { + return types.ActionContinue, nil + } + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return types.ActionContinue, err + } + err := m.contextCache.GetContent(func(content string, err error) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + insertContextMessage(request, content) + if err := replaceJsonRequestBody(request, log); err != nil { + _ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } + }, log) + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 1594ce2aa..b64095448 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -9,7 +9,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) -// azureProvider is the provider for Azure OpenAI service. +// openaiProvider is the provider for OpenAI service. const ( openaiDomain = "api.openai.com" diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index b917e78ec..ee440e64d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -19,6 +19,7 @@ const ( providerTypeAzure = "azure" providerTypeQwen = "qwen" providerTypeOpenAI = "openai" + providerTypeGroq = "groq" protocolOpenAI = "openai" protocolOriginal = "original" @@ -51,6 +52,7 @@ var ( providerTypeAzure: &azureProviderInitializer{}, providerTypeQwen: &qwenProviderInitializer{}, providerTypeOpenAI: &openaiProviderInitializer{}, + providerTypeGroq: &groqProviderInitializer{}, } )