qwen bailian compatible bug fix (#1597)

This commit is contained in:
rinfx
2024-12-17 16:57:31 +08:00
committed by GitHub
parent 2a200cdd42
commit 2f5709a93e
3 changed files with 90 additions and 49 deletions

View File

@@ -27,6 +27,7 @@ const (
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation" qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding" qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions" qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation" qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenTopPMin = 0.000001 qwenTopPMin = 0.000001
@@ -71,7 +72,8 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
} }
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
if m.config.qwenEnableCompatible { if m.config.IsOriginal() {
} else if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion { } else if apiName == ApiNameChatCompletion {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
@@ -762,6 +764,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
switch { switch {
case strings.Contains(path, qwenChatCompletionPath), case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath), strings.Contains(path, qwenMultimodalGenerationPath),
strings.Contains(path, qwenBailianPath),
strings.Contains(path, qwenCompatiblePath): strings.Contains(path, qwenCompatiblePath):
return ApiNameChatCompletion return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath): case strings.Contains(path, qwenTextEmbeddingPath):

View File

@@ -384,26 +384,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
} }
headers, err := proxywasm.GetHttpResponseHeaders() statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
if err != nil { if statusCode != "200" {
log.Warnf("failed to get response headers: %v", err)
return types.ActionContinue
}
hdsMap := convertHeaders(headers)
if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
log.Debugf("response is not 200, skip response body check") log.Debugf("response is not 200, skip response body check")
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
} }
ctx.SetContext("headers", hdsMap)
return types.HeaderStopIteration return types.HeaderStopIteration
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking response body...") log.Debugf("checking response body...")
startTime := time.Now().UnixMilli() startTime := time.Now().UnixMilli()
hdsMap := ctx.GetContext("headers").(map[string][]string) contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") isStreamingResponse := strings.Contains(contentType, "event-stream")
model := ctx.GetStringContext("requestModel", "unknown") model := ctx.GetStringContext("requestModel", "unknown")
var content string var content string
if isStreamingResponse { if isStreamingResponse {

View File

@@ -303,39 +303,33 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
// fetches the tracing span value from the specified source. // fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
for _, attribute := range config.attributes { for _, attribute := range config.attributes {
var key, value string var key string
var err error var value interface{}
if source == attribute.ValueSource { if source == attribute.ValueSource {
key = attribute.Key key = attribute.Key
switch source { switch source {
case FixedValue: case FixedValue:
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
value = attribute.Value value = attribute.Value
case RequestHeader: case RequestHeader:
if value, err = proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
}
case RequestBody: case RequestBody:
raw := gjson.GetBytes(body, attribute.Value).Raw value = gjson.GetBytes(body, attribute.Value).Value()
if len(raw) > 2 {
value = raw[1 : len(raw)-1]
}
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
case ResponseHeader: case ResponseHeader:
if value, err = proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
}
case ResponseStreamingBody: case ResponseStreamingBody:
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
case ResponseBody: case ResponseBody:
value = gjson.GetBytes(body, attribute.Value).String() value = gjson.GetBytes(body, attribute.Value).Value()
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
default: default:
} }
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
if attribute.ApplyToLog { if attribute.ApplyToLog {
ctx.SetUserAttribute(key, value) ctx.SetUserAttribute(key, value)
} }
// for metrics
if key == Model || key == InputToken || key == OutputToken {
ctx.SetContext(key, value)
}
if attribute.ApplyToSpan { if attribute.ApplyToSpan {
setSpanAttribute(key, value, log) setSpanAttribute(key, value, log)
} }
@@ -343,14 +337,14 @@ func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, so
} }
} }
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string { func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
var value string var value interface{}
if rule == RuleFirst { if rule == RuleFirst {
for _, chunk := range chunks { for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath) jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() { if jsonObj.Exists() {
value = jsonObj.String() value = jsonObj.Value()
break break
} }
} }
@@ -358,17 +352,19 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
for _, chunk := range chunks { for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath) jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() { if jsonObj.Exists() {
value = jsonObj.String() value = jsonObj.Value()
} }
} }
} else if rule == RuleAppend { } else if rule == RuleAppend {
// extract llm response // extract llm response
var strValue string
for _, chunk := range chunks { for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath) jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() { if jsonObj.Exists() {
value += jsonObj.String() strValue += jsonObj.String()
} }
} }
value = strValue
} else { } else {
log.Errorf("unsupported rule type: %s", rule) log.Errorf("unsupported rule type: %s", rule)
} }
@@ -376,10 +372,10 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
} }
// Set the tracing span with value. // Set the tracing span with value.
func setSpanAttribute(key, value string, log wrapper.Log) { func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
if value != "" { if value != "" {
traceSpanTag := wrapper.TraceSpanTagPrefix + key traceSpanTag := wrapper.TraceSpanTagPrefix + key
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil { if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e) log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
} }
} else { } else {
@@ -388,36 +384,84 @@ func setSpanAttribute(key, value string, log wrapper.Log) {
} }
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
route := ctx.GetContext(RouteName).(string)
cluster := ctx.GetContext(ClusterName).(string)
// Generate usage metrics // Generate usage metrics
var model string var ok bool
var inputToken, outputToken int64 var route, cluster, model string
var inputToken, outputToken uint64
route, ok = ctx.GetContext(RouteName).(string)
if !ok {
log.Warnf("RouteName typd assert failed, skip metric record")
return
}
cluster, ok = ctx.GetContext(ClusterName).(string)
if !ok {
log.Warnf("ClusterName typd assert failed, skip metric record")
return
}
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil { if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
log.Warnf("get usage information failed, skip metric record") log.Warnf("get usage information failed, skip metric record")
return return
} }
model = ctx.GetUserAttribute(Model).(string) model, ok = ctx.GetUserAttribute(Model).(string)
inputToken = ctx.GetUserAttribute(InputToken).(int64) if !ok {
outputToken = ctx.GetUserAttribute(OutputToken).(int64) log.Warnf("Model typd assert failed, skip metric record")
return
}
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
if !ok {
log.Warnf("InputToken typd assert failed, skip metric record")
return
}
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
if !ok {
log.Warnf("OutputToken typd assert failed, skip metric record")
return
}
if inputToken == 0 || outputToken == 0 { if inputToken == 0 || outputToken == 0 {
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record") log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), uint64(inputToken)) config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), uint64(outputToken)) config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
// Generate duration metrics // Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration int64 var llmFirstTokenDuration, llmServiceDuration uint64
// Is stream response // Is stream response
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil { if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
llmFirstTokenDuration = ctx.GetUserAttribute(LLMFirstTokenDuration).(int64) llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), uint64(llmFirstTokenDuration)) if !ok {
log.Warnf("LLMFirstTokenDuration typd assert failed")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1) config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
} }
if ctx.GetUserAttribute(LLMServiceDuration) != nil { if ctx.GetUserAttribute(LLMServiceDuration) != nil {
llmServiceDuration = ctx.GetUserAttribute(LLMServiceDuration).(int64) llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), uint64(llmServiceDuration)) if !ok {
log.Warnf("LLMServiceDuration typd assert failed")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1) config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
} }
} }
// convertToUInt converts a dynamically typed numeric value into a uint64.
//
// It accepts every built-in signed and unsigned integer kind as well as
// float32 and float64. Floating-point inputs are truncated toward zero.
//
// The second result is false, and the first is 0, when:
//   - val is nil or not one of the supported numeric types;
//   - val is a negative integer, which would otherwise wrap around to a huge
//     unsigned value;
//   - val is a float that is negative, NaN, infinite, or too large for uint64;
//     converting such a float to an integer yields an implementation-dependent
//     result in Go, so it is rejected rather than converted.
func convertToUInt(val interface{}) (uint64, bool) {
	// Exclusive upper bound for float inputs: 2^64, exactly representable as a
	// float64.
	const floatLimit = 18446744073709551616.0

	var (
		signed   int64
		floating float64
		isFloat  bool
	)

	// Normalize the input into one of three shapes: an unsigned value that can
	// be returned immediately, a signed integer, or a float.
	switch v := val.(type) {
	case uint64:
		return v, true
	case uint:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case int64:
		signed = v
	case int:
		signed = int64(v)
	case int32:
		signed = int64(v)
	case int16:
		signed = int64(v)
	case int8:
		signed = int64(v)
	case float64:
		floating, isFloat = v, true
	case float32:
		floating, isFloat = float64(v), true
	default:
		return 0, false
	}

	if isFloat {
		// Written as a negated range check on purpose: every comparison with
		// NaN is false, so NaN is rejected here together with negative and
		// out-of-range values.
		if !(floating >= 0 && floating < floatLimit) {
			return 0, false
		}
		return uint64(floating), true
	}

	if signed < 0 {
		return 0, false
	}
	return uint64(signed), true
}