From 04b12529eb5a1fe5c4bc13fe89828fbdc403bc54 Mon Sep 17 00:00:00 2001 From: "Chang, Hui-Tang" Date: Sun, 3 Nov 2024 20:52:40 +0800 Subject: [PATCH] feat: add `condition` field support for iterator (#803) Because - Previously, we did not support a `condition` field for the iterator. This commit: - Adds `condition` field support for the iterator. --- pkg/data/struct.go | 8 ++++++ pkg/worker/workflow.go | 59 ++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/pkg/data/struct.go b/pkg/data/struct.go index 15af2ca3c..b4e661d68 100644 --- a/pkg/data/struct.go +++ b/pkg/data/struct.go @@ -84,8 +84,16 @@ func unmarshalValue(val format.Value, field reflect.Value, structField reflect.S case format.String: return unmarshalString(v, field) case Array: + if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) { + field.Set(reflect.ValueOf(v)) + return nil + } return unmarshalArray(v, field) case Map: + if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) { + field.Set(reflect.ValueOf(v)) + return nil + } return unmarshalMap(v, field) case format.Null: if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) { diff --git a/pkg/worker/workflow.go b/pkg/worker/workflow.go index 0642ecab5..43e51a411 100644 --- a/pkg/worker/workflow.go +++ b/pkg/worker/workflow.go @@ -82,6 +82,7 @@ type PreIteratorActivityParam struct { WorkflowID string ID string UpstreamIDs []string + Condition string Input string Range any Index string @@ -90,12 +91,13 @@ type PreIteratorActivityParam struct { type PreIteratorActivityResult struct { ChildWorkflowIDs []string - ElementSize []int + ConditionMap map[int]int } type PostIteratorActivityParam struct { WorkflowID string ID string + ConditionMap map[int]int OutputElements map[string]string SystemVariables recipe.SystemVariables } @@ -306,7 +308,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip futureArgs = append(futureArgs, args) case datamodel.Iterator: - // TODO tillknuesting: support intermediate result streaming for Iterator + // TODO: support intermediate result streaming for Iterator preIteratorResult := &PreIteratorActivityResult{} if err = workflow.ExecuteActivity(ctx, w.PreIteratorActivity, &PreIteratorActivityParam{ @@ -320,17 +322,16 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip return "" }(comp), Range: comp.Range, + Condition: comp.Condition, Index: comp.Index, SystemVariables: param.SystemVariables, }).Get(ctx, &preIteratorResult); err != nil { - if err != nil { - errs = append(errs, err) - continue - } + errs = append(errs, err) + continue } itFutures := []workflow.Future{} - for iter := range dagData.BatchSize { + for iter := range preIteratorResult.ConditionMap { childWorkflowOptions := workflow.ChildWorkflowOptions{ TaskQueue: TaskQueue, WorkflowID: preIteratorResult.ChildWorkflowIDs[iter], @@ -352,7 +353,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip // IsStreaming: param.IsStreaming, })) } - for iter := 0; iter < dagData.BatchSize; iter++ { + for iter := 0; iter < len(itFutures); iter++ { err = itFutures[iter].Get(ctx, nil) if err != nil { errs = append(errs, err) @@ -363,6 +364,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip if err = workflow.ExecuteActivity(ctx, w.PostIteratorActivity, &PostIteratorActivityParam{ WorkflowID: workflowID, ID: compID, + ConditionMap: preIteratorResult.ConditionMap, OutputElements: comp.OutputElements, SystemVariables: param.SystemVariables, }).Get(ctx, nil); err != nil { @@ -614,20 +616,21 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } - - result := &PreIteratorActivityResult{ - ElementSize: make([]int, wfm.GetBatchSize()), + conditionMap, err := w.processCondition(ctx, wfm, param.ID, param.UpstreamIDs, param.Condition) + if err != nil { + return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } - batchSize := wfm.GetBatchSize() - childWorkflowIDs := make([]string, batchSize) + result := &PreIteratorActivityResult{} + + childWorkflowIDs := make([]string, len(conditionMap)) - for iter := range wfm.GetBatchSize() { - if err = wfm.SetComponentStatus(ctx, iter, param.ID, memory.ComponentStatusStarted, true); err != nil { + for idx, originalIdx := range conditionMap { + if err = wfm.SetComponentStatus(ctx, originalIdx, param.ID, memory.ComponentStatusStarted, true); err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } - childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, iter, constant.SegComponent, param.ID, constant.SegIteration) - childWorkflowIDs[iter] = childWorkflowID + childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, originalIdx, constant.SegComponent, param.ID, constant.SegIteration) + childWorkflowIDs[idx] = childWorkflowID // If `input` is provided, the iteration will be performed over it; // otherwise, the iteration will be based on the `range` setup. @@ -636,7 +639,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi var indexes []int var elems []format.Value if useInput { - input, err := recipe.Render(ctx, data.NewString(param.Input), iter, wfm, false) + input, err := recipe.Render(ctx, data.NewString(param.Input), originalIdx, wfm, false) if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } @@ -679,7 +682,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi return nil, componentActivityError(ctx, wfm, fmt.Errorf("iterator range error"), preIteratorActivityErrorType, param.ID) } - renderedRangeParam, err := recipe.Render(ctx, rangeParam, iter, wfm, false) + renderedRangeParam, err := recipe.Render(ctx, rangeParam, originalIdx, wfm, false) if err != nil { return nil, err } @@ -750,12 +753,11 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi } } - result.ElementSize[iter] = len(indexes) iteratorRecipe := &datamodel.Recipe{ Component: wfm.GetRecipe().Component[param.ID].Component, } - childWFM, err := w.memoryStore.NewWorkflowMemory(ctx, childWorkflowIDs[iter], iteratorRecipe, len(indexes)) + childWFM, err := w.memoryStore.NewWorkflowMemory(ctx, childWorkflowIDs[idx], iteratorRecipe, len(indexes)) if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } @@ -786,11 +788,11 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi } for e, rangeIndex := range indexes { - variable, err := wfm.Get(ctx, iter, constant.SegVariable) + variable, err := wfm.Get(ctx, originalIdx, constant.SegVariable) if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } - secret, err := wfm.Get(ctx, iter, constant.SegSecret) + secret, err := wfm.Get(ctx, originalIdx, constant.SegSecret) if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } @@ -804,7 +806,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi } for _, id := range param.UpstreamIDs { - component, err := wfm.Get(ctx, iter, id) + component, err := wfm.Get(ctx, originalIdx, id) if err != nil { return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID) } @@ -836,6 +838,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi } result.ChildWorkflowIDs = childWorkflowIDs + result.ConditionMap = conditionMap logger.Info("PreIteratorActivity completed") return result, nil } @@ -850,8 +853,8 @@ func (w *worker) PostIteratorActivity(ctx context.Context, param *PostIteratorAc return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID) } - for iter := range wfm.GetBatchSize() { - childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, iter, constant.SegComponent, param.ID, constant.SegIteration) + for _, originalIdx := range param.ConditionMap { + childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, originalIdx, constant.SegComponent, param.ID, constant.SegIteration) childWFM, err := w.memoryStore.GetWorkflowMemory(ctx, childWorkflowID) if err != nil { return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID) @@ -871,11 +874,11 @@ func (w *worker) PostIteratorActivity(ctx context.Context, param *PostIteratorAc } output[k] = elemVals } - if err = wfm.SetComponentData(ctx, iter, param.ID, memory.ComponentDataOutput, output); err != nil { + if err = wfm.SetComponentData(ctx, originalIdx, param.ID, memory.ComponentDataOutput, output); err != nil { return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID) } - if err = wfm.SetComponentStatus(ctx, iter, param.ID, memory.ComponentStatusCompleted, true); err != nil { + if err = wfm.SetComponentStatus(ctx, originalIdx, param.ID, memory.ComponentStatusCompleted, true); err != nil { return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID) } }