[MOE] update moe cpp example and aggregate implementation (#555)
* [MOE] update moe cpp example and aggregate implementation

* [MOE] bug fixes to make the MOE example work
jiazhihao authored and goliaro committed Jan 17, 2023
1 parent 3920d98 commit e8770cc
Showing 4 changed files with 33 additions and 28 deletions.
22 changes: 10 additions & 12 deletions examples/cpp/inference/mixture_of_experts/moe.cc
@@ -31,7 +31,7 @@
using namespace Legion;

LegionRuntime::Logger::Category log_app("MoE");
-int num_exp = 5;
+int num_exp = 4;
int num_select = 2;

void parse_input_args(char **argv, int argc, MoeConfig &config) {
@@ -54,22 +54,20 @@ Tensor create_moe(FFModel *model,
gate_preds = model->dense(gate_preds, num_exp, AC_MODE_RELU);
Tensor topK_output[2];
model->top_k(gate_preds, topK_output, num_select, false);
-Tensor exp_tensors[num_exp];
-model->group_by(input, topK_output[1], exp_tensors, num_exp, alpha);
-for (int i=0; i<num_exp; i++) {
-exp_tensors[i]->dims[2] = 1; // temporary fix to replica dimension being undefined
-exp_tensors[i]->print("exp_tensors[i]");
-}
Tensor agg_inputs[num_exp + 4];
agg_inputs[0] = model->softmax(topK_output[0]); // gate preds
agg_inputs[1] = topK_output[1]; // gate assign
agg_inputs[2] = topK_output[1]; // gate assign TopK (for cache)
agg_inputs[3] = gate_preds; // full gate preds
-for (int i = 0; i < num_exp; i++) {
-Tensor exp_pred =
-model->dense(exp_tensors[i], moeConfig->hidden_size, AC_MODE_RELU);
-exp_pred->print("exp_pred");
-agg_inputs[i + 4] = model->softmax(exp_pred);
+for (int i = 0; i < num_exp /*number of experts layers*/; i++) {
+Tensor exp_pred = model->experts(gate_preds,
+topK_output[1],
+32 /*number of experts*/,
+32 * i /*expert start index*/,
+1 /*number of linear layers*/,
+moeConfig->hidden_size /*output_size*/,
+moeConfig->hidden_size /*internal_size*/);
+agg_inputs[i + 4] = exp_pred;
}
for (int i = 0; i < num_exp + 4; i++) {
agg_inputs[i]->print("agg_inputs[i]");
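
For reference, the rewritten expert section of create_moe() in moe.cc, assembled as contiguous code from the added lines above. This is a sketch rather than a verbatim copy of the file, and it assumes the FFModel::experts() argument order visible in this commit (input, indices, number of experts, expert start index, number of linear layers, output size, internal size):

// Sketch of the new expert wiring; mirrors the '+' lines above.
Tensor topK_output[2];
model->top_k(gate_preds, topK_output, num_select, false);
Tensor agg_inputs[num_exp + 4];
agg_inputs[0] = model->softmax(topK_output[0]); // gate preds
agg_inputs[1] = topK_output[1];                 // gate assign
agg_inputs[2] = topK_output[1];                 // gate assign TopK (for cache)
agg_inputs[3] = gate_preds;                     // full gate preds
for (int i = 0; i < num_exp; i++) {
  // One experts layer per iteration; layer i hosts the 32 experts
  // starting at index 32 * i.
  agg_inputs[i + 4] = model->experts(gate_preds,
                                     topK_output[1],
                                     32 /*number of experts*/,
                                     32 * i /*expert start index*/,
                                     1 /*number of linear layers*/,
                                     moeConfig->hidden_size /*output_size*/,
                                     moeConfig->hidden_size /*internal_size*/);
}
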
21 changes: 12 additions & 9 deletions src/ops/aggregate.cc
@@ -55,10 +55,8 @@ Tensor FFModel::aggregate(
int num_dim = inputs[4]->num_dims;
// Set output shape
int dims[MAX_TENSOR_DIM];
-for (int i = 0; i < num_dim - 1; i++) {
+for (int i = 0; i < num_dim; i++)
dims[i] = inputs[4]->dims[i];
-}
-dims[num_dim - 1] = inputs[0]->dims[num_dim - 1];
li->outputs[0] = create_tensor_legion_ordering(
num_dim, dims, DT_FLOAT, li, 0, true /*create_grad*/);
}
@@ -143,11 +141,16 @@ Aggregate::Aggregate(FFModel &model,
}
// Set output shape
ParallelDim dims[MAX_TENSOR_DIM];
+<<<<<<< HEAD
for (int i = 0; i < num_dim - 1; i++) {
dims[i] = inputs[4]->dims[i];
}
dims[num_dim - 2] = inputs[0]->dims[num_dim - 2];
dims[num_dim - 1] = inputs[0]->dims[num_dim - 1];
+=======
+for (int i = 0; i < num_dim; i++)
+dims[i] = inputs[4]->dims[i];
+>>>>>>> 99a89a9b... [MOE] update moe cpp example and aggregate implementation (#555)
numOutputs = 1;
outputs[0] = model.create_parallel_tensor_legion_ordering(
num_dim, dims, DT_FLOAT, this);
@@ -204,7 +207,7 @@ void Aggregate::forward(FFModel const &ff) {
set_argumentmap_for_forward(ff, argmap);
IndexLauncher launcher(AGGREGATE_FWD_TASK_ID,
parallel_is,
-TaskArgument(this, sizeof(Aggregate)),
+TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
@@ -255,7 +258,7 @@ FutureMap Aggregate::inference(FFModel const &ff,
size_t machine_view_hash = mv ? mv->hash() : outputs[0]->machine_view.hash();
IndexLauncher launcher(AGGREGATE_FWD_TASK_ID,
parallel_is,
-TaskArgument(this, sizeof(Aggregate)),
+TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
@@ -299,10 +302,10 @@ void Aggregate::forward_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
-int n = ((Aggregate *)task->args)->n;
-
-assert((int)regions.size() == n + 3);
-assert((int)task->regions.size() == n + 3);
+assert(regions.size() == task->regions.size());
+int n = regions.size() - 3;
+// FIXME: skip the aggregate computation for now
+return;

AggregateMeta const *m = *((AggregateMeta **)task->local_args);

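
A note on the Aggregate launcher change above: since the op is no longer serialized into the TaskArgument, forward_task recovers the number of experts from the region count. A minimal sketch of that bookkeeping, assuming the region layout implied by the old asserts (gate predictions, gate assignments, n expert predictions, and one output region, i.e. n + 3 regions in total):

// Sketch only; the region layout is an assumption, not spelled out in this diff.
assert(regions.size() == task->regions.size());
int n = (int)regions.size() - 3; // subtract gate_preds, gate_assign and output
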
16 changes: 9 additions & 7 deletions src/ops/experts.cc
@@ -86,9 +86,9 @@ Tensor FFModel::experts(const Tensor input,
1 /*outputs*/,
input,
indices);
-assert(input->num_dims == indices->num_dims + 1);
-for (int i = 0; i < indices->num_dims; i++)
-assert(input->dims[i + 1] == indices->dims[i]);
+assert(input->num_dims == indices->num_dims);
+for (int i = 1; i < indices->num_dims; i++)
+assert(input->dims[i] == indices->dims[i]);
assert(indices->data_type == DT_INT32 || indices->data_type == DT_INT64);
int dims[MAX_TENSOR_DIM];
int numdim = input->num_dims;
@@ -168,12 +168,14 @@ Experts::Experts(FFModel &model,
experts_num_layers(_experts_num_layers),
experts_output_dim_size(_experts_output_dim_size),
experts_internal_dim_size(_experts_internal_dim_size) {
-assert(input->num_dims == indices->num_dims + 1);
+assert(input->num_dims == indices->num_dims);
assert(indices->data_type == DT_INT32 || indices->data_type == DT_INT64);
-for (int i = 0; i < indices->num_dims; i++)
-assert(input->dims[i + 1] == indices->dims[i]);
-// Assume that we don't parallelize the channel dim
+for (int i = 1; i < indices->num_dims; i++)
+assert(input->dims[i] == indices->dims[i]);
+// Assume that we don't parallelize the channel dim of input
+// nor the expert_assigned dim of indices
assert(input->dims[0].degree == 1);
+assert(indices->dims[0].degree == 1);
ParallelDim dims[MAX_TENSOR_DIM];
for (int i = 0; i < input->num_dims; i++)
dims[i] = input->dims[i];
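
For context on the reshuffled shape checks in experts.cc: input and indices are now expected to have the same rank and to match on every dimension except the innermost one, which holds the data channels of input and the selected experts of indices. A hypothetical pair of shapes that satisfies the new asserts (the concrete sizes are illustrative, not taken from the commit):

// Legion ordering, innermost dimension first; sizes are hypothetical.
// input:   {hidden_size, batch_size, ...}  e.g. {1024, 64, ...}
// indices: {num_select,  batch_size, ...}  e.g. {2,    64, ...}
assert(input->num_dims == indices->num_dims);   // same rank
for (int i = 1; i < indices->num_dims; i++)
  assert(input->dims[i] == indices->dims[i]);   // match above dim 0
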
2 changes: 2 additions & 0 deletions src/runtime/ffconst_utils.cc
@@ -45,6 +45,8 @@ std::string get_operator_type_name(OperatorType type) {
return "Split";
case OP_EMBEDDING:
return "Embedding";
+case OP_EXPERTS:
+return "Experts";
case OP_GROUP_BY:
return "Group_by";
case OP_CACHE:
