Skip to content

Commit

Permalink
Better error messages on outdated context manager.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 20, 2023
1 parent a404265 commit fced0a0
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 40 deletions.
18 changes: 1 addition & 17 deletions bindings/python/py_src/safetensors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,4 @@
__version__ = "0.2.9"

# Re-export this
from ._safetensors_rust import safe_open as rust_open, serialize, serialize_file, deserialize, SafetensorError


class safe_open:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs

def __getattr__(self, __name: str):
return getattr(self.f, __name)

def __enter__(self):
self.f = rust_open(*self.args, **self.kwargs)
return self

def __exit__(self, type, value, traceback):
del self.f
from ._safetensors_rust import safe_open, serialize, serialize_file, deserialize, SafetensorError # noqa: F401
131 changes: 108 additions & 23 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,32 +466,15 @@ impl Version {
}
}

/// Opens a safetensors lazily and returns tensors as asked
///
/// Args:
/// filename (`str`, or `os.PathLike`):
/// The filename to open
///
/// framework (`str`):
/// The framework you want you tensors in. Supported values:
/// `pt`, `tf`, `flax`, `numpy`.
///
/// device (`str`, defaults to `"cpu"`):
/// The device on which you want the tensors.
#[pyclass]
#[allow(non_camel_case_types)]
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
struct safe_open {
struct Open {
metadata: Metadata,
offset: usize,
framework: Framework,
device: Device,
storage: Arc<Storage>,
}

#[pymethods]
impl safe_open {
#[new]
impl Open {
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
let file = File::open(&filename)?;
let device = device.unwrap_or(Device::Cpu);
Expand Down Expand Up @@ -714,6 +697,105 @@ impl safe_open {
)))
}
}
}

/// Opens a safetensors lazily and returns tensors as asked
///
/// Args:
/// filename (`str`, or `os.PathLike`):
/// The filename to open
///
/// framework (`str`):
/// The framework you want you tensors in. Supported values:
/// `pt`, `tf`, `flax`, `numpy`.
///
/// device (`str`, defaults to `"cpu"`):
/// The device on which you want the tensors.
#[pyclass]
#[allow(non_camel_case_types)]
#[pyo3(text_signature = "(self, filename, framework, device=\"cpu\")")]
struct safe_open {
inner: Option<Open>,
}

impl safe_open {
fn inner(&self) -> PyResult<&Open> {
let inner = self
.inner
.as_ref()
.ok_or_else(|| SafetensorError::new_err(format!("File is closed",)))?;
Ok(inner)
}
}

#[pymethods]
impl safe_open {
#[new]
fn new(filename: PathBuf, framework: Framework, device: Option<Device>) -> PyResult<Self> {
let inner = Some(Open::new(filename, framework, device)?);
Ok(Self { inner })
}

/// Return the special non tensor information in the header
///
/// Returns:
/// (`Dict[str, str]`):
/// The freeform metadata.
pub fn metadata(&self) -> PyResult<Option<BTreeMap<String, String>>> {
Ok(self.inner()?.metadata())
}

/// Returns the names of the tensors in the file.
///
/// Returns:
/// (`List[str]`):
/// The name of the tensors contained in that file
pub fn keys(&self) -> PyResult<Vec<String>> {
self.inner()?.keys()
}

/// Returns a full tensor
///
/// Args:
/// name (`str`):
/// The name of the tensor you want
///
/// Returns:
/// (`Tensor`):
/// The tensor in the framework you opened the file for.
///
/// Example:
/// ```python
/// from safetensors import safe_open
///
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
/// tensor = f.get_tensor("embedding")
///
/// ```
pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
self.inner()?.get_tensor(name)
}

/// Returns a full slice view object
///
/// Args:
/// name (`str`):
/// The name of the tensor you want
///
/// Returns:
/// (`PySafeSlice`):
/// A dummy object you can slice into to get a real tensor
/// Example:
/// ```python
/// from safetensors import safe_open
///
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
/// tensor_part = f.get_slice("embedding")[:, ::8]
///
/// ```
pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
self.inner()?.get_slice(name)
}

pub fn __enter__(slf: Py<Self>) -> Py<Self> {
// SAFETY: This code is extremely important to the GPU fast load.
Expand All @@ -726,9 +808,10 @@ impl safe_open {
// of the context manager lifecycle.
Python::with_gil(|py| -> PyResult<()> {
let _self: &safe_open = &slf.borrow(py);
if let (Device::Cuda(_), Framework::Pytorch) = (&_self.device, &_self.framework) {
let inner = _self.inner()?;
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
let module = get_module(py, &TORCH_MODULE)?;
let device: PyObject = _self.device.clone().into_py(py);
let device: PyObject = inner.device.clone().into_py(py);
let torch_device = module
.getattr(intern!(py, "cuda"))?
.getattr(intern!(py, "device"))?;
Expand All @@ -742,10 +825,11 @@ impl safe_open {
}

pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
if let (Device::Cuda(_), Framework::Pytorch) = (&self.device, &self.framework) {
let inner = self.inner().unwrap();
if let (Device::Cuda(_), Framework::Pytorch) = (&inner.device, &inner.framework) {
Python::with_gil(|py| -> PyResult<()> {
let module = get_module(py, &TORCH_MODULE)?;
let device: PyObject = self.device.clone().into_py(py);
let device: PyObject = inner.device.clone().into_py(py);
let torch_device = module
.getattr(intern!(py, "cuda"))?
.getattr(intern!(py, "device"))?;
Expand All @@ -756,6 +840,7 @@ impl safe_open {
})
.ok();
}
self.inner = None;
}
}

Expand Down
3 changes: 3 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def test_get_correctly_dropped(self):
with safe_open("./out.safetensors", framework="pt") as f:
pass

with self.assertRaises(SafetensorError):
print(f.keys())

with open("./out.safetensors", "w") as g:
g.write("something")

Expand Down

0 comments on commit fced0a0

Please sign in to comment.