diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py
index 250cd700..f0e61319 100644
--- a/inseq/attr/feat/feature_attribution.py
+++ b/inseq/attr/feat/feature_attribution.py
@@ -590,11 +590,11 @@ def filtered_attribute_step(
                     batch=batch,
                 )
                 step_fn_extra_args = get_step_scores_args([score], step_scores_args)
-                step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
+                step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args)
         # Reinsert finished sentences
         if target_attention_mask is not None and is_filtered:
             step_output.remap_from_filtered(target_attention_mask, orig_batch)
-        step_output = step_output.detach().to("cpu")
+        step_output = step_output.detach()
         return step_output
 
     def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
diff --git a/inseq/attr/feat/gradient_attribution.py b/inseq/attr/feat/gradient_attribution.py
index 3eefe63e..6f3fff95 100644
--- a/inseq/attr/feat/gradient_attribution.py
+++ b/inseq/attr/feat/gradient_attribution.py
@@ -95,9 +95,9 @@ def attribute_step(
             attr, self.attribution_model.is_encoder_decoder
         )
         return GranularFeatureAttributionStepOutput(
-            source_attributions=source_attributions.to("cpu") if source_attributions is not None else None,
-            target_attributions=target_attributions.to("cpu") if target_attributions is not None else None,
-            step_scores={"deltas": deltas.to("cpu")} if deltas is not None else None,
+            source_attributions=source_attributions if source_attributions is not None else None,
+            target_attributions=target_attributions if target_attributions is not None else None,
+            step_scores={"deltas": deltas} if deltas is not None else None,
         )
 
 
diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index c77c2e09..bb475707 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -475,7 +475,7 @@ def _filter_scores(
         if len(set(indices)) != len(indices):
             raise IndexError("Duplicate indices are not allowed.")
         if isinstance(indices, tuple):
-            scores = scores.index_select(dim, torch.arange(indices[0], indices[1]))
+            scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device))
         else:
             scores = scores.index_select(dim, torch.tensor(indices, device=scores.device))
         return scores
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index f68ffb9a..f3671244 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -249,7 +249,9 @@ def from_step_attributions(
                 start_idx:end_idx, : len(targets[seq_id]), ...  # noqa: E203
             ]
             if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]):
-                empty_final_row = torch.ones(1, *target_attributions[seq_id].shape[1:]) * float("nan")
+                empty_final_row = torch.ones(
+                    1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device
+                ) * float("nan")
                 target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0)
             seq_attributions[seq_id].target_attributions = target_attributions[seq_id]
         if attr.step_scores is not None: