Skip to content

Commit

Permalink
Merge pull request #98 from Nixtla/feat/ray
Browse files Browse the repository at this point in the history
Feat: add ray integration
  • Loading branch information
AzulGarza authored May 1, 2022
2 parents 39b2a4d + d8cfd1c commit 85253e5
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 7 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,30 @@ jobs:

- name: Run tests
run: nbdev_test_nbs


test-ray:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- name: Clone repo
uses: actions/checkout@v2

- name: Set up python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install library with ray
run: pip3 install ".[ray]"

- name: Windows redis
if: ${{ matrix.os == 'windows-latest' }}
run: pip3 install redis #https://github.com/ray-project/ray/pull/23991, remove later

- name: Run ray
run: python3 action_files/test_ray.py
38 changes: 38 additions & 0 deletions action_files/test_ray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import ray
from statsforecast.core import StatsForecast
from statsforecast.models import (
adida,
auto_arima,
croston_classic,
croston_optimized,
croston_sba,
historic_average,
imapa,
naive,
random_walk_with_drift,
seasonal_exponential_smoothing,
seasonal_naive,
seasonal_window_average,
ses,
tsb,
window_average,
)
from statsforecast.utils import generate_series


if __name__ == "__main__":
series = generate_series(20)
ray_context = ray.init()
fcst = StatsForecast(
series,
[adida, croston_classic, croston_optimized,
croston_sba, historic_average, imapa, naive,
random_walk_with_drift, (seasonal_exponential_smoothing, 7, 0.1),
(seasonal_naive, 7), (seasonal_window_average, 7, 4),
(ses, 0.1), (tsb, 0.1, 0.3), (window_average, 4)],
freq='D',
n_jobs=int(ray.cluster_resources()['CPU']),
ray_address=ray_context.address_info['address']
)
fcst.forecast(7)
ray.shutdown()
23 changes: 19 additions & 4 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
"#export\n",
"import inspect\n",
"import logging\n",
"from multiprocessing import Pool\n",
"from functools import partial\n",
"\n",
"import numpy as np\n",
Expand Down Expand Up @@ -309,11 +308,12 @@
"#export\n",
"class StatsForecast:\n",
" \n",
" def __init__(self, df, models, freq, n_jobs=1):\n",
" def __init__(self, df, models, freq, n_jobs=1, ray_address=None):\n",
" self.ga, self.uids, self.last_dates = _grouped_array_from_df(df)\n",
" self.models = models\n",
" self.freq = pd.tseries.frequencies.to_offset(freq)\n",
" self.n_jobs = n_jobs\n",
" self.ray_address = ray_address\n",
" \n",
" def forecast(self, h, xreg=None, level=None):\n",
" if xreg is not None:\n",
Expand Down Expand Up @@ -363,7 +363,22 @@
" from itertools import repeat\n",
" \n",
" xregs = repeat(None)\n",
" with Pool(self.n_jobs) as executor:\n",
" \n",
" if self.ray_address is not None:\n",
" try:\n",
" from ray.util.multiprocessing import Pool\n",
" except ModuleNotFoundError as e:\n",
" msg = (\n",
" '{e}. To use a ray cluster you have to install '\n",
" 'ray. Please run `pip install ray`. '\n",
" )\n",
" raise ModuleNotFoundError(msg) from e\n",
" kwargs = dict(ray_address=self.ray_address)\n",
" else:\n",
" from multiprocessing import Pool\n",
" kwargs = dict()\n",
" \n",
" with Pool(self.n_jobs, **kwargs) as executor:\n",
" for model_args in self.models:\n",
" model, *args = _as_tuple(model_args)\n",
" model_name = _build_forecast_name(model, *args)\n",
Expand Down Expand Up @@ -812,7 +827,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.7.12"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ status = 2

# Optional. Same format as setuptools requirements
requirements = numba numpy pandas scipy statsmodels
ray_requirements = ray
# Optional. Same format as setuptools console_scripts
# console_scripts =
# Optional. Same format as setuptools dependency-links
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9'.split()

requirements = cfg.get('requirements','').split()
ray_requirements = cfg.get('ray_requirements', '').split()
lic = licenses[cfg['license']]
min_python = cfg['min_python']

Expand All @@ -41,6 +42,7 @@
packages = setuptools.find_packages(),
include_package_data = True,
install_requires = requirements,
extras_require = {'ray': ray_requirements,},
dependency_links = cfg.get('dep_links','').split(),
python_requires = '>=' + cfg['min_python'],
long_description = open('README.md', encoding='utf8').read(),
Expand Down
21 changes: 18 additions & 3 deletions statsforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# Cell
import inspect
import logging
from multiprocessing import Pool
from functools import partial

import numpy as np
Expand Down Expand Up @@ -110,11 +109,12 @@ def _as_tuple(x):
# Cell
class StatsForecast:

def __init__(self, df, models, freq, n_jobs=1):
def __init__(self, df, models, freq, n_jobs=1, ray_address=None):
self.ga, self.uids, self.last_dates = _grouped_array_from_df(df)
self.models = models
self.freq = pd.tseries.frequencies.to_offset(freq)
self.n_jobs = n_jobs
self.ray_address = ray_address

def forecast(self, h, xreg=None, level=None):
if xreg is not None:
Expand Down Expand Up @@ -164,7 +164,22 @@ def _data_parallel_forecast(self, h, xreg, level):
from itertools import repeat

xregs = repeat(None)
with Pool(self.n_jobs) as executor:

if self.ray_address is not None:
try:
from ray.util.multiprocessing import Pool
except ModuleNotFoundError as e:
msg = (
'{e}. To use a ray cluster you have to install '
'ray. Please run `pip install ray`. '
)
raise ModuleNotFoundError(msg) from e
kwargs = dict(ray_address=self.ray_address)
else:
from multiprocessing import Pool
kwargs = dict()

with Pool(self.n_jobs, **kwargs) as executor:
for model_args in self.models:
model, *args = _as_tuple(model_args)
model_name = _build_forecast_name(model, *args)
Expand Down

0 comments on commit 85253e5

Please sign in to comment.