Skip to content

Commit

Permalink
fixes issue pytorch#36
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Gschwind authored and malfet committed Jul 17, 2024
1 parent 15709e0 commit 9745fe0
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,8 @@ def main(
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path(
"checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
),
checkpoint_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
Expand All @@ -339,14 +338,21 @@ def main(
quantize=None,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
assert checkpoint_path.is_file(), checkpoint_path

torch.manual_seed(1234)

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert (
(checkpoint_path and checkpoint_path.is_file()) or
(dso_path and Path(dso_path).is_file()) or
(pte_path and Path(pte_path).is_file())
), "need to specified a valid checkpoint path, DSO path, or PTE path"
assert not (dso_path and pte_path), "specify either DSO path or PTE path, but not both"

if (checkpoint_path and (dso_path or pte_path)):
print("Warning: checkpoint path ignored because an exported DSO or PTE path specified")

if not tokenizer_path:
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path

global print
# global print
# from tp import maybe_init_dist
# rank = maybe_init_dist()
use_tp = False
Expand Down Expand Up @@ -540,10 +546,22 @@ def cli():
parser.add_argument(
"--temperature", type=float, default=0.8, help="Temperature for sampling."
)
parser.add_argument(
"--seed",
type=int,
default=1234, # set None for release
help="Initialize torch seed"
)
parser.add_argument(
"--checkpoint-path",
type=Path,
default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
default=None,
help="Model checkpoint path.",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
default=None,
help="Model checkpoint path.",
)
parser.add_argument(
Expand Down Expand Up @@ -590,6 +608,10 @@ def cli():


args = parser.parse_args()

if args.seed:
torch.manual_seed(args.seed)

main(
args.prompt,
args.interactive,
Expand All @@ -598,6 +620,7 @@ def cli():
args.top_k,
args.temperature,
args.checkpoint_path,
args.tokenizer_path,
args.compile,
args.compile_prefill,
args.profile,
Expand Down

0 comments on commit 9745fe0

Please sign in to comment.