diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/longcat.go b/plugins/wasm-go/extensions/ai-proxy/provider/longcat.go new file mode 100644 index 000000000..2049f274d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/longcat.go @@ -0,0 +1,90 @@ +package provider + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +// longcatProvider is the provider for LongCat AI service. + +const ( + longcatDomain = "api.longcat.chat" +) + +type longcatProviderInitializer struct{} + +func (m *longcatProviderInitializer) ValidateConfig(config *ProviderConfig) error { + if config.apiTokens == nil || len(config.apiTokens) == 0 { + return errors.New("no apiToken found in provider config") + } + return nil +} + +func (m *longcatProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameModels): PathOpenAIModels, + } +} + +func (m *longcatProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(m.DefaultCapabilities()) + return &longcatProvider{ + config: config, + contextCache: createContextCache(&config), + }, nil +} + +type longcatProvider struct { + config ProviderConfig + contextCache *contextCache +} + +func (m *longcatProvider) GetProviderType() string { + return providerTypeLongcat +} + +func (m *longcatProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { + m.config.handleRequestHeaders(m, ctx, apiName) + return nil +} + +func (m *longcatProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities) + util.OverwriteRequestHostHeader(headers, longcatDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *longcatProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + if !m.config.isSupportedAPI(apiName) { + return types.ActionContinue, errUnsupportedApiName + } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body) +} + +func (m *longcatProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + if m.config.responseJsonSchema != nil && apiName == ApiNameChatCompletion { + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return nil, err + } + request.ResponseFormat = m.config.responseJsonSchema + body, err := json.Marshal(request) + if err != nil { + return nil, err + } + return body, nil + } + // For testing purposes, skip defaultTransformRequestBody if ctx is nil + if ctx != nil { + return m.config.defaultTransformRequestBody(ctx, apiName, body) + } + return body, nil +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/longcat_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/longcat_test.go new file mode 100644 index 000000000..9364a3cf5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/longcat_test.go @@ -0,0 +1,205 @@ +package provider + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLongcatProviderInitializer_ValidateConfig(t *testing.T) { + initializer := &longcatProviderInitializer{} + + t.Run("valid_config_with_api_tokens", func(t *testing.T) { + config := &ProviderConfig{ + apiTokens: []string{"test-token"}, + } + err := initializer.ValidateConfig(config) + assert.NoError(t, err) + }) + + t.Run("invalid_config_without_api_tokens", func(t *testing.T) { + config := &ProviderConfig{ + apiTokens: nil, + } + err := initializer.ValidateConfig(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no apiToken found in provider config") + }) + + t.Run("invalid_config_with_empty_api_tokens", func(t *testing.T) { + config := &ProviderConfig{ + apiTokens: []string{}, + } + err := initializer.ValidateConfig(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no apiToken found in provider config") + }) +} + +func TestLongcatProviderInitializer_DefaultCapabilities(t *testing.T) { + initializer := &longcatProviderInitializer{} + + capabilities := initializer.DefaultCapabilities() + expected := map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameModels): PathOpenAIModels, + } + + assert.Equal(t, expected, capabilities) +} + +func TestLongcatProviderInitializer_CreateProvider(t *testing.T) { + initializer := &longcatProviderInitializer{} + + config := ProviderConfig{ + apiTokens: []string{"test-token"}, + } + + provider, err := initializer.CreateProvider(config) + require.NoError(t, err) + require.NotNil(t, provider) + + assert.Equal(t, providerTypeLongcat, provider.GetProviderType()) + + longcatProvider, ok := provider.(*longcatProvider) + require.True(t, ok) + assert.NotNil(t, longcatProvider.config.apiTokens) + assert.Equal(t, []string{"test-token"}, longcatProvider.config.apiTokens) +} + +func TestLongcatProvider_GetProviderType(t *testing.T) { + provider := &longcatProvider{ + config: ProviderConfig{ + apiTokens: []string{"test-token"}, + }, + contextCache: createContextCache(&ProviderConfig{}), + } + + assert.Equal(t, providerTypeLongcat, provider.GetProviderType()) +} + +func TestLongcatProvider_IsSupportedAPI(t *testing.T) { + provider := &longcatProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + }, + }, + } + + t.Run("supported_api", func(t *testing.T) { + assert.True(t, provider.config.isSupportedAPI(ApiNameChatCompletion)) + assert.True(t, provider.config.isSupportedAPI(ApiNameEmbeddings)) + }) + + t.Run("unsupported_api", func(t *testing.T) { + assert.False(t, provider.config.isSupportedAPI(ApiName("unsupported"))) + assert.False(t, provider.config.isSupportedAPI(ApiNameModels)) + }) +} + +func TestLongcatProvider_TransformRequestBody(t *testing.T) { + t.Run("with_response_schema", func(t *testing.T) { + provider := &longcatProvider{ + config: ProviderConfig{ + apiTokens: []string{"test-token"}, + responseJsonSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "answer": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + } + + requestBody := `{"model":"test","messages":[{"role":"user","content":"Hello"}]}` + + result, err := provider.TransformRequestBody(nil, ApiNameChatCompletion, []byte(requestBody)) + require.NoError(t, err) + + var transformedRequest chatCompletionRequest + err = json.Unmarshal(result, &transformedRequest) + require.NoError(t, err) + + assert.Equal(t, provider.config.responseJsonSchema, transformedRequest.ResponseFormat) + }) + + t.Run("invalid_json_request", func(t *testing.T) { + provider := &longcatProvider{ + config: ProviderConfig{ + responseJsonSchema: map[string]interface{}{ + "type": "object", + }, + }, + } + + requestBody := `invalid json` + + _, err := provider.TransformRequestBody(nil, ApiNameChatCompletion, []byte(requestBody)) + assert.Error(t, err) + }) + + t.Run("without_response_schema", func(t *testing.T) { + provider := &longcatProvider{ + config: ProviderConfig{ + apiTokens: []string{"test-token"}, + }, + } + + requestBody := `{"model":"test","messages":[{"role":"user","content":"Hello"}]}` + + result, err := provider.TransformRequestBody(nil, ApiNameChatCompletion, []byte(requestBody)) + assert.NoError(t, err) + + var transformedRequest chatCompletionRequest + err = json.Unmarshal(result, &transformedRequest) + require.NoError(t, err) + + // Without response schema, the request should remain unchanged + assert.Nil(t, transformedRequest.ResponseFormat) + }) +} + +func TestLongcatProvider_Integration(t *testing.T) { + // Test the complete flow from initialization to basic functionality + initializer := &longcatProviderInitializer{} + + config := ProviderConfig{ + apiTokens: []string{"test-token-123"}, + } + + provider, err := initializer.CreateProvider(config) + require.NoError(t, err) + + // Test provider type + assert.Equal(t, providerTypeLongcat, provider.GetProviderType()) + + // Test capabilities are set correctly + longcatProvider, ok := provider.(*longcatProvider) + require.True(t, ok) + + expectedCapabilities := map[string]string{ + string(ApiNameChatCompletion): PathOpenAIChatCompletions, + string(ApiNameEmbeddings): PathOpenAIEmbeddings, + string(ApiNameModels): PathOpenAIModels, + } + assert.Equal(t, expectedCapabilities, longcatProvider.config.capabilities) + + // Test API support + assert.True(t, longcatProvider.config.isSupportedAPI(ApiNameChatCompletion)) + assert.True(t, longcatProvider.config.isSupportedAPI(ApiNameEmbeddings)) + assert.True(t, longcatProvider.config.isSupportedAPI(ApiNameModels)) + assert.False(t, longcatProvider.config.isSupportedAPI(ApiName("unsupported"))) +} + +// Test constants +func TestLongcatConstants(t *testing.T) { + assert.Equal(t, "api.longcat.chat", longcatDomain) + assert.Equal(t, "longcat", providerTypeLongcat) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 513a3e47d..726bce521 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -133,6 +133,7 @@ const ( providerTypeVertex = "vertex" providerTypeTriton = "triton" providerTypeOpenRouter = "openrouter" + providerTypeLongcat = "longcat" protocolOpenAI = "openai" protocolOriginal = "original" @@ -213,6 +214,7 @@ var ( providerTypeVertex: &vertexProviderInitializer{}, providerTypeTriton: &tritonProviderInitializer{}, providerTypeOpenRouter: &openrouterProviderInitializer{}, + providerTypeLongcat: &longcatProviderInitializer{}, } ) @@ -852,6 +854,9 @@ func (c *ProviderConfig) IsSupportedAPI(apiName ApiName) bool { } func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) { + if c.capabilities == nil { + c.capabilities = make(map[string]string) + } for capability, path := range capabilities { c.capabilities[capability] = path }