From ddb5126cb3cd40567451f5f5107d93469b275edc Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Fri, 13 Sep 2024 08:51:55 +0100 Subject: [PATCH] Fix tests (#196) * cattrs 24.1.0 * msgpack 1.0.8 * multidict 6.0.5 * narwhal 1.6.3 * protobuf 5.28.0 * pyasn1 0.6.0 * pytz 2024.1 * print files * add more print * make sure unzipping is in same directory * use shutil.copyfile * fix * tidy up * format --- quartz_solar_forecast/forecasts/v2.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/quartz_solar_forecast/forecasts/v2.py b/quartz_solar_forecast/forecasts/v2.py index 590630f7..2d5c3660 100644 --- a/quartz_solar_forecast/forecasts/v2.py +++ b/quartz_solar_forecast/forecasts/v2.py @@ -62,7 +62,9 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str: downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir) target_path = os.path.join(download_dir, filename) - shutil.copy2(downloaded_file, target_path) + + # copy file from downloaded_file to target_path + shutil.copyfile(downloaded_file, target_path) return target_path @@ -76,8 +78,11 @@ def _decompress_zipfile(self, filename: str) -> None: filename : str The name of the .zip file to be decompressed """ + # get the directory of the file + directory = os.path.dirname(filename) + with zipfile.ZipFile(filename, "r") as zip_file: - zip_file.extractall() + zip_file.extractall(path=directory) def load_model( self, @@ -117,7 +122,7 @@ def load_model( if not os.path.isfile(model_path): logger.info("Preparing model...") self._decompress_zipfile(zipfile_model) - + logger.info("Loading model...") loaded_model = XGBRegressor() loaded_model.load_model(model_path)