Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 29, 2024
1 parent 7dbb864 commit 6dc8c68
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/m1_script.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

export BUILD_VERSION=0.4.0
export TENSORDICT_BUILD_VERSION=0.4.0
4 changes: 2 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install wheel
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
shell: bash
run: |
python3 -mpip install wheel
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
Expand Down
14 changes: 9 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:

def get_version():
version = (ROOT_DIR / "version.txt").read_text().strip()
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
if os.getenv("TENSORDICT_BUILD_VERSION"):
version = os.getenv("TENSORDICT_BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
return version
Expand All @@ -62,11 +62,13 @@ def write_version_file(version):
f.write(f"git_version = {repr(sha)}\n")


def _get_pytorch_version(is_nightly):
def _get_pytorch_version(is_nightly, is_local):
# if "PYTORCH_VERSION" in os.environ:
# return f"torch=={os.environ['PYTORCH_VERSION']}"
if is_nightly:
return "torch>=2.4.0.dev"
if is_local:
return "torch"
return "torch>=2.3.0"


Expand Down Expand Up @@ -153,9 +155,11 @@ def _main(argv):

write_version_file(version)
logging.info(f"Building wheel {package_name}-{version}")
logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")
BUILD_VERSION = os.getenv("TENSORDICT_BUILD_VERSION")
logging.info(f"TENSORDICT_BUILD_VERSION is {BUILD_VERSION}")
local_build = BUILD_VERSION is None

pytorch_package_dep = _get_pytorch_version(is_nightly)
pytorch_package_dep = _get_pytorch_version(is_nightly, local_build)
logging.info("-- PyTorch dependency:", pytorch_package_dep)

long_description = (ROOT_DIR / "README.md").read_text()
Expand Down

0 comments on commit 6dc8c68

Please sign in to comment.