Skip to content

Commit

Permalink
Allow stacking dataloaders for training
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 8, 2024
1 parent 7e6bf5f commit d457c41
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 26 deletions.
31 changes: 31 additions & 0 deletions supirfactor_dynamical/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,34 @@ def __iter__(self):
from .stratified_file_dataset import (
StratifiedFilesDataset
)


def stack_dataloaders(loaders):
    """Yield batches from one dataloader or from several chained in sequence.

    :param loaders: A single DataLoader or torch.Tensor, an iterable of
        DataLoaders/Tensors to be chained one after another, or None.
    :yields: Batches, in order. A single loader (or tensor) is iterated
        directly; any other iterable is flattened one level so that each
        contained loader is exhausted in turn.

    NOTE(review): this is a generator function, so calling it with None
    produces an *empty generator*, not None — iterating the result
    simply yields nothing.
    """

    if loaders is None:
        # Bare return (not `return None`): in a generator this just
        # terminates iteration; a value could never reach the caller.
        return
    elif (
        isinstance(loaders, DataLoader) or
        torch.is_tensor(loaders)
    ):
        # Single loader / tensor-of-batches: iterate it as-is
        yield from loaders
    else:
        # Collection of loaders: chain their batches together
        for loader in loaders:
            yield from loader


def _shuffle_time_data(dl):
if dl is None:
return None

try:
dl.dataset.shuffle()
except AttributeError:
pass

if (
not isinstance(dl, DataLoader) and
not torch.is_tensor(dl)
):
for d in dl:
_shuffle_time_data(d)
4 changes: 2 additions & 2 deletions supirfactor_dynamical/models/_base_recurrent_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def r2_over_time(
self.eval()

self.training_r2_over_time = [
r2_score([x], self, multioutput=multioutput)
r2_score([[x]], self, multioutput=multioutput)
for x in training_dataloader.dataset.get_times_in_order()
]

if validation_dataloader is not None:

self.validation_r2_over_time = [
r2_score([x], self, multioutput=multioutput)
r2_score([[x]], self, multioutput=multioutput)
for x in validation_dataloader.dataset.get_times_in_order()
]

Expand Down
10 changes: 2 additions & 8 deletions supirfactor_dynamical/models/_model_mixins/training_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from supirfactor_dynamical._io._writer import write
from supirfactor_dynamical._utils import _check_data_offsets
from supirfactor_dynamical.postprocessing.eval import r2_score
from supirfactor_dynamical.datasets import stack_dataloaders

from torch.utils.data import DataLoader

Expand Down Expand Up @@ -696,7 +697,7 @@ def score(
_count = []

with torch.no_grad():
for data in dataloader:
for data in stack_dataloaders(dataloader):
_score.append(
loss_function(
self._slice_data_and_forward(data),
Expand Down Expand Up @@ -725,10 +726,3 @@ def score(
)

return _score


def _shuffle_time_data(dl):
    # Reshuffle the loader's backing dataset between epochs, best-effort:
    # loaders whose dataset has no shuffle() method are silently left alone.
    try:
        dl.dataset.shuffle()
    except AttributeError:
        pass
3 changes: 2 additions & 1 deletion supirfactor_dynamical/postprocessing/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
argmax_last_dim,
to
)
from supirfactor_dynamical.datasets import stack_dataloaders


def r2_score(
Expand All @@ -34,7 +35,7 @@ def r2_score(
_n = 0

with torch.no_grad():
for data in dataloader:
for data in stack_dataloaders(dataloader):

input_data, target_data = _extract_data(
data,
Expand Down
2 changes: 1 addition & 1 deletion supirfactor_dynamical/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_r2(self):
)

r2s = r2_score(
[[target, predicts]],
[[[target, predicts]]],
_ModelStub(),
target_data_idx=0,
input_data_idx=1,
Expand Down
85 changes: 84 additions & 1 deletion supirfactor_dynamical/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
pretrain_and_tune_dynamic_model,
process_results_to_dataframes,
process_combined_results,
train_simple_model
train_simple_model,
train_model
)

from supirfactor_dynamical.models import _CLASS_DICT
from supirfactor_dynamical.datasets import stack_dataloaders

from ._stubs import (
X,
Expand Down Expand Up @@ -97,6 +99,87 @@ def test_training(self):
)


class TestStandardTraining(_SetupMixin, unittest.TestCase):
    """Exercise train_model with a single dataloader and with stacked
    (tuple-of-dataloader) training and validation inputs."""

    def _encoder_weights(self, model):
        # Detached copy of the first encoder layer's weight matrix
        return torch.clone(model.encoder[0].weight.detach())

    def _assert_trained(self, before, after):
        # Training must have moved the weights ...
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(before, after)
        # ... without disturbing the zero/nonzero sparsity pattern
        torch.testing.assert_close(before != 0, after != 0)

    def test_training(self):

        model = get_model('static')(self.prior)
        before = self._encoder_weights(model)

        train_model(model, self.static_dataloader, 10)

        model(X_tensor)

        self._assert_trained(before, self._encoder_weights(model))

    def test_stacked_dl_training(self):

        model = get_model('static')(self.prior)
        before = self._encoder_weights(model)

        # stack_dataloaders over two copies of the loader must yield the
        # same batches as manually doubling a single pass
        stacked = torch.cat(
            list(stack_dataloaders(
                (self.static_dataloader, self.static_dataloader)
            ))
        )

        single_pass = torch.cat(list(self.static_dataloader))
        doubled = torch.cat([single_pass, single_pass])

        torch.testing.assert_close(stacked, doubled)

        train_model(
            model,
            (self.static_dataloader, self.static_dataloader),
            10,
            validation_dataloader=(
                self.static_dataloader,
                self.static_dataloader
            )
        )

        model(X_tensor)

        self._assert_trained(before, self._encoder_weights(model))


class TestCoupledTraining(_SetupMixin, unittest.TestCase):

def test_training(self):
Expand Down
2 changes: 1 addition & 1 deletion supirfactor_dynamical/training/train_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import warnings

from supirfactor_dynamical.models._model_mixins.training_mixin import (
from supirfactor_dynamical.datasets import (
_shuffle_time_data
)
from supirfactor_dynamical._utils import to
Expand Down
2 changes: 1 addition & 1 deletion supirfactor_dynamical/training/train_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
import numpy as np

from supirfactor_dynamical.models._model_mixins.training_mixin import (
from supirfactor_dynamical.datasets import (
_shuffle_time_data
)
from supirfactor_dynamical.training._utils import _set_submodels
Expand Down
5 changes: 3 additions & 2 deletions supirfactor_dynamical/training/train_simple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
to,
_nobs
)
from supirfactor_dynamical.datasets import stack_dataloaders


def train_simple_model(
Expand Down Expand Up @@ -69,7 +70,7 @@ def train_simple_model(

_batch_losses = []
_batch_n = []
for train_x in training_dataloader:
for train_x in stack_dataloaders(training_dataloader):

train_x = to(train_x, device)

Expand All @@ -93,7 +94,7 @@ def train_simple_model(

_validation_loss = []
_validation_n = []
for val_x in training_dataloader:
for val_x in stack_dataloaders(validation_dataloader):

val_x = to(val_x, device)

Expand Down
16 changes: 7 additions & 9 deletions supirfactor_dynamical/training/train_standard_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
to,
_nobs
)
from supirfactor_dynamical.datasets import (
stack_dataloaders,
_shuffle_time_data
)


def train_model(
Expand Down Expand Up @@ -65,7 +69,8 @@ def train_model(

_batch_losses = []
_batch_n = []
for train_x in training_dataloader:

for train_x in stack_dataloaders(training_dataloader):

if output_data_index is not None:
target_x = to(
Expand Down Expand Up @@ -106,7 +111,7 @@ def train_model(
if validation_dataloader is not None:

_vloss, _vn = model_ref._calculate_validation_loss(
validation_dataloader,
stack_dataloaders(validation_dataloader),
loss_function,
input_data_index=input_data_index,
output_data_index=output_data_index
Expand Down Expand Up @@ -142,10 +147,3 @@ def train_model(
to(model_ref, final_device)

return model


def _shuffle_time_data(dl):
    # Reshuffle the loader's backing dataset between epochs, best-effort:
    # loaders whose dataset has no shuffle() method are silently left alone.
    try:
        dl.dataset.shuffle()
    except AttributeError:
        pass

0 comments on commit d457c41

Please sign in to comment.