Skip to content

Commit

Permalink
Fix tests (#196)
Browse files Browse the repository at this point in the history
* cattrs 24.1.0

* msgpack 1.0.8

* multidict 6.0.5

* narwhal 1.6.3

* protobuf 5.28.0

* pyasn1 0.6.0

* pytz 2024.1

* print files

* add more print

* make sure unzipping is in same directory

* use shutil.copyfile

* fix

* tidy up

* format
  • Loading branch information
peterdudfield authored Sep 13, 2024
1 parent b8c469b commit ddb5126
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions quartz_solar_forecast/forecasts/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir)

target_path = os.path.join(download_dir, filename)
shutil.copy2(downloaded_file, target_path)

# copy file from downloaded_file to target_path
shutil.copyfile(downloaded_file, target_path)

return target_path

Expand All @@ -76,8 +78,11 @@ def _decompress_zipfile(self, filename: str) -> None:
filename : str
The name of the .zip file to be decompressed
"""
# get the directory of the file
directory = os.path.dirname(filename)

with zipfile.ZipFile(filename, "r") as zip_file:
zip_file.extractall()
zip_file.extractall(path=directory)

def load_model(
self,
Expand Down Expand Up @@ -117,7 +122,7 @@ def load_model(
if not os.path.isfile(model_path):
logger.info("Preparing model...")
self._decompress_zipfile(zipfile_model)

logger.info("Loading model...")
loaded_model = XGBRegressor()
loaded_model.load_model(model_path)
Expand Down

0 comments on commit ddb5126

Please sign in to comment.