gguf : track writer state, free unneeded tensors, cleanup #3871

Merged · 7 commits · Nov 7, 2023
Changes from 5 commits
85 changes: 55 additions & 30 deletions gguf-py/gguf/gguf.py
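The substantive change in the diff below is a small write-order state machine: GGUFWriter now tracks which section of the GGUF file has been written, and each write step checks that the previous section is in place before proceeding, failing fast with a ValueError instead of silently producing a malformed file. Here is a minimal sketch of the pattern, boiled down from the diff — StatefulWriter is a toy stand-in, not the real class, and the real methods also write the magic, version, counts, and payloads:

```python
from enum import Enum, auto

class WriterState(Enum):
    EMPTY = auto()
    HEADER = auto()
    KV_DATA = auto()

class StatefulWriter:
    def __init__(self) -> None:
        self.state = WriterState.EMPTY

    def write_header(self) -> None:
        if self.state is not WriterState.EMPTY:
            raise ValueError(f'Expected output file to be empty, got {self.state}')
        # ... write magic, version, and section counts here ...
        self.state = WriterState.HEADER

    def write_kv_data(self) -> None:
        if self.state is not WriterState.HEADER:
            raise ValueError(f'Expected output file to contain the header, got {self.state}')
        # ... write the key-value section here ...
        self.state = WriterState.KV_DATA

w = StatefulWriter()
w.write_header()
w.write_kv_data()   # OK: sections arrive in order
# w.write_header()  # would raise ValueError: the file is no longer empty
```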
@@ -7,7 +7,7 @@
 import struct
 import sys
 import tempfile
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from io import BufferedWriter
 from pathlib import Path
 from typing import IO, Any, BinaryIO, Callable, Sequence
@@ -636,18 +636,17 @@ def get_type(val):
             sys.exit()


+class WriterState(Enum):
+    EMPTY = auto()
+    HEADER = auto()
+    KV_DATA = auto()
+    TI_DATA = auto()
+
+
 class GGUFWriter:
     fout: BufferedWriter
-    arch: str
-    offset_tensor = 0
-    data_alignment = GGUF_DEFAULT_ALIGNMENT
-    kv_data = b""
-    kv_data_count = 0
-    ti_data = b""
-    ti_data_count = 0
-    use_temp_file: bool
-    temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
-    tensors: list[tuple[np.ndarray[Any, Any], int]]
+    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
+    tensors: list[np.ndarray[Any, Any]]

     @property
     def pack_prefix(self):
@@ -673,27 +672,48 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True
             GGUFValueType.FLOAT64: f"{self.pack_prefix}d",
             GGUFValueType.BOOL: "?" ,
         }
-        self.add_architecture()
+        self.offset_tensor = 0
+        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
+        self.kv_data = b""
+        self.kv_data_count = 0
+        self.ti_data = b""
+        self.ti_data_count = 0
         self.use_temp_file = use_temp_file
+        self.temp_file = None
         self.tensors = []
         endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian"
         print(f"This gguf file is for {endianess_str} only")
+        self.state = WriterState.EMPTY
+
+        self.add_architecture()

     def write_header_to_file(self):
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         self.fout.write(struct.pack("<I", GGUF_MAGIC))
         self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION))
         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count))
         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count))
         self.flush()
-        # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        #print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        self.state = WriterState.HEADER

     def write_kv_data_to_file(self):
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected output file to contain the header, got {self.state}')
+
         self.fout.write(self.kv_data)
         self.flush()
+        self.state = WriterState.KV_DATA

     def write_ti_data_to_file(self):
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+
         self.fout.write(self.ti_data)
         self.flush()
+        self.state = WriterState.TI_DATA

     def add_key(self, key: str):
         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@@ -786,6 +806,9 @@ def ggml_pad(x: int, n: int) -> int:
         return ((x + n - 1) // n) * n

     def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"

         encoded_name = name.encode("utf8")
@@ -815,23 +838,22 @@ def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequenc
         shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

-        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-
-        if self.temp_file is None:
-            self.tensors.append((tensor, pad))
+        if self.temp_file is None:
+            self.tensors.append(tensor)
             return

         tensor.tofile(self.temp_file)
+        self.write_padding(self.temp_file, tensor.nbytes)

-        if pad != 0:
-            self.temp_file.write(bytes([0] * pad))
-
-    def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
             fp.write(bytes([0] * pad))

     def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
+        if self.state is not WriterState.TI_DATA:
+            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
+
         if self.endianess==GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
         self.write_padding(self.fout, self.fout.tell())
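A note on the two padding changes in this hunk: write_padding now accepts IO[bytes] rather than BinaryIO — a wider type that, presumably the motivation, also covers the SpooledTemporaryFile used for staged tensor data — and add_tensor reuses the helper instead of computing pad by hand. The amount written comes from ggml_pad, which rounds a byte count up to the next multiple of the alignment. A standalone check of the arithmetic, assuming the default 32-byte GGUF alignment:

```python
def ggml_pad(x: int, n: int) -> int:
    # round x up to the nearest multiple of n
    return ((x + n - 1) // n) * n

assert ggml_pad(5000, 32) == 5024  # 24 zero bytes of padding follow the tensor
assert ggml_pad(5024, 32) == 5024  # already aligned: no padding written
```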
@@ -844,10 +866,13 @@ def write_tensors_to_file(self):
         self.write_padding(self.fout, self.fout.tell())

         if self.temp_file is None:
-            for (currtensor, currpad) in self.tensors:
-                currtensor.tofile(self.fout)
-                if currpad != 0:
-                    self.fout.write(bytes([0] * currpad))
+            while True:
+                try:
+                    tensor = self.tensors.pop(0)
+                except IndexError:
+                    break
+                tensor.tofile(self.fout)
+                self.write_padding(self.fout, tensor.nbytes)
             return

         self.temp_file.seek(0)
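This rewritten loop is the "free unneeded tensors" part of the title: pop(0) removes each array from self.tensors before writing it, so the writer gives up its reference one tensor at a time and already-written arrays become eligible for garbage collection, instead of the whole list staying alive until the method returns. The pattern in isolation, with a hypothetical pending list standing in for self.tensors:

```python
import numpy as np

pending = [np.zeros(1024, dtype=np.float32), np.zeros(2048, dtype=np.float32)]

written = 0
while True:
    try:
        arr = pending.pop(0)  # the list gives up its reference here
    except IndexError:
        break
    written += arr.nbytes     # stand-in for arr.tofile(fout)
    # on the next iteration `arr` is rebound, so the previous array has
    # no remaining references and can be reclaimed by the GC

print(written)  # 12288 bytes staged
```

pop(0) on a Python list is O(n) per call, but for the number of tensors in a model file that cost is negligible; a collections.deque with popleft() would make it O(1) if it ever mattered.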
@@ -983,11 +1008,8 @@ def add_pad_token_id(self, id: int):


 class SpecialVocab:
-    load_merges: bool = False
-    merges: list[str] = []
-    special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
-    special_token_ids: dict[str, int] = {}
-    n_vocab: int | None = None
+    merges: list[str]
+    special_token_ids: dict[str, int]

     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -997,8 +1019,11 @@ def __init__(
         self.special_token_ids = {}
         self.n_vocab = n_vocab
         self.load_merges = load_merges
+        self.merges = []
         if special_token_types is not None:
             self.special_token_types = special_token_types
+        else:
+            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
         self._load(Path(path))

     def _load(self, path: Path) -> None:
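The SpecialVocab hunk fixes a classic Python pitfall rather than adding behavior: mutable defaults declared at class scope, such as merges: list[str] = [] and special_token_ids: dict[str, int] = {}, are created once when the class is defined and shared by every instance, so a mutation through one SpecialVocab object would have been visible in all of them. Moving the initialization into __init__ gives each instance its own list and dict. A self-contained illustration with toy class names (not from the codebase):

```python
class Shared:
    items: list[str] = []  # one list object, created at class-definition time

class Independent:
    def __init__(self) -> None:
        self.items: list[str] = []  # a fresh list for every instance

a, b = Shared(), Shared()
a.items.append("x")
print(b.items)  # ['x'] -- b observes a's mutation

c, d = Independent(), Independent()
c.items.append("x")
print(d.items)  # [] -- instances stay independent
```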