Reused memory usage for FusedOp
jiazhihao committed Oct 6, 2024
1 parent 5b2b9b3 commit c66112d
Showing 1 changed file with 36 additions and 46 deletions.
82 changes: 36 additions & 46 deletions src/ops/fused.cu
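
In essence, this commit replaces the fixed-size accessor arrays (GenericTensorAccessorR/W indexed up to MAX_NUM_INPUTS, MAX_NUM_WEIGHTS, MAX_NUM_OUTPUTS) with std::vector containers that are filled with push_back, so each task only holds as many accessors as the fused op actually uses. Below is a standalone sketch of that before/after pattern; the Accessor struct and make_accessor helper are stand-ins invented for illustration, not FlexFlow's GenericTensorAccessorR or helperGetGenericTensorAccessorRO.

// Standalone sketch of the change pattern (stand-in types, not FlexFlow code).
#include <cassert>
#include <vector>

struct Accessor {
  int region_idx; // placeholder for the real accessor state
};

constexpr int MAX_NUM_INPUTS = 256;

Accessor make_accessor(int region_idx) {
  return Accessor{region_idx};
}

int main() {
  int num_inputs = 3; // would come from fused->numInputs

  // Before: a fixed-size array always reserves MAX_NUM_INPUTS slots,
  // regardless of how many inputs the fused op actually has.
  Accessor input_accessor_old[MAX_NUM_INPUTS];
  for (int i = 0; i < num_inputs; i++) {
    input_accessor_old[i] = make_accessor(i);
  }

  // After: a vector grows to exactly num_inputs elements.
  std::vector<Accessor> input_accessor;
  assert(num_inputs <= MAX_NUM_INPUTS);
  for (int i = 0; i < num_inputs; i++) {
    input_accessor.push_back(make_accessor(i));
  }
  assert((int)input_accessor.size() == num_inputs);
  return 0;
}
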
@@ -101,40 +101,40 @@ __host__ void
   assert((int)regions.size() == fused->numInputs + fused->numWeights +
                                     fused->numOutputs +
                                     softmax_grad_additional_region);
-  GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS];
+  std::vector<GenericTensorAccessorR> input_accessor;
+  std::vector<GenericTensorAccessorR> weight_accessor;
+  std::vector<GenericTensorAccessorW> output_accessor;
   assert(fused->numInputs <= MAX_NUM_INPUTS);
   for (int i = 0; i < fused->numInputs; i++) {
-    input_accessor[i] =
+    input_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->input_data_types[i],
                                          regions[i],
                                          task->regions[i],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   int roff = fused->numInputs;
   assert(fused->numWeights <= MAX_NUM_WEIGHTS);
   for (int i = 0; i < fused->numWeights; i++) {
-    weight_accessor[i] =
+    weight_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->weight_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numWeights;
   assert(fused->numOutputs <= MAX_NUM_OUTPUTS);
   for (int i = 0; i < fused->numOutputs; i++) {
-    output_accessor[i] =
+    output_accessor.push_back(
         helperGetGenericTensorAccessorWO(fused->output_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numOutputs;
   // Assert that all meta share the same dnn/blas handler
@@ -153,39 +153,28 @@ __host__ void

   int ioff = 0, woff = 0, ooff = 0;
   for (int op = 0; op < fused->numOperators; op++) {
-#if 0
-    std::cout << get_operator_type_name(fused->op_op_type[op]) << std::endl;
-#endif
-    GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
-    GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS];
-    GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS];
+    std::vector<GenericTensorAccessorR> my_input_accessor;
+    std::vector<GenericTensorAccessorR> my_weight_accessor;
+    std::vector<GenericTensorAccessorW> my_output_accessor;
     for (int i = 0; i < fused->op_num_inputs[op]; i++) {
       int my_off = fused->op_input_idx[i + ioff];
       if (fused->op_input_source[i + ioff] == SOURCE_INPUT) {
-        my_input_accessor[i] = input_accessor[my_off];
-#if 0
-        printf("\tmy_input_accessor[%i] = input_accessor[%i]\n", i, my_off);
-#endif
+        my_input_accessor.push_back(input_accessor[my_off]);
       } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) {
-        my_input_accessor[i] = output_accessor[my_off];
-#if 0
-        printf("\tmy_input_accessor[%i] = output_accessor[%i]\n", i, my_off);
-#endif
+        my_input_accessor.push_back(output_accessor[my_off]);
      } else {
        assert(false);
      }
    }
    for (int i = 0; i < fused->op_num_weights[op]; i++) {
      assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
-      my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]];
+      my_weight_accessor.push_back(
+          weight_accessor[fused->op_weight_idx[i + woff]]);
    }
    for (int i = 0; i < fused->op_num_outputs[op]; i++) {
      int my_off = fused->op_output_idx[i + ooff];
      assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT);
-      my_output_accessor[i] = output_accessor[my_off];
-#if 0
-      printf("\tmy_output_accessor[%i] = output_accessor[%i]\n", i, my_off);
-#endif
+      my_output_accessor.push_back(output_accessor[my_off]);
    }
    switch (fused->op_op_type[op]) {
      case OP_CONCAT: {
@@ -195,7 +184,7 @@ __host__ void
         int num_inputs = fused->op_num_inputs[op];
         Kernels::Concat::forward_kernel_wrapper(m,
                                                 my_output_accessor[0],
-                                                my_input_accessor,
+                                                my_input_accessor.data(),
                                                 num_inputs,
                                                 m->legion_axis);
         break;
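
A note on the call-site change above: the old C array argument decayed to a pointer automatically, while the new std::vector must expose its contiguous buffer explicitly via .data(). The sketch below illustrates that difference under the assumption that the wrapper takes a pointer-to-accessor plus a count, as the old array argument suggests; take_accessors is a hypothetical stand-in for Kernels::Concat::forward_kernel_wrapper's accessor parameter.

// Sketch: why .data() replaces the bare array argument (stand-in types).
#include <vector>

struct Accessor {
  int id;
};

void take_accessors(Accessor const *accessors, int n) {
  // A real kernel wrapper would read accessors[0..n-1] here.
  (void)accessors;
  (void)n;
}

int main() {
  Accessor arr[4] = {};         // old style: array name decays to Accessor const *
  take_accessors(arr, 4);

  std::vector<Accessor> vec(4); // new style: .data() exposes the same contiguous buffer
  take_accessors(vec.data(), (int)vec.size());
  return 0;
}
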
@@ -1242,40 +1231,40 @@ __host__ void FusedOp::forward_task(Task const *task,
   assert(regions.size() == task->regions.size());
   assert((int)regions.size() ==
          fused->numInputs + fused->numWeights + fused->numOutputs);
-  GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS];
-  GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS];
-  GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS];
+  std::vector<GenericTensorAccessorR> input_accessor;
+  std::vector<GenericTensorAccessorR> weight_accessor;
+  std::vector<GenericTensorAccessorW> output_accessor;
   assert(fused->numInputs <= MAX_NUM_INPUTS);
   for (int i = 0; i < fused->numInputs; i++) {
-    input_accessor[i] =
+    input_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->input_data_types[i],
                                          regions[i],
                                          task->regions[i],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   int roff = fused->numInputs;
   assert(fused->numWeights <= MAX_NUM_WEIGHTS);
   for (int i = 0; i < fused->numWeights; i++) {
-    weight_accessor[i] =
+    weight_accessor.push_back(
         helperGetGenericTensorAccessorRO(fused->weight_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   roff += fused->numWeights;
   assert(fused->numOutputs <= MAX_NUM_OUTPUTS);
   for (int i = 0; i < fused->numOutputs; i++) {
-    output_accessor[i] =
+    output_accessor.push_back(
         helperGetGenericTensorAccessorWO(fused->output_data_types[i],
                                          regions[i + roff],
                                          task->regions[i + roff],
                                          FID_DATA,
                                          ctx,
-                                         runtime);
+                                         runtime));
   }
   // Assert that all meta share the same dnn/blas handler
   int start = 0;
@@ -1293,31 +1282,32 @@ __host__ void FusedOp::forward_task(Task const *task,

   int ioff = 0, woff = 0, ooff = 0;
   for (int op = 0; op < fused->numOperators; op++) {
-    GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS];
-    GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS];
-    GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS];
+    std::vector<GenericTensorAccessorR> my_input_accessor;
+    std::vector<GenericTensorAccessorR> my_weight_accessor;
+    std::vector<GenericTensorAccessorW> my_output_accessor;
     for (int i = 0; i < fused->op_num_inputs[op]; i++) {
       int my_off = fused->op_input_idx[i + ioff];
       if (fused->op_input_source[i + ioff] == SOURCE_INPUT) {
         assert(my_off < fused->numInputs);
-        my_input_accessor[i] = input_accessor[my_off];
+        my_input_accessor.push_back(input_accessor[my_off]);
       } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) {
         assert(my_off < fused->numOutputs);
-        my_input_accessor[i] = output_accessor[my_off];
+        my_input_accessor.push_back(output_accessor[my_off]);
       } else {
         assert(false);
       }
     }
     for (int i = 0; i < fused->op_num_weights[op]; i++) {
       assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT);
       assert(fused->op_weight_idx[i + woff] < fused->numWeights);
-      my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]];
+      my_weight_accessor.push_back(
+          weight_accessor[fused->op_weight_idx[i + woff]]);
     }
     for (int i = 0; i < fused->op_num_outputs[op]; i++) {
       int my_off = fused->op_output_idx[i + ooff];
       assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT);
       assert(my_off < fused->numOutputs);
-      my_output_accessor[i] = output_accessor[my_off];
+      my_output_accessor.push_back(output_accessor[my_off]);
     }
     switch (fused->op_op_type[op]) {
       case OP_CONCAT: {
@@ -1327,7 +1317,7 @@ __host__ void FusedOp::forward_task(Task const *task,
         int num_inputs = fused->op_num_inputs[op];
         Kernels::Concat::forward_kernel_wrapper(m,
                                                 my_output_accessor[0],
-                                                my_input_accessor,
+                                                my_input_accessor.data(),
                                                 num_inputs,
                                                 m->legion_axis);
         break;
