Skip to content

Commit

Permalink
Graceful handling of cpp extensions (#296)
Browse files Browse the repository at this point in the history
* Graceful handling of cpp extensions

* update

* push

* yolo

* revert some changes'

* Update __init__.py

* Update README.md

---------
  • Loading branch information
msaroufim authored May 30, 2024
1 parent 4c1d568 commit e7837d7
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 9 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ python setup.py develop

If you want to install from source run
```Shell
python setup.py install
python setup.py install
```

** Note:
Since we are building pytorch c++/cuda extensions by default, running `pip install .` will
not work.
If you are running into any issues while building `ao` cpp extensions you can instead build using

```shell
USE_CPP=0 python setup.py install
```

### Quantization

Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def read_requirements(file_path):
# Determine the package name based on the presence of an environment variable
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"
version_suffix = os.getenv("VERSION_SUFFIX", "")
use_cpp = os.getenv('USE_CPP')


# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else "0.2.0"
Expand Down Expand Up @@ -92,7 +94,7 @@ def get_extensions():
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
ext_modules=get_extensions(),
ext_modules=get_extensions() if use_cpp != "0" else None,
install_requires=read_requirements("requirements.txt"),
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
Expand Down
7 changes: 7 additions & 0 deletions test/dtypes/test_float6_e3m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
parametrize,
run_tests,
)

try:
import torchao.ops
except RuntimeError:
pytest.skip("torchao.ops not available")


from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


Expand Down
5 changes: 5 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from parameterized import parameterized
import pytest

try:
import torchao.ops
except RuntimeError:
pytest.skip("torchao.ops not available")


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
Expand Down
10 changes: 8 additions & 2 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import torch
import logging

_IS_FBCODE = (
hasattr(torch._utils_internal, "IS_FBSOURCE") and
torch._utils_internal.IS_FBSOURCE
)

if not _IS_FBCODE:
from . import _C
from . import ops
try:
from . import _C
from . import ops
except:
_C = None
logging.info("Skipping import of cpp extensions")

from torchao.quantization import (
apply_weight_only_int8_quant,
Expand Down
10 changes: 7 additions & 3 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from .nf4tensor import NF4Tensor, to_nf4
from .uint4 import UInt4Tensor
from .aqt import AffineQuantizedTensor, to_aq
from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2

__all__ = [
"NF4Tensor",
"to_nf4",
"UInt4Tensor"
"AffineQuantizedTensor",
"to_aq",
"to_float6_e3m2",
"from_float6_e3m2",
]

# CPP extensions
try:
from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2
__all__.extend(["to_float6_e3m2", "from_float6_e3m2"])
except RuntimeError:
pass

0 comments on commit e7837d7

Please sign in to comment.