convert.py: Outfile default name change and additional metadata support
mofosyne committed Apr 5, 2024
1 parent a307375 commit da064a8
Showing 4 changed files with 143 additions and 24 deletions.
158 changes: 134 additions & 24 deletions convert.py
@@ -336,6 +336,39 @@ def load(model_plus: ModelPlus) -> Params:

return params

@dataclass
class Metadata:
    name: Optional[str] = None
    author: Optional[str] = None
    version: Optional[str] = None
    url: Optional[str] = None
    description: Optional[str] = None
    licence: Optional[str] = None
    source_url: Optional[str] = None
    source_hf_repo: Optional[str] = None

    @staticmethod
    def load(metadata_path: Path) -> "Metadata":
        if metadata_path is None or not metadata_path.exists():
            return Metadata()

        with open(metadata_path, 'r') as file:
            data = json.load(file)

        # Create a new Metadata instance
        metadata = Metadata()

        # Assign values to Metadata attributes if they exist in the JSON file
        metadata.name = data.get("general.name")
        metadata.author = data.get("general.author")
        metadata.version = data.get("general.version")
        metadata.url = data.get("general.url")
        metadata.description = data.get("general.description")
        metadata.licence = data.get("general.license")
        metadata.source_url = data.get("general.source_url")
        metadata.source_hf_repo = data.get("general.source_hf_repo")

        return metadata
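For reference, a minimal sketch of a metadata file this loader accepts. The key names come straight from the data.get() calls above; every value here is hypothetical:

import json
from pathlib import Path

# Hypothetical metadata.json contents; keys mirror the GGUF KV names read above.
example = {
    "general.name": "TinyLlama",
    "general.author": "Example Author",
    "general.version": "v1.0",
    "general.url": "https://example.com/tinyllama",
    "general.description": "A small example model.",
    "general.license": "Apache-2.0",
    "general.source_url": "https://example.com/tinyllama/releases",
    "general.source_hf_repo": "example/tinyllama",
}

path = Path("metadata.json")
path.write_text(json.dumps(example, indent=2))

metadata = Metadata.load(path)
assert metadata.name == "TinyLlama" and metadata.licence == "Apache-2.0"

Such a file is passed in via the new --metadata flag, e.g. python convert.py <model-dir> --metadata metadata.json.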

#
# vocab
@@ -1053,21 +1086,41 @@ class OutputFile:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

def add_meta_arch(self, params: Params) -> None:
def add_meta_model(self, params: Params, metadata: Metadata) -> None:
# Metadata About The Model And Its Provenance
name = "LLaMA"

# TODO: better logic to determine model name
if params.n_ctx == 4096:
name = "LLaMA v2"
if metadata is not None and metadata.name is not None:
name = metadata.name
elif params.path_model is not None:
name = str(params.path_model.parent).split('/')[-1]

self.gguf.add_name (name)
self.gguf.add_vocab_size (params.n_vocab)
self.gguf.add_context_length (params.n_ctx)
self.gguf.add_embedding_length (params.n_embd)
self.gguf.add_block_count (params.n_layer)
self.gguf.add_feed_forward_length (params.n_ff)
name = str(params.path_model.parent).split("/")[-1]
elif params.n_ctx == 4096:
# Heuristic detection of LLaMA v2 model
name = "LLaMA v2"

self.gguf.add_name(name)

if metadata is not None:
if metadata.author is not None:
self.gguf.add_author(metadata.author)
if metadata.version is not None:
self.gguf.add_version(metadata.version)
if metadata.url is not None:
self.gguf.add_url(metadata.url)
if metadata.description is not None:
self.gguf.add_description(metadata.description)
if metadata.licence is not None:
self.gguf.add_licence(metadata.licence)
if metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None:
self.gguf.add_source_hf_repo(metadata.source_hf_repo)

def add_meta_arch(self, params: Params) -> None:
# Metadata About The Neural Architecture Itself
self.gguf.add_context_length(params.n_ctx)
self.gguf.add_embedding_length(params.n_embd)
self.gguf.add_block_count(params.n_layer)
self.gguf.add_feed_forward_length(params.n_ff)
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head)
self.gguf.add_head_count_kv (params.n_head_kv)
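Net effect of the add_meta_model/add_meta_arch split above: model provenance metadata is now written separately from architecture hyperparameters, and the model name resolves with an explicit precedence. A condensed sketch of that naming logic (illustrative only, not the actual method):

def resolve_name(metadata, params):
    if metadata is not None and metadata.name is not None:
        return metadata.name                                 # 1. explicit metadata
    if params.path_model is not None:
        return str(params.path_model.parent).split("/")[-1]  # 2. model directory name
    if params.n_ctx == 4096:
        return "LLaMA v2"                                    # 3. context-length heuristic
    return "LLaMA"                                           # 4. fallback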
@@ -1170,13 +1223,14 @@ def close(self) -> None:
@staticmethod
def write_vocab_only(
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)

of = OutputFile(fname_out, endianess=endianess)

# meta data
of.add_meta_model(params, metadata)
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
@@ -1203,12 +1257,14 @@ def write_all(
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
metadata: Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)

of = OutputFile(fname_out, endianess=endianess)

# meta data
of.add_meta_model(params, metadata)
of.add_meta_arch(params)
if isinstance(vocab, Vocab):
of.add_meta_vocab(vocab)
@@ -1244,6 +1300,37 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
raise ValueError(f"Unexpected combination of types: {name_to_type}")


def model_parameter_count(model: LazyModel) -> int:
    total_model_parameters = 0
    for lazy_tensor in model.values():
        sum_weights_in_tensor = 1
        for dim in lazy_tensor.shape:
            sum_weights_in_tensor *= dim
        total_model_parameters += sum_weights_in_tensor
    return total_model_parameters
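A quick check of the counting logic, with hypothetical tensor shapes:

# Each tensor contributes the product of its dimensions.
shapes = [(32000, 4096), (4096, 4096)]
total = 0
for shape in shapes:
    n = 1
    for dim in shape:
        n *= dim
    total += n
print(total)  # 147849216 (131072000 + 16777216)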


def model_parameter_count_rounded_notation(model_params_count: int) -> str:
    if model_params_count > 1e12:
        # Trillions Of Parameters
        scaled_model_params = model_params_count * 1e-12
        scale_suffix = "T"
    elif model_params_count > 1e9:
        # Billions Of Parameters
        scaled_model_params = model_params_count * 1e-9
        scale_suffix = "B"
    elif model_params_count > 1e6:
        # Millions Of Parameters
        scaled_model_params = model_params_count * 1e-6
        scale_suffix = "M"
    else:
        # Thousands Of Parameters
        scaled_model_params = model_params_count * 1e-3
        scale_suffix = "K"

    return f"{round(scaled_model_params)}{scale_suffix}"


def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
for (name, tensor) in model.items()}
@@ -1423,13 +1510,26 @@ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) ->
return vocab, special_vocab


def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
namestr = {
GGMLFileType.AllF32: "f32",
GGMLFileType.MostlyF16: "f16",
GGMLFileType.MostlyQ8_0:"q8_0",
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
quantization = {
GGMLFileType.AllF32: "F32",
GGMLFileType.MostlyF16: "F16",
GGMLFileType.MostlyQ8_0: "Q8_0",
}[file_type]
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"

parameters = model_parameter_count_rounded_notation(model_params_count)

version = ""
if metadata is not None and metadata.version is not None:
version = f"-{metadata.version}"

name = "ggml-model"
if metadata is not None and metadata.name is not None:
name = metadata.name
elif params.path_model is not None:
name = params.path_model.name

ret = model_paths[0].parent / f"{name}{version}-{parameters}-{quantization}.gguf"
if ret in model_paths:
sys.stderr.write(
f"Error: Default output path ({ret}) would overwrite the input. "
@@ -1466,8 +1566,12 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")

args = parser.parse_args(args_in)

metadata = Metadata.load(args.metadata)

if args.no_vocab and args.vocab_only:
raise ValueError("--vocab-only does not make sense with --no-vocab")

@@ -1481,6 +1585,9 @@ def main(args_in: list[str] | None = None) -> None:
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

model_params_count = model_parameter_count(model_plus.model)
print(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")

if args.dump:
do_dump_model(model_plus)
return
@@ -1520,27 +1627,30 @@ def main(args_in: list[str] | None = None) -> None:
raise ValueError("need --outfile if using --vocab-only")
outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
endianess=endianess, pad_vocab=args.pad_vocab)
endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
print(f"Wrote {outfile}")
return

if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
vocab = model_plus.vocab

print(f"Vocab info: {vocab}")
print(f"Special vocab info: {special_vocab}")
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = True,
n_vocab = vocab.vocab_size)

print(f"Special vocab info: {special_vocab}")
model = model_plus.model
model = convert_model_names(model, params, args.skip_unknown)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)

params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")

OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
print(f"Wrote {outfile}")


1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -24,6 +24,7 @@ class General:
ALIGNMENT = "general.alignment"
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
URL = "general.url"
DESCRIPTION = "general.description"
LICENSE = "general.license"
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -296,6 +296,9 @@ def add_architecture(self) -> None:
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)

def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)

def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

@@ -305,6 +308,9 @@ def add_url(self, url: str) -> None:
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)

def add_licence(self, licence: str) -> None:
self.add_string(Keys.General.LICENSE, licence)

def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
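A minimal sketch of the two new writer methods in use (output path and architecture are hypothetical). Note the method name uses the British spelling licence while the key it writes is general.license:

import gguf

w = gguf.GGUFWriter("example.gguf", "llama")  # hypothetical file name and arch
w.add_version("v1.0")  # stored under general.version
w.add_licence("MIT")   # stored under general.license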

2 changes: 2 additions & 0 deletions llama.cpp
@@ -261,6 +261,7 @@ enum llm_kv {
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,
LLM_KV_GENERAL_URL,
LLM_KV_GENERAL_DESCRIPTION,
LLM_KV_GENERAL_LICENSE,
@@ -330,6 +331,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
{ LLM_KV_GENERAL_URL, "general.url" },
{ LLM_KV_GENERAL_DESCRIPTION, "general.description" },
{ LLM_KV_GENERAL_LICENSE, "general.license" },