feat: Support token usage statistics for the images/audio/responses APIs (#2542)

Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
Xijun Dai
2025-07-16 10:34:09 +08:00
committed by GitHub
parent ce271849de
commit 8346b4a4a2
13 changed files with 193 additions and 262 deletions

View File

@@ -15,7 +15,6 @@
package main
import (
"bytes"
_ "embed"
"errors"
"fmt"
@@ -27,11 +26,10 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
@@ -86,16 +84,16 @@ func main() {}
// init registers the "ai-search" plugin with the wrapper, supplying the config
// parser and the handlers for request headers, request body, response headers,
// streaming response body and buffered response body.
//
// NOTE(review): both the `...By` option constructors and their un-suffixed
// counterparts appear in this call; this looks like the old and new sides of a
// diff shown together — confirm that only one set is meant to be registered.
func init() {
wrapper.SetCtx(
"ai-search",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
)
}
func parseConfig(json gjson.Result, config *Config, log log.Log) error {
func parseConfig(json gjson.Result, config *Config) error {
config.defaultEnable = true // Default to true if not specified
if json.Get("defaultEnable").Exists() {
config.defaultEnable = json.Get("defaultEnable").Bool()
@@ -279,7 +277,7 @@ func parseConfig(json gjson.Result, config *Config, log log.Log) error {
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
// The request does not have a body.
if contentType == "" {
@@ -296,7 +294,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) t
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
// Check if plugin should be enabled based on config and request
webSearchOptions := gjson.GetBytes(body, "web_search_options")
if !config.defaultEnable {
@@ -437,7 +435,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
proxywasm.ResumeHttpRequest()
return
}
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts) {
proxywasm.ResumeHttpRequest()
}
}, searchRewrite.timeoutMillisecond)
@@ -453,10 +451,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
Querys: []string{query},
Language: config.defaultLanguage,
}}, log)
}})
}
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action {
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext) types.Action {
searchResultGroups := make([][]engine.SearchResult, len(config.engine))
var finished int
var searching int
@@ -559,7 +557,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body
return types.ActionContinue
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
if !config.needReference {
ctx.DontReadResponseBody()
return types.ActionContinue
@@ -576,7 +574,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log)
return types.ActionContinue
}
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
references := ctx.GetStringContext("References", "")
if references == "" {
return types.ActionContinue
@@ -618,19 +616,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
return types.ActionContinue
}
// unifySSEChunk returns a copy of data in which every line terminator is a
// single "\n": each "\r\n" pair and each remaining lone "\r" is replaced by
// "\n". The input slice is not modified.
func unifySSEChunk(data []byte) []byte {
	lf := []byte("\n")
	// Collapse CRLF pairs first; otherwise the lone-CR pass below would turn
	// each "\r\n" into two newlines.
	withoutCRLF := bytes.ReplaceAll(data, []byte("\r\n"), lf)
	return bytes.ReplaceAll(withoutCRLF, []byte("\r"), lf)
}
// Context keys and a size constant for the streaming response path.
const (
// PARTIAL_MESSAGE_CONTEXT_KEY is the HttpContext key that
// onStreamingResponseBody reads with ctx.GetContext to obtain a partial
// message; presumably the incomplete trailing SSE message carried over
// from the previous chunk — TODO confirm against the writer of this key.
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
// BUFFER_CONTENT_CONTEXT_KEY is the HttpContext key whose value
// onStreamingResponseBody logs as "buffer content".
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
// BUFFER_SIZE is 30; its unit and point of use are not visible here.
// NOTE(review): presumably bounds the buffered content — confirm.
BUFFER_SIZE = 30
)
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte {
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool) []byte {
if ctx.GetBoolContext("ReferenceAppended", false) {
return chunk
}
@@ -638,7 +630,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
if references == "" {
return chunk
}
chunk = unifySSEChunk(chunk)
chunk = wrapper.UnifySSEChunk(chunk)
var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
@@ -651,7 +643,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
var newMessages []string
for i, msg := range messages {
if i < len(messages)-1 {
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail")
if newMsg != "" {
newMessages = append(newMessages, newMsg)
}
@@ -669,7 +661,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
}
}
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string {
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool) string {
log.Debugf("single sse message: %s", sseMessage)
subMessages := strings.Split(sseMessage, "\n")
var message string