Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override output_path if provided in Trainer init and use temp folder for test outputs #12

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Setup uv
runs:
using: 'composite'
steps:
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.5.1"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
8 changes: 2 additions & 6 deletions .github/workflows/pypi-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ jobs:
if [[ "v$version" != "$tag" ]]; then
exit 1
fi
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.4.27"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Set up Python
run: uv python install 3.12
- name: Build sdist and wheel
Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ jobs:
python-version: [3.9]
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.4.27"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Lint check
Expand Down
10 changes: 3 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@ jobs:
uv-resolution: ["lowest-direct", "highest"]
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "0.4.27"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Unit tests
run: uv run --resolution=${{ matrix.uv-resolution }} --all-extras coverage run -m pytest trainer tests
run: uv run --resolution=${{ matrix.uv-resolution }} --all-extras coverage run --parallel
- name: Upload coverage data
uses: actions/upload-artifact@v4
with:
Expand Down
1 change: 0 additions & 1 deletion examples/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def main():
trainer = Trainer(
train_args,
config,
config.output_path,
model=model,
train_samples=model.get_data_loader(config, None, False, None, None, None),
eval_samples=model.get_data_loader(config, None, True, None, None, None),
Expand Down
2 changes: 1 addition & 1 deletion examples/train_simple_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,6 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
config.grad_clip = None

model = GANModel()
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer = Trainer(TrainerArgs(), config, model=model, gpu=0 if is_cuda else None)
trainer.config.epochs = 10
trainer.fit()
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ lint.ignore = [
"F403", # init files may have star imports for now
]

[tool.coverage.report]
show_missing = true
skip_empty = true

[tool.coverage.run]
parallel = true
source = ["trainer"]
source = ["trainer", "tests"]
command_line = "-m pytest"
27 changes: 11 additions & 16 deletions tests/test_continue_train.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
import glob
import os
import shutil

from tests import run_cli


def test_continue_train():
output_path = "output/"

command_train = "python tests/utils/train_mnist.py"
def test_continue_train(tmp_path):
command_train = f"python tests/utils/train_mnist.py --coqpit.output_path {tmp_path}"
run_cli(command_train)

continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
continue_path = max(tmp_path.iterdir(), key=lambda p: p.stat().st_mtime)
number_of_checkpoints = len(list(continue_path.glob("*.pth")))

# Continue training from the best model
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
run_cli(command_continue)

assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
assert number_of_checkpoints < len(list(continue_path.glob("*.pth")))

# Continue training from the last checkpoint
for best in glob.glob(os.path.join(continue_path, "best_model*")):
os.remove(best)
for best in continue_path.glob("best_model*"):
best.unlink()
run_cli(command_continue)

# Continue training from a specific checkpoint
restore_path = os.path.join(continue_path, "checkpoint_5.pth")
command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
restore_path = continue_path / "checkpoint_5.pth"
command_continue = (
f"python tests/utils/train_mnist.py --restore_path {restore_path} --coqpit.output_path {tmp_path}"
)
run_cli(command_continue)
shutil.rmtree(continue_path)
7 changes: 2 additions & 5 deletions tests/test_generic_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from pathlib import Path

from trainer.generic_utils import remove_experiment_folder


def test_remove_experiment_folder():
output_dir = Path("output")
run_dir = output_dir / "run"
def test_remove_experiment_folder(tmp_path):
run_dir = tmp_path / "run"
run_dir.mkdir(exist_ok=True, parents=True)

remove_experiment_folder(run_dir)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_lr_schedulers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import time

import torch
Expand All @@ -10,7 +9,7 @@
is_cuda = torch.cuda.is_available()


def test_train_mnist():
def test_train_mnist(tmp_path):
model = MnistModel()
# Test StepwiseGradualLR
config = MnistModelConfig(
Expand All @@ -23,7 +22,7 @@ def test_train_mnist():
},
scheduler_after_epoch=False,
)
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer = Trainer(TrainerArgs(), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None)
trainer.train_loader = trainer.get_train_dataloader(
trainer.training_assets,
trainer.train_samples,
Expand Down
8 changes: 2 additions & 6 deletions tests/test_train_batch_size_finder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import torch

from tests.utils.mnist import MnistModel, MnistModelConfig
Expand All @@ -8,11 +6,9 @@
is_cuda = torch.cuda.is_available()


def test_train_largest_batch_mnist():
def test_train_largest_batch_mnist(tmp_path):
model = MnistModel()
trainer = Trainer(
TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
)
trainer = Trainer(TrainerArgs(), MnistModelConfig(), output_path=tmp_path, model=model, gpu=0 if is_cuda else None)

trainer.fit_with_largest_batch_size(starting_batch_size=2048)
loss1 = trainer.keep_avg_train["avg_loss"]
Expand Down
20 changes: 10 additions & 10 deletions tests/test_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(self, img):
return validity


def test_overfit_mnist_simple_gan():
def test_overfit_mnist_simple_gan(tmp_path):
@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
config.grad_clip = None

model = GANModel()
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer = Trainer(TrainerArgs(), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None)

trainer.config.epochs = 1
trainer.fit()
Expand All @@ -155,7 +155,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"


def test_overfit_accelerate_mnist_simple_gan():
def test_overfit_accelerate_mnist_simple_gan(tmp_path):
@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
Expand Down Expand Up @@ -231,7 +231,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r

model = GANModel()
trainer = Trainer(
TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
TrainerArgs(use_accelerate=True), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None
)

trainer.eval_epoch()
Expand All @@ -249,7 +249,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"


def test_overfit_manual_optimize_mnist_simple_gan():
def test_overfit_manual_optimize_mnist_simple_gan(tmp_path):
@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
Expand Down Expand Up @@ -342,7 +342,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
config.grad_clip = None

model = GANModel()
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer = Trainer(TrainerArgs(), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None)

trainer.config.epochs = 1
trainer.fit()
Expand All @@ -360,7 +360,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"


def test_overfit_manual_optimize_grad_accum_mnist_simple_gan():
def test_overfit_manual_optimize_grad_accum_mnist_simple_gan(tmp_path):
@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
Expand Down Expand Up @@ -456,7 +456,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
config.grad_clip = None

model = GANModel()
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer = Trainer(TrainerArgs(), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None)

trainer.config.epochs = 1
trainer.fit()
Expand All @@ -474,7 +474,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"


def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan():
def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan(tmp_path):
@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
Expand Down Expand Up @@ -573,7 +573,7 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r

model = GANModel()
trainer = Trainer(
TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
TrainerArgs(use_accelerate=True), config, output_path=tmp_path, model=model, gpu=0 if is_cuda else None
)

trainer.config.epochs = 1
Expand Down
39 changes: 33 additions & 6 deletions tests/test_train_mnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import torch

from tests.utils.mnist import MnistModel, MnistModelConfig
Expand All @@ -8,11 +6,11 @@
is_cuda = torch.cuda.is_available()


def test_train_mnist():
def test_train_mnist(tmp_path):
model = MnistModel()
trainer = Trainer(
TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
)

# Parsing command line args
trainer = Trainer(TrainerArgs(), MnistModelConfig(), output_path=tmp_path, model=model, gpu=0 if is_cuda else None)

trainer.fit()
loss1 = trainer.keep_avg_train["avg_loss"]
Expand All @@ -21,3 +19,32 @@ def test_train_mnist():
loss2 = trainer.keep_avg_train["avg_loss"]

assert loss1 > loss2

# Without parsing command line args
args = TrainerArgs()

trainer2 = Trainer(
args,
MnistModelConfig(),
output_path=tmp_path,
model=model,
gpu=0 if is_cuda else None,
parse_command_line_args=False,
)
trainer2.fit()
loss3 = trainer2.keep_avg_train["avg_loss"]

args.continue_path = str(max(tmp_path.iterdir(), key=lambda p: p.stat().st_mtime))

trainer3 = Trainer(
args,
MnistModelConfig(),
output_path=tmp_path,
model=model,
gpu=0 if is_cuda else None,
parse_command_line_args=False,
)
trainer3.fit()
loss4 = trainer3.keep_avg_train["avg_loss"]

assert loss3 > loss4
1 change: 0 additions & 1 deletion tests/utils/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def main():
trainer = Trainer(
train_args,
config,
config.output_path,
model=model,
train_samples=model.get_data_loader(config, None, False, None, None, None),
eval_samples=model.get_data_loader(config, None, True, None, None, None),
Expand Down
5 changes: 5 additions & 0 deletions trainer/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from trainer.logger import logger


def is_pytorch_at_least_2_3() -> bool:
"""Check if the installed Pytorch version is 2.3 or higher."""
return Version(torch.__version__) >= Version("2.3")


def is_pytorch_at_least_2_4() -> bool:
"""Check if the installed Pytorch version is 2.4 or higher."""
return Version(torch.__version__) >= Version("2.4")
Expand Down
Loading