Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Feb 4, 2025
1 parent 3a5e033 commit dbfc9ea
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tabpfn/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def select_features(x: torch.Tensor, sel: torch.Tensor) -> torch.Tensor:
if B == 1:
return x[:, :, sel[0]]

new_x = torch.zeros((sequence_length, B, total_features), device=x.device, dtype=x.dtype)
new_x = torch.zeros(
(sequence_length, B, total_features),
device=x.device,
dtype=x.dtype,
)

# For each batch, compute the number of selected features.
sel_counts = sel.sum(dim=-1) # shape: (B,)
Expand Down

0 comments on commit dbfc9ea

Please sign in to comment.