Commit

Clean up CLI output (pytorch#473)
GregoryComer authored and malfet committed Jul 17, 2024
1 parent 1a8ff8f commit 770f70c
Showing 4 changed files with 11 additions and 13 deletions.

build/builder.py (2 changes: 1 addition & 1 deletion)
@@ -340,7 +340,7 @@ def _initialize_model(
     quantize,
     tokenizer=None,
 ):
-    print("Loading model ...")
+    print("Loading model...")
 
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")

build/model.py (2 changes: 0 additions & 2 deletions)
@@ -68,7 +68,6 @@ def from_params(cls, params_path):
 
     @classmethod
     def from_table(cls, name: str):
-        print(f"name {name}")
         json_path = config_path / f"{name}.json"
         if json_path.is_file():
             return ModelArgs.from_params(json_path)
@@ -82,7 +81,6 @@ def from_table(cls, name: str):
 
     @classmethod
     def from_name(cls, name: str):
-        print(f"name {name}")
         json_path = config_path / f"{name}.json"
         if Path(json_path).is_file():
             return ModelArgs.from_params(json_path)
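Both methods now resolve a model name to its JSON config without echoing the name. A minimal runnable sketch of that lookup pattern; the `config_path` value, the `ModelArgs` fields, and the error fallback are illustrative assumptions rather than the repo's exact code, and `from_name` follows the same shape:

from dataclasses import dataclass
from pathlib import Path

config_path = Path("config/data/models")  # assumption: the real config directory may differ


@dataclass
class ModelArgs:
    name: str  # assumption: the real ModelArgs carries many more fields

    @classmethod
    def from_params(cls, params_path):
        # assumption: the real version parses the JSON file into full model args
        return cls(name=Path(params_path).stem)

    @classmethod
    def from_table(cls, name: str):
        # Resolve a known model name to its bundled JSON config, with no debug print.
        json_path = config_path / f"{name}.json"
        if json_path.is_file():
            return cls.from_params(json_path)
        raise ValueError(f"unknown model name: {name!r}")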

generate.py (16 changes: 8 additions & 8 deletions)
@@ -539,9 +539,7 @@ def _main(
     if generator_args.chat_mode:
         max_seq_length = 2048
         print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
-        get_system_prompt = input("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n")
-        if (get_system_prompt == "y" or get_system_prompt == "Y"):
-            system_prompt = input("What is your system prompt? \n")
+        system_prompt = input("System Prompt [Optional]: ")
         if is_llama3_model:
             chat_formatter = ChatFormat(tokenizer)
         else:
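The replacement collapses the old y/N question and follow-up prompt into one optional field: pressing Enter returns an empty string, which is falsy, so the `if system_prompt:` checks later in this diff simply skip it. A small self-contained sketch of the pattern (prompt text aside, the names and prints are illustrative):

system_prompt = input("System Prompt [Optional]: ")
if system_prompt:
    # Non-empty reply: apply it once, then clear it, mirroring the diff below.
    print(f"using system prompt: {system_prompt.strip()}")
    system_prompt = None
else:
    print("no system prompt set")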
@@ -567,12 +565,12 @@ def _main(
         i += 1
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
-            prompt = input("What is your prompt? \n")
+            prompt = input("User: ")
             if (prompt == "/bye"):
                 print("Exiting Chat.\n")
                 break
             if not is_llama3_model:
-                if system_prompt is not None:
+                if system_prompt:
                     prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
                     system_prompt = None  # can only provide system prompt on first interaction
                 else:
@@ -581,7 +579,7 @@ def _main(
                     tokenizer, prompt, bos=True, device=builder_args.device
                 )
             else:
-                if system_prompt is not None:
+                if system_prompt:
                     encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
                     system_prompt = None
                 elif(i == 0):
@@ -595,6 +593,8 @@ def _main(
                 break
 
         if generator_args.chat_mode and i >= 0:
+            print("Model: ", end="")
+
             buffer = []
             period_id = tokenizer.encode(".")[0]
             done_generating = False
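The added `print("Model: ", end="")` suppresses the trailing newline so streamed tokens continue on the same line as the label. A standalone sketch of the effect; the token list, the delay, and `flush=True` are illustrative assumptions, not the repo's code:

import time

print("Model: ", end="", flush=True)  # flush=True makes the label visible before tokens arrive
for token in ["Hello", ",", " how", " can", " I", " help?"]:
    print(token, end="", flush=True)  # each token continues the same line
    time.sleep(0.1)
print()  # finish the line once generation is done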
@@ -667,10 +667,10 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        logging.info(
+        logging.debug(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
-        logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
 
         if (start_pos >= max_seq_length):
             print("Max Sequence Length Reached. Ending Conversation.")

tokenizer/tiktoken.py (4 changes: 2 additions & 2 deletions)
@@ -85,7 +85,7 @@ def __init__(self, model_path: str):
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
         )
-        logger.info(f"Reloaded Tiktoken model from {model_path}")
+        logger.debug(f"Reloaded Tiktoken model from {model_path}")
 
         # BOS / EOS token IDs
         self.n_words: int = self.model.n_vocab
@@ -96,7 +96,7 @@ def __init__(self, model_path: str):
             self.special_tokens["<|end_of_text|>"],
             self.special_tokens["<|eot_id|>"],
         }
-        logger.info(
+        logger.debug(
             f"#words: {self.n_words} - BOS ID: {self._bos_id} - EOS ID: {self._eos_id}"
         )

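Because the tokenizer logs through a module-level `logger`, its downgraded messages can be re-enabled selectively rather than globally. A sketch that assumes the logger is named after the module's import path, `tokenizer.tiktoken`, which may not match the actual package layout:

import logging

logging.basicConfig(level=logging.INFO)  # keep other modules at INFO
logging.getLogger("tokenizer.tiktoken").setLevel(logging.DEBUG)  # assumed logger name
logging.getLogger("tokenizer.tiktoken").debug("now visible")  # prints; DEBUG from other modules stays hidden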
