Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added functionality to save and load StateDict to and from a safetensors file #57

Merged
merged 11 commits into from
Jan 23, 2024
12 changes: 6 additions & 6 deletions .devcontainer/cpu/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"context": "../..",
"dockerfile": "../Dockerfile",
"args": {
"CLANG_VERSION": ""
}
"CLANG_VERSION": "",
},
},

// Use 'forwardPorts' to make a list of ports inside the container available locally.
Expand All @@ -26,10 +26,10 @@
"ms-python.python",
"ms-vsliveshare.vsliveshare",
"DavidAnson.vscode-markdownlint",
"GitHub.copilot"
]
}
}
"GitHub.copilot",
],
},
},

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
Expand Down
122 changes: 122 additions & 0 deletions src/caustics/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from pathlib import Path
import json
import struct

# Encoding applied by default when string data is converted to bytes
DEFAULT_ENCODING = "utf-8"
# Key looked up in a safetensors header to retrieve the file's metadata
SAFETENSORS_METADATA = "__metadata__"


def _normalize_path(path: "str | Path") -> Path:
# Convert string path to Path object
if isinstance(path, str):
path = Path(path)
return path


def to_file(
    path: "str | Path", data: "str | bytes", encoding: str = DEFAULT_ENCODING
) -> str:
    """
    Write text or binary data to a file and report where it was written

    Parameters
    ----------
    path : str or Path
        The destination file path
    data : str | bytes
        The payload to write; a string is encoded with ``encoding`` first
    encoding : str, optional
        The encoding applied to a string payload, by default "utf-8"

    Returns
    -------
    str
        The absolute path of the written file, as a string
    """
    # TODO: Update to allow for remote paths saving

    # The file is always written as bytes, so text is encoded up front
    payload = data.encode(encoding) if isinstance(data, str) else data

    target = _normalize_path(path)
    target.write_bytes(payload)

    return str(target.absolute())


def from_file(path: "str | Path") -> bytes:
    """
    Read the full contents of a file as bytes

    Parameters
    ----------
    path : str or Path
        The path of the file to read

    Returns
    -------
    bytes
        The raw bytes read from the file
    """
    # TODO: Update to allow for remote paths loading

    # Normalizing first lets plain strings use the Path reading API
    return _normalize_path(path).read_bytes()


def _get_safetensors_header(path: "str | Path") -> dict:
    """
    Parse the JSON header found at the start of a file

    Parameters
    ----------
    path : str or Path
        The path of the file whose header is read

    Returns
    -------
    dict
        The complete decoded header
    """
    # TODO: Update to allow for remote paths loading of header

    file_path = _normalize_path(path)

    # Only the header bytes are read, so the rest of a
    # potentially large file is never pulled into memory
    with open(file_path, "rb") as stream:
        # The first 8 bytes hold the header size, unpacked as a
        # little-endian unsigned 64-bit integer
        size_prefix = stream.read(8)
        header_size = struct.unpack("<Q", size_prefix)[0]

        # The header itself follows immediately after the size prefix
        raw_header = stream.read(header_size)

    # The whole header is returned here; extracting the metadata
    # entry is left to the caller
    return json.loads(raw_header)


def get_safetensors_metadata(path: "str | Path") -> dict:
    """
    Read the metadata entry from a file's header

    Parameters
    ----------
    path : str or Path
        The path of the file to read the metadata from

    Returns
    -------
    dict
        The metadata dictionary, or an empty dict
        when the header has no metadata entry
    """
    # A header without a metadata entry falls back to a blank dict
    return _get_safetensors_header(path).get(SAFETENSORS_METADATA, {})
67 changes: 67 additions & 0 deletions src/caustics/sims/simulator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from pathlib import Path
from typing import Dict
from torch import Tensor

from ..parametrized import Parametrized
from .state_dict import StateDict
from ..namespace_dict import NestedNamespaceDict

# Names exported by a star-import of this module
__all__ = ("Simulator",)

Expand Down Expand Up @@ -27,5 +32,67 @@ def __call__(self, *args, **kwargs):

return self.forward(packed_args, *rest_args, **kwargs)

@staticmethod
def __set_module_params(module: Parametrized, params: Dict[str, Tensor]):
    """
    Assign every entry of ``params`` as an attribute on ``module``

    Parameters
    ----------
    module : Parametrized
        The module that receives the attributes
    params : Dict[str, Tensor]
        Mapping of attribute name to the value to assign
    """
    for attr_name, attr_value in params.items():
        setattr(module, attr_name, attr_value)

def state_dict(self) -> StateDict:
    """
    Build a state dict from the simulator's current parameters

    Returns
    -------
    StateDict
        The result of ``StateDict.from_params`` applied to ``self.params``
    """
    current_params = self.params
    return StateDict.from_params(current_params)

def load_state_dict(self, file_path: "str | Path") -> "Simulator":
    """
    Load a state dict from a file and apply it to this simulator

    Parameters
    ----------
    file_path : str | Path
        The path of the safetensors file
        holding the state to load

    Returns
    -------
    Simulator
        This simulator, after the loaded state has been applied
    """
    # Loading and applying are chained; ``self`` is returned
    # explicitly so the call can be used fluently
    self.set_state_dict(StateDict.load(file_path))
    return self

def set_state_dict(self, state_dict: StateDict) -> "Simulator":
    """
    Apply a state dict to this simulator and to the
    modules reachable through its ``_childs``

    Parameters
    ----------
    state_dict : StateDict
        The state dict holding the parameters to apply

    Returns
    -------
    Simulator
        This simulator, after the state has been applied
    """
    # TODO: Do some checks for the state dict metadata

    # View the state dict as a nested namespace dict
    # (presumably grouping parameters per module name; verify
    # against NestedNamespaceDict)
    nested_params = NestedNamespaceDict(state_dict)

    # Take this simulator's own entry out and apply it directly
    own_params = nested_params.pop(self.name)
    self.__set_module_params(self, own_params)

    def _apply(module):
        # Apply the entry for this module, if there is one,
        # then continue down through its children
        if module.name in nested_params:
            self.__set_module_params(module, nested_params[module.name])
        # Iterating an empty mapping does nothing, so no
        # separate emptiness check is required
        for child in module._childs.values():
            _apply(child)

    # Walk down from the root to reach the children modules
    _apply(self)
    return self
Loading
Loading