Skip to content

Commit

Permalink
Fix PyTorch version string parsing logic
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu committed Dec 18, 2024
1 parent 57c668d commit 0ca42c7
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,26 @@


def _is_torch_compatible(current: str, required: str):
major, minor, patch = current.split(".") # eg. 2.1.2+cu121
major_, minor_, patch_ = required.split(".")
# PyTorch version tag examples:
# * 2.1.2+cu121
# * 2.1.2+cpu
# * 2.1.2
major, minor, patch = current.split(".")
required_major, required_minor, required_patch = required.split(".")

if (major, minor) != (major_, minor_):
if (major, minor) != (required_major, required_minor):
return False

_, cuda = patch.split("+")
_, cuda_ = patch_.split("+")
_, *cuda = patch.split("+")
_, *required_cuda = required_patch.split("+")

if cuda != cuda_:
return False
if not cuda:
return True

cuda, = cuda
required_cuda, = required_cuda

return True
return cuda == required_cuda or cuda == "cpu"


def _check_requirements():
Expand Down

0 comments on commit 0ca42c7

Please sign in to comment.