Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support recent versions of black and lightning #260

Merged
merged 2 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_shared_cli_options(mocker, minimum_valid_train, minimum_valid_predict):
mocker.patch("zamba.cli.ModelManager.predict", pred_mock)

for command in [minimum_valid_train, minimum_valid_predict]:

# check default model is time distributed one
result = runner.invoke(app, command)
assert result.exit_code == 0
Expand Down
1 change: 0 additions & 1 deletion tests/test_load_video_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def assert_megadetector_total_or_none(original_video_metadata, video_shape, **kw


def assert_no_frames_or_correct_shape(original_video_metadata, video_shape, **kwargs):

return (video_shape["frames"] == 0) or (
(video_shape["height"] == kwargs["frame_selection_height"])
and (video_shape["width"] == kwargs["frame_selection_width"])
Expand Down
1 change: 0 additions & 1 deletion zamba/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,5 +600,4 @@ def depth(


if __name__ == "__main__":

app()
4 changes: 0 additions & 4 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def check_files_exist_and_load(

bad_load = []
if not skip_load_validation:

logger.info(
"Checking that all videos can be loaded. If you're very confident all your videos can be loaded, you can skip this with `skip_load_validation`, but it's not recommended."
)
Expand Down Expand Up @@ -503,7 +502,6 @@ def validate_filepaths_and_labels(cls, values):

# validate split column has no partial nulls or invalid values
if "split" in labels.columns:

# if split is entirely null, warn, drop column, and generate splits automatically
if labels.split.isnull().all():
logger.warning(
Expand Down Expand Up @@ -559,7 +557,6 @@ def validate_provided_species_and_use_default_model_labels(cls, values):
)

if not provided_species.issubset(model_species):

# if labels are not a subset, user cannot set use_default_model_labels to True
if values["use_default_model_labels"]:
raise ValueError(
Expand Down Expand Up @@ -677,7 +674,6 @@ def make_split(labels, values):
species_df = labels[labels[c] > 0]

if len(species_df):

# within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
labels.loc[species_df.index, "split"] = expected_splits + random.choices(
list(values["split_proportions"].keys()),
Expand Down
8 changes: 4 additions & 4 deletions zamba/models/depth_estimation/depth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def depth_transforms(size):

class DepthDataset(torch.utils.data.Dataset):
def __init__(self, filepaths):

# these are hardcoded because they depend on the trained model weights used for inference
self.height = 270
self.width = 480
Expand All @@ -55,7 +54,6 @@ def __init__(self, filepaths):

logger.info(f"Running object detection on {len(filepaths)} videos.")
for video_filepath in tqdm(filepaths):

# get video array at 1 fps, use full size for detecting objects
logger.debug(f"Loading video: {video_filepath}")
try:
Expand All @@ -73,7 +71,6 @@ def __init__(self, filepaths):

# iterate over frames
for frame_idx, (detections, scores) in enumerate(detections_per_frame):

# if anything is detected in the frame, save out relevant frames
if len(detections) > 0:
logger.debug(f"{len(detections)} detection(s) found at second {frame_idx}.")
Expand Down Expand Up @@ -234,7 +231,10 @@ def predict(self, filepaths):
for d, vid, t in zip(distance.cpu().numpy(), filepath, time):
predictions.append((vid, t, d))

predictions = pd.DataFrame(predictions, columns=["filepath", "time", "distance"],).round(
predictions = pd.DataFrame(
predictions,
columns=["filepath", "time", "distance"],
).round(
{"distance": 1}
) # round to useful number of decimal places

Expand Down
1 change: 0 additions & 1 deletion zamba/models/efficientnet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule):
def __init__(
self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
):

super().__init__(**kwargs)

if finetune_from is None:
Expand Down
2 changes: 0 additions & 2 deletions zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,6 @@ def predict_model(
}

if predict_config.save is not False:

config_path = predict_config.save_dir / "predict_configuration.yaml"
logger.info(f"Writing out full configuration to {config_path}.")
with config_path.open("w") as fp:
Expand All @@ -415,7 +414,6 @@ def predict_model(
df = df.round(5)

if predict_config.save is not False:

preds_path = predict_config.save_dir / "zamba_predictions.csv"
logger.info(f"Saving out predictions to {preds_path}.")
with preds_path.open("w") as fp:
Expand Down
9 changes: 9 additions & 0 deletions zamba/pytorch_lightning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule
from sklearn.metrics import f1_score, top_k_accuracy_score, accuracy_score
import torch
Expand Down Expand Up @@ -273,9 +274,17 @@ def configure_optimizers(self):
}

def to_disk(self, path: os.PathLike):
"""Save out model weights to a checkpoint file on disk.

Note: this does not include callbacks, optimizer_states, or lr_schedulers.
To include those, use `Trainer.save_checkpoint()` instead.
"""

checkpoint = {
"state_dict": self.state_dict(),
"hyper_parameters": self.hparams,
"global_step": self.global_step,
"pytorch-lightning_version": pl.__version__,
}
torch.save(checkpoint, path)

Expand Down