Skip to content

Commit

Permalink
feat: add condition field support for iterator (#803)
Browse files Browse the repository at this point in the history
Because

- Previously, we did not support a `condition` field for the iterator.

This commit:

- Adds `condition` field support for the iterator.
  • Loading branch information
donch1989 authored Nov 3, 2024
1 parent f207d40 commit 04b1252
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 28 deletions.
8 changes: 8 additions & 0 deletions pkg/data/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,16 @@ func unmarshalValue(val format.Value, field reflect.Value, structField reflect.S
case format.String:
return unmarshalString(v, field)
case Array:
if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) {
field.Set(reflect.ValueOf(v))
return nil
}
return unmarshalArray(v, field)
case Map:
if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) {
field.Set(reflect.ValueOf(v))
return nil
}
return unmarshalMap(v, field)
case format.Null:
if field.Type().Implements(reflect.TypeOf((*format.Value)(nil)).Elem()) {
Expand Down
59 changes: 31 additions & 28 deletions pkg/worker/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type PreIteratorActivityParam struct {
WorkflowID string
ID string
UpstreamIDs []string
Condition string
Input string
Range any
Index string
Expand All @@ -90,12 +91,13 @@ type PreIteratorActivityParam struct {

type PreIteratorActivityResult struct {
ChildWorkflowIDs []string
ElementSize []int
ConditionMap map[int]int
}

type PostIteratorActivityParam struct {
WorkflowID string
ID string
ConditionMap map[int]int
OutputElements map[string]string
SystemVariables recipe.SystemVariables
}
Expand Down Expand Up @@ -306,7 +308,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip
futureArgs = append(futureArgs, args)

case datamodel.Iterator:
// TODO tillknuesting: support intermediate result streaming for Iterator
// TODO: support intermediate result streaming for Iterator

preIteratorResult := &PreIteratorActivityResult{}
if err = workflow.ExecuteActivity(ctx, w.PreIteratorActivity, &PreIteratorActivityParam{
Expand All @@ -320,17 +322,16 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip
return ""
}(comp),
Range: comp.Range,
Condition: comp.Condition,
Index: comp.Index,
SystemVariables: param.SystemVariables,
}).Get(ctx, &preIteratorResult); err != nil {
if err != nil {
errs = append(errs, err)
continue
}
errs = append(errs, err)
continue
}

itFutures := []workflow.Future{}
for iter := range dagData.BatchSize {
for iter := range preIteratorResult.ConditionMap {
childWorkflowOptions := workflow.ChildWorkflowOptions{
TaskQueue: TaskQueue,
WorkflowID: preIteratorResult.ChildWorkflowIDs[iter],
Expand All @@ -352,7 +353,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip
// IsStreaming: param.IsStreaming,
}))
}
for iter := 0; iter < dagData.BatchSize; iter++ {
for iter := 0; iter < len(itFutures); iter++ {
err = itFutures[iter].Get(ctx, nil)
if err != nil {
errs = append(errs, err)
Expand All @@ -363,6 +364,7 @@ func (w *worker) TriggerPipelineWorkflow(ctx workflow.Context, param *TriggerPip
if err = workflow.ExecuteActivity(ctx, w.PostIteratorActivity, &PostIteratorActivityParam{
WorkflowID: workflowID,
ID: compID,
ConditionMap: preIteratorResult.ConditionMap,
OutputElements: comp.OutputElements,
SystemVariables: param.SystemVariables,
}).Get(ctx, nil); err != nil {
Expand Down Expand Up @@ -614,20 +616,21 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}

result := &PreIteratorActivityResult{
ElementSize: make([]int, wfm.GetBatchSize()),
conditionMap, err := w.processCondition(ctx, wfm, param.ID, param.UpstreamIDs, param.Condition)
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}

batchSize := wfm.GetBatchSize()
childWorkflowIDs := make([]string, batchSize)
result := &PreIteratorActivityResult{}

childWorkflowIDs := make([]string, len(conditionMap))

for iter := range wfm.GetBatchSize() {
if err = wfm.SetComponentStatus(ctx, iter, param.ID, memory.ComponentStatusStarted, true); err != nil {
for idx, originalIdx := range conditionMap {
if err = wfm.SetComponentStatus(ctx, originalIdx, param.ID, memory.ComponentStatusStarted, true); err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, iter, constant.SegComponent, param.ID, constant.SegIteration)
childWorkflowIDs[iter] = childWorkflowID
childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, originalIdx, constant.SegComponent, param.ID, constant.SegIteration)
childWorkflowIDs[idx] = childWorkflowID

// If `input` is provided, the iteration will be performed over it;
// otherwise, the iteration will be based on the `range` setup.
Expand All @@ -636,7 +639,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
var indexes []int
var elems []format.Value
if useInput {
input, err := recipe.Render(ctx, data.NewString(param.Input), iter, wfm, false)
input, err := recipe.Render(ctx, data.NewString(param.Input), originalIdx, wfm, false)
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
Expand Down Expand Up @@ -679,7 +682,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
return nil, componentActivityError(ctx, wfm, fmt.Errorf("iterator range error"), preIteratorActivityErrorType, param.ID)
}

renderedRangeParam, err := recipe.Render(ctx, rangeParam, iter, wfm, false)
renderedRangeParam, err := recipe.Render(ctx, rangeParam, originalIdx, wfm, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -750,12 +753,11 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
}
}

result.ElementSize[iter] = len(indexes)
iteratorRecipe := &datamodel.Recipe{
Component: wfm.GetRecipe().Component[param.ID].Component,
}

childWFM, err := w.memoryStore.NewWorkflowMemory(ctx, childWorkflowIDs[iter], iteratorRecipe, len(indexes))
childWFM, err := w.memoryStore.NewWorkflowMemory(ctx, childWorkflowIDs[idx], iteratorRecipe, len(indexes))
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
Expand Down Expand Up @@ -786,11 +788,11 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
}

for e, rangeIndex := range indexes {
variable, err := wfm.Get(ctx, iter, constant.SegVariable)
variable, err := wfm.Get(ctx, originalIdx, constant.SegVariable)
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
secret, err := wfm.Get(ctx, iter, constant.SegSecret)
secret, err := wfm.Get(ctx, originalIdx, constant.SegSecret)
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
Expand All @@ -804,7 +806,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi
}

for _, id := range param.UpstreamIDs {
component, err := wfm.Get(ctx, iter, id)
component, err := wfm.Get(ctx, originalIdx, id)
if err != nil {
return nil, componentActivityError(ctx, wfm, err, preIteratorActivityErrorType, param.ID)
}
Expand Down Expand Up @@ -836,6 +838,7 @@ func (w *worker) PreIteratorActivity(ctx context.Context, param *PreIteratorActi

}
result.ChildWorkflowIDs = childWorkflowIDs
result.ConditionMap = conditionMap
logger.Info("PreIteratorActivity completed")
return result, nil
}
Expand All @@ -850,8 +853,8 @@ func (w *worker) PostIteratorActivity(ctx context.Context, param *PostIteratorAc
return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID)
}

for iter := range wfm.GetBatchSize() {
childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, iter, constant.SegComponent, param.ID, constant.SegIteration)
for _, originalIdx := range param.ConditionMap {
childWorkflowID := fmt.Sprintf("%s:%d:%s:%s:%s", param.WorkflowID, originalIdx, constant.SegComponent, param.ID, constant.SegIteration)
childWFM, err := w.memoryStore.GetWorkflowMemory(ctx, childWorkflowID)
if err != nil {
return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID)
Expand All @@ -871,11 +874,11 @@ func (w *worker) PostIteratorActivity(ctx context.Context, param *PostIteratorAc
}
output[k] = elemVals
}
if err = wfm.SetComponentData(ctx, iter, param.ID, memory.ComponentDataOutput, output); err != nil {
if err = wfm.SetComponentData(ctx, originalIdx, param.ID, memory.ComponentDataOutput, output); err != nil {
return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID)
}

if err = wfm.SetComponentStatus(ctx, iter, param.ID, memory.ComponentStatusCompleted, true); err != nil {
if err = wfm.SetComponentStatus(ctx, originalIdx, param.ID, memory.ComponentStatusCompleted, true); err != nil {
return componentActivityError(ctx, wfm, err, postIteratorActivityErrorType, param.ID)
}
}
Expand Down

0 comments on commit 04b1252

Please sign in to comment.