Resolve Issue princeton-nlp#126
john-b-yang committed Jun 18, 2024
1 parent 9de97b8 commit dee44dc
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions inference/run_llama.py
@@ -231,7 +231,7 @@ def load_data(
     return dataset


-def generate(model, dataset, tokenizer, temperature, top_p, fileobj, peft_path):
return dataset


def generate(model, dataset, tokenizer, temperature, top_p, fileobj, peft_path):
+def generate(model, dataset, tokenizer, temperature, top_p, fileobj, model_name_or_path, peft_path):
     class RepeatingTokensCriteria(StoppingCriteria):
         """
         Stopping criteria based on repeating tokens in the generated sequence.
@@ -291,16 +291,18 @@ def __call__(self, input_ids, scores, **kwargs):
             output = output[0].cpu()[input_ids.shape[-1] :]
             new_len = len(output)
             logger.info(
-                f"Generated {new_len} tokens ({total_len} total) in {(datetime.now() - start).total_seconds()} seconds (speed: {new_len / (datetime.now() - start).total_seconds()} tps)"
+                f"Generated {new_len} tokens ({total_len} total) in {(datetime.now() - start).total_seconds()} " + \
+                f"seconds (speed: {new_len / (datetime.now() - start).total_seconds()} tps)"
             )
             output = tokenizer.decode(output, skip_special_tokens=False)
             logger.info(output[:200])
             diff = extract_diff(output)
+            model_name_or_path += f"__{peft_path}" if peft_path is not None else ""
             res = {
                 "instance_id": instance["instance_id"],
                 "full_output": output,
                 "model_patch": diff,
-                "model_name_or_path": peft_path,
+                "model_name_or_path": model_name_or_path,
             }
             print(json.dumps(res), file=fileobj, flush=True)
         except Exception as e:
@@ -392,6 +394,7 @@ def main(
             temperature=temperature,
             top_p=top_p,
             fileobj=f,
+            model_name_or_path=model_name_or_path,
             peft_path=peft_path,
         )
     logger.info(f"Done")
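The substantive change: generate() now receives model_name_or_path and records it, rather than peft_path, in each prediction's model_name_or_path field; when a PEFT adapter is used, its path is appended as a double-underscore suffix. Below is a minimal sketch of the resulting naming, where record_name and the example values are hypothetical illustrations of the commit's inline expression, not code from the repository.

from typing import Optional

def record_name(model_name_or_path: str, peft_path: Optional[str]) -> str:
    """Hypothetical helper mirroring the commit's inline naming expression."""
    return model_name_or_path + (f"__{peft_path}" if peft_path is not None else "")

# Example values (hypothetical): a base model with and without a PEFT adapter.
print(record_name("meta-llama/Llama-2-7b-hf", "peft/checkpoint-500"))
# meta-llama/Llama-2-7b-hf__peft/checkpoint-500
print(record_name("meta-llama/Llama-2-7b-hf", None))
# meta-llama/Llama-2-7b-hf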

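For context on the truncated class in the first hunk: RepeatingTokensCriteria subclasses transformers' StoppingCriteria, whose __call__(input_ids, scores, **kwargs) hook decides when generation should halt. The repository's actual check is collapsed in the diff above; the sketch below is a hypothetical stand-in, assuming only the standard StoppingCriteria interface, that stops once the last ten generated tokens are identical.

import torch
from transformers import StoppingCriteria


class RepeatingTokensSketch(StoppingCriteria):
    """Hypothetical stand-in: stop when the last `window` tokens are all the same."""

    def __init__(self, window: int = 10):
        self.window = window

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Look at the tail of the generated sequence for the first batch element.
        tail = input_ids[0, -self.window :].tolist()
        return len(tail) == self.window and len(set(tail)) == 1


# A sequence that ends in ten identical tokens triggers the stop.
print(RepeatingTokensSketch()(torch.tensor([[5, 6] + [7] * 10]), scores=None))  # True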