
Version 0.2.0

@ethanwharris released this on 21 Aug 2018 at 10:34
021373f

See the [NEW!] section in README.md for the new key features.

[0.2.0] - 2018-08-21

Added

  • Added the ability to pass custom arguments to the tqdm callback
  • Added an ignore_index flag to the categorical accuracy metric, similar to nn.CrossEntropyLoss. Usage: metrics=[CategoricalAccuracyFactory(ignore_index=0)]
  • Added TopKCategoricalAccuracy metric (default for key: top_5_acc)
  • Added BinaryAccuracy metric (default for key: binary_acc)
  • Added MeanSquaredError metric (default for key: mse)
  • Added DefaultAccuracy metric (use with 'acc' or 'accuracy'), which infers the appropriate accuracy metric from the criterion
  • Added the new torchbearer.Trial API to replace the Model API. The Trial API is more atomic and uses the fluent pattern to allow chaining of methods.
  • torchbearer.Trial has with_x_generator and with_x_data methods to add training/validation/testing generators to the trial, and a with_generators method to pass all generators in one call.
  • torchbearer.Trial has for_x_steps and for_steps methods to allow running trials without explicit generators or data tensors
  • torchbearer.Trial keeps a history of run calls, which tracks the number of epochs run and the final metrics at each epoch. This allows seamless resuming of trial running.
  • torchbearer.Trial.state_dict now returns the trial history and callback list state, allowing trials to be fully resumed
  • torchbearer.Trial has a replay method that can replay training (with callbacks and display) from the history. This is useful when loading trials from state.
  • Arguments can now be passed to the backward call by setting state[torchbearer.BACKWARD_ARGS]
  • torchbearer.Trial implements the forward pass, loss calculation and backward call as an optimizer closure
  • Metrics are now explicitly calculated with no gradient
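The fluent, history-tracking design described above can be illustrated with a small pure-Python sketch. `MiniTrial`, `for_train_steps`, `step_fn` and the `run`/`replay` semantics here are hypothetical stand-ins, not torchbearer's implementation: each configuration method returns `self` so calls chain, `run(epochs)` treats `epochs` as the total target and skips epochs already recorded in the history, and `state_dict`/`load_state_dict` carry that history so a reloaded trial resumes where it stopped.

```python
class MiniTrial:
    """Hypothetical sketch of a fluent, resumable trial (not torchbearer's code)."""

    def __init__(self, step_fn):
        self.step_fn = step_fn      # callable(epoch, step) -> float loss
        self.train_steps = None
        self.history = []           # one (epoch_index, metrics_dict) entry per finished epoch

    def for_train_steps(self, steps):
        self.train_steps = steps
        return self                 # returning self enables method chaining

    def run(self, epochs=1):
        # Assumption: `epochs` is the total target, so already-run epochs are skipped.
        for epoch in range(len(self.history), epochs):
            losses = [self.step_fn(epoch, s) for s in range(self.train_steps)]
            self.history.append((epoch, {'loss': sum(losses) / len(losses)}))
        return self.history

    def state_dict(self):
        return {'history': list(self.history), 'train_steps': self.train_steps}

    def load_state_dict(self, state):
        self.history = list(state['history'])
        self.train_steps = state['train_steps']
        return self

    def replay(self):
        # Rebuild a per-epoch summary from the stored history without re-training.
        return [f"epoch {e}: loss={m['loss']:.3f}" for e, m in self.history]


trial = MiniTrial(lambda epoch, step: 1.0 / (epoch + 1)).for_train_steps(4)
trial.run(epochs=2)
resumed = MiniTrial(lambda epoch, step: 1.0 / (epoch + 1)).load_state_dict(trial.state_dict())
resumed.run(epochs=3)   # only epoch 2 is actually run
for line in resumed.replay():
    print(line)
```

Storing the history inside the state dictionary is what makes resuming a matter of loading state and calling `run` again, which mirrors the intent of the history and replay entries above.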

Changed

  • Callback decorators can now be chained to allow construction with multiple methods filled
  • Callbacks can now implement state_dict and load_state_dict to allow callbacks to resume with state
  • The state dictionary now accepts StateKey objects, which are unique and generated through torchbearer.state.get_state
  • The state dictionary now warns when accessed with strings, since string keys can collide
  • Checkpointer callbacks will now resume from a state dict when resume=True in Trial
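Why unique keys avoid collisions, and why string access warrants a warning, can be shown with a pure-Python sketch. The `StateKey` and `State` classes below are hypothetical illustrations, not torchbearer's classes: keys compare by object identity, so two keys created with the same name never overwrite each other, while plain string keys still work but emit a warning.

```python
import warnings


class StateKey:
    """Hypothetical unique key: default object equality/hash is identity-based."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"StateKey({self.name!r})"


class State(dict):
    """Hypothetical dict that warns on string keys (only [] access is checked here)."""

    def _check(self, key):
        if isinstance(key, str):
            warnings.warn(f"string key {key!r} may collide; use a StateKey instead",
                          stacklevel=3)

    def __setitem__(self, key, value):
        self._check(key)
        super().__setitem__(key, value)

    def __getitem__(self, key):
        self._check(key)
        return super().__getitem__(key)


# Two independent components both choose the name 'loss' but cannot clash.
key_a, key_b = StateKey('loss'), StateKey('loss')
state = State()
state[key_a] = 0.5
state[key_b] = 0.9
print(state[key_a], state[key_b])  # both values survive
```

Note that `dict.update` and the `dict(...)` constructor bypass the overridden `__setitem__`, so a production version would need to guard those paths too.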

Deprecated

  • torchbearer.Model has been deprecated in favour of the new torchbearer.Trial API

Removed

  • Removed the MetricFactory class. Decorators still work in the same way but the Factory is no longer needed.

Fixed