Revert "[AIR] Deprecations for 2.3 (#31763)" (#31866)
This reverts commit 91b632b.
Alex Wu authored Jan 23, 2023
1 parent 40c4571 commit 58386d0
Showing 6 changed files with 319 additions and 5 deletions.
38 changes: 37 additions & 1 deletion python/ray/air/checkpoint.py
@@ -27,7 +27,7 @@
upload_to_uri,
)
from ray.air.constants import PREPROCESSOR_KEY, CHECKPOINT_ID_ATTR
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@@ -415,6 +415,42 @@ def to_dict(self) -> dict:
checkpoint_data[PREPROCESSOR_KEY] = self._override_preprocessor
return checkpoint_data

@classmethod
@Deprecated(
message="To restore a checkpoint from a remote object ref, call "
"`ray.get(obj_ref)` instead."
)
def from_object_ref(cls, obj_ref: ray.ObjectRef) -> "Checkpoint":
"""Create checkpoint object from object reference.
Args:
obj_ref: ObjectRef pointing to checkpoint data.
Returns:
Checkpoint: checkpoint object.
"""
raise DeprecationWarning(
"`from_object_ref` is deprecated and will be removed in a future Ray "
"version. To restore a Checkpoint from a remote object ref, call "
"`ray.get(obj_ref)` instead.",
)

@Deprecated(
message="To store the checkpoint in the Ray object store, call `ray.put(ckpt)` "
"instead of `ckpt.to_object_ref()`."
)
def to_object_ref(self) -> ray.ObjectRef:
"""Return checkpoint data as object reference.
Returns:
ray.ObjectRef: ObjectRef pointing to checkpoint data.
"""
raise DeprecationWarning(
"`to_object_ref` is deprecated and will be removed in a future Ray "
"version. To store the checkpoint in the Ray object store, call "
"`ray.put(ckpt)` instead of `ckpt.to_object_ref()`.",
)

@classmethod
def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
"""Create checkpoint object from directory.
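For reference, a minimal sketch of the migration these deprecation messages point to, using only `ray.put`/`ray.get`; `Checkpoint.from_dict`/`to_dict` are standard AIR helpers, and the dictionary contents are illustrative:

import ray
from ray.air.checkpoint import Checkpoint

ray.init()

ckpt = Checkpoint.from_dict({"epoch": 3})
obj_ref = ray.put(ckpt)      # replaces ckpt.to_object_ref()
restored = ray.get(obj_ref)  # replaces Checkpoint.from_object_ref(obj_ref)
assert restored.to_dict()["epoch"] == 3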
12 changes: 8 additions & 4 deletions python/ray/data/preprocessors/batch_mapper.py
@@ -80,19 +80,23 @@ def __init__(
Union[np.ndarray, Dict[str, np.ndarray]],
],
],
batch_format: Optional[BatchFormat],
batch_format: Optional[BatchFormat] = None,
batch_size: Optional[Union[int, Literal["default"]]] = "default",
# TODO: Make batch_format required from user
# TODO: Introduce a "zero_copy" format
# TODO: We should reach consistency of args between BatchMapper and map_batches.
):
        if not batch_format:
            raise DeprecationWarning(
                "batch_format is a required argument for BatchMapper from Ray 2.1. "
                "You must specify either 'pandas' or 'numpy' batch format."
            )

if batch_format not in [
BatchFormat.PANDAS,
BatchFormat.NUMPY,
]:
raise ValueError(
"BatchMapper only supports 'pandas' or 'numpy' batch format."
)
raise ValueError("BatchMapper only supports pandas and numpy batch format.")

self.batch_format = batch_format
self.batch_size = batch_size
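With this revert in place, `batch_format` is effectively mandatory again: omitting it raises the DeprecationWarning above. A minimal valid construction (the column name and transform are illustrative):

import pandas as pd
from ray.data.preprocessors import BatchMapper

def add_one(batch: pd.DataFrame) -> pd.DataFrame:
    batch["value"] = batch["value"] + 1
    return batch

mapper = BatchMapper(fn=add_one, batch_format="pandas")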
16 changes: 16 additions & 0 deletions python/ray/train/__init__.py
@@ -1,13 +1,29 @@
from ray._private.usage import usage_lib
from ray.train.backend import BackendConfig
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.train_loop_utils import (
get_dataset_shard,
load_checkpoint,
local_rank,
report,
save_checkpoint,
world_rank,
world_size,
)
from ray.train.trainer import TrainingIterator


usage_lib.record_library_usage("train")

__all__ = [
"BackendConfig",
"get_dataset_shard",
"load_checkpoint",
"local_rank",
"report",
"save_checkpoint",
"TrainingIterator",
"world_rank",
"world_size",
"TRAIN_DATASET_KEY",
]
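The effect of restoring these re-exports is that the legacy `train.*` utilities import cleanly again but fail loudly when called. A quick sanity check (sketch):

from ray.train import world_rank

try:
    world_rank()
except DeprecationWarning as e:
    print(e)  # the message directs users to ray.air.session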
10 changes: 10 additions & 0 deletions python/ray/train/tests/test_torch_trainer.py
@@ -378,6 +378,16 @@ def __getstate__(self):
assert results.checkpoint


def test_torch_prepare_model_deprecated():
model = torch.nn.Linear(1, 1)

with pytest.raises(DeprecationWarning):
train.torch.prepare_model(model, wrap_ddp=True)

with pytest.raises(DeprecationWarning):
train.torch.prepare_model(model, ddp_kwargs={"x": "y"})


if __name__ == "__main__":
import sys

15 changes: 15 additions & 0 deletions python/ray/train/torch/train_loop_utils.py
@@ -53,6 +53,9 @@ def prepare_model(
move_to_device: bool = True,
parallel_strategy: Optional[str] = "ddp",
parallel_strategy_kwargs: Optional[Dict[str, Any]] = None,
# Deprecated args.
wrap_ddp: bool = False,
ddp_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
"""Prepares the model for distributed execution.
@@ -73,6 +76,18 @@
or "fsdp", respectively.
"""

if wrap_ddp:
raise DeprecationWarning(
"The `wrap_ddp` argument is deprecated as of Ray 2.1. Use the "
"`parallel_strategy` argument instead."
)

if ddp_kwargs:
raise DeprecationWarning(
"The `ddp_kwargs` argument is deprecated as of Ray 2.1. Use the "
"`parallel_strategy_kwargs` arg instead."
)

if parallel_strategy == "fsdp" and FullyShardedDataParallel is None:
raise ImportError(
"FullyShardedDataParallel requires torch>=1.11.0. "
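The replacement call shape for the two deprecated arguments is `parallel_strategy`/`parallel_strategy_kwargs`, per the messages above. A sketch of the migrated call, which must run inside a Ray Train worker loop (the DDP kwarg shown is illustrative):

# Before (now raises DeprecationWarning):
#   model = train.torch.prepare_model(model, wrap_ddp=True, ddp_kwargs={...})
# After:
model = train.torch.prepare_model(
    model,
    parallel_strategy="ddp",
    parallel_strategy_kwargs={"find_unused_parameters": True},
)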
233 changes: 233 additions & 0 deletions python/ray/train/train_loop_utils.py
@@ -0,0 +1,233 @@
from typing import TYPE_CHECKING, Dict, Optional, Union

from ray.util.annotations import Deprecated

if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline


def _get_deprecation_msg(is_docstring: bool, fn_name: Optional[str] = None):
    if is_docstring:
        session_api_link = ":ref:`ray.air.session <air-session-ref>`"
    else:
        session_api_link = (
            "`ray.air.session` ("
            "https://docs.ray.io/en/latest/ray-air/package-ref.html"
            "#module-ray.air.session"
            ")"
        )

    deprecation_msg = (
        f"The `train.{fn_name or '*'}` APIs are deprecated in Ray 2.1 "
        f"and are replaced by {session_api_link}. "
        "The `ray.air.session` APIs provide the same functionality, "
        "but in a unified manner across Ray Train and Ray Tune."
    )
    return deprecation_msg


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def get_dataset_shard(
dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
"""Returns the Ray Dataset or DatasetPipeline shard for this worker.
Call :meth:`~ray.data.Dataset.iter_torch_batches` or
:meth:`~ray.data.Dataset.to_tf` on this shard to convert it to the appropriate
framework-specific data type.
.. code-block:: python
import ray
from ray import train
def train_func():
model = Net()
for iter in range(100):
data_shard = session.get_dataset_shard("train")
for batch in data_shard.iter_torch_batches():
# ...
return model
dataset = ray.data.read_csv("train.csv")
dataset.filter(...).repeat().random_shuffle()
trainer = Trainer(backend="torch")
trainer.start()
# Trainer will automatically handle sharding.
train_model = trainer.run(train_func, dataset=dataset)
trainer.shutdown()
Args:
dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
specifies which dataset shard to return.
Returns:
The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
If no dataset is passed into Trainer, then return None.
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=get_dataset_shard.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def report(**kwargs) -> None:
"""Reports all keyword arguments to Train as intermediate results.
.. code-block:: python
import time
from ray import train
def train_func():
for iter in range(100):
time.sleep(1)
train.report(hello="world")
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Args:
**kwargs: Any key value pair to be reported by Train.
If callbacks are provided, they are executed on these
intermediate results.
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=report.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def world_rank() -> int:
"""Get the world rank of this worker.
.. code-block:: python
import time
from ray import train
def train_func():
for iter in range(100):
time.sleep(1)
if train.world_rank() == 0:
print("Worker 0")
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func)
trainer.shutdown()
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=world_rank.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. code-block:: python
import time
from ray import train
def train_func():
if torch.cuda.is_available():
torch.cuda.set_device(train.local_rank())
...
trainer = Trainer(backend="torch", use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=local_rank.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def load_checkpoint() -> Optional[Dict]:
"""Loads checkpoint data onto the worker.
.. code-block:: python
from ray import train
def train_func():
checkpoint = train.load_checkpoint()
for iter in range(checkpoint["epoch"], 5):
print(iter)
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func, checkpoint={"epoch": 3})
# 3
# 4
trainer.shutdown()
Args:
**kwargs: Any key value pair to be checkpointed by Train.
Returns:
The most recently saved checkpoint if ``train.save_checkpoint()``
has been called. Otherwise, the checkpoint that the session was
originally initialized with. ``None`` if neither exist.
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=load_checkpoint.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def save_checkpoint(**kwargs) -> None:
"""Checkpoints all keyword arguments to Train as restorable state.
.. code-block:: python
import time
from ray import train
def train_func():
for iter in range(100):
time.sleep(1)
train.save_checkpoint(epoch=iter)
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Args:
**kwargs: Any key value pair to be checkpointed by Train.
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=save_checkpoint.__name__),
)


@Deprecated(message=_get_deprecation_msg(is_docstring=True))
def world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
.. code-block:: python
import time
from ray import train
def train_func():
assert train.world_size() == 4
trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
"""
raise DeprecationWarning(
_get_deprecation_msg(is_docstring=False, fn_name=world_size.__name__),
)
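For completeness, a sketch of the `ray.air.session` equivalents that the deprecation message refers to (the getter names follow the Ray 2.x session API), to be called inside a training function:

from ray.air import session

def train_func():
    shard = session.get_dataset_shard("train")  # replaces train.get_dataset_shard
    rank = session.get_world_rank()             # replaces train.world_rank
    local = session.get_local_rank()            # replaces train.local_rank
    size = session.get_world_size()             # replaces train.world_size
    ckpt = session.get_checkpoint()             # replaces train.load_checkpoint
    session.report({"loss": 0.1})               # replaces train.report / train.save_checkpoint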
