diff --git a/art/experimental/attacks/evasion/fast_gradient.py b/art/experimental/attacks/evasion/fast_gradient.py index 48a6b2d9ac..5278cadf5a 100644 --- a/art/experimental/attacks/evasion/fast_gradient.py +++ b/art/experimental/attacks/evasion/fast_gradient.py @@ -358,12 +358,4 @@ def _compute( x_adv_result.append(x_adv_batch['pixel_values']) x_adv_result = torch.concatenate(x_adv_result) - sentinel = object() - - def myfunc(adv, x_sample=sentinel): - if x_sample is sentinel: - x_sample = x_adv - x_sample['pixel_values'] = adv.type(original_type) - return x_sample - - return myfunc(x_adv_result, x_sample=x_adv) + return x_adv.update_pixels(x_adv_result) diff --git a/art/experimental/estimators/hugging_face_multimodal/hugging_face_mm_inputs.py b/art/experimental/estimators/hugging_face_multimodal/hugging_face_mm_inputs.py index d10ab8510d..e3f9c763b4 100644 --- a/art/experimental/estimators/hugging_face_multimodal/hugging_face_mm_inputs.py +++ b/art/experimental/estimators/hugging_face_multimodal/hugging_face_mm_inputs.py @@ -184,6 +184,9 @@ def __len__(self) -> int: pixel_values = UserDict.__getitem__(self, "pixel_values") return len(pixel_values) + def update_pixels(self, pixel_values: torch.Tensor) -> None: + super().__setitem__("pixel_values", pixel_values) + def reshape(self, new_shape: Tuple) -> HuggingFaceMultiModalInput: """ Defines reshaping on the HuggingFaceMultiModalInput input.