Skip to content

Commit

Permalink
use pytorch version env variable (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim authored Jun 15, 2024
1 parent bc2f8b7 commit aa61c45
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
torch
numpy
sentencepiece
packaging
expecttest # So we can use IS_FBCODE flag
Expand Down
12 changes: 11 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def get_extensions():

return ext_modules

# Mimic code from torchvision https://github.com/pytorch/vision/blob/143d078b28f00471156a4e562dd3836370acc9ee/setup.py#L58
pytorch_dep = "torch"
if os.getenv("PYTORCH_VERSION"):
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
"numpy",
pytorch_dep,
]

setup(
name=package_name,
version=version+version_suffix,
Expand All @@ -97,7 +107,7 @@ def get_extensions():
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions() if use_cpp != "0" else None,
install_requires=read_requirements("requirements.txt"),
install_requires=requirements,
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
long_description=open("README.md").read(),
Expand Down

0 comments on commit aa61c45

Please sign in to comment.