feat(crypt): Add fast encryption

Eta0 committed Nov 30, 2023
1 parent 53f26d4 commit a2cd158
Showing 13 changed files with 1,705 additions and 493 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -22,5 +22,8 @@ jobs:
- name: Install Redis
run: sudo apt-get install -y redis-server

- name: Install libsodium
run: sudo apt-get install -y libsodium23

- name: Run tests
run: python -m unittest discover tests/ --verbose
20 changes: 20 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Tensor encryption
- Encrypts all tensor weights in a file with minimal overhead
- Doesn't encrypt tensor metadata, such as:
- Tensor name
- Tensor `dtype`
- Tensor shape & size
- Requires an up-to-date version of `libsodium`
- Use `apt-get install libsodium23` on Ubuntu or Debian
- On other platforms, follow the
[installation instructions from the libsodium documentation](https://doc.libsodium.org/installation)
- Takes up less than 500 KiB once installed
- Uses a parallelized version of XSalsa20-Poly1305 as its encryption algorithm
  - Splits each tensor's weights into ≤ 2 MiB chunks, encrypted separately (see the sketch after this file's diff)
- Example usage: see [examples/encryption.py](examples/encryption.py)

## [2.6.0] - 2023-10-30

### Added
@@ -220,6 +239,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.6.0...HEAD
[2.6.0]: https://github.com/coreweave/tensorizer/compare/v2.5.1...v2.6.0
[2.5.1]: https://github.com/coreweave/tensorizer/compare/v2.5.0...v2.5.1
[2.5.0]: https://github.com/coreweave/tensorizer/compare/v2.4.0...v2.5.0
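The chunked scheme in the changelog entry above is easy to picture in miniature. The sketch below uses PyNaCl's SecretBox (XSalsa20-Poly1305) purely for illustration — tensorizer's actual implementation is built on libnacl with its own chunk framing, and both helper functions here are invented names:

import nacl.utils
from nacl.secret import SecretBox  # illustration only; tensorizer uses libnacl

CHUNK_SIZE = 2 << 20  # the <= 2 MiB chunk bound described above


def encrypt_chunked(key: bytes, weights: bytes) -> list[bytes]:
    # Each chunk gets its own random nonce and Poly1305 tag, so chunks are
    # independent and can be encrypted or decrypted in parallel.
    box = SecretBox(key)  # XSalsa20-Poly1305 with a 32-byte key
    return [
        box.encrypt(weights[i : i + CHUNK_SIZE])
        for i in range(0, len(weights), CHUNK_SIZE)
    ]


def decrypt_chunked(key: bytes, chunks: list[bytes]) -> bytes:
    box = SecretBox(key)
    return b"".join(box.decrypt(chunk) for chunk in chunks)


key = nacl.utils.random(SecretBox.KEY_SIZE)
data = b"\x00" * (5 << 20)  # 5 MiB of dummy weights -> 3 chunks
assert decrypt_chunked(key, encrypt_chunked(key, data)) == data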
92 changes: 92 additions & 0 deletions examples/encryption.py
@@ -0,0 +1,92 @@
import os
import tempfile
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from tensorizer import (
DecryptionParams,
EncryptionParams,
TensorDeserializer,
TensorSerializer,
)
from tensorizer.utils import no_init_or_tensor

model_ref = "EleutherAI/gpt-neo-2.7B"


def original_model(ref) -> torch.nn.Module:
return AutoModelForCausalLM.from_pretrained(ref)


def empty_model(ref) -> torch.nn.Module:
config = AutoConfig.from_pretrained(ref)
with no_init_or_tensor():
return AutoModelForCausalLM.from_config(config)


# Set a strong string or bytes passphrase here
passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD", "") or input(
"Passphrase to use for encryption: "
)

fd, path = tempfile.mkstemp(prefix="encrypted-tensors")

try:
# Encrypt a model during serialization
encryption_params = EncryptionParams.from_passphrase_fast(passphrase)

model = original_model(model_ref)
serialization_start = time.monotonic()

serializer = TensorSerializer(path, encryption=encryption_params)
serializer.write_module(model)
serializer.close()

serialization_end = time.monotonic()
del model

# Then decrypt it again during deserialization
decryption_params = DecryptionParams.from_passphrase(passphrase)

model = empty_model(model_ref)
deserialization_start = time.monotonic()

deserializer = TensorDeserializer(
path, encryption=decryption_params, plaid_mode=True
)
deserializer.load_into_module(model)
deserializer.close()

deserialization_end = time.monotonic()
del model
finally:
os.close(fd)
os.unlink(path)


def print_speed(prefix, start, end, size):
mebibyte = 1 << 20
gibibyte = 1 << 30
duration = end - start
rate = size / duration
print(
f"{prefix} {size / gibibyte:.2f} GiB model in {duration:.2f} seconds,"
f" {rate / mebibyte:.2f} MiB/s"
)
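# e.g. a 9.90 GiB model loaded in 20.00 seconds works out to
# (9.90 * 1024) MiB / 20.00 s = 506.88 MiB/s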


print_speed(
"Serialized and encrypted",
serialization_start,
serialization_end,
serializer.total_tensor_bytes,
)

print_speed(
"Deserialized encrypted",
deserialization_start,
deserialization_end,
deserializer.total_tensor_bytes,
)
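A note on running the example: it reads the passphrase from the SUPER_SECRET_STRONG_PASSWORD environment variable when set, and falls back to an interactive prompt otherwise.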
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ dependencies = [
"boto3>=1.26.0",
"redis>=5.0.0",
"hiredis>=2.2.0",
"pynacl>=1.5.0",
"libnacl>=2.1.0"
]
classifiers = [
"Programming Language :: Python :: 3",
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,5 +3,6 @@ numpy>=1.19.5
protobuf>=3.19.5
psutil>=5.9.4
boto3>=1.26.0
-redis==5.0.0
hiredis
+redis==5.0.0
+libnacl>=2.1.0
85 changes: 59 additions & 26 deletions tensorizer/_NumpyTensor.py
@@ -13,39 +13,74 @@
8: torch.int64,
}

-# torch types with no numpy equivalents
-# i.e. the only ones that need to be opaque
-# Uses a comprehension to filter out any dtypes
-# that don't exist in older torch versions
-_ASYMMETRIC_TYPES = {
-    getattr(torch, t)
+# Listing of types from a static copy of:
+# tuple(
+#     dict.fromkeys(
+#         str(t)
+#         for t in vars(torch).values()
+#         if isinstance(t, torch.dtype)
+#     )
+# )
+_ALL_TYPES = {
+    f"torch.{t}": v
    for t in (
-        "bfloat16",
-        "quint8",
+        "uint8",
+        "int8",
+        "int16",
+        "int32",
+        "int64",
+        "float16",
+        "float32",
+        "float64",
+        "complex32",
+        "complex64",
+        "complex128",
+        "bool",
+        "qint8",
+        "quint8",
+        "qint32",
+        "bfloat16",
        "quint4x2",
        "quint2x4",
-        "complex32",
    )
-    if hasattr(torch, t)
+    if isinstance(v := getattr(torch, t, None), torch.dtype)
}

+# torch types with no numpy equivalents
+# i.e. the only ones that need to be opaque
+# Uses a comprehension to filter out any dtypes
+# that don't exist in older torch versions
+_ASYMMETRIC_TYPES = {
+    _ALL_TYPES[t]
+    for t in {
+        "torch.bfloat16",
+        "torch.quint8",
+        "torch.qint8",
+        "torch.qint32",
+        "torch.quint4x2",
+        "torch.quint2x4",
+        "torch.complex32",
+    }
+    & _ALL_TYPES.keys()
+}
+
# These types aren't supported yet because they require supplemental
# quantization parameters to deserialize correctly
_UNSUPPORTED_TYPES = {
-    getattr(torch, t)
-    for t in (
-        "quint8",
-        "qint8",
-        "qint32",
-        "quint4x2",
-        "quint2x4",
-    )
-    if hasattr(torch, t)
+    _ALL_TYPES[t]
+    for t in {
+        "torch.quint8",
+        "torch.qint8",
+        "torch.qint32",
+        "torch.quint4x2",
+        "torch.quint2x4",
+    }
+    & _ALL_TYPES.keys()
}

-_DECODE_MAPPING = {str(t): t for t in _ASYMMETRIC_TYPES}
+_DECODE_MAPPING = {
+    k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES
+}


class _NumpyTensor(NamedTuple):
@@ -85,14 +120,12 @@ def from_buffer(
buffer=buffer,
offset=offset,
)
-        return cls(data=data,
-                   numpy_dtype=numpy_dtype,
-                   torch_dtype=torch_dtype)
+        return cls(data=data, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype)

@classmethod
-    def from_tensor(cls,
-                    tensor: Union[torch.Tensor,
-                                  torch.nn.Module]) -> "_NumpyTensor":
+    def from_tensor(
+        cls, tensor: Union[torch.Tensor, torch.nn.Module]
+    ) -> "_NumpyTensor":
"""
Converts a torch tensor into a `_NumpyTensor`.
May use an opaque dtype for the numpy array stored in
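As a quick illustration of how the new string-keyed tables above behave — a hypothetical snippet, not part of this commit, assuming it runs in the module's scope where _ALL_TYPES and _DECODE_MAPPING are defined:

import torch

key = str(torch.bfloat16)                      # -> "torch.bfloat16"
assert _ALL_TYPES[key] is torch.bfloat16       # full name -> dtype
assert _DECODE_MAPPING[key] is torch.bfloat16  # bfloat16 round-trips
assert "torch.quint8" not in _DECODE_MAPPING   # quantized types are excluded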
(Diffs for the remaining 7 changed files were not loaded.)
