-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added predict method to trainer * Added State.PREDICT enum * Added Phase.PREDICT_BEGIN and Phase.PREDICT_END enums * Added CollectOutputs callback * Added predict example * Moving some elements from nn.functional to nn * Added sample count to save/load model
- Loading branch information
Showing
17 changed files
with
309 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import torch as T | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from lpd.trainer import Trainer | ||
from lpd.callbacks import StatsPrint | ||
from lpd.extensions.custom_schedulers import DoNothingToLR | ||
import lpd.utils.torch_utils as tu | ||
import lpd.utils.general_utils as gu | ||
import examples.utils as eu | ||
|
||
def get_parameters(): | ||
# N is batch size; D_in is input dimension; | ||
# H is hidden dimension; D_out is output dimension. | ||
N, D_in, H, D_out = 8, 1000, 100, 10 | ||
num_epochs = 5 | ||
data_loader = eu.examples_data_generator(N, D_in, D_out) | ||
data_loader_steps = 100 | ||
return N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps | ||
|
||
def get_trainer(N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps): | ||
|
||
device = tu.get_gpu_device_if_available() | ||
|
||
model = eu.get_basic_model(D_in, H, D_out).to(device) | ||
|
||
loss_func = nn.MSELoss(reduction='sum').to(device) | ||
|
||
optimizer = optim.Adam(model.parameters(), lr=1e-4) | ||
|
||
scheduler = DoNothingToLR() #CAN ALSO USE scheduler=None, BUT DoNothingToLR IS MORE EXPLICIT | ||
|
||
metric_name_to_func = None # THIS EXAMPLE DOES NOT USE METRICS, ONLY LOSS | ||
|
||
callbacks = [ | ||
StatsPrint() | ||
] | ||
|
||
trainer = Trainer(model=model, | ||
device=device, | ||
loss_func=loss_func, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
metric_name_to_func=metric_name_to_func, | ||
train_data_loader=data_loader, | ||
val_data_loader=data_loader, | ||
train_steps=data_loader_steps, | ||
val_steps=data_loader_steps, | ||
num_epochs=num_epochs, | ||
callbacks=callbacks, | ||
name='Train-Evaluate-Predict-Example') | ||
return trainer | ||
|
||
def run(): | ||
gu.seed_all(42) # BECAUSE ITS THE ANSWER TO LIFE AND THE UNIVERSE | ||
|
||
N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps = get_parameters() | ||
|
||
trainer = get_trainer(N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps) | ||
|
||
trainer.summary() | ||
|
||
trainer.train() | ||
|
||
trainer.evaluate(data_loader, data_loader_steps) | ||
|
||
data_generator_for_predictions = eu.examples_prediction_data_generator(data_loader, data_loader_steps) | ||
predictions = trainer.predict(data_generator_for_predictions, data_loader_steps) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from lpd.callbacks.callback_base import CallbackBase | ||
from lpd.callbacks.callback_monitor import CallbackMonitor, CallbackMonitorResult | ||
from lpd.callbacks.stats_print import StatsPrint | ||
from lpd.callbacks.model_checkpoint import ModelCheckPoint | ||
from lpd.callbacks.tensorboard import Tensorboard | ||
from lpd.callbacks.early_stopping import EarlyStopping | ||
from lpd.callbacks.scheduler_step import SchedulerStep | ||
from lpd.callbacks.callback_context import CallbackContext | ||
from lpd.callbacks.callback_monitor import CallbackMonitor, CallbackMonitorResult | ||
from lpd.callbacks.collect_outputs import CollectOutputs | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from lpd.enums import Phase, State, MonitorType, MonitorMode, StatsType | ||
from lpd.callbacks.callback_base import CallbackBase | ||
from lpd.callbacks.callback_context import CallbackContext | ||
from lpd.callbacks.callback_monitor import CallbackMonitor, CallbackMonitorResult | ||
from typing import Union, List, Optional, Dict | ||
|
||
class CollectOutputs(CallbackBase): | ||
""" | ||
This callback will collect outputs per each state, (it is currently used in trainer.predict() method.) | ||
It will collect the numpy outputs in the defined states to a dictionary. | ||
Methods: | ||
get_outputs_for_state - for a given state, returns the collected outputs | ||
Args: | ||
apply_on_phase - see in CallbackBase | ||
apply_on_states - see in CallbackBase | ||
""" | ||
|
||
def __init__(self, | ||
apply_on_phase: Phase=Phase.BATCH_END, | ||
apply_on_states: Union[State, List[State]]=None): | ||
super(CollectOutputs, self).__init__(apply_on_phase, apply_on_states) | ||
self.state_to_outputs = {} | ||
|
||
def get_outputs_for_state(self, state: State): | ||
return self.state_to_outputs[state] | ||
|
||
def __call__(self, callback_context: CallbackContext): | ||
c = callback_context #READABILITY DOWN THE ROAD | ||
state = c.trainer_state | ||
|
||
if self.should_apply_on_state(c): | ||
|
||
if state not in self.state_to_outputs: | ||
self.state_to_outputs[state] = [] | ||
|
||
last_outputs = c.trainer.get_last_outputs() | ||
self.state_to_outputs[state].append(last_outputs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.