diff --git a/lib/op-attrs/src/aggregate.cc b/lib/op-attrs/src/aggregate.cc index c3b63b914c..883d6a2def 100644 --- a/lib/op-attrs/src/aggregate.cc +++ b/lib/op-attrs/src/aggregate.cc @@ -21,6 +21,18 @@ ParallelTensorShape return output_shape; } +/* TensorShape overload of get_output_shape: returns a copy of exp_preds.at(0). attrs and the four gate shapes are accepted but not read. std::vector::at throws std::out_of_range if exp_preds is empty. NOTE(review): the vector element type is missing in this view - presumably TensorShape, confirm. */ +TensorShape + get_output_shape(AggregateAttrs const &attrs, + TensorShape const &gate_preds, + TensorShape const &gate_assign, + TensorShape const &true_gate_assign, + TensorShape const &full_gate_gradients, + std::vector const &exp_preds) { + TensorShape output_shape = exp_preds.at(0); + return output_shape; +} + + bool is_valid(AggregateAttrs const &attrs, ParallelTensorShape const &gate_preds, ParallelTensorShape const &gate_assign, diff --git a/lib/op-attrs/test/src/tensor_shape_inf.cc b/lib/op-attrs/test/src/tensor_shape_inf.cc new file mode 100644 index 0000000000..34e966f14e --- /dev/null +++ b/lib/op-attrs/test/src/tensor_shape_inf.cc @@ -0,0 +1,44 @@ +#include "doctest/doctest.h" +#include "op-attrs/operator_attrs.h" +#include "parallel_tensor_shape.h" +#include + +using namespace FlexFlow; +using namespace rc; + +/* Test helper: builds AggregateAttrs and the gate/gradient ParallelTensorShape inputs from the given dimension lists (gate_preds and full_gate_gradients as DataType::FLOAT, gate_assign and true_gate_assign as DataType::INT32). */ +ParallelTensorShape aggregate_parallel_helper(int n, + FFOrdered const &gate_pred_assign_dims, + FFOrdered const &full_gate_grad_dims, + FFOrdered const &exp_dims) { + /* NOTE(review): lambda_bal: 1.0 uses the GNU field-colon initializer form, which is not standard C++ - confirm every target compiler accepts it. */ + AggregateAttrs attrs {n, lambda_bal: 1.0}; + ParallelTensorShape gate_preds {gate_pred_assign_dims, DataType::FLOAT}; + ParallelTensorShape gate_assign {gate_pred_assign_dims, DataType::INT32}; + ParallelTensorShape true_gate_assign {gate_pred_assign_dims, DataType::INT32}; + ParallelTensorShape full_gate_gradients {full_gate_grad_dims, DataType::FLOAT}; + std::vector exp_preds; + /* NOTE(review): the text following i in the loop header below is missing in this view (it resumes inside a later initializer list), so the loop body, the return of this helper and the start of the test case cannot be reviewed here. */ + for (int i=0; i gate_pred_assign_dims {{k, 1, false}, {batch_size, 1, false}, replica_dim}; + FFOrdered full_gate_grad_dims {{n, 1, false}, {batch_size, 1, false}, {1, 1, true}}; + FFOrdered exp_dims {{output_dim, 1, false}, {rows, 1, false}, {1, 1, true}}; + FFOrdered output_dims {{output_dim, 1, false}, {rows, 1, false}, replica_dim}; + ParallelTensorShape 
correct_output {output_dims, DataType::FLOAT}; + + /* The result of the helper for the dims above must equal correct_output, which is built from output_dims with DataType::FLOAT. */ + CHECK(aggregate_parallel_helper(n, gate_pred_assign_dims, full_gate_grad_dims, exp_dims) == correct_output); +} \ No newline at end of file