Skip to content

Commit

Permalink
Add end_pos for contrast_targets_alignments
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 26, 2023
1 parent b95fd18 commit 3a8638b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def attribute(
),
special_tokens=self.attribution_model.special_tokens,
start_pos=attr_pos_start,
end_pos=attr_pos_end,
)
attributed_fn_args["contrast_targets_alignments"] = contrast_targets_alignments
if "contrast_targets" in step_scores_args:
Expand Down
2 changes: 2 additions & 0 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def format_contrast_targets_alignments(
contrast_tokens: List[List[str]],
special_tokens: List[str] = [],
start_pos: int = 0,
end_pos: Optional[int] = None,
) -> Tuple[DecoderOnlyBatch, Optional[List[List[Tuple[int, int]]]]]:
# Ensure that contrast_targets_alignments is in the correct format (a list of lists of index pairs)
if contrast_targets_alignments:
Expand Down Expand Up @@ -184,6 +185,7 @@ def format_contrast_targets_alignments(
fill_missing=True,
special_tokens=special_tokens,
start_pos=start_pos,
end_pos=end_pos,
)
)
return adjusted_alignments
Expand Down
16 changes: 11 additions & 5 deletions inseq/utils/alignment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def get_adjusted_alignments(
fill_missing: bool = False,
special_tokens: List[str] = [],
start_pos: int = 0,
end_pos: Optional[int] = None,
) -> List[Tuple[int, int]]:
is_auto_aligned = False
if fill_missing and not target_tokens:
Expand Down Expand Up @@ -312,12 +313,17 @@ def get_adjusted_alignments(
# Filling alignments with missing tokens
if fill_missing:
filled_alignments = []
for pair_idx in range(start_pos, len(target_tokens)):
match_pairs = [pair for pair in alignments if pair[0] == pair_idx]
if end_pos is None:
end_pos = len(target_tokens)
for pair_idx in range(start_pos, end_pos):
match_pairs = [pair for pair in alignments if pair[0] == pair_idx and 0 <= pair[1] < len(contrast_tokens)]

if not match_pairs:
# Assuming 1:1 mapping to cover all tokens from the original sequence
filled_alignments.append((pair_idx, pair_idx))
if pair_idx < len(contrast_tokens):
# Assuming 1:1 mapping to cover all tokens from the original sequence
filled_alignments.append((pair_idx, pair_idx))
else:
filled_alignments.append((pair_idx, len(contrast_tokens) - 1))
else:
match_pairs_unaligned = [p for p in match_pairs if p[1] not in [f[1] for f in filled_alignments]]
# If found, use the first match containing a not-yet-aligned target token; otherwise fall back to the first match
Expand All @@ -329,7 +335,7 @@ def get_adjusted_alignments(
" sequence.\nFilling missing position with 1:1 position alignments."
)
if is_auto_aligned:
filled_alignments = [(a_idx, b_idx) for a_idx, b_idx in filled_alignments if a_idx >= start_pos]
filled_alignments = [(a_idx, b_idx) for a_idx, b_idx in filled_alignments if start_pos <= a_idx < end_pos]
logger.warning(
f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic "
f"sequences, or for languages not covered by the aligner.\nGenerated alignments: {filled_alignments}"
Expand Down

0 comments on commit 3a8638b

Please sign in to comment.