From a93847e07f1e73afb7d6afd38079a7b087f8eeca Mon Sep 17 00:00:00 2001 From: woody Date: Thu, 14 May 2026 16:18:00 +0800 Subject: [PATCH] Add Kling provider support (#3742) Signed-off-by: wydream --- plugins/wasm-go/extensions/ai-proxy/main.go | 3 +- .../wasm-go/extensions/ai-proxy/main_test.go | 7 + .../extensions/ai-proxy/provider/kling.go | 492 +++++++++ .../ai-proxy/provider/kling_test.go | 942 ++++++++++++++++++ .../extensions/ai-proxy/provider/provider.go | 25 +- .../ai-proxy/provider/provider_test.go | 15 + .../wasm-go/extensions/ai-proxy/test/kling.go | 418 ++++++++ .../wasm-go/extensions/ai-proxy/test/util.go | 9 + .../wasm-go/extensions/ai-proxy/util/http.go | 4 +- 9 files changed, 1910 insertions(+), 5 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/kling.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/provider/kling_test.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/test/kling.go diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 4e2b14e6..c887caa5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -264,8 +264,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf return types.ActionContinue } + _, hasRequestBodyHandler := activeProvider.(provider.RequestBodyHandler) hasRequestBody := ctx.HasRequestBody() - if hasRequestBody { + if hasRequestBody && hasRequestBodyHandler { _ = proxywasm.RemoveHttpRequestHeader("Content-Length") ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) // Delay the header processing to allow changing in OnRequestBody diff --git a/plugins/wasm-go/extensions/ai-proxy/main_test.go b/plugins/wasm-go/extensions/ai-proxy/main_test.go index 178e1e1b..515c82d9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/main_test.go @@ -314,6 +314,13 @@ func TestGeneric(t *testing.T) { test.RunGenericOnHttpRequestBodyTests(t) } +func TestKling(t *testing.T) { + test.RunKlingParseConfigTests(t) + test.RunKlingOnHttpRequestHeadersTests(t) + test.RunKlingOnHttpRequestBodyTests(t) + test.RunKlingOnHttpResponseBodyTests(t) +} + func TestVertex(t *testing.T) { test.RunVertexParseConfigTests(t) test.RunVertexExpressModeOnHttpRequestHeadersTests(t) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/kling.go b/plugins/wasm-go/extensions/ai-proxy/provider/kling.go new file mode 100644 index 00000000..ad9e8554 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/kling.go @@ -0,0 +1,492 @@ +package provider + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + klingDefaultDomain = "api-singapore.klingai.com" + klingTextToVideoPath = "/v1/videos/text2video" + klingImageToVideoPath = "/v1/videos/image2video" + klingTextToVideoTaskPath = "/v1/videos/text2video/{video_id}" + klingImageToVideoTaskPath = "/v1/videos/image2video/{video_id}" + klingJWTLifetimeSeconds = int64(1800) + klingJWTNotBeforeSkewSecond = int64(5) + klingDefaultRefreshAhead = int64(60) + klingTaskTypeTextToVideo = "text2video" + klingTaskTypeImageToVideo = "image2video" + klingTextTaskIDPrefix = "kling-t2v-" + klingImageTaskIDPrefix = "kling-i2v-" + klingTaskTypeQueryKey = "kling_task_type" + ctxKeyKlingVideoTaskType = "klingVideoTaskType" +) + +type klingProviderInitializer struct{} + +func (k *klingProviderInitializer) ValidateConfig(config *ProviderConfig) error { + hasAccessKey := strings.TrimSpace(config.klingAccessKey) != "" + hasSecretKey := strings.TrimSpace(config.klingSecretKey) != "" + if hasAccessKey || hasSecretKey { + if !hasAccessKey || !hasSecretKey { + return errors.New("missing klingAccessKey or klingSecretKey in provider config") + } + return nil + } + if len(config.apiTokens) > 0 { + return nil + } + return errors.New("missing kling authentication parameters: either apiTokens or (klingAccessKey + klingSecretKey) is required") +} + +func (k *klingProviderInitializer) DefaultCapabilities() map[string]string { + return map[string]string{ + string(ApiNameVideos): klingTextToVideoPath, + string(ApiNameKlingImageToVideo): klingImageToVideoPath, + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + string(ApiNameKlingRetrieveImageVideo): klingImageToVideoTaskPath, + } +} + +func (k *klingProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + config.setDefaultCapabilities(k.DefaultCapabilities()) + if config.klingTokenRefreshAhead == 0 { + config.klingTokenRefreshAhead = klingDefaultRefreshAhead + } + provider := &klingProvider{ + config: config, + contextCache: createContextCache(&config), + } + if config.IsOriginal() { + return provider, nil + } + return &klingOpenAIProvider{klingProvider: provider}, nil +} + +type klingProvider struct { + config ProviderConfig + contextCache *contextCache + jwtToken string + jwtExpireAt int64 +} + +type klingOpenAIProvider struct { + *klingProvider +} + +func (k *klingProvider) GetProviderType() string { + return providerTypeKling +} + +func (k *klingProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error { + k.config.handleRequestHeaders(k, ctx, apiName) + if k.config.IsOriginal() { + ctx.DontReadRequestBody() + } + return nil +} + +func (k *klingOpenAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { + if !k.config.isSupportedAPI(apiName) { + return types.ActionContinue, errUnsupportedApiName + } + return k.config.handleRequestBody(k, k.contextCache, ctx, apiName, body) +} + +func (k *klingProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) { + if !k.config.IsOriginal() { + mappedPath := "" + if apiName == ApiNameRetrieveVideo { + mappedPath = k.mapRetrieveVideoPath(headers.Get(util.HeaderPath)) + } else { + mappedPath = util.MapRequestPathByCapability(string(apiName), headers.Get(util.HeaderPath), k.config.capabilities) + } + if mappedPath != "" { + util.OverwriteRequestPathHeader(headers, mappedPath) + } + } + if k.config.providerDomain == "" { + util.OverwriteRequestHostHeader(headers, klingDefaultDomain) + } + if token := k.getAuthorizationToken(ctx); token != "" { + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token) + } + if !k.config.IsOriginal() { + headers.Del("Content-Length") + } +} + +func (k *klingProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) { + if apiName != ApiNameVideos { + return k.config.defaultTransformRequestBody(ctx, apiName, body) + } + + taskType := klingTaskTypeTextToVideo + targetPath := k.textCreateVideoPath() + if k.isImageToVideoRequest(body) { + taskType = klingTaskTypeImageToVideo + targetPath = k.imageCreateVideoPath() + } + ctx.SetContext(ctxKeyKlingVideoTaskType, taskType) + util.OverwriteRequestPathHeader(headers, klingPathWithOriginalQuery(ctx, headers.Get(util.HeaderPath), targetPath)) + return k.transformOpenAIVideoRequest(ctx, body) +} + +func (k *klingOpenAIProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) { + if apiName != ApiNameVideos { + return body, nil + } + + taskType, _ := ctx.GetContext(ctxKeyKlingVideoTaskType).(string) + switch taskType { + case klingTaskTypeTextToVideo: + return prefixKlingTaskIDs(body, klingTextTaskIDPrefix) + case klingTaskTypeImageToVideo: + return prefixKlingTaskIDs(body, klingImageTaskIDPrefix) + default: + return body, nil + } +} + +func (k *klingProvider) GetApiName(path string) ApiName { + switch { + case isKlingNativeRetrieveVideoPath(path): + return ApiNameRetrieveVideo + case isKlingNativeCreateVideoPath(path): + return ApiNameVideos + case util.RegRetrieveVideoPath.MatchString(path): + return ApiNameRetrieveVideo + default: + return "" + } +} + +func isKlingNativeCreateVideoPath(path string) bool { + return strings.HasSuffix(path, klingTextToVideoPath) || + strings.HasSuffix(path, klingImageToVideoPath) +} + +func isKlingNativeRetrieveVideoPath(path string) bool { + return hasSinglePathSegmentAfter(path, klingTextToVideoPath) || + hasSinglePathSegmentAfter(path, klingImageToVideoPath) +} + +func hasSinglePathSegmentAfter(path, prefix string) bool { + index := strings.Index(path, prefix+"/") + if index < 0 { + return false + } + remaining := path[index+len(prefix)+1:] + return remaining != "" && !strings.Contains(remaining, "/") +} + +func (k *klingProvider) getAuthorizationToken(ctx wrapper.HttpContext) string { + if k.isOfficialMode() { + return k.getJWTToken() + } + return k.config.GetApiTokenInUse(ctx) +} + +func (k *klingProvider) isOfficialMode() bool { + return strings.TrimSpace(k.config.klingAccessKey) != "" && strings.TrimSpace(k.config.klingSecretKey) != "" +} + +func (k *klingProvider) getJWTToken() string { + now := time.Now().Unix() + if k.jwtToken != "" && k.jwtExpireAt > now+k.config.klingTokenRefreshAhead { + return k.jwtToken + } + + token, expireAt, err := createKlingJWT(k.config.klingAccessKey, k.config.klingSecretKey, now) + if err != nil { + return "" + } + k.jwtToken = token + k.jwtExpireAt = expireAt + return k.jwtToken +} + +func createKlingJWT(accessKey, secretKey string, now int64) (string, int64, error) { + expireAt := now + klingJWTLifetimeSeconds + header := struct { + Alg string `json:"alg"` + Typ string `json:"typ"` + }{ + Alg: "HS256", + Typ: "JWT", + } + payload := struct { + Iss string `json:"iss"` + Exp int64 `json:"exp"` + Nbf int64 `json:"nbf"` + }{ + Iss: strings.TrimSpace(accessKey), + Exp: expireAt, + Nbf: now - klingJWTNotBeforeSkewSecond, + } + + headerJSON, err := json.Marshal(header) + if err != nil { + return "", 0, fmt.Errorf("unable to marshal kling jwt header: %v", err) + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + return "", 0, fmt.Errorf("unable to marshal kling jwt payload: %v", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) + signingInput := headerB64 + "." + payloadB64 + mac := hmac.New(sha256.New, []byte(strings.TrimSpace(secretKey))) + _, _ = mac.Write([]byte(signingInput)) + signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + return signingInput + "." + signature, expireAt, nil +} + +func (k *klingProvider) transformOpenAIVideoRequest(ctx wrapper.HttpContext, body []byte) ([]byte, error) { + model := gjson.GetBytes(body, "model") + modelPath := "model" + if !model.Exists() { + model = gjson.GetBytes(body, "model_name") + modelPath = "model_name" + } + if model.Exists() { + rawModel := model.String() + ctx.SetContext(ctxKeyOriginalRequestModel, rawModel) + mappedModel := getMappedModel(rawModel, k.config.modelMapping) + ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) + var err error + body, err = sjson.SetBytes(body, "model_name", mappedModel) + if err != nil { + return nil, err + } + if modelPath == "model" { + body, err = sjson.DeleteBytes(body, "model") + if err != nil { + return nil, err + } + } + } + return body, nil +} + +func (k *klingProvider) mapRetrieveVideoPath(originPath string) string { + pathOnly, query := splitKlingPathAndQuery(originPath) + matches := util.RegRetrieveVideoPath.FindStringSubmatch(pathOnly) + if matches == nil { + return util.MapRequestPathByCapability(string(ApiNameRetrieveVideo), originPath, k.config.capabilities) + } + + index := util.RegRetrieveVideoPath.SubexpIndex("video_id") + if index < 0 || index >= len(matches) { + return util.MapRequestPathByCapability(string(ApiNameRetrieveVideo), originPath, k.config.capabilities) + } + + videoID := matches[index] + taskType, forwardedQuery := extractKlingTaskTypeQuery(query) + switch { + case strings.HasPrefix(videoID, klingImageTaskIDPrefix): + rawID := strings.TrimPrefix(videoID, klingImageTaskIDPrefix) + return appendKlingQuery(replaceKlingVideoID(k.imageRetrieveVideoPath(), rawID), forwardedQuery) + case strings.HasPrefix(videoID, klingTextTaskIDPrefix): + rawID := strings.TrimPrefix(videoID, klingTextTaskIDPrefix) + return appendKlingQuery(replaceKlingVideoID(k.textRetrieveVideoPath(), rawID), forwardedQuery) + default: + if taskType == klingTaskTypeImageToVideo { + return appendKlingQuery(replaceKlingVideoID(k.imageRetrieveVideoPath(), videoID), forwardedQuery) + } + if taskType == klingTaskTypeTextToVideo { + return appendKlingQuery(replaceKlingVideoID(k.textRetrieveVideoPath(), videoID), forwardedQuery) + } + return util.MapRequestPathByCapability(string(ApiNameRetrieveVideo), pathOnly+forwardedQuery, k.config.capabilities) + } +} + +func (k *klingProvider) textCreateVideoPath() string { + return klingCapabilityPath(k.config.capabilities, ApiNameVideos, klingTextToVideoPath) +} + +func (k *klingProvider) imageCreateVideoPath() string { + return klingCapabilityPath(k.config.capabilities, ApiNameKlingImageToVideo, klingImageToVideoPath) +} + +func (k *klingProvider) textRetrieveVideoPath() string { + return klingCapabilityPath(k.config.capabilities, ApiNameRetrieveVideo, klingTextToVideoTaskPath) +} + +func (k *klingProvider) imageRetrieveVideoPath() string { + return klingCapabilityPath(k.config.capabilities, ApiNameKlingRetrieveImageVideo, klingImageToVideoTaskPath) +} + +func klingCapabilityPath(capabilities map[string]string, apiName ApiName, fallback string) string { + if path := capabilities[string(apiName)]; path != "" { + return path + } + return fallback +} + +func replaceKlingVideoID(taskPath, videoID string) string { + return strings.Replace(taskPath, "{video_id}", videoID, 1) +} + +func klingPathWithExistingQuery(currentPath, targetPath string) string { + _, query := splitKlingPathAndQuery(currentPath) + return appendKlingQuery(targetPath, query) +} + +func klingPathWithOriginalQuery(ctx wrapper.HttpContext, currentPath, targetPath string) string { + if originPath, ok := ctx.GetContext(CtxRequestPath).(string); ok && originPath != "" { + _, query := splitKlingPathAndQuery(originPath) + return appendKlingQuery(targetPath, query) + } + return klingPathWithExistingQuery(currentPath, targetPath) +} + +func appendKlingQuery(targetPath, query string) string { + if query == "" { + return targetPath + } + query = strings.TrimPrefix(query, "?") + if query == "" { + return targetPath + } + targetPathOnly, targetQuery := splitKlingPathAndQuery(targetPath) + targetQuery = strings.TrimPrefix(targetQuery, "?") + if targetQuery == "" { + return targetPathOnly + "?" + query + } + return targetPathOnly + "?" + mergeKlingQueryParts(targetQuery, query) +} + +func mergeKlingQueryParts(baseQuery, extraQuery string) string { + parts := make([]string, 0) + seen := make(map[string]struct{}) + for _, part := range strings.Split(baseQuery, "&") { + if part == "" { + continue + } + parts = append(parts, part) + seen[part] = struct{}{} + } + for _, part := range strings.Split(extraQuery, "&") { + if part == "" { + continue + } + if _, exists := seen[part]; exists { + continue + } + parts = append(parts, part) + seen[part] = struct{}{} + } + return strings.Join(parts, "&") +} + +func splitKlingPathAndQuery(rawPath string) (string, string) { + queryIndex := strings.Index(rawPath, "?") + if queryIndex < 0 { + return rawPath, "" + } + return rawPath[:queryIndex], rawPath[queryIndex:] +} + +func extractKlingTaskTypeQuery(query string) (string, string) { + if query == "" { + return "", "" + } + + parts := strings.Split(strings.TrimPrefix(query, "?"), "&") + forwardedParts := make([]string, 0, len(parts)) + taskType := "" + for _, part := range parts { + if part == "" { + continue + } + + key, value, _ := strings.Cut(part, "=") + decodedKey, err := url.QueryUnescape(key) + if err != nil { + decodedKey = key + } + if decodedKey != klingTaskTypeQueryKey { + forwardedParts = append(forwardedParts, part) + continue + } + + decodedValue, err := url.QueryUnescape(value) + if err != nil { + decodedValue = value + } + // If repeated, the last task type wins; all task type hints are stripped before forwarding. + taskType = normalizeKlingTaskType(decodedValue) + } + + if len(forwardedParts) == 0 { + return taskType, "" + } + return taskType, "?" + strings.Join(forwardedParts, "&") +} + +func normalizeKlingTaskType(taskType string) string { + switch strings.ToLower(strings.TrimSpace(taskType)) { + case klingTaskTypeImageToVideo, "image", "i2v": + return klingTaskTypeImageToVideo + case klingTaskTypeTextToVideo, "text", "t2v": + return klingTaskTypeTextToVideo + default: + return "" + } +} + +func prefixKlingTaskIDs(body []byte, prefix string) ([]byte, error) { + var err error + for _, path := range []string{"data.task_id", "task_id"} { + value := gjson.GetBytes(body, path) + if !value.Exists() || value.String() == "" { + continue + } + taskID := value.String() + if strings.HasPrefix(taskID, klingTextTaskIDPrefix) || strings.HasPrefix(taskID, klingImageTaskIDPrefix) { + continue + } + body, err = sjson.SetBytes(body, path, prefix+taskID) + if err != nil { + return nil, err + } + } + return body, nil +} + +func (k *klingProvider) isImageToVideoRequest(body []byte) bool { + // Keep this in sync with Kling video generation image input fields. + imageFields := []string{ + "image", + "image_url", + "image_urls", + "images", + "image_tail", + "image_tail_url", + "input_image", + "first_frame_image", + "last_frame_image", + } + for _, field := range imageFields { + if gjson.GetBytes(body, field).Exists() { + return true + } + } + return false +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/kling_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/kling_test.go new file mode 100644 index 00000000..ea05fa2b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/kling_test.go @@ -0,0 +1,942 @@ +package provider + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestKlingProviderValidateConfig(t *testing.T) { + initializer := &klingProviderInitializer{} + + t.Run("official credentials", func(t *testing.T) { + err := initializer.ValidateConfig(&ProviderConfig{ + klingAccessKey: "ak", + klingSecretKey: "sk", + }) + require.NoError(t, err) + }) + + t.Run("gateway token", func(t *testing.T) { + err := initializer.ValidateConfig(&ProviderConfig{ + apiTokens: []string{"gateway-token"}, + }) + require.NoError(t, err) + }) + + t.Run("official credentials preferred when both configured", func(t *testing.T) { + err := initializer.ValidateConfig(&ProviderConfig{ + apiTokens: []string{"gateway-token"}, + klingAccessKey: "ak", + klingSecretKey: "sk", + }) + require.NoError(t, err) + }) + + t.Run("partial official credentials rejected", func(t *testing.T) { + err := initializer.ValidateConfig(&ProviderConfig{ + klingAccessKey: "ak", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "klingAccessKey") + }) + + t.Run("missing auth rejected", func(t *testing.T) { + err := initializer.ValidateConfig(&ProviderConfig{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing kling authentication") + }) +} + +func TestKlingProviderConfigFromJson(t *testing.T) { + config := &ProviderConfig{} + config.FromJson(gjson.Parse(`{ + "type": "kling", + "klingAccessKey": "ak", + "klingSecretKey": "sk", + "klingTokenRefreshAhead": 120, + "capabilities": { + "kling/v1/image2video": "/gateway/image2video", + "kling/v1/retrieveimagevideo": "/gateway/image-tasks/{video_id}" + } + }`)) + + assert.Equal(t, "ak", config.klingAccessKey) + assert.Equal(t, "sk", config.klingSecretKey) + assert.Equal(t, int64(120), config.klingTokenRefreshAhead) + assert.Equal(t, "/gateway/image2video", config.capabilities[string(ApiNameKlingImageToVideo)]) + assert.Equal(t, "/gateway/image-tasks/{video_id}", config.capabilities[string(ApiNameKlingRetrieveImageVideo)]) + + defaultConfig := &ProviderConfig{} + defaultConfig.FromJson(gjson.Parse(`{"type": "kling", "apiTokens": ["token"]}`)) + assert.Equal(t, klingDefaultRefreshAhead, defaultConfig.klingTokenRefreshAhead) +} + +func TestKlingProviderInitializerCreateProvider(t *testing.T) { + initializer := &klingProviderInitializer{} + + capabilities := initializer.DefaultCapabilities() + assert.Equal(t, klingTextToVideoPath, capabilities[string(ApiNameVideos)]) + assert.Equal(t, klingImageToVideoPath, capabilities[string(ApiNameKlingImageToVideo)]) + assert.Equal(t, klingTextToVideoTaskPath, capabilities[string(ApiNameRetrieveVideo)]) + assert.Equal(t, klingImageToVideoTaskPath, capabilities[string(ApiNameKlingRetrieveImageVideo)]) + + created, err := initializer.CreateProvider(ProviderConfig{ + protocol: protocolOpenAI, + klingAccessKey: "ak", + klingSecretKey: "sk", + }) + require.NoError(t, err) + _, transformsOpenAIResponseBody := created.(TransformResponseBodyHandler) + assert.True(t, transformsOpenAIResponseBody) + _, transformsOpenAIRequestBody := created.(RequestBodyHandler) + assert.True(t, transformsOpenAIRequestBody) + + provider := requireKlingBaseProvider(t, created) + assert.Equal(t, providerTypeKling, provider.GetProviderType()) + assert.Equal(t, klingDefaultRefreshAhead, provider.config.klingTokenRefreshAhead) + assert.Equal(t, klingTextToVideoPath, provider.config.capabilities[string(ApiNameVideos)]) + assert.Equal(t, klingImageToVideoPath, provider.config.capabilities[string(ApiNameKlingImageToVideo)]) + assert.Equal(t, klingImageToVideoTaskPath, provider.config.capabilities[string(ApiNameKlingRetrieveImageVideo)]) + + original, err := initializer.CreateProvider(ProviderConfig{ + protocol: protocolOriginal, + apiTokens: []string{"token"}, + }) + require.NoError(t, err) + _, transformsOriginalResponseBody := original.(TransformResponseBodyHandler) + assert.False(t, transformsOriginalResponseBody) + _, transformsOriginalRequestBody := original.(RequestBodyHandler) + assert.False(t, transformsOriginalRequestBody) + _, isBaseProvider := original.(*klingProvider) + assert.True(t, isBaseProvider) +} + +func TestCreateKlingJWT(t *testing.T) { + now := int64(1710000000) + token, expireAt, err := createKlingJWT("access-key", "secret-key", now) + require.NoError(t, err) + assert.Equal(t, now+klingJWTLifetimeSeconds, expireAt) + + parts := strings.Split(token, ".") + require.Len(t, parts, 3) + + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + require.NoError(t, err) + var header map[string]string + require.NoError(t, json.Unmarshal(headerJSON, &header)) + assert.Equal(t, "HS256", header["alg"]) + assert.Equal(t, "JWT", header["typ"]) + + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + var payload map[string]interface{} + require.NoError(t, json.Unmarshal(payloadJSON, &payload)) + assert.Equal(t, "access-key", payload["iss"]) + assert.Equal(t, float64(now+klingJWTLifetimeSeconds), payload["exp"]) + assert.Equal(t, float64(now-klingJWTNotBeforeSkewSecond), payload["nbf"]) + + mac := hmac.New(sha256.New, []byte("secret-key")) + _, _ = mac.Write([]byte(parts[0] + "." + parts[1])) + expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) + assert.Equal(t, expectedSignature, parts[2]) +} + +func TestKlingProviderTransformRequestHeadersAuth(t *testing.T) { + t.Run("official mode uses jwt bearer", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOriginal, + klingAccessKey: "access-key", + klingSecretKey: "secret-key", + }, + } + headers := http.Header{} + + provider.TransformRequestHeaders(newMockMultipartHttpContext(), ApiNameVideos, headers) + + assert.Equal(t, klingDefaultDomain, headers.Get(":authority")) + auth := headers.Get("Authorization") + require.True(t, strings.HasPrefix(auth, "Bearer ")) + payload := decodeKlingJWTPayload(t, strings.TrimPrefix(auth, "Bearer ")) + assert.Equal(t, "access-key", payload["iss"]) + }) + + t.Run("gateway mode uses static bearer token", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOriginal, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + headers := http.Header{} + + provider.TransformRequestHeaders(ctx, ApiNameVideos, headers) + + assert.Equal(t, klingDefaultDomain, headers.Get(":authority")) + assert.Equal(t, "Bearer gateway-token", headers.Get("Authorization")) + }) + + t.Run("provider domain skips default host overwrite", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOriginal, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + providerDomain: "api.302.ai", + }, + } + headers := http.Header{":authority": []string{"example.com"}} + + provider.TransformRequestHeaders(ctx, ApiNameVideos, headers) + + assert.Equal(t, "example.com", headers.Get(":authority")) + assert.Equal(t, "Bearer gateway-token", headers.Get("Authorization")) + }) + + t.Run("original mode preserves content length", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOriginal, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + headers := http.Header{"Content-Length": []string{"128"}} + + provider.TransformRequestHeaders(ctx, ApiNameVideos, headers) + + assert.Equal(t, "128", headers.Get("Content-Length")) + }) + + t.Run("openai mode rewrites capability path and removes content length", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + klingAccessKey: "access-key", + klingSecretKey: "secret-key", + capabilities: map[string]string{ + string(ApiNameVideos): klingTextToVideoPath, + }, + }, + } + headers := http.Header{ + ":path": []string{"/v1/videos?trace=1"}, + "Content-Length": []string{"12"}, + } + + provider.TransformRequestHeaders(newMockMultipartHttpContext(), ApiNameVideos, headers) + + assert.Equal(t, klingTextToVideoPath+"?trace=1", headers.Get(":path")) + assert.Equal(t, klingDefaultDomain, headers.Get(":authority")) + assert.Empty(t, headers.Get("Content-Length")) + }) + + t.Run("prefixed image task id routes retrieve to image endpoint", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/" + klingImageTaskIDPrefix + "task-123?with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingImageToVideoPath+"/task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("prefixed image task id strips internal task type query", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/" + klingImageTaskIDPrefix + "task-123?kling_task_type=image2video&with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingImageToVideoPath+"/task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("prefixed text task id routes retrieve to text endpoint", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/" + klingTextTaskIDPrefix + "task-123?with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingTextToVideoPath+"/task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("prefixed text task id strips internal task type query", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/" + klingTextTaskIDPrefix + "task-123?with_status=true&kling_task_type=t2v"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingTextToVideoPath+"/task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("raw image task id uses explicit task type query without forwarding it", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/raw-task-123?kling_task_type=image2video&with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingImageToVideoPath+"/raw-task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("raw text task id uses explicit task type query without forwarding it", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/raw-task-123?with_status=true&kling_task_type=t2v"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, klingTextToVideoPath+"/raw-task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("image task id routes retrieve through configured image capability", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + string(ApiNameKlingRetrieveImageVideo): "/gateway/image-tasks/{video_id}", + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/" + klingImageTaskIDPrefix + "task-123?with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, "/gateway/image-tasks/task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("raw image task id routes retrieve through configured image capability", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): klingTextToVideoTaskPath, + string(ApiNameKlingRetrieveImageVideo): "/gateway/image-tasks/{video_id}", + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/raw-task-123?kling_task_type=i2v&with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, "/gateway/image-tasks/raw-task-123?with_status=true", headers.Get(":path")) + }) + + t.Run("retrieve capability query merges with request query", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + protocol: protocolOpenAI, + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): "/gateway/text-tasks/{video_id}?version=1", + string(ApiNameKlingRetrieveImageVideo): "/gateway/image-tasks/{video_id}?version=1", + }, + apiTokens: []string{"gateway-token"}, + failover: &failover{ctxApiTokenInUse: "kling-token-key"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext("kling-token-key", "gateway-token") + headers := http.Header{":path": []string{"/v1/videos/raw-task-123?kling_task_type=i2v&with_status=true"}} + + provider.TransformRequestHeaders(ctx, ApiNameRetrieveVideo, headers) + + assert.Equal(t, "/gateway/image-tasks/raw-task-123?version=1&with_status=true", headers.Get(":path")) + }) + + t.Run("retrieve path outside openai pattern falls back to capability mapping", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): "/gateway/retrieve", + }, + }, + } + + assert.Equal(t, "/gateway/retrieve?trace=1", provider.mapRetrieveVideoPath("/custom/retrieve?trace=1")) + }) + + t.Run("unknown task type hint is stripped before fallback retrieve mapping", func(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameRetrieveVideo): "/gateway/text-tasks/{video_id}?version=1", + }, + }, + } + + assert.Equal(t, "/gateway/text-tasks/task-123?version=1&with_status=true", provider.mapRetrieveVideoPath("/v1/videos/task-123?kling_task_type=bad&with_status=true")) + }) +} + +func TestKlingProviderGetJWTTokenUsesCache(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + klingAccessKey: "access-key", + klingSecretKey: "secret-key", + klingTokenRefreshAhead: 60, + }, + } + + first := provider.getJWTToken() + require.NotEmpty(t, first) + second := provider.getJWTToken() + assert.Equal(t, first, second) + + provider.jwtExpireAt = time.Now().Unix() + refreshed := provider.getJWTToken() + assert.NotEmpty(t, refreshed) +} + +func TestKlingProviderTransformRequestBodyHeaders(t *testing.T) { + provider := &klingProvider{ + config: ProviderConfig{ + modelMapping: map[string]string{"client-video": "kling-v2-1"}, + }, + } + + t.Run("text to video maps model to model_name", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath}} + body := []byte(`{"model":"client-video","prompt":"sunrise","duration":"5","mode":"std"}`) + + result, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, klingTextToVideoPath, headers.Get(":path")) + assert.False(t, gjson.GetBytes(result, "model").Exists()) + assert.Equal(t, "kling-v2-1", gjson.GetBytes(result, "model_name").String()) + assert.Equal(t, "sunrise", gjson.GetBytes(result, "prompt").String()) + assert.Equal(t, "5", gjson.GetBytes(result, "duration").String()) + }) + + t.Run("image to video switches path", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath}} + body := []byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`) + + result, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, klingImageToVideoPath, headers.Get(":path")) + assert.Equal(t, "kling-v2-1", gjson.GetBytes(result, "model_name").String()) + assert.Equal(t, "https://example.com/a.png", gjson.GetBytes(result, "image").String()) + }) + + t.Run("text to video preserves query string", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath + "?trace=1"}} + body := []byte(`{"model":"client-video","prompt":"sunrise"}`) + + _, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, klingTextToVideoPath+"?trace=1", headers.Get(":path")) + }) + + t.Run("image to video preserves query string", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath + "?trace=1"}} + body := []byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`) + + _, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, klingImageToVideoPath+"?trace=1", headers.Get(":path")) + }) + + t.Run("text create capability query merges with request query", func(t *testing.T) { + customProvider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{string(ApiNameVideos): "/gateway/text2video?version=1"}, + modelMapping: map[string]string{"client-video": "kling-v2-1"}, + }, + } + headers := http.Header{":path": []string{"/v1/videos?trace=1"}} + body := []byte(`{"model":"client-video","prompt":"sunrise"}`) + + _, err := customProvider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, "/gateway/text2video?version=1&trace=1", headers.Get(":path")) + }) + + t.Run("image create uses explicit image capability and merges query", func(t *testing.T) { + customProvider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameVideos): "/gateway/text2video", + string(ApiNameKlingImageToVideo): "/gateway/image2video?version=1", + }, + modelMapping: map[string]string{"client-video": "kling-v2-1"}, + }, + } + headers := http.Header{":path": []string{"/v1/videos?trace=1"}} + body := []byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`) + + result, err := customProvider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, "/gateway/image2video?version=1&trace=1", headers.Get(":path")) + assert.Equal(t, "kling-v2-1", gjson.GetBytes(result, "model_name").String()) + }) + + t.Run("image create does not duplicate capability query after header mapping", func(t *testing.T) { + customProvider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameKlingImageToVideo): "/gateway/image2video?version=1", + }, + modelMapping: map[string]string{"client-video": "kling-v2-1"}, + }, + } + headers := http.Header{":path": []string{"/gateway/image2video?version=1&trace=1"}} + body := []byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`) + + _, err := customProvider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, "/gateway/image2video?version=1&trace=1", headers.Get(":path")) + }) + + t.Run("image create does not inherit text capability query from header mapping", func(t *testing.T) { + customProvider := &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameVideos): "/gateway/text2video?mode=text", + string(ApiNameKlingImageToVideo): "/gateway/image2video?mode=image", + }, + modelMapping: map[string]string{"client-video": "kling-v2-1"}, + }, + } + ctx := newMockMultipartHttpContext() + ctx.SetContext(CtxRequestPath, "/v1/videos?trace=1") + headers := http.Header{":path": []string{"/gateway/text2video?mode=text&trace=1"}} + body := []byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`) + + _, err := customProvider.TransformRequestBodyHeaders(ctx, ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, "/gateway/image2video?mode=image&trace=1", headers.Get(":path")) + }) + + t.Run("model_name is accepted and mapped in place", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath}} + body := []byte(`{"model_name":"client-video","prompt":"sunrise"}`) + + result, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, "kling-v2-1", gjson.GetBytes(result, "model_name").String()) + }) + + t.Run("missing model passes body through", func(t *testing.T) { + headers := http.Header{":path": []string{klingTextToVideoPath}} + body := []byte(`{"prompt":"sunrise"}`) + + result, err := provider.TransformRequestBodyHeaders(newMockMultipartHttpContext(), ApiNameVideos, body, headers) + require.NoError(t, err) + + assert.Equal(t, string(body), string(result)) + assert.Equal(t, klingTextToVideoPath, headers.Get(":path")) + }) + +} + +func TestKlingProviderTransformResponseBody(t *testing.T) { + provider := &klingOpenAIProvider{klingProvider: &klingProvider{}} + + t.Run("image creation prefixes returned task ids", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext(ctxKeyKlingVideoTaskType, klingTaskTypeImageToVideo) + + result, err := provider.TransformResponseBody(ctx, ApiNameVideos, []byte(`{"id":"root-task","task_id":"top-task","data":{"task_id":"data-task"}}`)) + require.NoError(t, err) + + assert.Equal(t, "root-task", gjson.GetBytes(result, "id").String()) + assert.Equal(t, klingImageTaskIDPrefix+"top-task", gjson.GetBytes(result, "task_id").String()) + assert.Equal(t, klingImageTaskIDPrefix+"data-task", gjson.GetBytes(result, "data.task_id").String()) + }) + + t.Run("text creation prefixes returned task ids", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext(ctxKeyKlingVideoTaskType, klingTaskTypeTextToVideo) + + result, err := provider.TransformResponseBody(ctx, ApiNameVideos, []byte(`{"data":{"task_id":"data-task"}}`)) + require.NoError(t, err) + + assert.Equal(t, klingTextTaskIDPrefix+"data-task", gjson.GetBytes(result, "data.task_id").String()) + }) + + t.Run("retrieve video response body passes through", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + ctx.SetContext(ctxKeyKlingVideoTaskType, klingTaskTypeImageToVideo) + body := []byte(`{"id":"root-task","data":{"task_id":"data-task"}}`) + + result, err := provider.TransformResponseBody(ctx, ApiNameRetrieveVideo, body) + require.NoError(t, err) + + assert.Equal(t, string(body), string(result)) + }) + + t.Run("video response without task type passes through", func(t *testing.T) { + ctx := newMockMultipartHttpContext() + body := []byte(`{"data":{"task_id":"data-task"}}`) + + result, err := provider.TransformResponseBody(ctx, ApiNameVideos, body) + require.NoError(t, err) + + assert.Equal(t, string(body), string(result)) + }) +} + +func TestPrefixKlingTaskIDs(t *testing.T) { + t.Run("already prefixed task ids are unchanged", func(t *testing.T) { + body := []byte(`{"task_id":"kling-i2v-top-task","data":{"task_id":"kling-t2v-data-task"}}`) + + result, err := prefixKlingTaskIDs(body, klingImageTaskIDPrefix) + require.NoError(t, err) + + assert.Equal(t, "kling-i2v-top-task", gjson.GetBytes(result, "task_id").String()) + assert.Equal(t, "kling-t2v-data-task", gjson.GetBytes(result, "data.task_id").String()) + }) + + t.Run("missing task ids are ignored", func(t *testing.T) { + body := []byte(`{"data":{"status":"submitted"}}`) + + result, err := prefixKlingTaskIDs(body, klingImageTaskIDPrefix) + require.NoError(t, err) + + assert.Equal(t, string(body), string(result)) + }) +} + +func TestKlingProviderGetApiName(t *testing.T) { + provider := &klingProvider{} + + tests := []struct { + name string + path string + want ApiName + }{ + { + name: "text to video create", + path: "/proxy/v1/videos/text2video", + want: ApiNameVideos, + }, + { + name: "image to video create", + path: "/proxy/v1/videos/image2video", + want: ApiNameVideos, + }, + { + name: "openai retrieve", + path: "/proxy/v1/videos/task-123", + want: ApiNameRetrieveVideo, + }, + { + name: "native text retrieve", + path: "/proxy/v1/videos/text2video/task-123", + want: ApiNameRetrieveVideo, + }, + { + name: "native image retrieve", + path: "/proxy/v1/videos/image2video/task-123", + want: ApiNameRetrieveVideo, + }, + { + name: "native text create is not retrieve", + path: "/proxy/v1/videos/text2video", + want: ApiNameVideos, + }, + { + name: "unsupported path", + path: "/proxy/v1/images/generations", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, provider.GetApiName(tt.path)) + }) + } +} + +func TestKlingQueryMerge(t *testing.T) { + tests := []struct { + name string + targetPath string + query string + want string + }{ + { + name: "empty query leaves target unchanged", + targetPath: "/gateway/image2video", + query: "", + want: "/gateway/image2video", + }, + { + name: "request query is appended", + targetPath: "/gateway/image2video", + query: "?trace=1", + want: "/gateway/image2video?trace=1", + }, + { + name: "capability query and request query are merged", + targetPath: "/gateway/image2video?version=1", + query: "?trace=1", + want: "/gateway/image2video?version=1&trace=1", + }, + { + name: "duplicate capability query from mapped path is not repeated", + targetPath: "/gateway/image2video?version=1", + query: "?version=1&trace=1", + want: "/gateway/image2video?version=1&trace=1", + }, + { + name: "query without question mark is accepted", + targetPath: "/gateway/image2video?version=1", + query: "trace=1", + want: "/gateway/image2video?version=1&trace=1", + }, + { + name: "empty question mark query leaves target unchanged", + targetPath: "/gateway/image2video", + query: "?", + want: "/gateway/image2video", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, appendKlingQuery(tt.targetPath, tt.query)) + }) + } + + t.Run("merge skips empty and duplicate parts", func(t *testing.T) { + assert.Equal(t, "version=1&trace=1", mergeKlingQueryParts("version=1&&", "&trace=1&&version=1")) + }) +} + +func TestKlingTaskTypeQuery(t *testing.T) { + t.Run("extract task type query", func(t *testing.T) { + tests := []struct { + name string + query string + wantTaskType string + wantForwarding string + }{ + { + name: "empty query", + query: "", + wantTaskType: "", + wantForwarding: "", + }, + { + name: "only task type", + query: "?kling_task_type=image2video", + wantTaskType: klingTaskTypeImageToVideo, + wantForwarding: "", + }, + { + name: "task type is stripped and other query params are forwarded", + query: "?trace=1&kling_task_type=t2v&with_status=true", + wantTaskType: klingTaskTypeTextToVideo, + wantForwarding: "?trace=1&with_status=true", + }, + { + name: "url encoded value", + query: "?kling_task_type=image%32video&trace=1", + wantTaskType: klingTaskTypeImageToVideo, + wantForwarding: "?trace=1", + }, + { + name: "url encoded unknown value", + query: "?kling_task_type=image%202video&trace=1", + wantTaskType: "", + wantForwarding: "?trace=1", + }, + { + name: "repeated task type uses the last value", + query: "?kling_task_type=t2v&trace=1&kling_task_type=i2v", + wantTaskType: klingTaskTypeImageToVideo, + wantForwarding: "?trace=1", + }, + { + name: "invalid encoded key is forwarded", + query: "?%zz=image2video&trace=1", + wantTaskType: "", + wantForwarding: "?%zz=image2video&trace=1", + }, + { + name: "invalid encoded value falls back before normalization", + query: "?kling_task_type=%zz&trace=1", + wantTaskType: "", + wantForwarding: "?trace=1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + taskType, forwarding := extractKlingTaskTypeQuery(tt.query) + assert.Equal(t, tt.wantTaskType, taskType) + assert.Equal(t, tt.wantForwarding, forwarding) + }) + } + }) + + t.Run("normalize task type", func(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + { + name: "image alias", + raw: "image", + want: klingTaskTypeImageToVideo, + }, + { + name: "text alias", + raw: "t2v", + want: klingTaskTypeTextToVideo, + }, + { + name: "unknown value", + raw: "image 2video", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, normalizeKlingTaskType(tt.raw)) + }) + } + }) +} + +func TestKlingProviderOnRequestBodyUnsupportedAPI(t *testing.T) { + provider := &klingOpenAIProvider{ + klingProvider: &klingProvider{ + config: ProviderConfig{ + capabilities: map[string]string{}, + }, + }, + } + + action, err := provider.OnRequestBody(newMockMultipartHttpContext(), ApiNameVideos, []byte(`{}`)) + assert.Equal(t, types.ActionContinue, action) + assert.ErrorIs(t, err, errUnsupportedApiName) +} + +func decodeKlingJWTPayload(t *testing.T, token string) map[string]interface{} { + t.Helper() + + parts := strings.Split(token, ".") + require.Len(t, parts, 3) + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + var payload map[string]interface{} + require.NoError(t, json.Unmarshal(payloadJSON, &payload)) + return payload +} + +func requireKlingBaseProvider(t *testing.T, created Provider) *klingProvider { + t.Helper() + + switch provider := created.(type) { + case *klingProvider: + return provider + case *klingOpenAIProvider: + return provider.klingProvider + default: + t.Fatalf("expected kling provider, got %T", created) + return nil + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index adae9f5d..6088f645 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -65,6 +65,8 @@ const ( ApiNameRetrieveVideo ApiName = "openai/v1/retrievevideo" ApiNameVideoRemix ApiName = "openai/v1/videoremix" ApiNameRetrieveVideoContent ApiName = "openai/v1/retrievevideocontent" + ApiNameKlingImageToVideo ApiName = "kling/v1/image2video" + ApiNameKlingRetrieveImageVideo ApiName = "kling/v1/retrieveimagevideo" // TODO: 以下是一些非标准的API名称,需要进一步确认是否支持 ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank" @@ -159,6 +161,7 @@ const ( providerTypeFireworks = "fireworks" providerTypeVllm = "vllm" providerTypeGeneric = "generic" + providerTypeKling = "kling" protocolOpenAI = "openai" protocolOriginal = "original" @@ -253,6 +256,7 @@ var ( providerTypeFireworks: &fireworksProviderInitializer{}, providerTypeVllm: &vllmProviderInitializer{}, providerTypeGeneric: &genericProviderInitializer{}, + providerTypeKling: &klingProviderInitializer{}, } ) @@ -417,6 +421,15 @@ type ProviderConfig struct { // @Title zh-CN Vertex token刷新提前时间 // @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒 vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"` + // @Title zh-CN Kling Access Key + // @Description zh-CN 仅适用于KlingAI官方服务鉴权,用于生成JWT Token + klingAccessKey string `required:"false" yaml:"klingAccessKey" json:"klingAccessKey"` + // @Title zh-CN Kling Secret Key + // @Description zh-CN 仅适用于KlingAI官方服务鉴权,用于签名JWT Token + klingSecretKey string `required:"false" yaml:"klingSecretKey" json:"klingSecretKey"` + // @Title zh-CN Kling token刷新提前时间 + // @Description zh-CN Kling JWT过期前提前刷新的时间,单位为秒,默认值为60秒 + klingTokenRefreshAhead int64 `required:"false" yaml:"klingTokenRefreshAhead" json:"klingTokenRefreshAhead"` // @Title zh-CN Vertex AI OpenAI兼容模式 // @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。 vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"` @@ -614,6 +627,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.vertexTokenRefreshAhead == 0 { c.vertexTokenRefreshAhead = 60 } + c.klingAccessKey = json.Get("klingAccessKey").String() + c.klingSecretKey = json.Get("klingSecretKey").String() + c.klingTokenRefreshAhead = json.Get("klingTokenRefreshAhead").Int() + if c.klingTokenRefreshAhead == 0 { + c.klingTokenRefreshAhead = 60 + } c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool() c.targetLang = json.Get("targetLang").String() @@ -696,6 +715,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { string(ApiNameCohereV1Rerank), string(ApiNameVideos), string(ApiNameRetrieveVideo), + string(ApiNameKlingImageToVideo), + string(ApiNameKlingRetrieveImageVideo), string(ApiNameRetrieveVideoContent), string(ApiNameVideoRemix): c.capabilities[capability] = pathJson.String() @@ -1137,7 +1158,9 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) c.capabilities = make(map[string]string) } for capability, path := range capabilities { - c.capabilities[capability] = path + if _, exists := c.capabilities[capability]; !exists { + c.capabilities[capability] = path + } } } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go index 3f2f7b12..5b0f1037 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider_test.go @@ -690,6 +690,21 @@ func TestProviderConfig_SetDefaultCapabilities(t *testing.T) { assert.Equal(t, "/v1/embeddings", config.capabilities[string(ApiNameEmbeddings)]) assert.Equal(t, "/v1/chat/completions", config.capabilities[string(ApiNameChatCompletion)]) }) + + t.Run("preserve_existing_capability", func(t *testing.T) { + config := &ProviderConfig{ + capabilities: map[string]string{ + string(ApiNameChatCompletion): "/custom/chat/completions", + }, + } + + defaultCaps := map[string]string{ + string(ApiNameChatCompletion): "/v1/chat/completions", + } + config.setDefaultCapabilities(defaultCaps) + + assert.Equal(t, "/custom/chat/completions", config.capabilities[string(ApiNameChatCompletion)]) + }) } func TestCreateProvider(t *testing.T) { diff --git a/plugins/wasm-go/extensions/ai-proxy/test/kling.go b/plugins/wasm-go/extensions/ai-proxy/test/kling.go new file mode 100644 index 00000000..bd0f17d1 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/test/kling.go @@ -0,0 +1,418 @@ +package test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/test" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +var klingOfficialConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "kling", + "klingAccessKey": "kling-ak-test", + "klingSecretKey": "kling-sk-test", + "klingTokenRefreshAhead": 60, + "modelMapping": map[string]string{ + "client-video": "kling-v2-1", + }, + }, + }) + return data +}() + +var klingGatewayConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "kling", + "apiTokens": []string{"gateway-token"}, + "providerDomain": "api.302.ai", + "providerBasePath": "/klingai", + "modelMapping": map[string]string{ + "client-video": "kling-v2-1", + }, + }, + }) + return data +}() + +var klingGatewayCustomImageRetrieveConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "kling", + "apiTokens": []string{"gateway-token"}, + "providerDomain": "api.302.ai", + "providerBasePath": "/klingai", + "modelMapping": map[string]string{ + "client-video": "kling-v2-1", + }, + "capabilities": map[string]string{ + "openai/v1/videos": "/gateway/text2video?mode=text", + "kling/v1/image2video": "/gateway/image2video?mode=image", + "kling/v1/retrieveimagevideo": "/gateway/image-tasks/{video_id}?version=1", + }, + }, + }) + return data +}() + +var klingOriginalConfig = func() json.RawMessage { + data, _ := json.Marshal(map[string]interface{}{ + "provider": map[string]interface{}{ + "type": "kling", + "apiTokens": []string{"gateway-token"}, + "providerDomain": "api.302.ai", + "providerBasePath": "/klingai", + "protocol": "original", + }, + }) + return data +}() + +func RunKlingParseConfigTests(t *testing.T) { + test.RunGoTest(t, func(t *testing.T) { + t.Run("kling official config", func(t *testing.T) { + host, status := test.NewTestHost(klingOfficialConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + + t.Run("kling gateway config", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + config, err := host.GetMatchConfig() + require.NoError(t, err) + require.NotNil(t, config) + }) + }) +} + +func RunKlingOnHttpRequestHeadersTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("official mode sets jwt bearer and default host", func(t *testing.T) { + host, status := test.NewTestHost(klingOfficialConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, hasHost) + require.Equal(t, "api-singapore.klingai.com", hostValue) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth) + require.True(t, strings.HasPrefix(authValue, "Bearer ")) + require.Len(t, strings.Split(strings.TrimPrefix(authValue, "Bearer "), "."), 3) + + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/v1/videos/text2video", pathValue) + }) + + t.Run("providerDomain and providerBasePath apply to gateway mode", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.HeaderStopIteration, action) + + requestHeaders := host.GetRequestHeaders() + hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority") + require.True(t, hasHost) + require.Equal(t, "api.302.ai", hostValue) + + authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization") + require.True(t, hasAuth) + require.Equal(t, "Bearer gateway-token", authValue) + + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/text2video", pathValue) + }) + + t.Run("retrieve video query path is mapped under providerBasePath", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/task-123?with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/text2video/task-123?with_status=true", pathValue) + }) + + t.Run("prefixed image task query path is mapped to image endpoint", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/kling-i2v-task-123?with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/image2video/task-123?with_status=true", pathValue) + }) + + t.Run("prefixed image task query strips task type hint", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/kling-i2v-task-123?kling_task_type=image2video&with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/image2video/task-123?with_status=true", pathValue) + }) + + t.Run("raw image task query path uses explicit task type hint", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/raw-task-123?kling_task_type=image2video&with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/image2video/raw-task-123?with_status=true", pathValue) + }) + + t.Run("raw retrieve strips unknown task type hint before fallback mapping", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/raw-task-123?kling_task_type=bad&with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/text2video/raw-task-123?with_status=true", pathValue) + }) + + t.Run("image retrieve path uses configured image capability", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayCustomImageRetrieveConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/kling-i2v-task-123?with_status=true"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/gateway/image-tasks/task-123?version=1&with_status=true", pathValue) + }) + }) +} + +func RunKlingOnHttpRequestBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("text to video keeps text endpoint and maps model_name", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos?gateway_param=1"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + action := host.CallOnHttpRequestBody([]byte(`{"model":"client-video","prompt":"sunrise","duration":"5"}`)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/text2video?gateway_param=1", pathValue) + + processedBody := host.GetRequestBody() + require.Equal(t, "kling-v2-1", gjson.GetBytes(processedBody, "model_name").String()) + require.False(t, gjson.GetBytes(processedBody, "model").Exists()) + require.Equal(t, "sunrise", gjson.GetBytes(processedBody, "prompt").String()) + }) + + t.Run("image to video switches endpoint after body inspection", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos?gateway_param=1"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + action := host.CallOnHttpRequestBody([]byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/image2video?gateway_param=1", pathValue) + + processedBody := host.GetRequestBody() + require.Equal(t, "kling-v2-1", gjson.GetBytes(processedBody, "model_name").String()) + require.Equal(t, "https://example.com/a.png", gjson.GetBytes(processedBody, "image").String()) + }) + + t.Run("image to video uses configured image capability and merges query", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayCustomImageRetrieveConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos?gateway_param=1"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + + action := host.CallOnHttpRequestBody([]byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`)) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/gateway/image2video?mode=image&gateway_param=1", pathValue) + + processedBody := host.GetRequestBody() + require.Equal(t, "kling-v2-1", gjson.GetBytes(processedBody, "model_name").String()) + }) + + t.Run("original protocol does not expose request body handler", func(t *testing.T) { + host, status := test.NewTestHost(klingOriginalConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/image2video"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + {"Content-Length", "64"}, + }) + require.Equal(t, types.ActionContinue, action) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/image2video", pathValue) + contentLengthValue, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length") + require.True(t, hasContentLength) + require.Equal(t, "64", contentLengthValue) + }) + + t.Run("original protocol recognizes native retrieve video path", func(t *testing.T) { + host, status := test.NewTestHost(klingOriginalConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + action := host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos/text2video/task-123"}, + {":method", "GET"}, + }) + require.True(t, action == types.ActionContinue || action == types.HeaderStopIteration) + + requestHeaders := host.GetRequestHeaders() + pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path") + require.True(t, hasPath) + require.Equal(t, "/klingai/v1/videos/text2video/task-123", pathValue) + }) + }) +} + +func RunKlingOnHttpResponseBodyTests(t *testing.T) { + test.RunTest(t, func(t *testing.T) { + t.Run("image creation response prefixes task id", func(t *testing.T) { + host, status := test.NewTestHost(klingGatewayConfig) + defer host.Reset() + require.Equal(t, types.OnPluginStartStatusOK, status) + + host.CallOnHttpRequestHeaders([][2]string{ + {":authority", "example.com"}, + {":path", "/v1/videos"}, + {":method", "POST"}, + {"Content-Type", "application/json"}, + }) + action := host.CallOnHttpRequestBody([]byte(`{"model":"client-video","prompt":"animate","image":"https://example.com/a.png"}`)) + require.Equal(t, types.ActionContinue, action) + + require.NoError(t, host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))) + action = host.CallOnHttpResponseHeaders([][2]string{ + {":status", "200"}, + {"Content-Type", "application/json"}, + }) + require.Equal(t, types.ActionContinue, action) + action = host.CallOnHttpResponseBody([]byte(`{"id":"root-task","data":{"task_id":"task-123"}}`)) + require.Equal(t, types.ActionContinue, action) + + processedBody := host.GetResponseBody() + require.Equal(t, "root-task", gjson.GetBytes(processedBody, "id").String()) + require.Equal(t, "kling-i2v-task-123", gjson.GetBytes(processedBody, "data.task_id").String()) + }) + }) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/test/util.go b/plugins/wasm-go/extensions/ai-proxy/test/util.go index bfa0e720..609d1a4f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/test/util.go +++ b/plugins/wasm-go/extensions/ai-proxy/test/util.go @@ -83,6 +83,15 @@ func RunMapRequestPathByCapabilityTests(t *testing.T) { }, expected: "/v1/videos/video-xyz", }, + { + name: "video placeholder is replaced in nested provider path", + apiName: "openai/v1/retrievevideo", + origin: "/openai/v1/videos/video-xyz", + mapping: map[string]string{ + "openai/v1/retrievevideo": "/v1/videos/text2video/{video_id}", + }, + expected: "/v1/videos/text2video/video-xyz", + }, { name: "video content placeholder with query", apiName: "openai/v1/retrievevideocontent", diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index b041104d..38825301 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -131,9 +131,7 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s continue } id := subMatch[index] - mappedPathOnly = r.regx.ReplaceAllStringFunc(mappedPathOnly, func(s string) string { - return strings.Replace(s, "{"+r.key+"}", id, 1) - }) + mappedPathOnly = strings.Replace(mappedPathOnly, "{"+r.key+"}", id, 1) } } }