Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update submodules, prepare for leasing v0.2.4 #127

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions benchmarks/prefill_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import os
import time

# import torch_xla2 first!
# pylint: disable-next=all
import torch_xla2
import humanize
import jax
import numpy as np
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/run_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import os
import time

# import torch_xla2 first!
# pylint: disable-next=all
import torch_xla2
import jax
import jax.numpy as jnp
# pylint: disable-next=all
Expand Down
2 changes: 1 addition & 1 deletion deps/JetStream
2 changes: 1 addition & 1 deletion deps/xla
Submodule xla updated 231 files
6 changes: 3 additions & 3 deletions install_everything.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ pip show tensorboard && pip uninstall -y tensorboard
pip show tensorflow-text && pip uninstall -y tensorflow-text
pip show torch_xla2 && pip uninstall -y torch_xla2

pip install flax==0.8.3
pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
pip install tensorflow-text
pip install tensorflow

pip install ray[default]==2.22.0
# torch cpu
pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
pip install safetensors colorama coverage humanize

git submodule update --init --recursive
pip show google-jetstream && pip uninstall -y google-jetstream
pip show torch_xla2 && pip uninstall -y torch_xla2
pip install -e .
pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
2 changes: 2 additions & 0 deletions run_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import time
from typing import List

# import torch_xla2 first!
import torch_xla2 # pylint: disable
import jax
import numpy as np
from absl import app, flags
Expand Down
2 changes: 2 additions & 0 deletions run_server_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Sequence
from absl import app, flags

# import torch_xla2 first!
import torch_xla2 # pylint: disable
import jax
from jetstream.core import server_lib
from jetstream.core.config_lib import ServerConfig
Expand Down
Loading