Skip to content

Commit

Permalink
Merge pull request #96 from Nixtla/refactor/pool
Browse files Browse the repository at this point in the history
refactor: use Pool instead of ProcessPoolExecutor
  • Loading branch information
AzulGarza authored Apr 27, 2022
2 parents e82b5ec + 990e4b7 commit 39b2a4d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"#export\n",
"import inspect\n",
"import logging\n",
"from concurrent.futures import ProcessPoolExecutor\n",
"from multiprocessing import Pool\n",
"from functools import partial\n",
"\n",
"import numpy as np\n",
Expand Down Expand Up @@ -363,15 +363,15 @@
" from itertools import repeat\n",
" \n",
" xregs = repeat(None)\n",
" with ProcessPoolExecutor(self.n_jobs) as executor:\n",
" with Pool(self.n_jobs) as executor:\n",
" for model_args in self.models:\n",
" model, *args = _as_tuple(model_args)\n",
" model_name = _build_forecast_name(model, *args)\n",
" futures = []\n",
" for ga, xr in zip(gas, xregs):\n",
" future = executor.submit(ga.compute_forecasts, h, model, xr, level, *args)\n",
" future = executor.apply_async(ga.compute_forecasts, (h, model, xr, level, *args,))\n",
" futures.append(future)\n",
" values, keys = list(zip(*[f.result() for f in futures]))\n",
" values, keys = list(zip(*[f.get() for f in futures]))\n",
" keys = keys[0]\n",
" if keys is not None:\n",
" values = np.vstack(values)\n",
Expand Down Expand Up @@ -812,7 +812,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions statsforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Cell
import inspect
import logging
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pool
from functools import partial

import numpy as np
Expand Down Expand Up @@ -164,15 +164,15 @@ def _data_parallel_forecast(self, h, xreg, level):
from itertools import repeat

xregs = repeat(None)
with ProcessPoolExecutor(self.n_jobs) as executor:
with Pool(self.n_jobs) as executor:
for model_args in self.models:
model, *args = _as_tuple(model_args)
model_name = _build_forecast_name(model, *args)
futures = []
for ga, xr in zip(gas, xregs):
future = executor.submit(ga.compute_forecasts, h, model, xr, level, *args)
future = executor.apply_async(ga.compute_forecasts, (h, model, xr, level, *args,))
futures.append(future)
values, keys = list(zip(*[f.result() for f in futures]))
values, keys = list(zip(*[f.get() for f in futures]))
keys = keys[0]
if keys is not None:
values = np.vstack(values)
Expand Down

0 comments on commit 39b2a4d

Please sign in to comment.