Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Mar 31, 2020
1 parent 8d8a70e commit 987df24
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
14 changes: 12 additions & 2 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,15 @@ def log_row(action, mean, total):
output_string += os.linesep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.flush()
self.output_file.close()


Expand Down Expand Up @@ -214,8 +219,13 @@ def summary(self) -> str:

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.flush()
self.output_file.close()
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
(<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
<class 'NoneType'>),
None),
...
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
Expand Down
13 changes: 5 additions & 8 deletions tests/test_profiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tempfile
import os
import time
from pathlib import Path

Expand Down Expand Up @@ -30,8 +30,8 @@ def simple_profiler():


@pytest.fixture
def advanced_profiler():
profiler = AdvancedProfiler()
def advanced_profiler(tmpdir):
profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
return profiler


Expand Down Expand Up @@ -168,12 +168,9 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler):
# record at least one event
with advanced_profiler.profile("test"):
pass
# log to stdout
# log to stdout and print to file
advanced_profiler.describe()
# print to file
advanced_profiler.output_filename = Path(tmpdir, "profiler.txt")
advanced_profiler.describe()
data = Path(advanced_profiler.output_filename).read_text()
data = Path(advanced_profiler.output_fname).read_text()
assert len(data) > 0


Expand Down

0 comments on commit 987df24

Please sign in to comment.