diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py index a267fec..1eb7212 100644 --- a/inseq/commands/attribute_context/attribute_context.py +++ b/inseq/commands/attribute_context/attribute_context.py @@ -157,13 +157,16 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM output_full_tokens[: output_current_text_offset + cti_idx + 1], skip_special_tokens=False ).lstrip(" ") if not contextual_output: - contextual_output = output_full_tokens[output_current_text_offset + cti_idx] - + output_ctx_tokens = [output_full_tokens[output_current_text_offset + cti_idx]] + if model.is_encoder_decoder: + output_ctx_tokens.append(model.pad_token) + contextual_output = model.convert_tokens_to_string(output_ctx_tokens, skip_special_tokens=True) + else: + output_ctx_tokens = model.convert_string_to_tokens( + contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder + ) cci_kwargs = {} contextless_output = None - output_ctx_tokens = model.convert_string_to_tokens( - contextual_output, skip_special_tokens=False, as_targets=model.is_encoder_decoder - ) if args.attributed_fn is not None and is_contrastive_step_function(args.attributed_fn): if not model.is_encoder_decoder: formatted_input_current_text = concat_with_sep(