Version 0.2.0
See [NEW!] in README.md for new key features
[0.2.0] - 2018-08-21
Added
- Added the ability to pass custom arguments to the tqdm callback
- Added an `ignore_index` flag to the categorical accuracy metric, similar to `nn.CrossEntropyLoss`. Usage: `metrics=[CategoricalAccuracyFactory(ignore_index=0)]`
- Added TopKCategoricalAccuracy metric (default for key: top_5_acc)
- Added BinaryAccuracy metric (default for key: binary_acc)
- Added MeanSquaredError metric (default for key: mse)
- Added DefaultAccuracy metric (use with 'acc' or 'accuracy'), which infers the appropriate accuracy metric from the criterion
- New Trial api `torchbearer.Trial` to replace the Model api. The Trial api is more atomic and uses the fluent pattern to allow chaining of methods.
- `torchbearer.Trial` has `with_x_generator` and `with_x_data` methods to add training/validation/testing generators to the trial. There is a `with_generators` method to allow passing of all generators in one call.
- `torchbearer.Trial` has `for_x_steps` and `for_steps` to allow running of trials without explicit generators or data tensors
- `torchbearer.Trial` keeps a history of run calls which tracks the number of epochs run and the final metrics at each epoch. This allows seamless resuming of trial running.
- `torchbearer.Trial.state_dict` now returns the trial history and callback list state, allowing for full resuming of trials
- `torchbearer.Trial` has a `replay` method that can replay training (with callbacks and display) from the history. This is useful when loading trials from state.
- The backward call can now be passed args by setting `state[torchbearer.BACKWARD_ARGS]`
- `torchbearer.Trial` implements the forward pass, loss calculation and backward call as an optimizer closure
- Metrics are now explicitly calculated with no gradient
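To illustrate what the `ignore_index` flag and top-k accuracy compute, here is a minimal pure-Python sketch that assumes class scores arrive as nested lists rather than tensors. `categorical_accuracy` is a hypothetical helper, not torchbearer's own metric code, and returning 0.0 when every sample is ignored is an assumption of this sketch.

```python
from typing import List, Optional, Sequence


def categorical_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    k: int = 1,
    ignore_index: Optional[int] = None,
) -> float:
    """Fraction of counted samples whose target is among the k highest scores.

    Hypothetical helper. Samples whose target equals ignore_index are skipped
    entirely, so they count towards neither the numerator nor the denominator.
    """
    correct = 0
    counted = 0
    for row, target in zip(scores, targets):
        if ignore_index is not None and target == ignore_index:
            continue  # e.g. a padding class that should not affect accuracy
        # Class indices ordered by descending score; keep the first k
        top_k: List[int] = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        counted += 1
        correct += int(target in top_k)
    # Assumption: report 0.0 when every sample was ignored
    return correct / counted if counted else 0.0
```

With scores `[[0.1, 0.6, 0.3], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]` and targets `[1, 0, 1]`, plain accuracy is 2/3. Setting `ignore_index=0` drops the second sample and gives 0.5, while `k=2` gives 1.0.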
Changed
- Callback decorators can now be chained so that a single callback can be constructed with multiple methods filled in
- Callbacks can now implement `state_dict` and `load_state_dict` to allow callbacks to resume with state
- State dictionary now accepts StateKey objects, which are unique and generated through `torchbearer.state.get_state`
- State dictionary now warns when accessed with strings, since string keys allow for collisions
- Checkpointer callbacks will now resume from a state dict when `resume=True` in Trial
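The collision problem that motivates StateKey objects can be sketched in plain Python. This illustration rests on assumptions and does not reproduce the `torchbearer.state` module: the hypothetical `StateKey` below de-duplicates names by appending a numeric suffix, and the hypothetical `State` is a `dict` subclass that warns when indexed by a plain string. Because each `StateKey` is a distinct object hashed by identity, two components that both ask for "loss" get separate slots.

```python
import warnings
from typing import Any, Dict

# Hypothetical registry of every key created so far (name -> key object)
_KEYS: Dict[str, "StateKey"] = {}


class StateKey:
    """A state key whose name is made unique at creation time (illustrative only)."""

    def __init__(self, name: str) -> None:
        unique = name
        count = 1
        # Append a numeric suffix until the name is unused, so two components
        # asking for the same name get distinct keys instead of colliding
        while unique in _KEYS:
            unique = f"{name}_{count}"
            count += 1
        self.name = unique
        _KEYS[unique] = self

    def __repr__(self) -> str:
        return f"StateKey({self.name!r})"


class State(dict):
    """Hypothetical dict that accepts StateKey objects and warns on plain string keys."""

    def _check(self, key: Any) -> None:
        if isinstance(key, str):
            warnings.warn(
                f"State accessed with string key {key!r}; string keys can collide",
                UserWarning,
                stacklevel=3,  # point the warning at the caller, not this class
            )

    def __getitem__(self, key: Any) -> Any:
        self._check(key)
        return super().__getitem__(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        self._check(key)
        super().__setitem__(key, value)
```

Starting from an empty registry, creating `StateKey("loss")` twice yields keys named `loss` and `loss_1`, and writing `state["loss"] = ...` emits a `UserWarning`.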
Deprecated
- `torchbearer.Model` has been deprecated in favour of the new `torchbearer.Trial` api
Removed
- Removed the MetricFactory class. Decorators still work in the same way but the Factory is no longer needed.
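Here is a minimal sketch of why a separate factory becomes unnecessary once a decorator registers the metric class itself. The registry maps a string key to the class, and the class is constructed on demand. `Metric`, `register_default`, `DEFAULTS` and `build_metrics` are hypothetical names chosen for illustration and are not torchbearer's metric API.

```python
from typing import Callable, Dict, List, Sequence, Type


class Metric:
    """Minimal hypothetical metric base class."""

    name = "metric"

    def process(self, predictions: Sequence[float], targets: Sequence[float]) -> float:
        raise NotImplementedError


# Hypothetical registry: string key (as passed in metrics=[...]) -> metric class
DEFAULTS: Dict[str, Type[Metric]] = {}


def register_default(key: str) -> Callable[[Type[Metric]], Type[Metric]]:
    """Class decorator that records `cls` under `key` and returns the class unchanged."""

    def decorator(cls: Type[Metric]) -> Type[Metric]:
        DEFAULTS[key] = cls
        return cls  # the class itself stays directly constructible, no factory needed

    return decorator


@register_default("mse")
class MeanSquaredError(Metric):
    name = "mse"

    def process(self, predictions: Sequence[float], targets: Sequence[float]) -> float:
        pairs = list(zip(predictions, targets))
        return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


def build_metrics(keys: List[str]) -> List[Metric]:
    """Turn string keys into metric instances by calling the registered classes directly."""
    return [DEFAULTS[key]() for key in keys]
```

The design choice here is that the decorator returns the class it decorates. Code that used to go through a factory can instantiate the class directly, and a string key such as `"mse"` still resolves to the same class.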