Commit 071f932

* l3

* model load

* support params file

* typo

* typo
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 201ca6e commit 071f932
Showing 3 changed files with 149 additions and 43 deletions.
20 changes: 19 additions & 1 deletion export.py
@@ -70,7 +70,13 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
print("Loading model ...")
t0 = time.time()
model = _load_model(
checkpoint_path, device=device, precision=precision, use_tp=False)
checkpoint_path,
args.checkpoint_dir,
args.params_path,
device=device,
precision=precision,
use_tp=False
)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
@@ -152,6 +158,18 @@ def cli():
default="not_specified",
help="Model checkpoint path.",
)
parser.add_argument(
"--checkpoint-dir",
type=Path,
default=None,
help="Model checkpoint directory.",
)
parser.add_argument(
"--params-path",
type=Path,
default=None,
help="Parameter file path.",
)
parser.add_argument(
"--output-pte-path",
type=str,
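
For context on the new --params-path option: below is a minimal, hedged sketch of a parameter file that ModelArgs.from_params (added in model.py further down) could consume. It is not part of the commit; the keys are assumed to mirror the ModelArgs field names exactly, since from_params simply unpacks the parsed JSON into the dataclass, and every value is a placeholder.

```python
import json
from pathlib import Path

# Hypothetical params file; keys are assumed to match the ModelArgs fields,
# because from_params() does cls(**json.loads(...)) with no key renaming.
params = {
    "dim": 4096,
    "n_layer": 32,
    "n_heads": 32,
    "n_local_heads": 8,
    "vocab_size": 32000,
    "norm_eps": 1e-5,
}
Path("params.json").write_text(json.dumps(params, indent=2))

# from model import ModelArgs                 # assumes model.py is importable
# args = ModelArgs.from_params("params.json")
# print(args.hidden_dim)                      # derived in __post_init__ when omitted
```
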
78 changes: 73 additions & 5 deletions generate.py
@@ -274,12 +274,50 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
return torch.tensor(tokens, dtype=torch.int, device=device)


def _load_model(checkpoint_path, device, precision, use_tp=False):
def _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
device,
precision,
use_tp=False
):
use_cuda = "cuda" in device
with torch.device("meta"):
model = Transformer.from_name(checkpoint_path.parent.name)
if params_path:
model = Transformer.from_params(params_path)
else:
model = Transformer.from_name(checkpoint_path.parent.name)

# checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
cps = []
if checkpoint_dir is not None:
# Load multiple checkpoints; ignore the single path.
checkpoint_path = None
for i in range(4):
cp_name = f"consolidated.{i}.pth"
print(f"Loading {cp_name}")
cps.append(
torch.load(
os.path.join(checkpoint_dir, cp_name),
map_location=device,
mmap=True,
)
)

checkpoint = {}
for key in cps[0].keys():
if not torch.allclose(cps[0][key], cps[1][key]):
values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
if key.endswith("wo.weight") or key.endswith("w2.weight"):
checkpoint[key] = torch.cat(values, dim=1)
else:
checkpoint[key] = torch.cat(values, dim=0)
else:
checkpoint[key] = cps[0][key]
else:
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True)

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
if "model" in checkpoint and "stories" in str(checkpoint_path):
checkpoint = checkpoint["model"]

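The shard-merging loop above decides, per parameter, whether the four consolidated.{i}.pth shards are replicated copies or splits of one tensor, and along which axis to rejoin them. A minimal standalone sketch of that convention with toy tensors (shapes and key names are illustrative only):

```python
import torch

def merge_shards(key, values):
    # Mirrors the per-key logic in _load_model: shards that compare equal are
    # replicated parameters; wo/w2 weights are rejoined along the input
    # dimension (dim=1); everything else along the output dimension (dim=0).
    if torch.allclose(values[0], values[1]):
        return values[0]
    if key.endswith("wo.weight") or key.endswith("w2.weight"):
        return torch.cat(values, dim=1)
    return torch.cat(values, dim=0)

wq_shards = [torch.randn(2, 8) for _ in range(4)]    # output dimension sharded
wo_shards = [torch.randn(8, 2) for _ in range(4)]    # input dimension sharded
norm_shards = [torch.ones(8)] * 4                    # identical on every shard

print(merge_shards("layers.0.attention.wq.weight", wq_shards).shape)     # (8, 8)
print(merge_shards("layers.0.attention.wo.weight", wo_shards).shape)     # (8, 8)
print(merge_shards("layers.0.attention_norm.weight", norm_shards).shape) # (8,)
```
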
@@ -306,6 +344,8 @@ def main(
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Optional[Path] = None,
checkpoint_dir: Optional[Path] = None,
params_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
compile: bool = True,
compile_prefill: bool = False,
@@ -351,7 +391,14 @@ def main(

print("Loading model ...")
t0 = time.time()
model_ = _load_model(checkpoint_path, device, precision, use_tp)
model_ = _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
device,
precision,
use_tp
)
if dso_path:
assert not model_dtype, f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
@@ -390,7 +437,14 @@ def main(
model.to(dtype=name_to_dtype(model_dtype))

if is_speculative:
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
draft_model = _load_model(
draft_checkpoint_path,
None,
None,
device,
precision,
use_tp
)
else:
draft_model = None

@@ -553,6 +607,18 @@ def cli():
default=None,
help="Model checkpoint path.",
)
parser.add_argument(
"--checkpoint-dir",
type=Path,
default=None,
help="Model checkpoint directory.",
)
parser.add_argument(
"--params-path",
type=Path,
default=None,
help="Parameter file path.",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
@@ -621,6 +687,8 @@ def cli():
args.top_k,
args.temperature,
args.checkpoint_path,
args.checkpoint_dir,
args.params_path,
args.tokenizer_path,
args.compile,
args.compile_prefill,
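
A hedged sketch of driving generate.py with the new flags from Python. Every path below is a placeholder, and the checkpoint directory is assumed to contain the consolidated.0.pth through consolidated.3.pth shards that _load_model expects.

```python
import subprocess

# Illustrative invocation only; point the paths at a real checkpoint layout.
subprocess.run(
    [
        "python", "generate.py",
        "--checkpoint-dir", "checkpoints/llama-70b",             # consolidated.{0..3}.pth
        "--params-path", "checkpoints/llama-70b/params.json",
        "--tokenizer-path", "checkpoints/llama-70b/tokenizer.model",
    ],
    check=True,
)
```
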
94 changes: 57 additions & 37 deletions model.py
@@ -24,23 +24,39 @@ class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
# n_head in gpt-fast
n_heads: int = 32
dim: int = 4096
intermediate_size: int = None
# hidden dim is intermediate_size in gpt-fast
hidden_dim: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5

multiple_of = 256
ffn_dim_multiplier = None

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
self.n_local_heads = self.n_heads
if self.hidden_dim is None:
# If hidden_dim is not explicitly set in the ModelArgs,
# then compute it implicitly from dim, rounding up
# to a multiple of `multiple_of`
multiple_of = self.multiple_of
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
hidden_dim = int(2 * hidden_dim / 3)
if self.ffn_dim_multiplier is not None:
hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.head_dim = self.dim // self.n_heads

@classmethod
def from_params(cls, params_path):
with open(params_path, "r") as f:
params = json.loads(f.read())
return cls(**params)

@classmethod
def from_name(cls, name: str):
print(f"name {name}")
@@ -70,47 +86,47 @@ def from_name(cls, name: str):
"CodeLlama-7b-Python-hf": dict(
block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
"30B": dict(n_layer=60, n_head=52, dim=6656),
"7B": dict(n_layer=32, n_heads=32, dim=4096),
"13B": dict(n_layer=40, n_heads=40, dim=5120),
"30B": dict(n_layer=60, n_heads=52, dim=6656),
"34B": dict(
n_layer=48,
n_head=64,
n_heads=64,
dim=8192,
vocab_size=32000,
n_local_heads=8,
intermediate_size=22016,
hidden_dim=22016,
rope_base=1000000,
), # CodeLlama-34B-Python-hf
"70B": dict(
n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672
),
"Mistral-7B": dict(
n_layer=32,
n_head=32,
n_heads=32,
n_local_heads=8,
dim=4096,
intermediate_size=14336,
hidden_dim=14336,
vocab_size=32000,
),
"Mistral-7B-Instruct-v0.1": dict(
n_layer=32,
n_head=32,
n_heads=32,
n_local_heads=8,
dim=4096,
intermediate_size=14336,
hidden_dim=14336,
vocab_size=32000,
),
"Mistral-7B-Instruct-v0.2": dict(
n_layer=32,
n_head=32,
n_heads=32,
n_local_heads=8,
dim=4096,
intermediate_size=14336,
hidden_dim=14336,
vocab_size=32000,
),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"stories15M": dict(n_layer=6, n_heads=6, dim=288),
"stories110M": dict(n_layer=12, n_heads=12, dim=768),
}


@@ -160,7 +176,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
and self.max_batch_size >= max_batch_size
):
return
head_dim = self.config.dim // self.config.n_head
head_dim = self.config.dim // self.config.n_heads
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
@@ ... @@
)

freqs_cis = precompute_freqs_cis(
self.config.block_size,
self.config.dim // self.config.n_head,
self.config.dim // self.config.n_heads,
self.config.block_size * 2,
self.config.rope_base,
)
self.register_buffer("freqs_cis", freqs_cis, persistent=True)
@@ -202,6 +218,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))

@classmethod
def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))


class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
@@ ... @@
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
assert config.dim % config.n_heads == 0

# key, query, value projections for all heads, but in a batch
# total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
# self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)

self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.n_heads = config.n_heads
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
@@ -263,7 +283,7 @@ def forward(
# kv_size = self.n_local_heads * self.head_dim
# q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

@@ ... @@
if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
@@ ... @@
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
@@ ... @@
output = self._norm(x.float()).type_as(x)
return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
# transposed first two arguments to align with model in ET
def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000) -> Tensor:
freqs = 1.0 / (
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
)
