Fix lints (#262)
mergennachin authored Apr 18, 2024
1 parent aa309a3 commit 3409955
Showing 4 changed files with 27 additions and 27 deletions.
10 changes: 6 additions & 4 deletions build/builder.py
@@ -37,7 +37,7 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False
is_chat_model: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
@@ -77,15 +77,15 @@ def from_args(cls, args): # -> BuilderArgs:
args.checkpoint_dir,
args.dso_path,
args.pte_path,
args.gguf_path
args.gguf_path,
]:
path = str(path)
if path.endswith('/'):
if path.endswith("/"):
path = path[:-1]
path_basename = os.path.basename(path)
if "chat" in path_basename:
is_chat_model = True

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
@@ -189,6 +189,7 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
if is_et:
builder_args.gguf_kwargs["load_as_quantized"] = False


def _unset_gguf_kwargs(builder_args):
builder_args.gguf_kwargs = None

@@ -264,6 +265,7 @@ def _load_model(builder_args):

if builder_args.use_tp:
from tp import apply_tp

print("Applying tensor parallel to model ...")
apply_tp(model)

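The build/builder.py edits above are purely mechanical style fixes: a trailing comma after the last element of a multi-line list, double quotes in place of single quotes, and two blank lines before top-level definitions. Below is a minimal, self-contained sketch of the same conventions, assuming a Black-compatible formatter (the commit message does not name the tool); the names candidate_paths and strip_trailing_slash are invented for illustration.

# Hypothetical illustration only; these names do not appear in the repository.
candidate_paths = [
    "checkpoints/model-a",
    "checkpoints/model-b-chat/",  # trailing comma after the last element of a multi-line list
]


# Two blank lines separate top-level definitions, as in the builder.py hunks.
def strip_trailing_slash(path: str) -> str:
    # Double quotes, mirroring the endswith("/") change above.
    if path.endswith("/"):
        path = path[:-1]
    return path


if __name__ == "__main__":
    for p in candidate_paths:
        print(strip_trailing_slash(p))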
28 changes: 13 additions & 15 deletions build/gguf_loader.py
@@ -8,27 +8,21 @@
import copy
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import Any

import gguf

import torch
import torch.nn as nn

wd = Path(__file__).parent.resolve()
sys.path.append(str(wd))

from gguf import GGUFValueType, ReaderTensor
from quantize import (
group_dequantize_tensor_from_qparams,
pack_scales_and_zeros,
WeightOnlyInt4Linear,
)

from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
from gguf import GGUFValueType
from model import ModelArgs, Transformer
from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear

from build.gguf_util import Q4_0, to_float

logger: logging.Logger = logging.getLogger(__name__)

@@ -116,9 +110,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
metadata = _get_metadata(reader)

arch = metadata["general.architecture"]
assert (
arch == "llama"
), "Only LLaMa models are supported by this converter."
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = ModelArgs(
dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +131,13 @@
return model


def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
def load_model_and_state_dict(
gguf_file: str,
*,
load_state_dict: bool = True,
load_as_quantized: bool = True,
inner_k_tiles=8,
) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
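The build/gguf_loader.py hunks above remove imports that are no longer used, collapse the architecture assert onto one line, and reflow the load_model_and_state_dict signature so that each keyword-only parameter sits on its own line ending in a comma. The following sketch shows that signature style under the same assumed Black-style conventions; the function name load_weights and its dict-returning body are placeholders, not the real loader.

# Hypothetical stand-in; load_weights and its return value are invented.
def load_weights(
    gguf_file: str,
    *,
    load_state_dict: bool = True,
    load_as_quantized: bool = True,
    inner_k_tiles=8,
) -> dict:
    # Annotated defaults keep spaces around "=", the unannotated one does not
    # (inner_k_tiles=8), and the exploded parameter list ends with a trailing comma.
    return {
        "gguf_file": gguf_file,
        "load_state_dict": load_state_dict,
        "load_as_quantized": load_as_quantized,
        "inner_k_tiles": inner_k_tiles,
    }


if __name__ == "__main__":
    print(load_weights("model.gguf", inner_k_tiles=4))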
1 change: 1 addition & 0 deletions build/model.py
@@ -248,6 +248,7 @@ def from_params(cls, params_path: str):
@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
from build.gguf_loader import load_model_and_state_dict

model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
if state_dict != {}:
model.load_state_dict(state_dict, assign=True)
15 changes: 7 additions & 8 deletions generate.py
@@ -7,7 +7,6 @@
import itertools

import logging
import os
import sys
import time
from dataclasses import dataclass
@@ -109,12 +108,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):


def prefill(
model: Transformer,
x: torch.Tensor,
input_pos: torch.Tensor,
*,
sequential_prefill = True,
**sampling_kwargs
model: Transformer,
x: torch.Tensor,
input_pos: torch.Tensor,
*,
sequential_prefill=True,
**sampling_kwargs,
) -> torch.Tensor:
logging.debug(f"x: {x}, input_pos: {input_pos}")
width = x.size(1)
@@ -348,7 +347,7 @@ def _main(
is_speculative = speculative_builder_args.checkpoint_path is not None

if generator_args.chat_mode and not builder_args.is_chat_model:
        # This is not a log message, it's a dangerous condition message
# that we must ensure is displayed
print(
"""
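The generate.py hunks drop an unused import os, move the prefill parameters from an eight-space to a four-space hanging indent, remove the spaces around the sequential_prefill default, and add a trailing comma after **sampling_kwargs. Unused imports like the removed one are what flake8 reports under code F401; below is a hedged sketch of checking for that, assuming flake8 is installed (the commit does not say which linter the project actually runs), with a helper name invented for the example.

# Hypothetical helper; nothing here is part of the repository's tooling.
import subprocess
import sys


def check_unused_imports(path: str) -> int:
    # F401 is flake8's "module imported but unused" code, the kind of finding
    # behind the removed "import os" above.
    result = subprocess.run(
        [sys.executable, "-m", "flake8", "--select=F401", path],
        capture_output=True,
        text=True,
    )
    print(result.stdout, end="")
    return result.returncode


if __name__ == "__main__":
    sys.exit(check_unused_imports("generate.py"))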
