mirror of
https://github.com/alibaba/higress.git
synced 2026-03-02 07:30:49 +08:00
qwen bailian compatible bug fix (#1597)
This commit is contained in:
@@ -27,6 +27,7 @@ const (
|
||||
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
||||
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
|
||||
qwenBailianPath = "/api/v1/apps"
|
||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
|
||||
qwenTopPMin = 0.000001
|
||||
@@ -71,7 +72,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
}
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if m.config.qwenEnableCompatible {
|
||||
if m.config.IsOriginal() {
|
||||
} else if m.config.qwenEnableCompatible {
|
||||
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
|
||||
} else if apiName == ApiNameChatCompletion {
|
||||
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
|
||||
@@ -762,6 +764,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
|
||||
switch {
|
||||
case strings.Contains(path, qwenChatCompletionPath),
|
||||
strings.Contains(path, qwenMultimodalGenerationPath),
|
||||
strings.Contains(path, qwenBailianPath),
|
||||
strings.Contains(path, qwenCompatiblePath):
|
||||
return ApiNameChatCompletion
|
||||
case strings.Contains(path, qwenTextEmbeddingPath):
|
||||
|
||||
@@ -384,26 +384,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get response headers: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
hdsMap := convertHeaders(headers)
|
||||
if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
|
||||
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
|
||||
if statusCode != "200" {
|
||||
log.Debugf("response is not 200, skip response body check")
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetContext("headers", hdsMap)
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||
model := ctx.GetStringContext("requestModel", "unknown")
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
|
||||
@@ -303,39 +303,33 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
|
||||
// fetches the tracing span value from the specified source.
|
||||
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
|
||||
for _, attribute := range config.attributes {
|
||||
var key, value string
|
||||
var err error
|
||||
var key string
|
||||
var value interface{}
|
||||
if source == attribute.ValueSource {
|
||||
key = attribute.Key
|
||||
switch source {
|
||||
case FixedValue:
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
|
||||
value = attribute.Value
|
||||
case RequestHeader:
|
||||
if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil {
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
}
|
||||
value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
|
||||
case RequestBody:
|
||||
raw := gjson.GetBytes(body, attribute.Value).Raw
|
||||
if len(raw) > 2 {
|
||||
value = raw[1 : len(raw)-1]
|
||||
}
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
value = gjson.GetBytes(body, attribute.Value).Value()
|
||||
case ResponseHeader:
|
||||
if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil {
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
}
|
||||
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
|
||||
case ResponseStreamingBody:
|
||||
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
case ResponseBody:
|
||||
value = gjson.GetBytes(body, attribute.Value).String()
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
value = gjson.GetBytes(body, attribute.Value).Value()
|
||||
default:
|
||||
}
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
|
||||
if attribute.ApplyToLog {
|
||||
ctx.SetUserAttribute(key, value)
|
||||
}
|
||||
// for metrics
|
||||
if key == Model || key == InputToken || key == OutputToken {
|
||||
ctx.SetContext(key, value)
|
||||
}
|
||||
if attribute.ApplyToSpan {
|
||||
setSpanAttribute(key, value, log)
|
||||
}
|
||||
@@ -343,14 +337,14 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
|
||||
}
|
||||
}
|
||||
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string {
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
var value string
|
||||
var value interface{}
|
||||
if rule == RuleFirst {
|
||||
for _, chunk := range chunks {
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
value = jsonObj.String()
|
||||
value = jsonObj.Value()
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -358,17 +352,19 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
|
||||
for _, chunk := range chunks {
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
value = jsonObj.String()
|
||||
value = jsonObj.Value()
|
||||
}
|
||||
}
|
||||
} else if rule == RuleAppend {
|
||||
// extract llm response
|
||||
var strValue string
|
||||
for _, chunk := range chunks {
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
value += jsonObj.String()
|
||||
strValue += jsonObj.String()
|
||||
}
|
||||
}
|
||||
value = strValue
|
||||
} else {
|
||||
log.Errorf("unsupported rule type: %s", rule)
|
||||
}
|
||||
@@ -376,10 +372,10 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
|
||||
}
|
||||
|
||||
// Set the tracing span with value.
|
||||
func setSpanAttribute(key, value string, log wrapper.Log) {
|
||||
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
|
||||
if value != "" {
|
||||
traceSpanTag := wrapper.TraceSpanTagPrefix + key
|
||||
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil {
|
||||
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
|
||||
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
|
||||
}
|
||||
} else {
|
||||
@@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) {
|
||||
}
|
||||
|
||||
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
|
||||
route := ctx.GetContext(RouteName).(string)
|
||||
cluster := ctx.GetContext(ClusterName).(string)
|
||||
// Generate usage metrics
|
||||
var model string
|
||||
var inputToken, outputToken int64
|
||||
var ok bool
|
||||
var route, cluster, model string
|
||||
var inputToken, outputToken uint64
|
||||
route, ok = ctx.GetContext(RouteName).(string)
|
||||
if !ok {
|
||||
log.Warnf("RouteName typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
cluster, ok = ctx.GetContext(ClusterName).(string)
|
||||
if !ok {
|
||||
log.Warnf("ClusterName typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
|
||||
log.Warnf("get usage information failed, skip metric record")
|
||||
return
|
||||
}
|
||||
model = ctx.GetUserAttribute(Model).(string)
|
||||
inputToken = ctx.GetUserAttribute(InputToken).(int64)
|
||||
outputToken = ctx.GetUserAttribute(OutputToken).(int64)
|
||||
model, ok = ctx.GetUserAttribute(Model).(string)
|
||||
if !ok {
|
||||
log.Warnf("Model typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
|
||||
if !ok {
|
||||
log.Warnf("InputToken typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
|
||||
if !ok {
|
||||
log.Warnf("OutputToken typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
if inputToken == 0 || outputToken == 0 {
|
||||
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken))
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken))
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
|
||||
|
||||
// Generate duration metrics
|
||||
var llmFirstTokenDuration, llmServiceDuration int64
|
||||
var llmFirstTokenDuration, llmServiceDuration uint64
|
||||
// Is stream response
|
||||
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
|
||||
llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration))
|
||||
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
|
||||
if !ok {
|
||||
log.Warnf("LLMFirstTokenDuration typd assert failed")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
|
||||
}
|
||||
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
|
||||
llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration))
|
||||
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
|
||||
if !ok {
|
||||
log.Warnf("LLMServiceDuration typd assert failed")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
|
||||
}
|
||||
}
|
||||
|
||||
// convertToUInt converts a numeric value stored in an interface into a uint64.
//
// Every built-in integer and floating-point type is accepted. Floating-point
// values are truncated toward zero. The second result is false, and the first
// is 0, when val is nil, is not a numeric type, is negative, is NaN, or does
// not fit into a uint64. Rejecting those inputs (instead of letting the
// conversion wrap around) keeps a bogus value from turning into an enormous
// counter increment in the caller.
func convertToUInt(val interface{}) (uint64, bool) {
	switch v := val.(type) {
	case float32:
		return floatToUint64(float64(v))
	case float64:
		return floatToUint64(v)
	case int:
		return signedToUint64(int64(v))
	case int8:
		return signedToUint64(int64(v))
	case int16:
		return signedToUint64(int64(v))
	case int32:
		return signedToUint64(int64(v))
	case int64:
		return signedToUint64(v)
	case uint:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	default:
		return 0, false
	}
}

// floatToUint64 truncates f toward zero and reports whether the result is
// representable as a uint64.
func floatToUint64(f float64) (uint64, bool) {
	// 2^64 is the first float64 value that no longer fits into a uint64.
	const limit float64 = 1 << 64
	// Written as a positive range check so that NaN, for which every
	// comparison is false, is rejected as well.
	if f >= 0 && f < limit {
		return uint64(f), true
	}
	return 0, false
}

// signedToUint64 converts a signed integer, rejecting negative values that
// would otherwise wrap around to very large unsigned numbers.
func signedToUint64(i int64) (uint64, bool) {
	if i < 0 {
		return 0, false
	}
	return uint64(i), true
}
|
||||
|
||||
Reference in New Issue
Block a user