Skip to content

Commit

Permalink
fix DeviceSpecific
Browse files Browse the repository at this point in the history
  • Loading branch information
KateUnger committed Aug 23, 2023
1 parent e954ab0 commit 5458923
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/runtime/src/ops/batch_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) {
return {BATCHMATMUL_BWD_TASK_ID, bwd};
}

static DeviceSpecificArg<BMMPerDeviceState>
static DeviceSpecific<BMMPerDeviceState>
init_task_impl(TaskArgumentAccessor const &acc) {
auto const a_seq_length_dim = acc.get_argument<int>(A_SEQ_LENGTH_DIM);
auto const b_seq_length_dim = acc.get_argument<int>(B_SEQ_LENGTH_DIM);
PerDeviceFFHandle handle = acc.get_argument<PerDeviceFFHandle>(HANDLE);
Allocator allocator = acc.get_allocator();

DeviceSpecificArg<BMMPerDeviceState> per_device_state =
DeviceSpecific<BMMPerDeviceState> per_device_state =
acc.create_device_specific<BMMPerDeviceState>(
init_kernel(handle, allocator, a_seq_length_dim, b_seq_length_dim));

Expand All @@ -86,7 +86,7 @@ static DeviceSpecificArg<BMMPerDeviceState>
return per_device_state;
}

static DeviceSpecificArg<BMMPerDeviceState>
static DeviceSpecific<BMMPerDeviceState>
init_task(Task const *task,
std::vector<PhysicalRegion> const &regions,
Context ctx,
Expand Down Expand Up @@ -241,7 +241,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim,

auto init_accessor =
env.get_init_accessor(BATCHMATMUL_INIT_TASK_ID, init_binding);
DeviceSpecificArg<BMMPerDeviceState> per_device_state =
DeviceSpecific<BMMPerDeviceState> per_device_state =
init_task_impl(init_accessor);

SimTaskBinding fwd_binding;
Expand Down

0 comments on commit 5458923

Please sign in to comment.