Skip to content

Commit

Permalink
Adding preprocessor checks for torch version during torch cpp extensi…
Browse files Browse the repository at this point in the history
…ons compilation (#8989)
  • Loading branch information
baijumeswani authored Sep 9, 2021
1 parent 0367e1f commit d78e90d
Showing 1 changed file with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,24 @@ std::vector<DLManagedTensor*> ExecuteATenOperator(const char* op_name, const cha
torch::jit::push(stack, arguments[i]);
}

#ifndef TORCH_VERSION_PREEQ
#define TORCH_VERSION_PREEQ(x, y) \
((TORCH_VERSION_MAJOR == (x) && TORCH_VERSION_MINOR >= (y)) || \
(TORCH_VERSION_MAJOR > (x)))
#endif

// pull request https://github.com/pytorch/pytorch/pull/63414 introduced
// a backwards incompatibility by changing the API. To make ORTModule
// work with both torch versions >=1.10 as well as < 1.10, we need
// preprocessor checks
#if TORCH_VERSION_PREEQ(1, 10)
// torch version is >= 1.10
aten_op.op->getOperation()(stack);
#else
// torch version is < 1.10
aten_op.op->getOperation()(&stack);
#endif

std::vector<DLManagedTensor*> result;
for (const auto& ret : torch::jit::pop(stack, aten_op.return_size)) {
const auto& tensor = ret.toTensor();
Expand Down

0 comments on commit d78e90d

Please sign in to comment.