forked from Ciela-Institute/caustics
feat: Added functionalities to save and load StateDict to and from a safetensors #57
Merged
Added new 'io' module that contains 'to_file' method, which saves data string or bytes to a file path
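A minimal sketch of what such a helper could look like, assuming it takes a path plus a 'str' or 'bytes' payload and returns the resolved path. The signature and return value here are assumptions for illustration, not the actual 'io.to_file' from this PR.

```python
from pathlib import Path
from typing import Union


def to_file(path: Union[str, Path], data: Union[str, bytes]) -> str:
    """Write text or bytes to `path` and return the absolute path as a string.

    Hypothetical signature; the real helper in the PR may differ.
    """
    target = Path(path).expanduser().resolve()
    # Pick text or binary mode from the payload type.
    if isinstance(data, str):
        target.write_text(data, encoding="utf-8")
    elif isinstance(data, (bytes, bytearray)):
        target.write_bytes(bytes(data))
    else:
        raise TypeError(f"data must be str or bytes, got {type(data).__name__}")
    return str(target)
```

Dispatching on the payload type keeps a single entry point for both the safetensors bytes and any plain-text output.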
* Added 'save' method for saving state dictionary safetensors bytes to a file.
* Added 'ImmutableODict' class for 'StateDict' to inherit from, which ensures that attributes are also immutable after creation.
* Modified '_metadata' attribute to be of type 'ImmutableODict' so that users can't modify the keys and values in this dictionary either.
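One way an ordered dictionary can be locked after construction is sketched below. This is an illustrative stand-in that reuses the 'ImmutableODict' name; the '_created' flag and the 'TypeError' are assumptions, and the class in this PR may be implemented differently.

```python
from collections import OrderedDict


class ImmutableODict(OrderedDict):
    """Ordered dict that rejects item and attribute changes after __init__.

    Sketch only: the '_created' flag and the error type are assumptions.
    """

    _created = False  # class-level default, set to True per instance in __init__

    def __init__(self, *args, **kwargs):
        # Items are filled in while '_created' is still False.
        super().__init__(*args, **kwargs)
        # Bypass our own __setattr__ guard to flip the flag exactly once.
        object.__setattr__(self, "_created", True)

    def _refuse(self, *args, **kwargs):
        raise TypeError(f"{type(self).__name__} is immutable")

    def __setitem__(self, key, value):
        if self._created:
            self._refuse()
        super().__setitem__(key, value)

    def __setattr__(self, name, value):
        if self._created:
            self._refuse()
        super().__setattr__(name, value)

    # Every other mutating entry point is refused unconditionally.
    __delitem__ = __delattr__ = _refuse
    pop = popitem = clear = update = setdefault = move_to_end = _refuse
```

Guarding '__setattr__' as well as '__setitem__' is what stops callers from replacing attributes such as '_metadata' on an existing instance.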
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             state_dict    #57   +/-  ##
==========================================
  Coverage              ?   87.56%
==========================================
  Files                 ?       38
  Lines                 ?     2099
  Branches              ?        0
==========================================
  Hits                  ?     1838
  Misses                ?      261
  Partials              ?        0
Fixed the 'type or type' error in type hints by writing the annotation as a string.
Added a 'metadata' argument to 'StateDict' so that an existing state dictionary can be loaded, with the metadata updated on creation ONLY.
* Added 'from_file' function to read the full bytes of a safetensors file.
* Added 'get_safetensors_metadata' to extract only the '__metadata__' from a safetensors file.
* Added '_get_safetensors_header' private function to extract the full header of the safetensors file.
* Added '_normalize_path' private function for normalizing paths to a Path object.
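For context, a safetensors file starts with an 8-byte unsigned little-endian integer giving the header length, followed by that many bytes of UTF-8 JSON; the optional '__metadata__' key lives in that JSON. The sketch below reads the header with the standard library only. The function names are hypothetical and chosen so they are not confused with the functions added in this PR.

```python
import json
import struct
from pathlib import Path
from typing import Union


def read_safetensors_header(path: Union[str, Path]) -> dict:
    """Return the JSON header of a safetensors file without loading tensor data."""
    with open(Path(path).expanduser().resolve(), "rb") as f:
        # First 8 bytes: unsigned little-endian 64-bit header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        # Next header_len bytes: the UTF-8 encoded JSON header.
        return json.loads(f.read(header_len).decode("utf-8"))


def read_safetensors_metadata(path: Union[str, Path]) -> dict:
    """Return only the optional '__metadata__' entry ({} when it is absent)."""
    return read_safetensors_header(path).get("__metadata__", {})
```

Reading only the header means the metadata can be inspected without deserializing any tensors.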
lsetiawan changed the title from "feat: Add a way to save StateDict to a safetensors file" to "feat: Added functionalities to save and load StateDict to and from a safetensors file" on Jan 20, 2024
lsetiawan changed the title from "feat: Added functionalities to save and load StateDict to and from a safetensors file" to "feat: Added functionalities to save StateDict to a safetensors file" on Jan 22, 2024
* Added 'set_state_dict' to set parameters for parents and children recursively from a state dictionary object.
* Added 'load_state_dict' to load and set parameters from a file.
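The recursive idea can be sketched on a toy tree. Everything below is hypothetical: the 'Node' class, the flat 'name.param' key layout, and the 'set_state' function are assumptions for illustration and do not reproduce the 'Simulator' implementation; the sketch only walks children.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Hypothetical stand-in for a parametrized module in a simulator tree."""

    name: str
    params: Dict[str, float] = field(default_factory=dict)
    children: List["Node"] = field(default_factory=list)


def set_state(node: Node, state: Dict[str, float]) -> None:
    """Assign 'name.param' entries to `node`, then recurse into its children."""
    prefix = node.name + "."
    for key, value in state.items():
        if key.startswith(prefix):
            param = key[len(prefix):]
            # Refuse keys that do not match a parameter the node already has.
            if param not in node.params:
                raise KeyError(f"unknown parameter {key!r}")
            node.params[param] = value
    for child in node.children:
        set_state(child, state)
```

Usage under the same assumptions:

```python
lens = Node("lens", {"x0": 0.0})
sim = Node("sim", {"z_s": 1.0}, [lens])
set_state(sim, {"sim.z_s": 2.0, "lens.x0": 0.5})
```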
lsetiawan changed the title from "feat: Added functionalities to save StateDict to a safetensors file" to "feat: Added functionalities to save and load StateDict to and from a safetensors" on Jan 22, 2024
uwcdc approved these changes on Jan 23, 2024
LGTM
lsetiawan added a commit that referenced this pull request on Jan 29, 2024
…nstitute#149)

* feat: Add 'state_dict' method to create state dictionary (#51)
* feat: Add more public attributes
  Added 'children' and 'parents' public attributes for easier access to objects that are above and below the tree hierarchy. Additionally, added '_key_maps' attribute to easily access the arguments that contains Parametrized class.
* feat: Add 'state_dict' attribute
  Added 'state_dict' attribute to capture the parameters for Simulation object. The 'state_dict' is an immutable MappingProxyType dictionary that follows similar structure as 'safetensors' file format.
* revert: Bring back the order of key/value for _module_key_map
* refactor: Separate out state_dict to a class
  Created a new class called 'StateDict' for state dictionary called from 'Simulator.state_dict()'
* revert: Removed '_name' from default attribute
* chore(deps): Add 'safetensors' as dependency
* revert: Remove 'children' and 'parents' irrelevant
* docs: Add doc to 'from_params'
* test: Add test for 'state_dict' method
* test: Renamed duplicate 'test_simulator'
* feat: Allow from_params to take in 'NamespaceDict'
  Added a catch for NestedNamespaceDict and grab only the 'static' parameters. Additionally, now allows 'from_params' to take in only 'NamespaceDict', so essentially if the parameters is already the static one, it still works.
* test: Move 'simple_common_sim' fixture to conftest
* test: Added tests for 'StateDict' class
* test: Remove directly getting the immutable error object
* test: Update to_safetensors test to compare bytes with digest
* test: Fix error matching for setitem and delitem
* test: Change to not use a fixture for to_safetensors test
* test: Remove hashing functions to avoid potential difference among OS
* test: Load the tensors bytes back in and compare
* chore: Apply review suggestions
* fix: Fixed missing parameters in "dynamic" (#56)
* fix: Fixed missing parameters in "dynamic"
  Fixed the missing parameters that are intentionally None in the "dynamic" section. Now StateDict contains dynamic parameters. Also introduced '_sanitize' function to make None to empty tensor of size 0 since safetensors format doesn't accept None.
* test: Fix state dict comparison with empty torch
  Fixed the comparison for StateDicts. Since comparing empty tensors doesn't work, I created isEquals internal function to the 'test_from_params' test to compare values one by one.
* test: Create utility and update tests
* test: Refactored how helpers are setup in tests
* refactor: Extract class name to variable
* feat: Added functionalities to save and load StateDict to and from a safetensors (#57)
* feat(io): Add new 'io' module
  Added new 'io' module that contains 'to_file' method, which saves data string or bytes to a file path
* feat(StateDict): Added 'save' method to save to file
* Added 'save' method for saving state dictionary safetensors bytes to a file.
* Added 'ImmutableODict' class for 'StateDict' to inherit from, which cracks down on ensuring that attributes are also immutable after creation.
* Modified '_metadata' attribute to be of type 'ImmutableODict' so that users can't also modify the key and values in this dictionary.
* fix: Fix type or type typehints
  Fixed type or type error for typehints by making the statement as string.
* feat: Added 'metadata' arg to 'StateDict'
  Added metadata arg to 'StateDict' for allowing loading of existing state dictionary by updating the metadata on creation ONLY.
* refactor: Removed metadata on cls creation
* feat: Added safetensors loading
* Added 'from_file' function to read the full bytes of safetensors file.
* Added 'get_safetensors_metadata' to extract only the '__metadata__' from safetensors file
* Added '_get_safetensors_header' private function to extract the full header of the safetensors
* Added '_normalize_path' private function for normalizing paths to a Path object
* feat(StateDict): Added 'load' method for loading state safetensors file
* test: Remove usage of '_sanitize' in 'helpers.sims'
* style: pre-commit fixes
* feat(simulator): Added methods to 'Simulator' for loading state dict
* Added 'set_state_dict' to set parameters for parents and children recursively from a state dictionary object
* Added 'load_state_dict' to load and set parameters from a file

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* test: Added tests for completeness of StateDict and Simulator (#59)
* test(StateDict): Added tests for methods in 'StateDict'
  Added save and load tests for 'StateDict' to ensure that every functionality within the methods are covered in test.
* test: Added test for 'load_state_dict' for 'Simulator' class
  Added test for the loading functionalities of a 'Simulator' class to ensure that it's working properly
* test: Added tests for 'io' module
  Added test functions for the whole 'io' module
* test: Allow for missing ok when unlink, in case it auto deleted
* fix: Removed hardcoding of current dir of '.'
  Removed hardcoding of current directory using '.', instead use 'os.path.curdir' to be more OS agnostic
* test: Add platform check before 'unlink', skip windows
* fix: Changed normalized path to use absolute
* fix: Fix I/O for accounting windows (#61)
* refactor: Use Path.cwd instead of os.path.curdir
* fix: Use open to write bytes and Path to construct path
* test: Update normalize path test
* test: Updated code to skip saving on Windows
* test: Remove missing_ok
* style: Fix mypy errors
* refactor: Added option to save from pathlib Path
* refactor: Cleaned up code repetition based on reviews
  From code review, it was pointed out that there were some code repetition that needed to be cleaned up. So this was done and now there's no need for a "helper" module for testing. However, in this commit, the "utils.py" have now become a "utils" directory and installed as a testing module. The "extract_tensors" method has become a private function that is now used in "from_params" and other places for tests.
* test: Add default values for EPL and Sersic for testing
  Added default values for EPL and Sersic objects for performing 'load_state_dict' testing.

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Overview
We now have a way to extract a state dictionary from 'Simulator'; this PR adds a way to save that state dictionary to a safetensors file. This starts with the creation of the 'io' module. In addition, this PR ensures that a 'StateDict' object can't be modified once created.