diff --git a/quartz_solar_forecast/forecasts/v2.py b/quartz_solar_forecast/forecasts/v2.py index 590630f7..2d5c3660 100644 --- a/quartz_solar_forecast/forecasts/v2.py +++ b/quartz_solar_forecast/forecasts/v2.py @@ -62,7 +62,9 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str: downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir) target_path = os.path.join(download_dir, filename) - shutil.copy2(downloaded_file, target_path) + + # copy file from downloaded_file to target_path + shutil.copyfile(downloaded_file, target_path) return target_path @@ -76,8 +78,11 @@ def _decompress_zipfile(self, filename: str) -> None: filename : str The name of the .zip file to be decompressed """ + # get the directory of the file + directory = os.path.dirname(filename) + with zipfile.ZipFile(filename, "r") as zip_file: - zip_file.extractall() + zip_file.extractall(path=directory) def load_model( self, @@ -117,7 +122,7 @@ def load_model( if not os.path.isfile(model_path): logger.info("Preparing model...") self._decompress_zipfile(zipfile_model) - + logger.info("Loading model...") loaded_model = XGBRegressor() loaded_model.load_model(model_path)