mirror of
https://github.com/alibaba/higress.git
synced 2026-02-23 04:00:51 +08:00
308 lines
9.8 KiB
Go
308 lines
9.8 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
|
||
ejson "encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"api-workflow/utils"
|
||
. "api-workflow/workflow"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
// Plugin-wide defaults and request-context keys.
const (
	// DefaultTimeout replaces env.timeout when it is missing or zero.
	// NOTE(review): presumably milliseconds, since it feeds the timeout
	// argument of wrapper.HttpCall — confirm.
	DefaultTimeout uint32 = 5000

	// DefaultMaxDepth replaces env.max_depth when it is missing or zero; it
	// bounds how many edges one workflow chain may traverse.
	DefaultMaxDepth uint32 = 100

	// WorkflowExecStatus is the HttpContext key holding the per-request map
	// from node name to the number of incoming edges it still waits for.
	WorkflowExecStatus string = "workflowExecStatus"
)
|
||
|
||
func main() {
|
||
wrapper.SetCtx(
|
||
"api-workflow",
|
||
wrapper.ParseConfigBy(parseConfig),
|
||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||
)
|
||
}
|
||
|
||
func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||
|
||
edges := make([]Edge, 0)
|
||
nodes := make(map[string]Node)
|
||
var err error
|
||
// env
|
||
env := json.Get("env")
|
||
// timeout
|
||
c.Env.Timeout = uint32(env.Get("timeout").Int())
|
||
if c.Env.Timeout == 0 {
|
||
c.Env.Timeout = DefaultTimeout
|
||
}
|
||
// max_depth
|
||
c.Env.MaxDepth = uint32(env.Get("max_depth").Int())
|
||
if c.Env.MaxDepth == 0 {
|
||
c.Env.MaxDepth = DefaultMaxDepth
|
||
}
|
||
// workflow
|
||
workflow := json.Get("workflow")
|
||
if !workflow.Exists() {
|
||
return errors.New("workflow is empty")
|
||
}
|
||
// workflow.edges
|
||
edges_ := workflow.Get("edges")
|
||
if edges_.Exists() && edges_.IsArray() {
|
||
for _, w := range edges_.Array() {
|
||
task := Task{}
|
||
edge := Edge{}
|
||
edge.Source = w.Get("source").String()
|
||
if edge.Source == "" {
|
||
return errors.New("source is empty")
|
||
}
|
||
edge.Target = w.Get("target").String()
|
||
if edge.Target == "" {
|
||
return errors.New("target is empty")
|
||
}
|
||
edge.Task = &task
|
||
|
||
edge.Conditional = w.Get("conditional").String()
|
||
edges = append(edges, edge)
|
||
}
|
||
}
|
||
c.Workflow.Edges = edges
|
||
// workflow.nodes
|
||
nodes_ := workflow.Get("nodes")
|
||
if nodes_.Exists() && nodes_.IsArray() {
|
||
for _, value := range nodes_.Array() {
|
||
node := Node{}
|
||
node.Name = value.Get("name").String()
|
||
if node.Name == "" {
|
||
return errors.New("tool name is empty")
|
||
}
|
||
node.ServiceName = value.Get("service_name").String()
|
||
if node.ServiceName == "" {
|
||
return errors.New("tool service name is empty")
|
||
}
|
||
node.ServicePort = value.Get("service_port").Int()
|
||
if node.ServicePort == 0 {
|
||
if strings.HasSuffix(node.ServiceName, ".static") {
|
||
// use default logic port which is 80 for static service
|
||
node.ServicePort = 80
|
||
} else {
|
||
return errors.New("tool service port is empty")
|
||
}
|
||
|
||
}
|
||
node.ServiceDomain = value.Get("service_domain").String()
|
||
node.ServicePath = value.Get("service_path").String()
|
||
if node.ServicePath == "" {
|
||
node.ServicePath = "/"
|
||
}
|
||
node.ServiceMethod = value.Get("service_method").String()
|
||
if node.ServiceMethod == "" {
|
||
return errors.New("service_method is empty")
|
||
}
|
||
serviceHeaders := value.Get("service_headers")
|
||
if serviceHeaders.Exists() && serviceHeaders.IsArray() {
|
||
serviceHeaders_ := []ServiceHeader{}
|
||
err = ejson.Unmarshal([]byte(serviceHeaders.Raw), &serviceHeaders_)
|
||
node.ServiceHeaders = serviceHeaders_
|
||
}
|
||
|
||
node.ServiceBodyTmpl = value.Get("service_body_tmpl").String()
|
||
serviceBodyReplaceKeys := value.Get("service_body_replace_keys")
|
||
if serviceBodyReplaceKeys.Exists() && serviceBodyReplaceKeys.IsArray() {
|
||
serviceBodyReplaceKeys_ := []BodyReplaceKeyPair{}
|
||
err = ejson.Unmarshal([]byte(serviceBodyReplaceKeys.Raw), &serviceBodyReplaceKeys_)
|
||
node.ServiceBodyReplaceKeys = serviceBodyReplaceKeys_
|
||
if err != nil {
|
||
return fmt.Errorf("unmarshal service body replace keys failed, err:%v", err)
|
||
}
|
||
}
|
||
|
||
nodes[node.Name] = node
|
||
}
|
||
c.Workflow.Nodes = nodes
|
||
// workflow.WorkflowExecStatus
|
||
c.Workflow.WorkflowExecStatus, err = initWorkflowExecStatus(c)
|
||
log.Debugf("init status : %v", c.Workflow.WorkflowExecStatus)
|
||
if err != nil {
|
||
log.Errorf("init workflow exec status failed, err:%v", err)
|
||
return fmt.Errorf("init workflow exec status failed, err:%v", err)
|
||
}
|
||
}
|
||
log.Debugf("config : %v", c)
|
||
return nil
|
||
}
|
||
|
||
func initWorkflowExecStatus(config *PluginConfig) (map[string]int, error) {
|
||
result := make(map[string]int)
|
||
|
||
for name, _ := range config.Workflow.Nodes {
|
||
result[name] = 0
|
||
}
|
||
for _, edge := range config.Workflow.Edges {
|
||
|
||
if edge.Source == TaskStart || edge.Target == TaskContinue || edge.Target == TaskEnd {
|
||
continue
|
||
}
|
||
|
||
count, ok := result[edge.Target]
|
||
if !ok {
|
||
return nil, fmt.Errorf("Target %s is not exist in nodes", edge.Target)
|
||
}
|
||
result[edge.Target] = count + 1
|
||
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
|
||
initHeader := make([][2]string, 0)
|
||
// 初始化运行状态
|
||
ctx.SetContext(WorkflowExecStatus, config.Workflow.WorkflowExecStatus)
|
||
|
||
// 执行工作流
|
||
for _, edge := range config.Workflow.Edges {
|
||
|
||
if edge.Source == TaskStart {
|
||
ctx.SetContext(fmt.Sprintf("%s", TaskStart), body)
|
||
err := recursive(edge, initHeader, body, 1, config, log, ctx)
|
||
if err != nil {
|
||
// 工作流处理错误,返回500给用户
|
||
log.Errorf("recursive failed: %v", err)
|
||
_ = utils.SendResponse(500, "api-workflow.recursive_failed", utils.MimeTypeTextPlain, fmt.Sprintf("workflow plugin recursive failed: %v", err))
|
||
|
||
}
|
||
}
|
||
}
|
||
|
||
return types.ActionPause
|
||
}
|
||
|
||
// 放入符合条件的edge
|
||
func recursive(edge Edge, headers [][2]string, body []byte, depth uint32, config PluginConfig, log wrapper.Log, ctx wrapper.HttpContext) error {
|
||
|
||
var err error
|
||
// 防止递归次数太多
|
||
if depth > config.Env.MaxDepth {
|
||
return fmt.Errorf("maximum recursion depth reached")
|
||
}
|
||
|
||
// 判断是不是end
|
||
if edge.IsEnd() {
|
||
log.Debugf("source is %s,target is %s,workflow is end", edge.Source, edge.Target)
|
||
log.Debugf("body is %s", string(body))
|
||
_ = proxywasm.SendHttpResponse(200, headers, body, -1)
|
||
return nil
|
||
}
|
||
// 判断是不是continue
|
||
if edge.IsContinue() {
|
||
log.Debugf("source is %s,target is %s,workflow is continue", edge.Source, edge.Target)
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
return nil
|
||
}
|
||
|
||
// 封装task
|
||
err = edge.WrapperTask(config, ctx)
|
||
if err != nil {
|
||
log.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", edge.Source, edge.Target, err)
|
||
return fmt.Errorf("workflow exec wrapperTask find error,source is %s,target is %s,error is %v ", edge.Source, edge.Target, err)
|
||
}
|
||
|
||
// 执行task
|
||
log.Debugf("workflow exec task,source is %s,target is %s, body is %s,header is %v", edge.Source, edge.Target, string(edge.Task.Body), edge.Task.Headers)
|
||
err = wrapper.HttpCall(edge.Task.Cluster, edge.Task.Method, edge.Task.ServicePath, edge.Task.Headers, edge.Task.Body, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
log.Debugf("code:%d", statusCode)
|
||
// 判断response code
|
||
if statusCode < 400 {
|
||
|
||
// 存入这轮返回的body
|
||
ctx.SetContext(fmt.Sprintf("%s", edge.Target), responseBody)
|
||
|
||
headers_ := make([][2]string, len(responseHeaders))
|
||
for key, value := range responseHeaders {
|
||
headers_ = append(headers_, [2]string{key, value[0]})
|
||
}
|
||
// 判断是否进入下一步
|
||
nextStatus := ctx.GetContext(WorkflowExecStatus).(map[string]int)
|
||
|
||
// 进入下一步
|
||
for _, next := range config.Workflow.Edges {
|
||
if next.Source == edge.Target {
|
||
// 更新workflow status
|
||
if next.Target != TaskContinue && next.Target != TaskEnd {
|
||
|
||
nextStatus[next.Target] = nextStatus[next.Target] - 1
|
||
log.Debugf("source is %s,target is %s,stauts is %v", next.Source, next.Target, nextStatus)
|
||
// 还有没执行完的边
|
||
if nextStatus[next.Target] > 0 {
|
||
ctx.SetContext(WorkflowExecStatus, nextStatus)
|
||
return
|
||
}
|
||
// 执行出了问题
|
||
if nextStatus[next.Target] < 0 {
|
||
log.Errorf("workflow exec status find error %v", nextStatus)
|
||
_ = utils.SendResponse(500, "api-workflow.exec_task_failed", utils.MimeTypeTextPlain, fmt.Sprintf("workflow exec status find error %v", nextStatus))
|
||
return
|
||
}
|
||
}
|
||
// 判断是否执行
|
||
isPass, err2 := next.IsPass(ctx)
|
||
if err2 != nil {
|
||
log.Errorf("check pass find error:%v", err2)
|
||
_ = utils.SendResponse(500, "api-workflow.task_check_paas_failed", utils.MimeTypeTextPlain, fmt.Sprintf("check pass find error:%v", err2))
|
||
return
|
||
}
|
||
if isPass {
|
||
log.Debugf("source is %s,target is %s,workflow is pass ", next.Source, next.Target)
|
||
nextStatus = ctx.GetContext(WorkflowExecStatus).(map[string]int)
|
||
nextStatus[next.Target] = nextStatus[next.Target] - 1
|
||
ctx.SetContext(WorkflowExecStatus, nextStatus)
|
||
continue
|
||
|
||
}
|
||
|
||
// 执行下一步
|
||
err = recursive(next, headers_, responseBody, depth+1, config, log, ctx)
|
||
if err != nil {
|
||
log.Errorf("recursive error:%v", err)
|
||
_ = utils.SendResponse(500, "api-workflow.recursive_failed", utils.MimeTypeTextPlain, fmt.Sprintf("recursive error:%v", err))
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
} else {
|
||
// statusCode >= 400 ,task httpCall执行失败,放行请求,打印错误,结束workflow
|
||
log.Errorf("workflow exec task find error,code is %d,body is %s", statusCode, string(responseBody))
|
||
_ = utils.SendResponse(500, "api-workflow.httpCall_failed", utils.MimeTypeTextPlain, fmt.Sprintf("workflow exec task find error,code is %d,body is %s", statusCode, string(responseBody)))
|
||
}
|
||
return
|
||
|
||
}, config.Env.MaxDepth*config.Env.Timeout)
|
||
if err != nil {
|
||
log.Errorf("httpcall error:%v", err)
|
||
}
|
||
|
||
return err
|
||
}
|