mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
optimize plugin sdk (#1930)
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -55,7 +56,7 @@ func writeTraceAttribute(ctx wrapper.HttpContext) {
|
||||
ctx.WriteUserAttributeToTrace()
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomLogConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomLogConfig, log log.Log) types.Action {
|
||||
if rand.Intn(10)%3 == 1 {
|
||||
writeLog(ctx)
|
||||
} else if rand.Intn(10)%3 == 2 {
|
||||
|
||||
@@ -2,9 +2,11 @@ module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent
|
||||
|
||||
go 1.19
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2 h1:gH7OIGXm4wtW5Vo7L2deMPqF7OVWNESDHv1CaaTGu6s=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2/go.mod h1:359don/ahMxpfeLMzr29Cjwcu8IywTTDUzWlBPRNLHw=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/dashscope"
|
||||
prompttpl "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/promptTpl"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -37,7 +38,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(gjson gjson.Result, c *PluginConfig, log log.Log) error {
|
||||
initResponsePromptTpl(gjson, c)
|
||||
|
||||
err := initAPIs(gjson, c)
|
||||
@@ -54,11 +55,11 @@ func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
|
||||
func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRequest Request, log log.Log) types.Action {
|
||||
log.Debugf("[onHttpRequestBody] firstreq:%s", prompt)
|
||||
|
||||
var userMessage Message
|
||||
@@ -88,7 +89,7 @@ func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRe
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debug("onHttpRequestBody start")
|
||||
defer log.Debug("onHttpRequestBody end")
|
||||
|
||||
@@ -172,7 +173,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
return ret
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpResponseHeaders start")
|
||||
defer log.Debug("onHttpResponseHeaders end")
|
||||
|
||||
@@ -200,7 +201,7 @@ func extractJson(bodyStr string) (string, error) {
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[string]interface{}, assistantMessage Message, actionInput string, headers [][2]string, streamMode bool, rawResponse Response, log wrapper.Log) string {
|
||||
func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[string]interface{}, assistantMessage Message, actionInput string, headers [][2]string, streamMode bool, rawResponse Response, log log.Log) string {
|
||||
prompt := fmt.Sprintf(prompttpl.Json_Resp_Template, jsonSchema, actionInput)
|
||||
|
||||
messages := []dashscope.Message{{Role: "user", Content: prompt}}
|
||||
@@ -241,7 +242,7 @@ func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[st
|
||||
return content
|
||||
}
|
||||
|
||||
func noneStream(assistantMessage Message, actionInput string, rawResponse Response, log wrapper.Log) {
|
||||
func noneStream(assistantMessage Message, actionInput string, rawResponse Response, log log.Log) {
|
||||
assistantMessage.Role = "assistant"
|
||||
assistantMessage.Content = actionInput
|
||||
rawResponse.Choices[0].Message = assistantMessage
|
||||
@@ -257,7 +258,7 @@ func noneStream(assistantMessage Message, actionInput string, rawResponse Respon
|
||||
}
|
||||
}
|
||||
|
||||
func stream(actionInput string, rawResponse Response, log wrapper.Log) {
|
||||
func stream(actionInput string, rawResponse Response, log log.Log) {
|
||||
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
|
||||
proxywasm.ReplaceHttpResponseHeaders(headers)
|
||||
// Remove quotes from actionInput
|
||||
@@ -271,7 +272,7 @@ func stream(actionInput string, rawResponse Response, log wrapper.Log) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
|
||||
func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||||
func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log log.Log, statusCode int, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
log.Debugf("statusCode: %d", statusCode)
|
||||
}
|
||||
@@ -332,7 +333,7 @@ func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmI
|
||||
}
|
||||
}
|
||||
|
||||
func outputParser(response string, log wrapper.Log) (string, string) {
|
||||
func outputParser(response string, log log.Log) (string, string) {
|
||||
log.Debugf("Raw response:%s", response)
|
||||
|
||||
start := strings.Index(response, "```")
|
||||
@@ -379,7 +380,7 @@ func outputParser(response string, log wrapper.Log) (string, string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
|
||||
func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log log.Log) (types.Action, string) {
|
||||
dashscope.MessageStore.AddForAssistant(content)
|
||||
|
||||
action, actionInput := outputParser(content, log)
|
||||
@@ -514,7 +515,7 @@ func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LL
|
||||
}
|
||||
|
||||
// 从response接收到firstreq的大模型返回
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debugf("onHttpResponseBody start")
|
||||
defer log.Debugf("onHttpResponseBody end")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -15,7 +16,7 @@ const (
|
||||
|
||||
type providerInitializer interface {
|
||||
ValidateConfig(ProviderConfig) error
|
||||
CreateProvider(ProviderConfig, wrapper.Log) (Provider, error)
|
||||
CreateProvider(ProviderConfig, log.Log) (Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -128,7 +129,7 @@ func (c *ProviderConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateProvider(pc ProviderConfig, log wrapper.Log) (Provider, error) {
|
||||
func CreateProvider(pc ProviderConfig, log log.Log) (Provider, error) {
|
||||
initializer, has := providerInitializers[pc.typ]
|
||||
if !has {
|
||||
return nil, errors.New("unknown provider type: " + pc.typ)
|
||||
|
||||
@@ -3,6 +3,7 @@ package cache
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -16,7 +17,7 @@ func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig, log wrapper.Log) (Provider, error) {
|
||||
func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig, log log.Log) (Provider, error) {
|
||||
rp := redisProvider{
|
||||
config: cf,
|
||||
client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
@@ -32,7 +33,7 @@ func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig, log wrapper
|
||||
type redisProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.RedisClient
|
||||
log wrapper.Log
|
||||
log log.Log
|
||||
}
|
||||
|
||||
func (rp *redisProvider) GetProviderType() string {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ type PluginConfig struct {
|
||||
CacheKeyStrategy string
|
||||
}
|
||||
|
||||
func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
|
||||
func (c *PluginConfig) FromJson(json gjson.Result, log log.Log) {
|
||||
c.embeddingProviderConfig = &embedding.ProviderConfig{}
|
||||
c.vectorProviderConfig = &vector.ProviderConfig{}
|
||||
c.cacheProviderConfig = &cache.ProviderConfig{}
|
||||
@@ -140,7 +140,7 @@ func (c *PluginConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) Complete(log wrapper.Log) error {
|
||||
func (c *PluginConfig) Complete(log log.Log) error {
|
||||
var err error
|
||||
if c.embeddingProviderConfig.GetProviderType() != "" {
|
||||
log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType())
|
||||
@@ -191,7 +191,7 @@ func (c *PluginConfig) GetCacheProvider() cache.Provider {
|
||||
return c.cacheProvider
|
||||
}
|
||||
|
||||
func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) {
|
||||
func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log log.Log) {
|
||||
keyMap := map[string]string{
|
||||
"cacheKeyFrom.requestBody": "cacheKeyFrom",
|
||||
"cacheValueFrom.requestBody": "cacheValueFrom",
|
||||
@@ -210,7 +210,7 @@ func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log)
|
||||
}
|
||||
}
|
||||
|
||||
func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) {
|
||||
func setField(c *PluginConfig, fieldName string, value string, log log.Log) {
|
||||
switch fieldName {
|
||||
case "cacheKeyFrom":
|
||||
c.CacheKeyFrom = value
|
||||
|
||||
@@ -8,13 +8,14 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
// CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found.
|
||||
func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error {
|
||||
func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log log.Log, stream bool, useSimilaritySearch bool) error {
|
||||
activeCacheProvider := c.GetCacheProvider()
|
||||
if activeCacheProvider == nil {
|
||||
log.Debugf("[%s] [CheckCacheForKey] no cache provider configured, performing similarity search", PLUGIN_NAME)
|
||||
@@ -37,7 +38,7 @@ func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig
|
||||
}
|
||||
|
||||
// handleCacheResponse processes cache response and handles cache hits and misses.
|
||||
func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) {
|
||||
func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log log.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) {
|
||||
if err := response.Error(); err == nil && !response.IsNull() {
|
||||
log.Infof("[%s] cache hit for key: %s", PLUGIN_NAME, key)
|
||||
processCacheHit(key, response.String(), stream, ctx, c, log)
|
||||
@@ -60,7 +61,7 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex
|
||||
}
|
||||
|
||||
// processCacheHit handles a successful cache hit.
|
||||
func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) {
|
||||
func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log log.Log) {
|
||||
if strings.TrimSpace(response) == "" {
|
||||
log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
@@ -85,7 +86,7 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC
|
||||
}
|
||||
|
||||
// performSimilaritySearch determines the appropriate similarity search method to use.
|
||||
func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error {
|
||||
func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log log.Log, queryString string, stream bool) error {
|
||||
activeVectorProvider := c.GetVectorProvider()
|
||||
if activeVectorProvider == nil {
|
||||
return logAndReturnError(log, "[performSimilaritySearch] no vector provider configured for similarity search")
|
||||
@@ -107,19 +108,19 @@ func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.Plugi
|
||||
}
|
||||
|
||||
// performStringQuery executes the string-based similarity search.
|
||||
func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error {
|
||||
func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log log.Log, stream bool) error {
|
||||
stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier)
|
||||
if !ok {
|
||||
return logAndReturnError(log, "[performStringQuery] active vector provider does not implement StringQuerier interface")
|
||||
}
|
||||
|
||||
return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) {
|
||||
return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log log.Log, err error) {
|
||||
handleQueryResults(key, results, ctx, log, stream, c, err)
|
||||
})
|
||||
}
|
||||
|
||||
// performEmbeddingQuery executes the embedding-based similarity search.
|
||||
func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error {
|
||||
func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log log.Log, stream bool) error {
|
||||
embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier)
|
||||
if !ok {
|
||||
return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] active vector provider does not implement EmbeddingQuerier interface"))
|
||||
@@ -138,7 +139,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC
|
||||
}
|
||||
ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding)
|
||||
|
||||
err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) {
|
||||
err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log log.Log, err error) {
|
||||
handleQueryResults(key, results, ctx, log, stream, c, err)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -148,7 +149,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC
|
||||
}
|
||||
|
||||
// handleQueryResults processes the results of similarity search and determines next actions.
|
||||
func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) {
|
||||
func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log log.Log, stream bool, c config.PluginConfig, err error) {
|
||||
if err != nil {
|
||||
handleInternalError(err, fmt.Sprintf("[%s] [handleQueryResults] error querying vector database for key: %s", PLUGIN_NAME, key), log)
|
||||
return
|
||||
@@ -186,14 +187,14 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht
|
||||
}
|
||||
|
||||
// logAndReturnError logs an error and returns it.
|
||||
func logAndReturnError(log wrapper.Log, message string) error {
|
||||
func logAndReturnError(log log.Log, message string) error {
|
||||
message = fmt.Sprintf("[%s] %s", PLUGIN_NAME, message)
|
||||
log.Errorf(message)
|
||||
return errors.New(message)
|
||||
}
|
||||
|
||||
// handleInternalError logs an error and resumes the HTTP request.
|
||||
func handleInternalError(err error, message string, log wrapper.Log) {
|
||||
func handleInternalError(err error, message string, log log.Log) {
|
||||
if err != nil {
|
||||
log.Errorf("[%s] [handleInternalError] %s: %v", PLUGIN_NAME, message, err)
|
||||
} else {
|
||||
@@ -204,7 +205,7 @@ func handleInternalError(err error, message string, log wrapper.Log) {
|
||||
}
|
||||
|
||||
// Caches the response value
|
||||
func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) {
|
||||
func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log log.Log) {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key)
|
||||
return
|
||||
@@ -219,7 +220,7 @@ func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, v
|
||||
}
|
||||
|
||||
// Handles embedding upload if available
|
||||
func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) {
|
||||
func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log log.Log) {
|
||||
embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY)
|
||||
if embedding == nil {
|
||||
return
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -79,7 +80,7 @@ type CohereProvider struct {
|
||||
func (t *CohereProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_COHERE
|
||||
}
|
||||
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *CohereProvider) constructParameters(texts []string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
model := t.config.model
|
||||
|
||||
if model == "" {
|
||||
@@ -118,7 +119,7 @@ func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereRespons
|
||||
func (t *CohereProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -103,7 +104,7 @@ type DSProvider struct {
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (d *DSProvider) constructParameters(texts []string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
|
||||
model := d.config.model
|
||||
|
||||
@@ -159,7 +160,7 @@ func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error)
|
||||
func (d *DSProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -78,7 +80,7 @@ type HuggingFaceEmbeddingRequest struct {
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
func (t *HuggingFaceProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *HuggingFaceProvider) constructParameters(text string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
if text == "" {
|
||||
err := errors.New("queryString text cannot be empty")
|
||||
return "", nil, nil, err
|
||||
@@ -127,7 +129,7 @@ func (t *HuggingFaceProvider) parseTextEmbedding(responseBody []byte) ([]float64
|
||||
func (t *HuggingFaceProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -69,7 +71,7 @@ type ollamaEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func (t *ollamaProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *ollamaProvider) constructParameters(text string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
if text == "" {
|
||||
err := errors.New("queryString text cannot be empty")
|
||||
return "", nil, nil, err
|
||||
@@ -105,7 +107,7 @@ func (t *ollamaProvider) parseTextEmbedding(responseBody []byte) (*ollamaRespons
|
||||
func (t *ollamaProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -93,7 +94,7 @@ type OpenAIProvider struct {
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *OpenAIProvider) constructParameters(text string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
if text == "" {
|
||||
err := errors.New("queryString text cannot be empty")
|
||||
return "", nil, nil, err
|
||||
@@ -130,7 +131,7 @@ func (t *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIRespons
|
||||
func (t *OpenAIProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package embedding
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -108,6 +109,6 @@ type Provider interface {
|
||||
GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -97,7 +98,7 @@ type TIProvider struct {
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
func (t *TIProvider) constructParameters(texts []string, log log.Log) (string, [][2]string, []byte, error) {
|
||||
|
||||
data := TextInEmbeddingRequest{
|
||||
Input: texts,
|
||||
@@ -142,7 +143,7 @@ func (t *TIProvider) parseTextEmbedding(responseBody []byte) (*TextInResponse, e
|
||||
func (t *TIProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
log log.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -38,7 +39,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, c *config.PluginConfig, log log.Log) error {
|
||||
// config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider"))
|
||||
// config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider"))
|
||||
// config.RedisConfig.FromJson(json.Get("redis"))
|
||||
@@ -54,7 +55,7 @@ func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log log.Log) types.Action {
|
||||
skipCache, _ := proxywasm.GetHttpRequestHeader(SKIP_CACHE_HEADER)
|
||||
if skipCache == "on" {
|
||||
ctx.SetContext(SKIP_CACHE_HEADER, struct{}{})
|
||||
@@ -78,7 +79,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wr
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log log.Log) types.Action {
|
||||
|
||||
bodyJson := gjson.ParseBytes(body)
|
||||
// TODO: It may be necessary to support stream mode determination for different LLM providers.
|
||||
@@ -128,7 +129,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log log.Log) types.Action {
|
||||
skipCache := ctx.GetContext(SKIP_CACHE_HEADER)
|
||||
if skipCache != nil {
|
||||
ctx.SetUserAttribute("cache_status", "skip")
|
||||
@@ -150,7 +151,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log log.Log) []byte {
|
||||
log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk)
|
||||
log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk))
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
|
||||
func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log log.Log) error {
|
||||
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
|
||||
if tempContentI == nil {
|
||||
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk)
|
||||
@@ -28,7 +29,7 @@ func unifySSEChunk(data []byte) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
|
||||
func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log log.Log) error {
|
||||
var partialMessage []byte
|
||||
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
|
||||
log.Debugf("[handleStreamChunk] cache content: %v", ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY))
|
||||
@@ -54,7 +55,7 @@ func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []b
|
||||
return nil
|
||||
}
|
||||
|
||||
func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
|
||||
func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log log.Log) (string, error) {
|
||||
var body []byte
|
||||
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
|
||||
if tempContentI != nil {
|
||||
@@ -70,7 +71,7 @@ func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, c
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
|
||||
func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log log.Log) (string, error) {
|
||||
if len(chunk) > 0 {
|
||||
var lastMessage []byte
|
||||
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
|
||||
@@ -96,7 +97,7 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun
|
||||
return tempContentI.(string), nil
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
|
||||
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log log.Log) (string, error) {
|
||||
content := ""
|
||||
for _, chunk := range strings.Split(sseMessage, "\n\n") {
|
||||
log.Debugf("single sse message: %s", chunk)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -44,8 +45,8 @@ func (c *ChromaProvider) GetProviderType() string {
|
||||
func (d *ChromaProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -96,8 +97,8 @@ func (d *ChromaProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collection_id, embeddings 和 ids
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -177,7 +178,7 @@ type chromaQueryResponse struct {
|
||||
Included []string `json:"included"`
|
||||
}
|
||||
|
||||
func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
var queryResp chromaQueryResponse
|
||||
err := json.Unmarshal(responseBody, &queryResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -119,8 +120,8 @@ func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, err
|
||||
func (d *DvProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
url, body, headers, err := d.constructEmbeddingQueryParameters(emb)
|
||||
log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers)
|
||||
if err != nil {
|
||||
@@ -157,7 +158,7 @@ func getStringValue(fields map[string]interface{}, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log log.Log) ([]QueryResult, error) {
|
||||
resp, err := d.parseQueryResponse(responseBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -215,7 +216,7 @@ func (d *DvProvider) constructUploadParameters(emb []float64, queryString string
|
||||
return url, requestBody, header, err
|
||||
}
|
||||
|
||||
func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log log.Log, callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -235,7 +236,7 @@ func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log log.Log, callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -45,8 +46,8 @@ func (c *ESProvider) GetProviderType() string {
|
||||
func (d *ESProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
|
||||
requestBody, err := json.Marshal(esQueryRequest{
|
||||
Source: Source{Excludes: []string{"embedding"}},
|
||||
@@ -99,8 +100,8 @@ func (d *ESProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 index, embeddings 和 question
|
||||
// 下面是一个例子
|
||||
// POST /<index>/_doc
|
||||
@@ -176,7 +177,7 @@ type esQueryResponse struct {
|
||||
} `json:"hits"`
|
||||
}
|
||||
|
||||
func (d *ESProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *ESProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
log.Infof("[ES] responseBody: %s", string(responseBody))
|
||||
var queryResp esQueryResponse
|
||||
err := json.Unmarshal(responseBody, &queryResp)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -58,8 +59,8 @@ func (d *milvusProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collectionName, data 和 Authorization. question, answer 可选
|
||||
// 需要填写 id,否则 v2.4.13-hotfix 提示 invalid syntax: invalid parameter[expected=Int64][actual=]
|
||||
// 如果不填写 id,要在创建 collection 的时候设置 autoId 为 true
|
||||
@@ -120,8 +121,8 @@ type milvusQueryRequest struct {
|
||||
func (d *milvusProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 collectionName, data, annsField. outputFields 为可选参数
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -175,7 +176,7 @@ func (d *milvusProvider) QueryEmbedding(
|
||||
)
|
||||
}
|
||||
|
||||
func (d *milvusProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *milvusProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
if !gjson.GetBytes(responseBody, "data.0.distance").Exists() {
|
||||
log.Errorf("[Milvus] No distance found in response body: %s", responseBody)
|
||||
return nil, errors.New("[Milvus] No distance found in response body")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -67,8 +68,8 @@ func (d *pineconeProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 vector 和 question
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -122,8 +123,8 @@ type pineconeQueryRequest struct {
|
||||
func (d *pineconeProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 vector
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -163,7 +164,7 @@ func (d *pineconeProvider) QueryEmbedding(
|
||||
)
|
||||
}
|
||||
|
||||
func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
if !gjson.GetBytes(responseBody, "matches.0.score").Exists() {
|
||||
log.Errorf("[Pinecone] No distance found in response body: %s", responseBody)
|
||||
return nil, errors.New("[Pinecone] No distance found in response body")
|
||||
|
||||
@@ -3,6 +3,7 @@ package vector
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -50,8 +51,8 @@ type EmbeddingQuerier interface {
|
||||
QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error
|
||||
}
|
||||
|
||||
type EmbeddingUploader interface {
|
||||
@@ -59,8 +60,8 @@ type EmbeddingUploader interface {
|
||||
queryString string,
|
||||
queryEmb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error
|
||||
}
|
||||
|
||||
type AnswerAndEmbeddingUploader interface {
|
||||
@@ -69,16 +70,16 @@ type AnswerAndEmbeddingUploader interface {
|
||||
queryEmb []float64,
|
||||
answer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error
|
||||
}
|
||||
|
||||
type StringQuerier interface {
|
||||
QueryString(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -63,8 +64,8 @@ func (d *qdrantProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 id 和 vector. payload 可选
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -122,8 +123,8 @@ type qdrantQueryRequest struct {
|
||||
func (d *qdrantProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 vector 和 limit. with_payload 可选,为了直接得到问题答案,所以这里需要
|
||||
// 下面是一个例子
|
||||
// {
|
||||
@@ -164,7 +165,7 @@ func (d *qdrantProvider) QueryEmbedding(
|
||||
)
|
||||
}
|
||||
|
||||
func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
// 返回的内容例子如下
|
||||
// {
|
||||
// "time": 0.002,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -45,8 +46,8 @@ func (c *WeaviateProvider) GetProviderType() string {
|
||||
func (d *WeaviateProvider) QueryEmbedding(
|
||||
emb []float64,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(results []QueryResult, ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 class, vector
|
||||
// 下面是一个例子
|
||||
// {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"}
|
||||
@@ -109,8 +110,8 @@ func (d *WeaviateProvider) UploadAnswerAndEmbedding(
|
||||
queryEmb []float64,
|
||||
queryAnswer string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
|
||||
log log.Log,
|
||||
callback func(ctx wrapper.HttpContext, log log.Log, err error)) error {
|
||||
// 最少需要填写的参数为 class, vector 和 question 和 answer
|
||||
// 下面是一个例子
|
||||
// {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题", "answer": "这里是答案"}}
|
||||
@@ -155,7 +156,7 @@ type weaviateQueryRequest struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
|
||||
func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log log.Log) ([]QueryResult, error) {
|
||||
log.Infof("[Weaviate] queryResp: %s", string(responseBody))
|
||||
|
||||
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.collectionID)).Exists() {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -121,7 +122,7 @@ type ChatHistory struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log log.Log) error {
|
||||
c.RedisInfo.ServiceName = json.Get("redis.serviceName").String()
|
||||
if c.RedisInfo.ServiceName == "" {
|
||||
return errors.New("redis service name must not be empty")
|
||||
@@ -166,7 +167,7 @@ func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout), wrapper.WithDataBase(c.RedisInfo.Database))
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
log.Warnf("content is not json, can't process:%s", contentType)
|
||||
@@ -192,7 +193,7 @@ func TrimQuote(source string) string {
|
||||
return strings.Trim(source, `"`)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
bodyJson := gjson.ParseBytes(body)
|
||||
if bodyJson.Get("stream").Bool() {
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
@@ -319,7 +320,7 @@ func getIntQueryParameter(name string, path string, defaultValue int) int {
|
||||
return num
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
|
||||
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log log.Log) string {
|
||||
content := ""
|
||||
for _, chunk := range strings.Split(sseMessage, "\n\n") {
|
||||
subMessages := strings.Split(chunk, "\n")
|
||||
@@ -355,14 +356,14 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage
|
||||
return content
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log log.Log) []byte {
|
||||
if ctx.GetContext(ToolCallsContextKey) != nil {
|
||||
// we should not cache tool call result
|
||||
return chunk
|
||||
@@ -454,7 +455,7 @@ func onHttpStreamResponseBody(ctx wrapper.HttpContext, config PluginConfig, chun
|
||||
return chunk
|
||||
}
|
||||
|
||||
func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log wrapper.Log) {
|
||||
func saveChatHistory(ctx wrapper.HttpContext, config PluginConfig, questionI any, value string, log log.Log) {
|
||||
question := questionI.(string)
|
||||
identityKey := ctx.GetStringContext(IdentityKey, "")
|
||||
var chat []ChatHistory
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -100,7 +101,7 @@ type KVExtractor struct {
|
||||
ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log log.Log) error {
|
||||
log.Infof("config:%s", json.Raw)
|
||||
// init scene
|
||||
c.SceneInfo.Category = json.Get("scene.category").String()
|
||||
@@ -194,14 +195,14 @@ func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
log.Debug("start onHttpRequestHeaders function.")
|
||||
|
||||
log.Debug("end onHttpRequestHeaders function.")
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debug("start onHttpRequestBody function.")
|
||||
bodyJson := gjson.ParseBytes(body)
|
||||
TempKey := strings.Trim(bodyJson.Get(config.KeyFrom.RequestBody).Raw, `"`)
|
||||
@@ -259,21 +260,21 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
log.Debug("start onHttpResponseHeaders function.")
|
||||
|
||||
log.Debug("end onHttpResponseHeaders function.")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log log.Log) []byte {
|
||||
log.Debug("start onStreamingResponseBody function.")
|
||||
|
||||
log.Debug("end onStreamingResponseBody function.")
|
||||
return chunk
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debug("start onHttpResponseBody function.")
|
||||
|
||||
log.Debug("end onHttpResponseBody function.")
|
||||
@@ -290,7 +291,7 @@ type ProxyRequestMessage struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func generateProxyRequest(c *PluginConfig, texts []string, log wrapper.Log) (string, []byte, [][2]string) {
|
||||
func generateProxyRequest(c *PluginConfig, texts []string, log log.Log) (string, []byte, [][2]string) {
|
||||
url := c.LLMInfo.ProxyPath
|
||||
var userMessage ProxyRequestMessage
|
||||
userMessage.Role = "user"
|
||||
@@ -338,7 +339,7 @@ type ProxyResponseOutputChoicesMessage struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func proxyResponseHandler(responseBody []byte, log wrapper.Log) (*ProxyResponse, error) {
|
||||
func proxyResponseHandler(responseBody []byte, log log.Log) (*ProxyResponse, error) {
|
||||
var response ProxyResponse
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
@@ -348,7 +349,7 @@ func proxyResponseHandler(responseBody []byte, log wrapper.Log) (*ProxyResponse,
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getProxyResponseByExtractor(c *PluginConfig, responseBody []byte, log wrapper.Log) string {
|
||||
func getProxyResponseByExtractor(c *PluginConfig, responseBody []byte, log log.Log) string {
|
||||
bodyJson := gjson.ParseBytes(responseBody)
|
||||
responseContent := strings.Trim(bodyJson.Get(c.KeyFrom.ResponseBody).Raw, `"`)
|
||||
// llm返回的结果
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/santhosh-tekuri/jsonschema"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
@@ -141,7 +142,7 @@ func parseUrl(url string) (string, string) {
|
||||
return url[:index], url[index:]
|
||||
}
|
||||
|
||||
func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(result gjson.Result, config *PluginConfig, log log.Log) error {
|
||||
config.serviceName = result.Get("serviceName").String()
|
||||
config.serviceUrl = result.Get("serviceUrl").String()
|
||||
config.serviceDomain = result.Get("serviceDomain").String()
|
||||
@@ -278,7 +279,7 @@ func (r *RequestContext) assembleReqBody(config PluginConfig) []byte {
|
||||
return reqBody
|
||||
}
|
||||
|
||||
func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) {
|
||||
func (r *RequestContext) SaveBodyToHistMsg(log log.Log, reqBody []byte, respBody []byte) {
|
||||
r.RespBody = respBody
|
||||
lastUserMessage := ""
|
||||
lastSystemMessage := ""
|
||||
@@ -318,7 +319,7 @@ func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, resp
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) {
|
||||
func (r *RequestContext) SaveStrToHistMsg(log log.Log, errMsg string) {
|
||||
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
|
||||
Role: "system",
|
||||
Content: errMsg,
|
||||
@@ -340,7 +341,7 @@ func (c *PluginConfig) ValidateBody(body []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) {
|
||||
func (c *PluginConfig) ValidateJson(body []byte, log log.Log) (string, error) {
|
||||
content := gjson.ParseBytes(body).Get(c.contentPath).String()
|
||||
// first extract json from response body
|
||||
if content == "" {
|
||||
@@ -399,7 +400,7 @@ func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) {
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) {
|
||||
func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log log.Log, body []byte) {
|
||||
log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body))
|
||||
header := [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
@@ -414,7 +415,7 @@ func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log,
|
||||
}
|
||||
}
|
||||
|
||||
func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) {
|
||||
func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log log.Log, retryCount int, requestContext *RequestContext) {
|
||||
// if retry count exceeds max retry count, return the response
|
||||
if retryCount >= config.maxRetry {
|
||||
log.Debugf("retry count exceeds max retry count")
|
||||
@@ -445,7 +446,7 @@ func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapp
|
||||
}, uint32(config.serviceTimeout))
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return types.ActionPause
|
||||
@@ -505,7 +506,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
// if the request is from this plugin, continue the request
|
||||
fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool)
|
||||
if ok && fromThisPlugin {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -31,11 +32,11 @@ type AIPromptDecoratorConfig struct {
|
||||
Append []Message `json:"append"`
|
||||
}
|
||||
|
||||
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig, log wrapper.Log) error {
|
||||
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig, log log.Log) error {
|
||||
return json.Unmarshal([]byte(jsonConfig.Raw), config)
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log log.Log) types.Action {
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -66,7 +67,7 @@ func decorateGeographicPrompt(entry *Message) (*Message, error) {
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log log.Log) types.Action {
|
||||
messageJson := `{"messages":[]}`
|
||||
|
||||
for _, entry := range config.Prepend {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -23,7 +24,7 @@ type AIPromptTemplateConfig struct {
|
||||
templates map[string]string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log log.Log) error {
|
||||
config.templates = make(map[string]string)
|
||||
for _, v := range json.Get("templates").Array() {
|
||||
config.templates[v.Get("name").String()] = v.Get("template").Raw
|
||||
@@ -32,7 +33,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log log.Log) types.Action {
|
||||
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
|
||||
if templateEnable == "false" {
|
||||
ctx.DontReadRequestBody()
|
||||
@@ -42,7 +43,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptTemplateConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptTemplateConfig, body []byte, log log.Log) types.Action {
|
||||
if gjson.GetBytes(body, "template").Exists() && gjson.GetBytes(body, "properties").Exists() {
|
||||
name := gjson.GetBytes(body, "template").String()
|
||||
template := config.templates[name]
|
||||
|
||||
@@ -2,7 +2,6 @@ package config
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -75,7 +74,7 @@ func (c *PluginConfig) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) Complete(log wrapper.Log) error {
|
||||
func (c *PluginConfig) Complete() error {
|
||||
if c.activeProviderConfig == nil {
|
||||
c.activeProvider = nil
|
||||
return nil
|
||||
@@ -89,7 +88,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
|
||||
}
|
||||
|
||||
providerConfig := c.GetProviderConfig()
|
||||
return providerConfig.SetApiTokensFailover(log, c.activeProvider)
|
||||
return providerConfig.SetApiTokensFailover(c.activeProvider)
|
||||
}
|
||||
|
||||
func (c *PluginConfig) GetProvider() provider.Provider {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -27,16 +28,16 @@ const (
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
pluginName,
|
||||
wrapper.ParseOverrideConfigBy(parseGlobalConfig, parseOverrideRuleConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeader),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
wrapper.ParseOverrideConfig(parseGlobalConfig, parseOverrideRuleConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeader),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
||||
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig) error {
|
||||
log.Debugf("loading global config: %s", json.String())
|
||||
|
||||
pluginConfig.FromJson(json)
|
||||
@@ -44,7 +45,7 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
|
||||
log.Errorf("global rule config is invalid: %v", err)
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
log.Errorf("failed to apply global rule config: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +53,7 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
||||
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig) error {
|
||||
log.Debugf("loading override rule config: %s", json.String())
|
||||
|
||||
*pluginConfig = global
|
||||
@@ -62,7 +63,7 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
||||
log.Errorf("overriden rule config is invalid: %v", err)
|
||||
return err
|
||||
}
|
||||
if err := pluginConfig.Complete(log); err != nil {
|
||||
if err := pluginConfig.Complete(); err != nil {
|
||||
log.Errorf("failed to apply overriden rule config: %v", err)
|
||||
return err
|
||||
}
|
||||
@@ -70,7 +71,7 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -112,15 +113,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||
// Set the apiToken for the current request.
|
||||
providerConfig.SetApiTokenInUse(ctx, log)
|
||||
providerConfig.SetApiTokenInUse(ctx)
|
||||
// Set available apiTokens of current request in the context, will be used in the retryOnFailure
|
||||
providerConfig.SetAvailableApiTokens(ctx, log)
|
||||
providerConfig.SetAvailableApiTokens(ctx)
|
||||
|
||||
// save the original request host and path in case they are needed for apiToken health check and retry
|
||||
ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost())
|
||||
ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath())
|
||||
|
||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
err := handler.OnRequestHeaders(ctx, apiName)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||
return types.ActionContinue
|
||||
@@ -140,7 +141,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -161,11 +162,11 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
||||
}
|
||||
if providerConfig.IsOpenAIProtocol() {
|
||||
newBody = normalizeOpenAiRequestBody(newBody, log)
|
||||
newBody = normalizeOpenAiRequestBody(newBody)
|
||||
}
|
||||
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
||||
body = newBody
|
||||
action, err := handler.OnRequestBody(ctx, apiName, body, log)
|
||||
action, err := handler.OnRequestBody(ctx, apiName, body)
|
||||
if err == nil {
|
||||
return action
|
||||
}
|
||||
@@ -174,7 +175,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig) types.Action {
|
||||
if !wrapper.IsResponseFromUpstream() {
|
||||
// Response is not coming from the upstream. Let it pass through.
|
||||
ctx.DontReadResponseBody()
|
||||
@@ -201,23 +202,23 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
log.Errorf("unable to load :status header from response: %v", err)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status, log)
|
||||
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, status)
|
||||
}
|
||||
|
||||
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse)
|
||||
|
||||
headers := util.GetOriginalResponseHeaders()
|
||||
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers)
|
||||
} else {
|
||||
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
||||
}
|
||||
util.ReplaceResponseHeaders(headers)
|
||||
|
||||
checkStream(ctx, log)
|
||||
checkStream(ctx)
|
||||
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
||||
var needHandleStreamingBody bool
|
||||
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
|
||||
@@ -233,7 +234,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool) []byte {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -246,7 +247,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
|
||||
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
return modifiedChunk
|
||||
}
|
||||
@@ -254,7 +255,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
}
|
||||
if handler, ok := activeProvider.(provider.StreamingEventHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk, log)
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk)
|
||||
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
|
||||
if len(events) == 0 {
|
||||
// No events are extracted, return the original chunk
|
||||
@@ -269,7 +270,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
continue
|
||||
}
|
||||
|
||||
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event, log)
|
||||
outputEvents, err := handler.OnStreamingEvent(ctx, apiName, event)
|
||||
if err != nil {
|
||||
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
|
||||
return chunk
|
||||
@@ -287,7 +288,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
return chunk
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
||||
activeProvider := pluginConfig.GetProvider()
|
||||
|
||||
if activeProvider == nil {
|
||||
@@ -299,19 +300,19 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
|
||||
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
||||
if err = provider.ReplaceResponseBody(body); err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||
}
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
|
||||
func normalizeOpenAiRequestBody(body []byte) []byte {
|
||||
var err error
|
||||
// Default setting include_usage.
|
||||
if gjson.GetBytes(body, "stream").Bool() {
|
||||
@@ -323,7 +324,7 @@ func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
func checkStream(ctx wrapper.HttpContext) {
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
||||
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
||||
if err != nil {
|
||||
|
||||
@@ -48,20 +48,20 @@ func (m *ai360Provider) GetProviderType() string {
|
||||
return providerTypeAi360
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
@@ -70,16 +71,16 @@ func (m *azureProvider) GetProviderType() string {
|
||||
return providerTypeAzure
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
finalRequestUrl := *m.serviceUrl
|
||||
if u, e := url.Parse(ctx.Path()); e == nil {
|
||||
if len(u.Query()) != 0 {
|
||||
|
||||
@@ -49,19 +49,19 @@ func (m *baichuanProvider) GetProviderType() string {
|
||||
return providerTypeBaichuan
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, baichuanDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -50,19 +50,19 @@ func (g *baiduProvider) GetProviderType() string {
|
||||
return providerTypeBaidu
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, baiduDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
@@ -112,12 +113,12 @@ func (c *claudeProvider) GetProviderType() string {
|
||||
return providerTypeClaude
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
c.config.handleRequestHeaders(c, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
|
||||
@@ -130,26 +131,26 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
headers.Set("anthropic-version", c.config.claudeVersion)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return c.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := c.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claudeRequest := c.buildClaudeTextGenRequest(request)
|
||||
return json.Marshal(claudeRequest)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -185,7 +186,7 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
log.Errorf("unable to unmarshal claude response: %v", err)
|
||||
continue
|
||||
}
|
||||
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse, log)
|
||||
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse)
|
||||
if response != nil {
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
@@ -266,7 +267,7 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse, log wrapper.Log) *chatCompletionResponse {
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
|
||||
switch origResponse.Type {
|
||||
case "message_start":
|
||||
choice := chatCompletionChoice{
|
||||
|
||||
@@ -48,19 +48,19 @@ func (c *cloudflareProvider) GetProviderType() string {
|
||||
return providerTypeCloudflare
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
c.config.handleRequestHeaders(c, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -65,16 +65,16 @@ func (m *cohereProvider) GetProviderType() string {
|
||||
return providerTypeCohere
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
|
||||
@@ -96,19 +96,19 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, cohereDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -64,7 +65,7 @@ type ContextInserter interface {
|
||||
insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
|
||||
func (c *contextCache) GetContent(callback func(string, error)) error {
|
||||
if callback == nil {
|
||||
return errors.New("callback is nil")
|
||||
}
|
||||
@@ -106,27 +107,27 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
|
||||
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte) error {
|
||||
if c.loaded {
|
||||
log.Debugf("context file loaded from cache")
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
insertContext(provider, c.content, nil, body)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("loading context file from %s", c.fileUrl.String())
|
||||
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
|
||||
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil)
|
||||
return
|
||||
}
|
||||
c.content = string(responseBody)
|
||||
c.loaded = true
|
||||
log.Debugf("content: %s", c.content)
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
insertContext(provider, c.content, nil, body)
|
||||
}, c.timeout)
|
||||
}
|
||||
|
||||
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
|
||||
func insertContext(provider Provider, content string, err error, body []byte) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
@@ -146,7 +147,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
|
||||
if err != nil {
|
||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
|
||||
}
|
||||
if err := replaceRequestBody(body, log); err != nil {
|
||||
if err := replaceRequestBody(body); err != nil {
|
||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func (m *cozeProvider) GetProviderType() string {
|
||||
return providerTypeCoze
|
||||
}
|
||||
|
||||
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, cozeDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -82,12 +82,12 @@ func (d *deeplProvider) GetProviderType() string {
|
||||
return providerTypeDeepl
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
d.config.handleRequestHeaders(d, ctx, apiName, log)
|
||||
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
d.config.handleRequestHeaders(d, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if apiName != "" {
|
||||
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
|
||||
}
|
||||
@@ -96,14 +96,14 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !d.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return nil, err
|
||||
@@ -119,7 +119,7 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
|
||||
return json.Marshal(baiduRequest)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -51,19 +51,19 @@ func (m *deepseekProvider) GetProviderType() string {
|
||||
return providerTypeDeepSeek
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, deepseekDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -50,12 +51,12 @@ func (d *difyProvider) GetProviderType() string {
|
||||
return providerTypeDify
|
||||
}
|
||||
|
||||
func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
d.config.handleRequestHeaders(d, ctx, apiName, log)
|
||||
func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
d.config.handleRequestHeaders(d, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if d.config.difyApiUrl != "" {
|
||||
log.Debugf("use local host: %s", d.config.difyApiUrl)
|
||||
util.OverwriteRequestHostHeader(headers, d.config.difyApiUrl)
|
||||
@@ -73,19 +74,19 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+d.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
|
||||
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return d.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
err := d.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
err := d.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
|
||||
return json.Marshal(difyRequest)
|
||||
}
|
||||
|
||||
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
@@ -146,7 +147,7 @@ func (d *difyProvider) responseDify2OpenAI(ctx wrapper.HttpContext, response *Di
|
||||
}
|
||||
}
|
||||
|
||||
func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -49,19 +49,19 @@ func (m *doubaoProvider) GetProviderType() string {
|
||||
return providerTypeDoubao
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, doubaoDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
@@ -125,11 +126,11 @@ func (c *ProviderConfig) initVariable() {
|
||||
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
|
||||
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
c.initVariable()
|
||||
// Reset shared data in case plugin configuration is updated
|
||||
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
|
||||
@@ -147,7 +148,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
|
||||
|
||||
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
|
||||
// Only the Wasm VM that successfully acquires the lease will perform health check
|
||||
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
|
||||
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID) {
|
||||
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
|
||||
unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
|
||||
if err != nil {
|
||||
@@ -157,7 +158,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
|
||||
if len(unavailableTokens) > 0 {
|
||||
for _, apiToken := range unavailableTokens {
|
||||
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
|
||||
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
|
||||
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
|
||||
Cluster: healthCheckEndpoint.Cluster,
|
||||
})
|
||||
@@ -165,7 +166,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
|
||||
ctx := createHttpContext()
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
|
||||
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform request headers and body: %v", err)
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
|
||||
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
|
||||
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode == 200 {
|
||||
c.handleAvailableApiToken(apiToken, log)
|
||||
c.handleAvailableApiToken(apiToken)
|
||||
}
|
||||
}, uint32(c.failover.healthCheckTimeout))
|
||||
if err != nil {
|
||||
@@ -191,19 +192,19 @@ func generateUrl(header http.Header) string {
|
||||
return fmt.Sprintf("https://%s%s", header.Get(":authority"), header.Get(":path"))
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) (http.Header, []byte, error) {
|
||||
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte) (http.Header, []byte, error) {
|
||||
modifiedHeaders := util.SliceToHeader(headers)
|
||||
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders, log)
|
||||
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders)
|
||||
}
|
||||
|
||||
var err error
|
||||
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body)
|
||||
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, modifiedHeaders, log)
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, modifiedHeaders)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
|
||||
@@ -213,14 +214,14 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
|
||||
}
|
||||
|
||||
func createHttpContext() *wrapper.CommonHttpCtx[any] {
|
||||
setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
|
||||
setParseConfig := wrapper.ParseConfig[any](parseConfig)
|
||||
vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
|
||||
pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
|
||||
ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
|
||||
func (c *ProviderConfig) generateRequestHeadersAndBody() (HealthCheckEndpoint, [][2]string, []byte) {
|
||||
data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get request host and path: %v", err)
|
||||
@@ -248,20 +249,20 @@ func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthC
|
||||
return healthCheckEndpoint, headers, body
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
|
||||
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string) bool {
|
||||
now := time.Now().Unix()
|
||||
|
||||
data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrorStatusNotFound) {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
return c.setLease(vmID, now, cas)
|
||||
} else {
|
||||
log.Errorf("Failed to get lease: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
return c.setLease(vmID, now, cas)
|
||||
}
|
||||
|
||||
var lease Lease
|
||||
@@ -275,13 +276,13 @@ func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bo
|
||||
if lease.VMID == vmID || now-lease.Timestamp > 60 {
|
||||
lease.VMID = vmID
|
||||
lease.Timestamp = now
|
||||
return c.setLease(vmID, now, cas, log)
|
||||
return c.setLease(vmID, now, cas)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
|
||||
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32) bool {
|
||||
lease := Lease{
|
||||
VMID: vmID,
|
||||
Timestamp: timestamp,
|
||||
@@ -305,7 +306,7 @@ func generateVMID() string {
|
||||
|
||||
// When number of request successes exceeds the threshold during health check,
|
||||
// add the apiToken back to the available list and remove it from the unavailable list
|
||||
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
|
||||
func (c *ProviderConfig) handleAvailableApiToken(apiToken string) {
|
||||
successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
|
||||
@@ -315,18 +316,18 @@ func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Lo
|
||||
successCount := successApiTokenRequestCount[apiToken] + 1
|
||||
if successCount >= c.failover.successThreshold {
|
||||
log.Infof("healthcheck after failover: apiToken %s is available now, add it back to the apiTokens list", apiToken)
|
||||
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
|
||||
addApiToken(c.failover.ctxApiTokens, apiToken, log)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
|
||||
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
|
||||
} else {
|
||||
log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
|
||||
}
|
||||
}
|
||||
|
||||
// When number of request failures exceeds the threshold,
|
||||
// remove the apiToken from the available list and add it to the unavailable list
|
||||
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
|
||||
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string) {
|
||||
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
|
||||
@@ -346,26 +347,26 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
|
||||
failureCount := failureApiTokenRequestCount[apiToken] + 1
|
||||
if failureCount >= c.failover.failureThreshold {
|
||||
log.Infof("failover: apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
|
||||
removeApiToken(c.failover.ctxApiTokens, apiToken, log)
|
||||
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
|
||||
removeApiToken(c.failover.ctxApiTokens, apiToken)
|
||||
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
// Set the request host and path to shared data in case they are needed in apiToken health check
|
||||
c.setHealthCheckEndpoint(ctx, log)
|
||||
c.setHealthCheckEndpoint(ctx)
|
||||
} else {
|
||||
log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
|
||||
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
|
||||
}
|
||||
}
|
||||
|
||||
func addApiToken(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiToken(key, apiToken, addApiTokenOperation, log)
|
||||
func addApiToken(key, apiToken string) {
|
||||
modifyApiToken(key, apiToken, addApiTokenOperation)
|
||||
}
|
||||
|
||||
func removeApiToken(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiToken(key, apiToken, removeApiTokenOperation, log)
|
||||
func removeApiToken(key, apiToken string) {
|
||||
modifyApiToken(key, apiToken, removeApiTokenOperation)
|
||||
}
|
||||
|
||||
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
|
||||
func modifyApiToken(key, apiToken, op string) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokens, cas, err := getApiTokens(key)
|
||||
if err != nil {
|
||||
@@ -468,15 +469,15 @@ func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
|
||||
return apiTokens, cas, nil
|
||||
}
|
||||
|
||||
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
|
||||
func addApiTokenRequestCount(key, apiToken string) {
|
||||
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation)
|
||||
}
|
||||
|
||||
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
|
||||
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
|
||||
func resetApiTokenRequestCount(key, apiToken string) {
|
||||
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
|
||||
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string) {
|
||||
if c.isFailoverEnabled() {
|
||||
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
|
||||
if err != nil {
|
||||
@@ -484,12 +485,12 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string,
|
||||
}
|
||||
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
|
||||
log.Infof("Reset apiToken %s request failure count", apiTokenInUse)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
|
||||
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
|
||||
func modifyApiTokenRequestCount(key, apiToken string, op string) {
|
||||
for attempt := 1; attempt <= casMaxRetries; attempt++ {
|
||||
apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
|
||||
if err != nil {
|
||||
@@ -524,7 +525,7 @@ func (c *ProviderConfig) initApiTokens() error {
|
||||
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
|
||||
func (c *ProviderConfig) GetGlobalRandomToken() string {
|
||||
apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
|
||||
unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
|
||||
log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
|
||||
@@ -550,7 +551,7 @@ func (c *ProviderConfig) GetAvailableApiToken(ctx wrapper.HttpContext) []string
|
||||
}
|
||||
|
||||
// SetAvailableApiTokens set available apiTokens of current request in the context, will be used in the retryOnFailure
|
||||
func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext) {
|
||||
var apiTokens []string
|
||||
if c.isFailoverEnabled() {
|
||||
apiTokens, _, _ = getApiTokens(c.failover.ctxApiTokens)
|
||||
@@ -572,14 +573,14 @@ func (c *ProviderConfig) resetSharedData() {
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string, log wrapper.Log) types.Action {
|
||||
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
|
||||
if c.isFailoverEnabled() && util.MatchStatus(status, c.failover.failoverOnStatus) {
|
||||
log.Warnf("apiToken:%s need failover, error status:%s", apiTokenInUse, status)
|
||||
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
|
||||
c.handleUnavailableApiToken(ctx, apiTokenInUse)
|
||||
}
|
||||
if c.IsRetryOnFailureEnabled() && util.MatchStatus(status, c.retryOnFailure.retryOnStatus) {
|
||||
log.Warnf("need retry, notice that retry response will be bufferd, error status:%s", status)
|
||||
err := c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens, log)
|
||||
err := c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens)
|
||||
if err != nil {
|
||||
log.Errorf("retryFailedRequest failed, err:%v", err)
|
||||
return types.ActionContinue
|
||||
@@ -598,11 +599,11 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||
return token
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
|
||||
var apiToken string
|
||||
// if enable apiToken failover, only use available apiToken from global apiTokens list
|
||||
if c.isFailoverEnabled() {
|
||||
apiToken = c.GetGlobalRandomToken(log)
|
||||
apiToken = c.GetGlobalRandomToken()
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
}
|
||||
@@ -610,7 +611,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext) {
|
||||
cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get cluster_name: %v", err)
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -56,35 +56,35 @@ func (g *geminiProvider) GetProviderType() string {
|
||||
return providerTypeGemini
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
err := g.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -95,9 +95,9 @@ func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
@@ -107,7 +107,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
log.Infof("chunk body:%s", string(chunk))
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
@@ -143,15 +143,15 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionResponseBody(ctx, body, log)
|
||||
return g.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
return g.onEmbeddingsResponseBody(ctx, body, log)
|
||||
return g.onEmbeddingsResponseBody(ctx, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
geminiResponse := &geminiChatResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||
@@ -164,7 +164,7 @@ func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, b
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
geminiResponse := &geminiEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
||||
@@ -434,7 +434,7 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
|
||||
}
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
proxywasm.LogErrorf("get toolCalls from gemini response failed: " + err.Error())
|
||||
log.Errorf("get toolCalls from gemini response failed: " + err.Error())
|
||||
return toolCalls
|
||||
}
|
||||
toolCall := toolCall{
|
||||
|
||||
@@ -51,20 +51,20 @@ func (m *githubProvider) GetProviderType() string {
|
||||
return providerTypeGithub
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, githubDomain)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -48,19 +48,19 @@ func (g *groqProvider) GetProviderType() string {
|
||||
return providerTypeGroq
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, groqDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -135,13 +136,13 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {
|
||||
return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if m.useOpenAICompatibleAPI() {
|
||||
util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain)
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
@@ -156,7 +157,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
||||
}
|
||||
|
||||
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
||||
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
@@ -185,7 +186,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
|
||||
// 若无配置文件,直接返回
|
||||
if m.config.context == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
return types.ActionContinue, replaceJsonRequestBody(request)
|
||||
}
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
log.Debugf("#debug nash5# ctx file loaded! callback start, content is: %s", content)
|
||||
@@ -204,17 +205,17 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
|
||||
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
|
||||
|
||||
if err := replaceJsonRequestBody(request, log); err != nil {
|
||||
if err := replaceJsonRequestBody(request); err != nil {
|
||||
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
})
|
||||
if err == nil {
|
||||
log.Debugf("#debug nash5# ctx file load success!")
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
|
||||
log.Debugf("#debug nash5# ctx file load failed!")
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
return types.ActionContinue, replaceJsonRequestBody(request)
|
||||
}
|
||||
|
||||
// 使用open ai接口协议
|
||||
@@ -228,7 +229,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
return types.ActionContinue, errors.New("missing model in chat completion request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model) // 设置原始请求的model,以便返回值使用
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
@@ -258,7 +259,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
string(body),
|
||||
)
|
||||
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
|
||||
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
|
||||
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
@@ -278,10 +279,10 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
|
||||
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
|
||||
|
||||
if err := replaceJsonRequestBody(hunyuanRequest, log); err != nil {
|
||||
if err := replaceJsonRequestBody(hunyuanRequest); err != nil {
|
||||
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
})
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
@@ -289,12 +290,12 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
}
|
||||
|
||||
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
|
||||
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if m.useOpenAICompatibleAPI() {
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
request := &chatCompletionRequest{}
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -317,7 +318,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
||||
return json.Marshal(hunyuanRequest)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
}
|
||||
@@ -364,7 +365,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
newBufferedBody = newBufferedBody[newEventPivot+2:] // 跳过结束标识
|
||||
|
||||
// 转换并追加到输出缓冲区
|
||||
convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData, log)
|
||||
convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData)
|
||||
// log.Debugf("@@@ >>> converted one chunk: %s", string(convertedData))
|
||||
outputBuffer = append(outputBuffer, convertedData...)
|
||||
}
|
||||
@@ -376,7 +377,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
return outputBuffer, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte) ([]byte, error) {
|
||||
// 将hunyuan的chunk转为openai的chunk
|
||||
hunyuanFormattedChunk := &hunyuanTextGenDetailedResponseNonStreaming{}
|
||||
if err := json.Unmarshal(hunyuanChunk, hunyuanFormattedChunk); err != nil {
|
||||
@@ -433,7 +434,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
||||
return []byte(openAIChunk.String()), nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -73,45 +74,45 @@ func (m *minimaxProvider) GetProviderType() string {
|
||||
return providerTypeMinimax
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, minimaxDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
if minimaxApiTypePro == m.config.minimaxApiType {
|
||||
// Use chat completion Pro API.
|
||||
return m.handleRequestBodyByChatCompletionPro(body, log)
|
||||
return m.handleRequestBodyByChatCompletionPro(body)
|
||||
} else {
|
||||
// Use chat completion V2 API.
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API.
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte) (types.Action, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// Map the model and rewrite the request path.
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping)
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
|
||||
if m.config.context == nil {
|
||||
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, "")
|
||||
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log)
|
||||
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest)
|
||||
}
|
||||
|
||||
err := m.contextCache.GetContent(func(content string, err error) {
|
||||
@@ -126,30 +127,30 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
|
||||
// For minimaxChatCompletionPro, we need to manually handle context messages.
|
||||
// minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages.
|
||||
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, content)
|
||||
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
|
||||
if err := replaceJsonRequestBody(minimaxRequest); err != nil {
|
||||
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
|
||||
}
|
||||
}, log)
|
||||
})
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
|
||||
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
return m.handleRequestBodyByChatCompletionV2(body, headers)
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API.
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header) ([]byte, error) {
|
||||
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
|
||||
|
||||
rawModel := gjson.GetBytes(body, "model").String()
|
||||
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
|
||||
mappedModel := getMappedModel(rawModel, m.config.modelMapping)
|
||||
return sjson.SetBytes(body, "model", mappedModel)
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
// Skip OnStreamingResponseBody() and OnResponseBody() when using the original protocol
|
||||
// or when the model corresponds to the chat completion V2 interface.
|
||||
if m.config.protocol == protocolOriginal || minimaxApiTypePro != m.config.minimaxApiType {
|
||||
@@ -160,7 +161,7 @@ func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiN
|
||||
}
|
||||
|
||||
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -199,7 +200,7 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
}
|
||||
|
||||
// TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@@ -47,19 +47,19 @@ func (m *mistralProvider) GetProviderType() string {
|
||||
return providerTypeMistral
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, mistralDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -172,7 +172,7 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
|
||||
if contentPushed {
|
||||
if m.ReasoningContent != "" {
|
||||
// This shouldn't happen, but if it does, we can add a log here.
|
||||
proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
|
||||
log.Warnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -63,12 +64,12 @@ func (m *moonshotProvider) GetProviderType() string {
|
||||
return providerTypeMoonshot
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, moonshotDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
@@ -77,7 +78,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
|
||||
|
||||
// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法
|
||||
// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法
|
||||
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
@@ -87,12 +88,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
|
||||
}
|
||||
|
||||
request := &chatCompletionRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
if m.config.moonshotFileId == "" && m.contextCache == nil {
|
||||
return types.ActionContinue, replaceJsonRequestBody(request, log)
|
||||
return types.ActionContinue, replaceJsonRequestBody(request)
|
||||
}
|
||||
|
||||
apiKey := m.config.GetOrSetTokenWithContext(ctx)
|
||||
@@ -105,23 +106,23 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
|
||||
_ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
|
||||
return
|
||||
}
|
||||
err = m.performChatCompletion(ctx, content, request, log)
|
||||
err = m.performChatCompletion(ctx, content, request)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
|
||||
}
|
||||
}, log)
|
||||
})
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileContent string, request *chatCompletionRequest, log wrapper.Log) error {
|
||||
func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileContent string, request *chatCompletionRequest) error {
|
||||
insertContextMessage(request, fileContent)
|
||||
return replaceJsonRequestBody(request, log)
|
||||
return replaceJsonRequestBody(request)
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error), log wrapper.Log) error {
|
||||
func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error)) error {
|
||||
if m.config.moonshotFileId != "" {
|
||||
if m.fileContent != "" {
|
||||
callback(m.fileContent, nil)
|
||||
@@ -142,7 +143,7 @@ func (m *moonshotProvider) getContextContent(apiKey string, callback func(string
|
||||
}
|
||||
|
||||
if m.contextCache != nil {
|
||||
return m.contextCache.GetContent(callback, log)
|
||||
return m.contextCache.GetContent(callback)
|
||||
}
|
||||
|
||||
return errors.New("both moonshotFileId and context are not configured")
|
||||
@@ -161,7 +162,7 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
|
||||
}
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
|
||||
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
|
||||
if name != ApiNameChatCompletion {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -54,19 +54,19 @@ func (m *ollamaProvider) GetProviderType() string {
|
||||
return providerTypeOllama
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceDomain)
|
||||
headers.Del("Content-Length")
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -69,7 +69,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
}
|
||||
}
|
||||
config.setDefaultCapabilities(capabilities)
|
||||
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
|
||||
log.Debugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
|
||||
pairs[0], customPath, isDirectCustomPath, capabilities)
|
||||
return &openaiProvider{
|
||||
config: config,
|
||||
@@ -92,12 +92,12 @@ func (m *openaiProvider) GetProviderType() string {
|
||||
return providerTypeOpenAI
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if m.customPath != "" {
|
||||
if m.isDirectCustomPath || apiName == "" {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
@@ -118,15 +118,15 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
// We don't need to process the request body for other APIs.
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if m.config.responseJsonSchema != nil {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
@@ -136,5 +136,5 @@ func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
|
||||
request.ResponseFormat = m.config.responseJsonSchema
|
||||
body, _ = json.Marshal(request)
|
||||
}
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -145,19 +146,19 @@ type Provider interface {
|
||||
}
|
||||
|
||||
type RequestHeadersHandler interface {
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error
|
||||
}
|
||||
|
||||
type RequestBodyHandler interface {
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error)
|
||||
}
|
||||
|
||||
type StreamingResponseBodyHandler interface {
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error)
|
||||
}
|
||||
|
||||
type StreamingEventHandler interface {
|
||||
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
|
||||
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error)
|
||||
}
|
||||
|
||||
type ApiNameHandler interface {
|
||||
@@ -165,25 +166,25 @@ type ApiNameHandler interface {
|
||||
}
|
||||
|
||||
type TransformRequestHeadersHandler interface {
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
|
||||
}
|
||||
|
||||
type TransformRequestBodyHandler interface {
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
|
||||
// Some providers (e.g. gemini) transform request headers (e.g., path) based on the request body (e.g., model).
|
||||
type TransformRequestBodyHeadersHandler interface {
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error)
|
||||
}
|
||||
|
||||
type TransformResponseHeadersHandler interface {
|
||||
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
|
||||
}
|
||||
|
||||
type TransformResponseBodyHandler interface {
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -496,7 +497,7 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
|
||||
switch req := request.(type) {
|
||||
case *chatCompletionRequest:
|
||||
if err := decodeChatCompletionRequest(body, req); err != nil {
|
||||
@@ -511,18 +512,18 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||
}
|
||||
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
return c.setRequestModel(ctx, req)
|
||||
case *embeddingsRequest:
|
||||
if err := decodeEmbeddingsRequest(body, req); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
return c.setRequestModel(ctx, req)
|
||||
default:
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}) error {
|
||||
var model *string
|
||||
|
||||
switch req := request.(type) {
|
||||
@@ -534,16 +535,16 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
|
||||
return errors.New("unsupported request type")
|
||||
}
|
||||
|
||||
return c.mapModel(ctx, model, log)
|
||||
return c.mapModel(ctx, model)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string) error {
|
||||
if *model == "" {
|
||||
return errors.New("missing model in request")
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
|
||||
|
||||
mappedModel := getMappedModel(*model, c.modelMapping, log)
|
||||
mappedModel := getMappedModel(*model, c.modelMapping)
|
||||
if mappedModel == "" {
|
||||
return errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
@@ -553,15 +554,15 @@ func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wr
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
mappedModel := doGetMappedModel(model, modelMapping, log)
|
||||
func getMappedModel(model string, modelMapping map[string]string) string {
|
||||
mappedModel := doGetMappedModel(model, modelMapping)
|
||||
if len(mappedModel) != 0 {
|
||||
return mappedModel
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
|
||||
func doGetMappedModel(model string, modelMapping map[string]string) string {
|
||||
if len(modelMapping) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -590,7 +591,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
|
||||
return ""
|
||||
}
|
||||
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
|
||||
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
|
||||
body := chunk
|
||||
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
|
||||
body = append(bufferedStreamingBody, chunk...)
|
||||
@@ -679,8 +680,7 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
|
||||
) (types.Action, error) {
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
@@ -689,13 +689,13 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
// use openai protocol
|
||||
var err error
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -704,28 +704,28 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if c.context == nil {
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body)
|
||||
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers, log)
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
}
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
@@ -738,7 +738,7 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
|
||||
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -79,7 +80,7 @@ type qwenProvider struct {
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if m.config.qwenDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, m.config.qwenDomain)
|
||||
} else {
|
||||
@@ -92,11 +93,11 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if m.config.qwenEnableCompatible {
|
||||
if gjson.GetBytes(body, "model").Exists() {
|
||||
rawModel := gjson.GetBytes(body, "model").String()
|
||||
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
|
||||
mappedModel := getMappedModel(rawModel, m.config.modelMapping)
|
||||
newBody, err := sjson.SetBytes(body, "model", mappedModel)
|
||||
if err != nil {
|
||||
log.Errorf("Replace model error: %v", err)
|
||||
@@ -108,11 +109,11 @@ func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
|
||||
}
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return m.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
return m.onChatCompletionRequestBody(ctx, body, headers)
|
||||
case ApiNameEmbeddings:
|
||||
return m.onEmbeddingsRequestBody(ctx, body, log)
|
||||
return m.onEmbeddingsRequestBody(ctx, body)
|
||||
default:
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,8 +121,8 @@ func (m *qwenProvider) GetProviderType() string {
|
||||
return providerTypeQwen
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadRequestBody()
|
||||
@@ -131,16 +132,16 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
err := m.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -162,9 +163,9 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
|
||||
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -175,7 +176,7 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
|
||||
return json.Marshal(qwenRequest)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
|
||||
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
|
||||
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -189,7 +190,7 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
|
||||
}
|
||||
|
||||
var outputEvents []StreamEvent
|
||||
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
|
||||
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming)
|
||||
for _, response := range responses {
|
||||
responseBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
@@ -203,15 +204,15 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
|
||||
return outputEvents, nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if m.config.qwenEnableCompatible {
|
||||
return body, nil
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionResponseBody(ctx, body, log)
|
||||
return m.onChatCompletionResponseBody(ctx, body)
|
||||
}
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsResponseBody(ctx, body, log)
|
||||
return m.onEmbeddingsResponseBody(ctx, body)
|
||||
}
|
||||
if m.config.isSupportedAPI(apiName) {
|
||||
return body, nil
|
||||
@@ -219,7 +220,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
qwenResponse := &qwenTextGenResponse{}
|
||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
@@ -228,7 +229,7 @@ func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, bod
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
qwenResponse := &qwenTextEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
@@ -308,7 +309,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
|
||||
}
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
|
||||
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool) []*chatCompletionResponse {
|
||||
baseMessage := chatCompletionResponse{
|
||||
Id: qwenResponse.RequestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
|
||||
@@ -3,7 +3,8 @@ package provider
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
||||
func replaceJsonRequestBody(request interface{}) error {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to marshal request: %v", err)
|
||||
@@ -37,7 +38,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func replaceRequestBody(body []byte, log wrapper.Log) error {
|
||||
func replaceRequestBody(body []byte) error {
|
||||
log.Debugf("request body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -65,7 +66,7 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func ReplaceResponseBody(body []byte, log wrapper.Log) error {
|
||||
func ReplaceResponseBody(body []byte) error {
|
||||
log.Debugf("response body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -50,24 +51,24 @@ func (c *ProviderConfig) IsRetryOnFailureEnabled() bool {
|
||||
return c.retryOnFailure.enabled
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, log wrapper.Log) error {
|
||||
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string) error {
|
||||
log.Infof("Retry failed request: provider=%s", activeProvider.GetProviderType())
|
||||
retryClient := createRetryClient()
|
||||
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
|
||||
ctx.SetContext(ctxRetryCount, 1)
|
||||
return c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log)
|
||||
return c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
|
||||
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte) ([][2]string, []byte) {
|
||||
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers)
|
||||
} else {
|
||||
c.DefaultTransformResponseHeaders(ctx, headers)
|
||||
}
|
||||
|
||||
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
|
||||
var err error
|
||||
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
|
||||
body, err = handler.TransformResponseBody(ctx, apiName, body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform response body: %v", err)
|
||||
}
|
||||
@@ -77,7 +78,7 @@ func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) retryCall(
|
||||
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
|
||||
ctx wrapper.HttpContext, activeProvider Provider,
|
||||
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
|
||||
retryClient *wrapper.ClusterClient[wrapper.RouteCluster],
|
||||
apiTokenInUse string, apiTokens []string) {
|
||||
@@ -87,7 +88,7 @@ func (c *ProviderConfig) retryCall(
|
||||
|
||||
if statusCode == 200 {
|
||||
log.Infof("Retry request succeeded")
|
||||
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
|
||||
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody)
|
||||
proxywasm.SendHttpResponse(200, headers, body, -1)
|
||||
return
|
||||
} else {
|
||||
@@ -97,7 +98,7 @@ func (c *ProviderConfig) retryCall(
|
||||
retryCount++
|
||||
if retryCount <= int(c.retryOnFailure.maxRetries) {
|
||||
ctx.SetContext(ctxRetryCount, retryCount)
|
||||
err := c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log)
|
||||
err := c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens)
|
||||
if err != nil {
|
||||
log.Errorf("sendRetryRequest failed, err:%v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
@@ -113,10 +114,10 @@ func (c *ProviderConfig) retryCall(
|
||||
func (c *ProviderConfig) sendRetryRequest(
|
||||
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
|
||||
retryClient *wrapper.ClusterClient[wrapper.RouteCluster],
|
||||
apiTokenInUse string, apiTokens []string, log wrapper.Log) error {
|
||||
apiTokenInUse string, apiTokens []string) error {
|
||||
|
||||
// Remove last failed token from retry apiTokens list
|
||||
apiTokens = removeApiTokenFromRetryList(apiTokens, apiTokenInUse, log)
|
||||
apiTokens = removeApiTokenFromRetryList(apiTokens, apiTokenInUse)
|
||||
if len(apiTokens) == 0 {
|
||||
return errors.New("No more apiTokens to retry")
|
||||
}
|
||||
@@ -130,14 +131,14 @@ func (c *ProviderConfig) sendRetryRequest(
|
||||
{"content-type", "application/json"},
|
||||
{":authority", ctx.GetStringContext(CtxRequestHost, "")},
|
||||
{":path", ctx.GetStringContext(CtxRequestPath, "")},
|
||||
}, requestBody, log)
|
||||
}, requestBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sendRetryRequest failed to transform request headers and body: %v", err)
|
||||
}
|
||||
|
||||
err = retryClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient, apiTokenInUse, apiTokens)
|
||||
c.retryCall(ctx, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient, apiTokenInUse, apiTokens)
|
||||
}, uint32(c.retryOnFailure.retryTimeout))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to send retry request: %v", err)
|
||||
@@ -150,7 +151,7 @@ func createRetryClient() *wrapper.ClusterClient[wrapper.RouteCluster] {
|
||||
return retryClient
|
||||
}
|
||||
|
||||
func removeApiTokenFromRetryList(apiTokens []string, removedApiToken string, log wrapper.Log) []string {
|
||||
func removeApiTokenFromRetryList(apiTokens []string, removedApiToken string) []string {
|
||||
var availableApiTokens []string
|
||||
for _, s := range apiTokens {
|
||||
if s != removedApiToken {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
@@ -73,19 +74,19 @@ func (p *sparkProvider) GetProviderType() string {
|
||||
return providerTypeSpark
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
p.config.handleRequestHeaders(p, ctx, apiName, log)
|
||||
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
p.config.handleRequestHeaders(p, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !p.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return body, nil
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName A
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, sparkHost)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -48,19 +48,19 @@ func (m *stepfunProvider) GetProviderType() string {
|
||||
return providerTypeStepfun
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, stepfunDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -47,19 +47,19 @@ func (m *togetherAIProvider) GetProviderType() string {
|
||||
return providerTypeTogetherAI
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, togetherAIDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -47,19 +47,19 @@ func (m *yiProvider) GetProviderType() string {
|
||||
return providerTypeYi
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, yiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -49,19 +49,19 @@ func (m *zhipuAiProvider) GetProviderType() string {
|
||||
return providerTypeZhipuAi
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !m.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -72,7 +73,7 @@ type RedisInfo struct {
|
||||
Database int `required:"false" yaml:"database" json:"database"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *QuotaConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *QuotaConfig, log log.Log) error {
|
||||
log.Debugf("parse config()")
|
||||
// admin
|
||||
config.AdminPath = json.Get("admin_path").String()
|
||||
@@ -126,7 +127,7 @@ func parseConfig(json gjson.Result, config *QuotaConfig, log wrapper.Log) error
|
||||
return config.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(database))
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log log.Log) types.Action {
|
||||
log.Debugf("onHttpRequestHeaders()")
|
||||
// get tokens
|
||||
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
|
||||
@@ -183,7 +184,7 @@ func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log w
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debugf("onHttpRequestBody()")
|
||||
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
|
||||
if !ok {
|
||||
@@ -211,7 +212,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte,
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log log.Log) []byte {
|
||||
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
|
||||
if !ok {
|
||||
return data
|
||||
@@ -274,7 +275,7 @@ func deniedUnauthorizedConsumer() types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func getOperationMode(path string, adminPath string, log wrapper.Log) (ChatMode, AdminMode) {
|
||||
func getOperationMode(path string, adminPath string, log log.Log) (ChatMode, AdminMode) {
|
||||
fullAdminPath := "/v1/chat/completions" + adminPath
|
||||
if strings.HasSuffix(path, fullAdminPath+"/refresh") {
|
||||
return ChatModeAdmin, AdminModeRefresh
|
||||
@@ -291,7 +292,7 @@ func getOperationMode(path string, adminPath string, log wrapper.Log) (ChatMode,
|
||||
return ChatModeNone, AdminModeNone
|
||||
}
|
||||
|
||||
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
|
||||
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action {
|
||||
// check consumer
|
||||
if adminConsumer != config.AdminConsumer {
|
||||
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
|
||||
@@ -325,7 +326,7 @@ func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer str
|
||||
|
||||
return types.ActionPause
|
||||
}
|
||||
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log wrapper.Log) types.Action {
|
||||
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log log.Log) types.Action {
|
||||
// check consumer
|
||||
if adminConsumer != config.AdminConsumer {
|
||||
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
|
||||
@@ -368,7 +369,7 @@ func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer strin
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
|
||||
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log log.Log) types.Action {
|
||||
// check consumer
|
||||
if adminConsumer != config.AdminConsumer {
|
||||
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"ai-rag/dashscope"
|
||||
"ai-rag/dashvector"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -51,7 +52,7 @@ type Message struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *AIRagConfig, log log.Log) error {
|
||||
checkList := []string{
|
||||
"dashscope.apiKey",
|
||||
"dashscope.serviceFQDN",
|
||||
@@ -91,12 +92,12 @@ func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log log.Log) types.Action {
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log log.Log) types.Action {
|
||||
var rawRequest Request
|
||||
_ = json.Unmarshal(body, &rawRequest)
|
||||
messageLength := len(rawRequest.Messages)
|
||||
@@ -165,7 +166,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log log.Log) types.Action {
|
||||
recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool)
|
||||
if ok && recall {
|
||||
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true")
|
||||
|
||||
@@ -24,19 +24,18 @@ import (
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/elasticsearch"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/google"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/quark"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type SearchRewrite struct {
|
||||
@@ -92,7 +91,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *Config, log log.Log) error {
|
||||
config.defaultEnable = true // Default to true if not specified
|
||||
if json.Get("defaultEnable").Exists() {
|
||||
config.defaultEnable = json.Get("defaultEnable").Bool()
|
||||
@@ -276,7 +275,7 @@ func parseConfig(json gjson.Result, config *Config, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
// The request does not have a body.
|
||||
if contentType == "" {
|
||||
@@ -292,7 +291,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Lo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
|
||||
// Check if plugin should be enabled based on config and request
|
||||
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
||||
if !config.defaultEnable {
|
||||
@@ -451,7 +450,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
}}, log)
|
||||
}
|
||||
|
||||
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log wrapper.Log) types.Action {
|
||||
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action {
|
||||
searchResultGroups := make([][]engine.SearchResult, len(config.engine))
|
||||
var finished int
|
||||
var searching int
|
||||
@@ -549,7 +548,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
|
||||
if !config.needReference {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
@@ -566,7 +565,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log wrapper.L
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
|
||||
references := ctx.GetStringContext("References", "")
|
||||
if references == "" {
|
||||
return types.ActionContinue
|
||||
@@ -620,7 +619,7 @@ const (
|
||||
BUFFER_SIZE = 30
|
||||
)
|
||||
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte {
|
||||
if ctx.GetBoolContext("ReferenceAppended", false) {
|
||||
return chunk
|
||||
}
|
||||
@@ -659,7 +658,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
}
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log wrapper.Log) string {
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string {
|
||||
log.Debugf("single sse message: %s", sseMessage)
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -172,7 +173,7 @@ func generateHexID(length int) (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *AISecurityConfig, log log.Log) error {
|
||||
serviceName := json.Get("serviceName").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
serviceHost := json.Get("serviceHost").String()
|
||||
@@ -250,7 +251,7 @@ func generateRandomID() string {
|
||||
return "chatcmpl-" + string(b)
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Log) types.Action {
|
||||
if !config.checkRequest {
|
||||
log.Debugf("request checking is disabled")
|
||||
ctx.DontReadRequestBody()
|
||||
@@ -258,7 +259,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debugf("checking request body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||
@@ -367,7 +368,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Log) types.Action {
|
||||
if !config.checkResponse {
|
||||
log.Debugf("response checking is disabled")
|
||||
ctx.DontReadResponseBody()
|
||||
@@ -382,7 +383,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
@@ -507,7 +508,7 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
|
||||
return strings.Join(strChunks, "")
|
||||
}
|
||||
|
||||
func marshalStr(raw string, log wrapper.Log) string {
|
||||
func marshalStr(raw string, log log.Log) string {
|
||||
helper := map[string]string{
|
||||
"placeholder": raw,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -128,7 +129,7 @@ func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrapper.Log) error {
|
||||
func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log log.Log) error {
|
||||
// Parse tracing span attributes setting.
|
||||
attributeConfigs := configJson.Get("attributes").Array()
|
||||
config.attributes = make([]Attribute, len(attributeConfigs))
|
||||
@@ -152,7 +153,7 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action {
|
||||
route, _ := getRouteName()
|
||||
cluster, _ := getClusterName()
|
||||
api, api_error := getAPIName()
|
||||
@@ -176,7 +177,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action {
|
||||
// Set user defined log & span attributes.
|
||||
setAttributeBySource(ctx, config, RequestBody, body, log)
|
||||
|
||||
@@ -185,7 +186,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if !strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.BufferResponseBody()
|
||||
@@ -197,7 +198,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, l
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, data []byte, endOfStream bool, log log.Log) []byte {
|
||||
// Buffer stream body for record log & span attributes
|
||||
if config.shouldBufferStreamingBody {
|
||||
var streamingBodyBuffer []byte
|
||||
@@ -255,7 +256,7 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
return data
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log log.Log) types.Action {
|
||||
// Get requestStartTime from http context
|
||||
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
|
||||
|
||||
@@ -313,7 +314,7 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
|
||||
}
|
||||
|
||||
// fetches the tracing span value from the specified source.
|
||||
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
|
||||
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log log.Log) {
|
||||
for _, attribute := range config.attributes {
|
||||
var key string
|
||||
var value interface{}
|
||||
@@ -352,7 +353,7 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
||||
}
|
||||
}
|
||||
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log log.Log) interface{} {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
var value interface{}
|
||||
if rule == RuleFirst {
|
||||
@@ -387,7 +388,7 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
|
||||
}
|
||||
|
||||
// Set the tracing span with value.
|
||||
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
|
||||
func setSpanAttribute(key string, value interface{}, log log.Log) {
|
||||
if value != "" {
|
||||
traceSpanTag := wrapper.TraceSpanTagPrefix + key
|
||||
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
|
||||
@@ -398,7 +399,7 @@ func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
|
||||
}
|
||||
}
|
||||
|
||||
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
|
||||
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log log.Log) {
|
||||
// Generate usage metrics
|
||||
var ok bool
|
||||
var route, cluster, model string
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -84,7 +85,7 @@ type LimitConfigItem struct {
|
||||
timeWindow int64 // 时间窗口大小
|
||||
}
|
||||
|
||||
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
|
||||
func initRedisClusterClient(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
|
||||
redisConfig := json.Get("redis")
|
||||
if !redisConfig.Exists() {
|
||||
return errors.New("missing redis in config")
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -81,7 +82,7 @@ type LimitRedisContext struct {
|
||||
window int64
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
|
||||
err := initRedisClusterClient(json, config, log)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -95,7 +96,7 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapp
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
|
||||
// 判断是否命中限流规则
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
|
||||
if ruleItem == nil || configItem == nil {
|
||||
@@ -143,7 +144,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log log.Log) []byte {
|
||||
var inputToken, outputToken int64
|
||||
if inputToken, outputToken, ok := getUsage(data); ok {
|
||||
ctx.SetContext("input_token", inputToken)
|
||||
@@ -189,7 +190,7 @@ func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bo
|
||||
return
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
for _, rule := range ruleItems {
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
|
||||
if ruleItem != nil && configItem != nil {
|
||||
@@ -199,7 +200,7 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
switch rule.limitType {
|
||||
// 根据HTTP请求头限流
|
||||
case limitByHeaderType, limitByPerHeaderType:
|
||||
@@ -258,7 +259,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Lo
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
log.Debugf(errMsg, args...)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -52,7 +53,7 @@ const llmRequestTemplate = `{
|
||||
}
|
||||
}`
|
||||
|
||||
func parseConfig(json gjson.Result, config *AITransformerConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *AITransformerConfig, log log.Log) error {
|
||||
config.requestTransformEnable = json.Get("request.enable").Bool()
|
||||
config.requestTransformPrompt = json.Get("request.prompt").String()
|
||||
config.responseTransformEnable = json.Get("response.enable").Bool()
|
||||
@@ -89,7 +90,7 @@ func extraceHttpFrame(frame string) ([][2]string, []byte, error) {
|
||||
return headers, body, nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log log.Log) types.Action {
|
||||
log.Info("onHttpRequestHeaders")
|
||||
if !config.requestTransformEnable || config.requestTransformPrompt == "" {
|
||||
ctx.DontReadRequestBody()
|
||||
@@ -99,7 +100,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AITransformerConfig, l
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log log.Log) types.Action {
|
||||
log.Info("onHttpRequestBody")
|
||||
headers, err := proxywasm.GetHttpRequestHeaders()
|
||||
if err != nil {
|
||||
@@ -133,7 +134,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AITransformerConfig, body
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig, log log.Log) types.Action {
|
||||
if !config.responseTransformEnable || config.responseTransformPrompt == "" {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
@@ -142,7 +143,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AITransformerConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AITransformerConfig, body []byte, log log.Log) types.Action {
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Error("Failed to get http response headers.")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"api-workflow/utils"
|
||||
. "api-workflow/workflow"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -44,7 +45,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, c *PluginConfig, log log.Log) error {
|
||||
|
||||
edges := make([]Edge, 0)
|
||||
nodes := make(map[string]Node)
|
||||
@@ -174,7 +175,7 @@ func initWorkflowExecStatus(config *PluginConfig) (map[string]int, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log log.Log) types.Action {
|
||||
|
||||
initHeader := make([][2]string, 0)
|
||||
// 初始化运行状态
|
||||
@@ -199,7 +200,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
}
|
||||
|
||||
// 放入符合条件的edge
|
||||
func recursive(edge Edge, headers [][2]string, body []byte, depth uint32, config PluginConfig, log wrapper.Log, ctx wrapper.HttpContext) error {
|
||||
func recursive(edge Edge, headers [][2]string, body []byte, depth uint32, config PluginConfig, log log.Log, ctx wrapper.HttpContext) error {
|
||||
|
||||
var err error
|
||||
// 防止递归次数太多
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
@@ -106,7 +107,7 @@ var (
|
||||
protectionSpace = "MSE Gateway" // 认证失败时,返回响应头 WWW-Authenticate: Basic realm=MSE Gateway
|
||||
)
|
||||
|
||||
func parseGlobalConfig(json gjson.Result, global *BasicAuthConfig, log wrapper.Log) error {
|
||||
func parseGlobalConfig(json gjson.Result, global *BasicAuthConfig, log log.Log) error {
|
||||
// log.Debug("global config")
|
||||
ruleSet = false
|
||||
global.credential2Name = make(map[string]string)
|
||||
|
||||
@@ -19,6 +19,7 @@ package main
|
||||
import (
|
||||
"bot-detect/config"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -34,7 +35,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, botDetectConfig *config.BotDetectConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, botDetectConfig *config.BotDetectConfig, log log.Log) error {
|
||||
log.Debug("parseConfig()")
|
||||
|
||||
if json.Get("blocked_code").Exists() {
|
||||
@@ -81,7 +82,7 @@ func parseConfig(json gjson.Result, botDetectConfig *config.BotDetectConfig, log
|
||||
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, botDetectConfig config.BotDetectConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, botDetectConfig config.BotDetectConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpRequestHeaders()")
|
||||
//// Get user-agent header
|
||||
ua, err := proxywasm.GetHttpRequestHeader("user-agent")
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -25,7 +27,7 @@ type CacheControlConfig struct {
|
||||
expires string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *CacheControlConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *CacheControlConfig, log log.Log) error {
|
||||
suffix := json.Get("suffix").String()
|
||||
if suffix != "" {
|
||||
parts := strings.Split(suffix, "|")
|
||||
@@ -38,7 +40,7 @@ func parseConfig(json gjson.Result, config *CacheControlConfig, log wrapper.Log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CacheControlConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CacheControlConfig, log log.Log) types.Action {
|
||||
path := ctx.Path()
|
||||
if strings.Contains(path, "?") {
|
||||
path = strings.Split(path, "?")[0]
|
||||
@@ -49,7 +51,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config CacheControlConfig, lo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config CacheControlConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config CacheControlConfig, log log.Log) types.Action {
|
||||
hit := false
|
||||
if len(config.suffix) == 0 {
|
||||
hit = true
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -31,7 +32,7 @@ type MyConfig struct {
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *MyConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *MyConfig, log log.Log) error {
|
||||
chatgptUri := json.Get("chatgptUri").String()
|
||||
var chatgptHost string
|
||||
if chatgptUri == "" {
|
||||
@@ -90,7 +91,7 @@ const bodyTemplate string = `
|
||||
}
|
||||
`
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config MyConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config MyConfig, log log.Log) types.Action {
|
||||
pairs := strings.SplitN(ctx.Path(), "?", 2)
|
||||
|
||||
if len(pairs) < 2 {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -64,7 +65,7 @@ type LimitContext struct {
|
||||
reset int
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log log.Log) error {
|
||||
err := initRedisClusterClient(json, config)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -76,7 +77,7 @@ func parseConfig(json gjson.Result, config *ClusterKeyRateLimitConfig, log wrapp
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
|
||||
// 判断是否命中限流规则
|
||||
val, ruleItem, configItem := checkRequestAgainstLimitRule(ctx, config.ruleItems, log)
|
||||
if ruleItem == nil || configItem == nil {
|
||||
@@ -115,7 +116,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, log log.Log) types.Action {
|
||||
limitContext, ok := ctx.GetContext(LimitContextKey).(LimitContext)
|
||||
if !ok {
|
||||
return types.ActionContinue
|
||||
@@ -127,7 +128,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCo
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
for _, rule := range ruleItems {
|
||||
val, ruleItem, configItem := hitRateRuleItem(ctx, rule, log)
|
||||
if ruleItem != nil && configItem != nil {
|
||||
@@ -137,7 +138,7 @@ func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRule
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log log.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
switch rule.limitType {
|
||||
// 根据HTTP请求头限流
|
||||
case limitByHeaderType, limitByPerHeaderType:
|
||||
@@ -196,7 +197,7 @@ func hitRateRuleItem(ctx wrapper.HttpContext, rule LimitRuleItem, log wrapper.Lo
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func logDebugAndReturnEmpty(log wrapper.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
func logDebugAndReturnEmpty(log log.Log, errMsg string, args ...interface{}) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
log.Debugf(errMsg, args...)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -34,7 +35,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, corsConfig *config.CorsConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, corsConfig *config.CorsConfig, log log.Log) error {
|
||||
log.Debug("parseConfig()")
|
||||
allowOrigins := json.Get("allow_origins").Array()
|
||||
for _, origin := range allowOrigins {
|
||||
@@ -71,7 +72,7 @@ func parseConfig(json gjson.Result, corsConfig *config.CorsConfig, log wrapper.L
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpRequestHeaders()")
|
||||
requestUrl, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
method, _ := proxywasm.GetHttpRequestHeader(":method")
|
||||
@@ -109,7 +110,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig,
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, corsConfig config.CorsConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpResponseHeaders()")
|
||||
// Remove trace header if existed
|
||||
proxywasm.RemoveHttpResponseHeader(config.HeaderPluginTrace)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -43,7 +44,7 @@ type CustomResponseConfig struct {
|
||||
contentType string
|
||||
}
|
||||
|
||||
func parseConfig(gjson gjson.Result, config *CustomResponseConfig, log wrapper.Log) error {
|
||||
func parseConfig(gjson gjson.Result, config *CustomResponseConfig, log log.Log) error {
|
||||
headersArray := gjson.Get("headers").Array()
|
||||
config.headers = make([][2]string, 0, len(headersArray))
|
||||
for _, v := range headersArray {
|
||||
@@ -96,7 +97,7 @@ func parseConfig(gjson gjson.Result, config *CustomResponseConfig, log wrapper.L
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomResponseConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomResponseConfig, log log.Log) types.Action {
|
||||
if len(config.enableOnStatus) != 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -108,7 +109,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomResponseConfig,
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config CustomResponseConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config CustomResponseConfig, log log.Log) types.Action {
|
||||
// enableOnStatus is not empty, compare the status code.
|
||||
// if match the status code, mock the response.
|
||||
statusCodeStr, err := proxywasm.GetHttpResponseHeader(":status")
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"de-graphql/config"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -37,7 +38,7 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *config.DeGraphQLConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *config.DeGraphQLConfig, log log.Log) error {
|
||||
log.Debug("parseConfig()")
|
||||
gql := json.Get("gql").String()
|
||||
endpoint := json.Get("endpoint").String()
|
||||
@@ -57,7 +58,7 @@ func parseConfig(json gjson.Result, config *config.DeGraphQLConfig, log wrapper.
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.DeGraphQLConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.DeGraphQLConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpRequestHeaders()")
|
||||
log.Debugf("schema:%s host:%s path:%s", ctx.Scheme(), ctx.Host(), ctx.Path())
|
||||
requestUrl, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
@@ -102,17 +103,17 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.DeGraphQLConfig
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config config.DeGraphQLConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config config.DeGraphQLConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debug("onHttpRequestBody()")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.DeGraphQLConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.DeGraphQLConfig, log log.Log) types.Action {
|
||||
log.Debug("onHttpResponseHeaders()")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config config.DeGraphQLConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config config.DeGraphQLConfig, body []byte, log log.Log) types.Action {
|
||||
log.Debug("onHttpResponseBody()")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"ext-auth/expr"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -56,7 +58,7 @@ type AuthorizationResponse struct {
|
||||
AllowedClientHeaders expr.Matcher
|
||||
}
|
||||
|
||||
func ParseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) error {
|
||||
func ParseConfig(json gjson.Result, config *ExtAuthConfig, log log.Log) error {
|
||||
httpServiceConfig := json.Get("http_service")
|
||||
if !httpServiceConfig.Exists() {
|
||||
return errors.New("missing http_service in config")
|
||||
@@ -88,7 +90,7 @@ func ParseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) error {
|
||||
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig, log log.Log) error {
|
||||
var httpService HttpService
|
||||
|
||||
if err := parseEndpointConfig(json, &httpService, log); err != nil {
|
||||
@@ -114,7 +116,7 @@ func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig, log wrappe
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseEndpointConfig(json gjson.Result, httpService *HttpService, log wrapper.Log) error {
|
||||
func parseEndpointConfig(json gjson.Result, httpService *HttpService, log log.Log) error {
|
||||
endpointMode := json.Get("endpoint_mode").String()
|
||||
if endpointMode == "" {
|
||||
endpointMode = EndpointModeEnvoy
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"ext-auth/expr"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"ext-auth/config"
|
||||
"ext-auth/util"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -50,7 +51,7 @@ const (
|
||||
HeaderXForwardedHost = "x-forwarded-host"
|
||||
)
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ExtAuthConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ExtAuthConfig, log log.Log) types.Action {
|
||||
// If the request's domain and path match the MatchRules, skip authentication
|
||||
if config.MatchRules.IsAllowedByMode(ctx.Host(), ctx.Method(), wrapper.GetRequestPathWithoutQuery()) {
|
||||
ctx.DontReadRequestBody()
|
||||
@@ -73,14 +74,14 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.ExtAuthConfig,
|
||||
return checkExtAuth(ctx, config, nil, log, types.HeaderStopAllIterationAndWatermark)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config config.ExtAuthConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config config.ExtAuthConfig, body []byte, log log.Log) types.Action {
|
||||
if config.HttpService.AuthorizationRequest.WithRequestBody {
|
||||
return checkExtAuth(ctx, config, body, log, types.DataStopIterationAndBuffer)
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func checkExtAuth(ctx wrapper.HttpContext, cfg config.ExtAuthConfig, body []byte, log wrapper.Log, pauseAction types.Action) types.Action {
|
||||
func checkExtAuth(ctx wrapper.HttpContext, cfg config.ExtAuthConfig, body []byte, log log.Log, pauseAction types.Action) types.Action {
|
||||
httpServiceConfig := cfg.HttpService
|
||||
|
||||
extAuthReqHeaders := buildExtAuthRequestHeaders(ctx, cfg)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/frontend-gray/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/frontend-gray/util"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -26,14 +27,14 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, grayConfig *config.GrayConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, grayConfig *config.GrayConfig, log log.Log) error {
|
||||
// 解析json 为GrayConfig
|
||||
config.JsonToGrayConfig(json, grayConfig)
|
||||
log.Infof("Rewrite: %v, GrayDeployments: %v", json.Get("rewrite"), json.Get("grayDeployments"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, grayConfig config.GrayConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, grayConfig config.GrayConfig, log log.Log) types.Action {
|
||||
requestPath, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
requestPath = path.Clean(requestPath)
|
||||
parsedURL, err := url.Parse(requestPath)
|
||||
@@ -129,7 +130,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, grayConfig config.GrayConfig,
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeader(ctx wrapper.HttpContext, grayConfig config.GrayConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeader(ctx wrapper.HttpContext, grayConfig config.GrayConfig, log log.Log) types.Action {
|
||||
enabledGray, _ := ctx.GetContext(config.EnabledGray).(bool)
|
||||
if !enabledGray {
|
||||
ctx.DontReadResponseBody()
|
||||
@@ -213,7 +214,7 @@ func onHttpResponseHeader(ctx wrapper.HttpContext, grayConfig config.GrayConfig,
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, grayConfig config.GrayConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, grayConfig config.GrayConfig, body []byte, log log.Log) types.Action {
|
||||
enabledGray, _ := ctx.GetContext(config.EnabledGray).(bool)
|
||||
if !enabledGray {
|
||||
return types.ActionContinue
|
||||
|
||||
@@ -2,9 +2,11 @@ module higress/plugins/wasm-go/extensions/geo-ip
|
||||
|
||||
go 1.19
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2 h1:gH7OIGXm4wtW5Vo7L2deMPqF7OVWNESDHv1CaaTGu6s=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2/go.mod h1:359don/ahMxpfeLMzr29Cjwcu8IywTTDUzWlBPRNLHw=
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56 h1:Wi5Tgn8K+jDcBYL+dIMS1+qXYH2r7tpRAyBgqrWfQtw=
|
||||
github.com/asergeyev/nradix v0.0.0-20170505151046-3872ab85bb56/go.mod h1:8BhOLuqtSuT5NZtZMwfvEibi09RO3u79uqfHZzfDTR4=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -8,8 +6,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -56,7 +57,7 @@ type GeoIpData struct {
|
||||
Isp string `json:"isp"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *GeoIpConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *GeoIpConfig, log log.Log) error {
|
||||
sourceType := json.Get("ip_source_type")
|
||||
if sourceType.Exists() && sourceType.String() != "" {
|
||||
switch sourceType.String() {
|
||||
@@ -104,7 +105,7 @@ func parseConfig(json gjson.Result, config *GeoIpConfig, log wrapper.Log) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadGeoIpDataToRdxtree(log wrapper.Log) error {
|
||||
func ReadGeoIpDataToRdxtree(log log.Log) error {
|
||||
GeoIpRdxTree = iptree.New()
|
||||
|
||||
//eg., cidr country province city isp
|
||||
@@ -141,7 +142,7 @@ func ReadGeoIpDataToRdxtree(log wrapper.Log) error {
|
||||
}
|
||||
|
||||
// search geodata using client ip in radixtree.
|
||||
func SearchGeoIpDataInRdxtree(ip string, log wrapper.Log) (*GeoIpData, error) {
|
||||
func SearchGeoIpDataInRdxtree(ip string, log log.Log) (*GeoIpData, error) {
|
||||
val, found, err := GeoIpRdxTree.GetByString(ip)
|
||||
if err != nil {
|
||||
log.Errorf("search geo ip data in raditree failed. %v %s", err, ip)
|
||||
@@ -196,7 +197,7 @@ func isInternalIp(ip string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config GeoIpConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config GeoIpConfig, log log.Log) types.Action {
|
||||
var (
|
||||
s string
|
||||
err error
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -23,7 +24,7 @@ type MyConfig struct {
|
||||
set_header []gjson.Result
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *MyConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *MyConfig, log log.Log) error {
|
||||
config.set_header = json.Get("set_header").Array()
|
||||
config.rules = json.Get("rules").Array()
|
||||
for _, item := range config.rules {
|
||||
|
||||
@@ -17,10 +17,10 @@ package main
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -33,7 +33,7 @@ func main() {
|
||||
type HelloWorldConfig struct {
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config HelloWorldConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config HelloWorldConfig, log log.Log) types.Action {
|
||||
err := proxywasm.AddHttpRequestHeader("hello", "world")
|
||||
if err != nil {
|
||||
log.Critical("failed to set request header")
|
||||
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -41,7 +41,7 @@ type HttpCallConfig struct {
|
||||
tokenHeader string
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *HttpCallConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *HttpCallConfig, log log.Log) error {
|
||||
config.bodyHeader = json.Get("bodyHeader").String()
|
||||
if config.bodyHeader == "" {
|
||||
return errors.New("missing bodyHeader in config")
|
||||
@@ -96,7 +96,7 @@ func parseConfig(json gjson.Result, config *HttpCallConfig, log wrapper.Log) err
|
||||
}
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config HttpCallConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config HttpCallConfig, log log.Log) types.Action {
|
||||
config.client.Get(config.requestPath, nil,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
defer proxywasm.ResumeHttpRequest()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -38,7 +39,7 @@ func main() {
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders))
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *RestrictionConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *RestrictionConfig, log log.Log) error {
|
||||
sourceType := json.Get("ip_source_type")
|
||||
if sourceType.Exists() && sourceType.String() != "" {
|
||||
switch sourceType.String() {
|
||||
@@ -117,7 +118,7 @@ func getDownStreamIp(config RestrictionConfig) (net.IP, error) {
|
||||
return realIP, nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(context wrapper.HttpContext, config RestrictionConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(context wrapper.HttpContext, config RestrictionConfig, log log.Log) types.Action {
|
||||
realIp, err := getDownStreamIp(config)
|
||||
if err != nil {
|
||||
return deniedUnauthorized(config, "get_ip_failed")
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -28,7 +28,7 @@ var RuleSet bool
|
||||
|
||||
// ParseGlobalConfig 从wrapper提供的配置中解析并转换到插件运行时需要使用的配置。
|
||||
// 此处解析的是全局配置,域名和路由级配置由 ParseRuleConfig 负责。
|
||||
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log wrapper.Log) error {
|
||||
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log log.Log) error {
|
||||
RuleSet = false
|
||||
consumers := json.Get("consumers")
|
||||
if !consumers.IsArray() {
|
||||
@@ -53,7 +53,7 @@ func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log wrapper.Log
|
||||
|
||||
// ParseRuleConfig 从wrapper提供的配置中解析并转换到插件运行时需要使用的配置。
|
||||
// 此处解析的是域名和路由级配置,全局配置由 ParseConfig 负责。
|
||||
func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthConfig, log wrapper.Log) error {
|
||||
func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthConfig, log log.Log) error {
|
||||
// override config via global
|
||||
*config = global
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/jwt-auth/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
@@ -36,7 +37,7 @@ import (
|
||||
// - 若有至少一个 domain/route 配置该插件:则遵循 (2*)
|
||||
//
|
||||
// https://github.com/alibaba/higress/blob/e09edff827b94fa5bcc149bbeadc905361100c2a/plugins/wasm-go/extensions/basic-auth/main.go#L191
|
||||
func OnHTTPRequestHeaders(ctx wrapper.HttpContext, config cfg.JWTAuthConfig, log wrapper.Log) types.Action {
|
||||
func OnHTTPRequestHeaders(ctx wrapper.HttpContext, config cfg.JWTAuthConfig, log log.Log) types.Action {
|
||||
var (
|
||||
noAllow = len(config.Allow) == 0 // 未配置 allow 列表,表示插件在该 domain/route 未生效
|
||||
globalAuthNoSet = config.GlobalAuthCheck() == cfg.GlobalAuthNoSet
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -127,7 +128,7 @@ type KeyAuthConfig struct {
|
||||
credential2Name map[string]string `yaml:"-"`
|
||||
}
|
||||
|
||||
func parseGlobalConfig(json gjson.Result, global *KeyAuthConfig, log wrapper.Log) error {
|
||||
func parseGlobalConfig(json gjson.Result, global *KeyAuthConfig, log log.Log) error {
|
||||
log.Debug("global config")
|
||||
|
||||
// init
|
||||
@@ -200,7 +201,7 @@ func parseGlobalConfig(json gjson.Result, global *KeyAuthConfig, log wrapper.Log
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseOverrideRuleConfig(json gjson.Result, global KeyAuthConfig, config *KeyAuthConfig, log wrapper.Log) error {
|
||||
func parseOverrideRuleConfig(json gjson.Result, global KeyAuthConfig, config *KeyAuthConfig, log log.Log) error {
|
||||
log.Debug("domain/route config")
|
||||
|
||||
*config = global
|
||||
@@ -233,7 +234,7 @@ func parseOverrideRuleConfig(json gjson.Result, global KeyAuthConfig, config *Ke
|
||||
// - global_auth 未设置:
|
||||
// - 若没有一个 domain/route 配置该插件:则遵循 (1*)
|
||||
// - 若有至少一个 domain/route 配置该插件:则遵循 (2*)
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config KeyAuthConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config KeyAuthConfig, log log.Log) types.Action {
|
||||
var (
|
||||
noAllow = len(config.allow) == 0 // 未配置 allow 列表,表示插件在该 domain/route 未生效
|
||||
globalAuthNoSet = config.globalAuth == nil
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
oidc "github.com/higress-group/oauth2-proxy"
|
||||
"github.com/higress-group/oauth2-proxy/pkg/apis/options"
|
||||
"github.com/higress-group/oauth2-proxy/pkg/util"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -35,7 +35,7 @@ type PluginConfig struct {
|
||||
}
|
||||
|
||||
// 在控制台插件配置中填写的yaml配置会自动转换为json,此处直接从json这个参数里解析配置即可
|
||||
func parseConfig(json gjson.Result, config *PluginConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *PluginConfig, log log.Log) error {
|
||||
oidc.SetLogger(log)
|
||||
opts, err := oidc.LoadOptions(json)
|
||||
if err != nil {
|
||||
@@ -55,7 +55,7 @@ func parseConfig(json gjson.Result, config *PluginConfig, log wrapper.Log) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
config.oidcHandler.SetContext(ctx)
|
||||
req := getHttpRequest()
|
||||
rw := util.NewRecorder()
|
||||
@@ -77,7 +77,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log log.Log) types.Action {
|
||||
value := ctx.GetContext(oidc.SetCookieHeader)
|
||||
if value != nil {
|
||||
proxywasm.AddHttpResponseHeader(oidc.SetCookieHeader, value.(string))
|
||||
|
||||
@@ -17,7 +17,7 @@ package main
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
func TestConfig(t *testing.T) {
|
||||
json := gjson.Result{Type: gjson.JSON, Raw: `{"serviceSource": "k8s","serviceName": "opa","servicePort": 8181,"namespace": "example1","policy": "example1","timeout": "5s"}`}
|
||||
config := &OpaConfig{}
|
||||
assert.NoError(t, parseConfig(json, config, wrapper.Log{}))
|
||||
assert.NoError(t, parseConfig(json, config, log.Log{}))
|
||||
assert.Equal(t, config.policy, "example1")
|
||||
assert.Equal(t, config.timeout, uint32(5000))
|
||||
assert.NotNil(t, config.client)
|
||||
@@ -45,6 +45,6 @@ func TestConfig(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
json = gjson.Result{Type: gjson.JSON, Raw: test.raw}
|
||||
assert.Equal(t, parseConfig(json, config, wrapper.Log{}) == nil, test.result)
|
||||
assert.Equal(t, parseConfig(json, config, log.Log{}) == nil, test.result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
@@ -42,7 +43,7 @@ type Metadata struct {
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *OpaConfig, log wrapper.Log) error {
|
||||
func parseConfig(json gjson.Result, config *OpaConfig, log log.Log) error {
|
||||
policy := json.Get("policy").String()
|
||||
if strings.TrimSpace(policy) == "" {
|
||||
return errors.New("policy not allow empty")
|
||||
@@ -76,15 +77,15 @@ func parseConfig(json gjson.Result, config *OpaConfig, log wrapper.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config OpaConfig, log wrapper.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config OpaConfig, log log.Log) types.Action {
|
||||
return opaCall(ctx, config, nil, log)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config OpaConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config OpaConfig, body []byte, log log.Log) types.Action {
|
||||
return opaCall(ctx, config, body, log)
|
||||
}
|
||||
|
||||
func opaCall(ctx wrapper.HttpContext, config OpaConfig, body []byte, log wrapper.Log) types.Action {
|
||||
func opaCall(ctx wrapper.HttpContext, config OpaConfig, body []byte, log log.Log) types.Action {
|
||||
request := make(map[string]interface{}, 6)
|
||||
headers, _ := proxywasm.GetHttpRequestHeaders()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
@@ -25,7 +26,7 @@ type RedisConfig struct {
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
func ParseConfig(json gjson.Result, config *ReplayProtectionConfig, log wrapper.Log) error {
|
||||
func ParseConfig(json gjson.Result, config *ReplayProtectionConfig, log log.Log) error {
|
||||
// Parse Redis configuration
|
||||
redisConfig := json.Get("redis")
|
||||
if !redisConfig.Exists() {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user