Skip to content

Commit

Permalink
remove debug messages (pytorch#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent ad164b0 commit 32e19ea
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,22 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):


def prefill(
model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
model: Transformer,
x: torch.Tensor,
input_pos: torch.Tensor,
*,
sequential_prefill = True,
**sampling_kwargs
) -> torch.Tensor:
print(f"x: {x}, input_pos: {input_pos}")
# print(f"x: {x}, input_pos: {input_pos}")
width = x.size(1)
assert input_pos.size(0) == width
sequential_prefill = True

if sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
#print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
else:
# input_pos: [B, S]
Expand Down Expand Up @@ -157,13 +162,6 @@ def decode_n_tokens(
return new_tokens, new_probs


# try:
# from .thin_wrapper import model_forward
#
# except:
# print("compiled model load not successful, running eager model")


def model_forward(model, x, input_pos):
return model(x, input_pos)

Expand Down Expand Up @@ -374,7 +372,7 @@ def _main(
encoded = encode_tokens(
tokenizer, generator_args.prompt, bos=True, device=builder_args.device
)
print(encoded)
# print(encoded)
prompt_length = encoded.size(0)

model_size = sum(
Expand Down

0 comments on commit 32e19ea

Please sign in to comment.