mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
feat: Support statistics images/audio/responses API Token usage (#2542)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -27,11 +26,10 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/arxiv"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-search/engine/bing"
|
||||
@@ -86,16 +84,16 @@ func main() {}
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"ai-search",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config, log log.Log) error {
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.defaultEnable = true // Default to true if not specified
|
||||
if json.Get("defaultEnable").Exists() {
|
||||
config.defaultEnable = json.Get("defaultEnable").Bool()
|
||||
@@ -279,7 +277,7 @@ func parseConfig(json gjson.Result, config *Config, log log.Log) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
// The request does not have a body.
|
||||
if contentType == "" {
|
||||
@@ -296,7 +294,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config, log log.Log) t
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
// Check if plugin should be enabled based on config and request
|
||||
webSearchOptions := gjson.GetBytes(body, "web_search_options")
|
||||
if !config.defaultEnable {
|
||||
@@ -437,7 +435,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts, log) {
|
||||
if types.ActionContinue == executeSearch(ctx, config, queryIndex, body, searchContexts) {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}, searchRewrite.timeoutMillisecond)
|
||||
@@ -453,10 +451,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
return executeSearch(ctx, config, queryIndex, body, []engine.SearchContext{{
|
||||
Querys: []string{query},
|
||||
Language: config.defaultLanguage,
|
||||
}}, log)
|
||||
}})
|
||||
}
|
||||
|
||||
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext, log log.Log) types.Action {
|
||||
func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body []byte, searchContexts []engine.SearchContext) types.Action {
|
||||
searchResultGroups := make([][]engine.SearchResult, len(config.engine))
|
||||
var finished int
|
||||
var searching int
|
||||
@@ -559,7 +557,7 @@ func executeSearch(ctx wrapper.HttpContext, config Config, queryIndex int, body
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log) types.Action {
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
if !config.needReference {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
@@ -576,7 +574,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config, log log.Log)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log log.Log) types.Action {
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
references := ctx.GetStringContext("References", "")
|
||||
if references == "" {
|
||||
return types.ActionContinue
|
||||
@@ -618,19 +616,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte, log
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func unifySSEChunk(data []byte) []byte {
|
||||
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
|
||||
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
|
||||
return data
|
||||
}
|
||||
|
||||
const (
|
||||
PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage"
|
||||
BUFFER_CONTENT_CONTEXT_KEY = "bufferContent"
|
||||
BUFFER_SIZE = 30
|
||||
)
|
||||
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool, log log.Log) []byte {
|
||||
func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byte, isLastChunk bool) []byte {
|
||||
if ctx.GetBoolContext("ReferenceAppended", false) {
|
||||
return chunk
|
||||
}
|
||||
@@ -638,7 +630,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
if references == "" {
|
||||
return chunk
|
||||
}
|
||||
chunk = unifySSEChunk(chunk)
|
||||
chunk = wrapper.UnifySSEChunk(chunk)
|
||||
var partialMessage []byte
|
||||
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
|
||||
log.Debugf("[handleStreamChunk] buffer content: %v", ctx.GetContext(BUFFER_CONTENT_CONTEXT_KEY))
|
||||
@@ -651,7 +643,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
var newMessages []string
|
||||
for i, msg := range messages {
|
||||
if i < len(messages)-1 {
|
||||
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail", log)
|
||||
newMsg := processSSEMessage(ctx, msg, fmt.Sprintf(config.referenceFormat, references), config.referenceLocation == "tail")
|
||||
if newMsg != "" {
|
||||
newMessages = append(newMessages, newMsg)
|
||||
}
|
||||
@@ -669,7 +661,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, config Config, chunk []byt
|
||||
}
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool, log log.Log) string {
|
||||
func processSSEMessage(ctx wrapper.HttpContext, sseMessage string, references string, tailReference bool) string {
|
||||
log.Debugf("single sse message: %s", sseMessage)
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
|
||||
Reference in New Issue
Block a user