feat: Added functionalities to save and load StateDict to and from a safetensors #57

Merged
merged 11 commits into from
Jan 23, 2024

Conversation

lsetiawan
Member

Overview

We now have a way to extract a state dictionary from a Simulator. This PR adds a way to save that state dictionary to a safetensors file, starting with the creation of a new io module. In addition, this PR ensures more strictly that a StateDict object can't be modified once created.

Added new 'io' module that contains a
'to_file' method, which saves string or bytes data
to a file path.
* Added 'save' method for saving state dictionary
safetensors bytes to a file.
* Added 'ImmutableODict' class for 'StateDict' to inherit from,
which ensures that attributes are also immutable
after creation.
* Modified '_metadata' attribute to be of type 'ImmutableODict'
so that users also can't modify the keys and values in this dictionary.
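
As an illustration of the immutability idea described above, here is a minimal sketch of an ordered dictionary that rejects changes to both items and attributes once it has been created. This is not the code from this PR: the class body, the `_frozen` flag, and the choice of `TypeError` are assumptions made for the example; only the class name comes from the description.

```python
from collections import OrderedDict


class ImmutableODict(OrderedDict):
    """Hypothetical sketch: an OrderedDict that is read-only after creation."""

    def __init__(self, *args, **kwargs):
        if getattr(self, "_frozen", False):
            raise TypeError("cannot re-initialize an immutable dictionary")
        # Writes are allowed only while the initial contents are filled in.
        object.__setattr__(self, "_frozen", False)
        super().__init__(*args, **kwargs)
        object.__setattr__(self, "_frozen", True)

    def _check_mutable(self):
        if getattr(self, "_frozen", False):
            raise TypeError(f"{type(self).__name__} is immutable after creation")

    def __setitem__(self, key, value):
        self._check_mutable()
        super().__setitem__(key, value)

    def __delitem__(self, key):
        self._check_mutable()
        super().__delitem__(key)

    def __setattr__(self, name, value):
        self._check_mutable()
        super().__setattr__(name, value)

    def __delattr__(self, name):
        self._check_mutable()
        super().__delattr__(name)

    def _blocked(self, *args, **kwargs):
        raise TypeError(f"{type(self).__name__} is immutable after creation")

    # Every other mutating method is rejected outright.
    pop = popitem = clear = update = setdefault = move_to_end = __ior__ = _blocked
```

In this sketch, reading works as with any ordered dictionary, while calls such as `d["key"] = value`, `d.update(...)`, or `d.extra = 1` raise `TypeError` after construction.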

codecov bot commented Jan 19, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

❗ No coverage uploaded for pull request base (state_dict@cc94a43).

Additional details and impacted files
@@              Coverage Diff              @@
##             state_dict      #57   +/-   ##
=============================================
  Coverage              ?   87.56%           
=============================================
  Files                 ?       38           
  Lines                 ?     2099           
  Branches              ?        0           
=============================================
  Hits                  ?     1838           
  Misses                ?      261           
  Partials              ?        0           


Fixed the type-or-type error in type hints by
writing the annotation as a string.
Added 'metadata' arg to 'StateDict' to allow
loading of an existing state dictionary by
updating the metadata on creation ONLY.
* Added 'from_file' function to read the full
bytes of a safetensors file.
* Added 'get_safetensors_metadata' to extract
only the '__metadata__' from a safetensors file.
* Added '_get_safetensors_header' private function
to extract the full header of the safetensors file.
* Added '_normalize_path' private function for normalizing
paths to a Path object.
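
For context on the header and metadata extraction listed above: a safetensors file begins with an 8-byte little-endian unsigned integer giving the size of a JSON header, followed by the header itself, which may carry an optional `__metadata__` entry. The sketch below reads those two pieces with the standard library only. The function names `read_safetensors_header` and `read_safetensors_metadata` are hypothetical and are not the functions added in this PR.

```python
import json
import struct
from pathlib import Path
from typing import Any, Dict, Union


def read_safetensors_header(path: Union[str, Path]) -> Dict[str, Any]:
    """Hypothetical helper: return the parsed JSON header of a safetensors file."""
    with open(path, "rb") as f:
        size_bytes = f.read(8)
        if len(size_bytes) != 8:
            raise ValueError("file is too short to contain a safetensors header")
        # The first 8 bytes hold the header length as a little-endian uint64.
        (header_size,) = struct.unpack("<Q", size_bytes)
        header_bytes = f.read(header_size)
        if len(header_bytes) != header_size:
            raise ValueError("safetensors header is truncated")
    return json.loads(header_bytes.decode("utf-8"))


def read_safetensors_metadata(path: Union[str, Path]) -> Dict[str, str]:
    """Hypothetical helper: return only the '__metadata__' entry, or {} if absent."""
    return read_safetensors_header(path).get("__metadata__", {})
```

Reading only the header avoids loading any tensor data, which is why a metadata-only lookup can stay cheap even for large files.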
@lsetiawan lsetiawan changed the title feat: Add a way to save StateDict to a safetensors file feat: Added functionalities to save and load StateDict to and from a safetensors file Jan 20, 2024
@lsetiawan lsetiawan changed the title feat: Added functionalities to save and load StateDict to and from a safetensors file feat: Added functionalities to save StateDict to a safetensors file Jan 22, 2024
@lsetiawan lsetiawan marked this pull request as ready for review January 22, 2024 21:36
@lsetiawan lsetiawan requested a review from uwcdc January 22, 2024 21:37
* Added 'set_state_dict' to set parameters for parents and children
recursively from a state dictionary object
* Added 'load_state_dict' to load and set parameters from a file
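
To illustrate what setting parameters recursively for a parent and its children can look like, here is a toy sketch. Everything in it is an assumption for the example: the `Module` class, the `set_state` function, and the flat `"<module name>.<parameter name>"` key format are hypothetical and are not taken from this PR.

```python
from typing import Dict, List, Optional


class Module:
    """Hypothetical stand-in for a parametrized object in a simulator tree."""

    def __init__(
        self,
        name: str,
        params: Dict[str, Optional[float]],
        children: Optional[List["Module"]] = None,
    ):
        self.name = name
        self.params = dict(params)
        self.children = list(children or [])


def set_state(module: Module, state: Dict[str, float]) -> None:
    """Assign matching values to this module, then recurse into its children."""
    for key in module.params:
        full_key = f"{module.name}.{key}"  # assumed key format
        if full_key in state:
            module.params[key] = state[full_key]
    for child in module.children:
        set_state(child, state)


lens = Module("lens", {"q": None, "phi": None})
sim = Module("sim", {"z_s": None}, children=[lens])
set_state(sim, {"sim.z_s": 1.5, "lens.q": 0.7})
```

Keys that are absent from the state mapping leave the existing value untouched, so `lens.params["phi"]` stays `None` in this example.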
@lsetiawan lsetiawan changed the title feat: Added functionalities to save StateDict to a safetensors file feat: Added functionalities to save and load StateDict to and from a safetensors Jan 22, 2024
Member

@uwcdc uwcdc left a comment


LGTM

@lsetiawan lsetiawan merged commit 8052970 into uw-ssec:state_dict Jan 23, 2024
13 checks passed
@lsetiawan lsetiawan deleted the add_save branch January 23, 2024 17:54
lsetiawan added a commit that referenced this pull request Jan 29, 2024
…nstitute#149)

* feat: Add 'state_dict' method to create state dictionary (#51)

* feat: Add more public attributes

Added 'children' and 'parents' public attributes
for easier access to objects that are above
and below in the tree hierarchy. Additionally,
added '_key_maps' attribute to easily access the
arguments that contain a Parametrized class.

* feat: Add 'state_dict' attribute

Added 'state_dict' attribute to capture the parameters
for a Simulation object. The 'state_dict' is an immutable
MappingProxyType dictionary that follows a structure similar
to the 'safetensors' file format.

* revert: Bring back the order of key/value for _module_key_map

* refactor: Separate out state_dict to a class

Created a new class called 'StateDict' for state
dictionary called from 'Simulator.state_dict()'

* revert: Removed '_name' from default attribute

* chore(deps): Add 'safetensors' as dependency

* revert: Remove 'children' and 'parents' irrelevant

* docs: Add doc to 'from_params'

* test: Add test for 'state_dict' method

* test: Renamed duplicate 'test_simulator'

* feat: Allow from_params to take in 'NamespaceDict'

Added a catch for NestedNamespaceDict that grabs only the
'static' parameters. Additionally, 'from_params' now accepts
a plain 'NamespaceDict', so it still works if the parameters
are already the static ones.

* test: Move 'simple_common_sim' fixture to conftest

* test: Added tests for 'StateDict' class

* test: Remove directly getting the immutable error object

* test: Update to_safetensors test to compare bytes with digest

* test: Fix error matching for setitem and delitem

* test: Change to not use a fixture for to_safetensors test

* test: Remove hashing functions to avoid potential difference among OS

* test: Load the tensors bytes back in and compare

* chore: Apply review suggestions

* fix: Fixed missing parameters in "dynamic" (#56)

* fix: Fixed missing parameters in "dynamic"

Fixed the missing parameters that are intentionally
None in the "dynamic" section. Now StateDict
contains dynamic parameters. Also introduced
'_sanitize' function to convert None to an empty
tensor of size 0, since the safetensors format
doesn't accept None.

* test: Fix state dict comparison with empty torch

Fixed the comparison for StateDicts. Since comparing
empty tensors doesn't work, I added an internal isEquals
function to the 'test_from_params' test to compare
values one by one.

* test: Create utility and update tests

* test: Refactored how helpers are setup in tests

* refactor: Extract class name to variable

* feat: Added functionalities to save and load StateDict to and from a safetensors (#57)

* feat(io): Add new 'io' module

Added new 'io' module that contains a
'to_file' method, which saves string or bytes data
to a file path.

* feat(StateDict): Added 'save' method to save to file

* Added 'save' method for saving state dictionary
safetensors bytes to a file.
* Added 'ImmutableODict' class for 'StateDict' to inherit from,
which ensures that attributes are also immutable
after creation.
* Modified '_metadata' attribute to be of type 'ImmutableODict'
so that users also can't modify the keys and values in this dictionary.

* fix: Fix type or type typehints

Fixed the type-or-type error in type hints by
writing the annotation as a string.

* feat: Added 'metadata' arg to 'StateDict'

Added 'metadata' arg to 'StateDict' to allow
loading of an existing state dictionary by
updating the metadata on creation ONLY.

* refactor: Removed metadata on cls creation

* feat: Added safetensors loading

* Added 'from_file' function to read the full
bytes of a safetensors file.
* Added 'get_safetensors_metadata' to extract
only the '__metadata__' from a safetensors file.
* Added '_get_safetensors_header' private function
to extract the full header of the safetensors file.
* Added '_normalize_path' private function for normalizing
paths to a Path object.

* feat(StateDict): Added 'load' method for loading state safetensors file

* test: Remove usage of '_sanitize' in 'helpers.sims'

* style: pre-commit fixes

* feat(simulator): Added methods to 'Simulator' for loading state dict

* Added 'set_state_dict' to set parameters for parents and children
recursively from a state dictionary object
* Added 'load_state_dict' to load and set parameters from a file

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* test: Added tests for completeness of StateDict and Simulator (#59)

* test(StateDict): Added tests for methods in 'StateDict'

Added save and load tests for 'StateDict' to ensure that every
functionality within the methods is covered by tests.

* test: Added test for 'load_state_dict' for 'Simulator' class

Added test for the loading functionalities of a 'Simulator'
class to ensure that it's working properly

* test: Added tests for 'io' module

Added test functions for the whole 'io' module

* test: Allow for missing ok when unlink, in case it auto deleted

* fix: Removed hardcoding of current dir of '.'

Removed hardcoding of the current directory as
'.'; use 'os.path.curdir' instead to be more
OS agnostic.

* test: Add platform check before 'unlink', skip windows

* fix: Changed normalized path to use absolute

* fix: Fix I/O for accounting windows (#61)

* refactor: Use Path.cwd instead of os.path.curdir

* fix: Use open to write bytes and Path to construct path

* test: Update normalize path test

* test: Updated code to skip saving on Windows

* test: Remove missing_ok

* style: Fix mypy errors

* refactor: Added option to save from pathlib Path

* refactor: Cleaned up code repetition based on reviews

Code review pointed out some code repetition that
needed to be cleaned up. That has been done, and
there's no longer a need for a "helper" module for testing.
However, in this commit, "utils.py" has become a "utils"
directory that is installed as a testing module. The "extract_tensors"
method has become a private function that is now used in "from_params"
and in other places for tests.

* test: Add default values for EPL and Sersic for testing

Added default values for EPL and Sersic objects for performing
'load_state_dict' testing.
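
The path-handling changes listed in the commits above (resolving against `Path.cwd`, using absolute paths, and writing bytes with `open`) can be sketched roughly as follows. This is an illustrative sketch, not the code from these commits: the function names `normalize_path` and `write_to_file` and their exact behavior are assumptions.

```python
from pathlib import Path
from typing import Union


def normalize_path(path: Union[str, Path]) -> Path:
    """Hypothetical helper: return an absolute Path for a str or Path input."""
    p = Path(path).expanduser()
    # Relative inputs are anchored at the current working directory
    # instead of a hardcoded '.' string.
    return p if p.is_absolute() else Path.cwd() / p


def write_to_file(path: Union[str, Path], data: Union[str, bytes]) -> Path:
    """Hypothetical helper: write str or bytes data and return the final path."""
    target = normalize_path(path)
    if isinstance(data, str):
        data = data.encode("utf-8")
    # Binary mode keeps the bytes identical across operating systems.
    with open(target, "wb") as f:
        f.write(data)
    return target
```

Building the path with `pathlib` rather than string concatenation leaves separator handling to the standard library, which is the usual motivation for changes of this kind on Windows.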

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>