From 482eabb579a293fb6aeccae9292ab016b65b251a Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 15 Oct 2024 09:03:23 +0000 Subject: [PATCH 1/2] Add convenience fn to move eraser to new device --- concept_erasure/leace.py | 8 ++++++++ concept_erasure/oracle.py | 4 ++++ concept_erasure/quadratic.py | 8 ++++++++ 3 files changed, 20 insertions(+) diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py index c599a80..ce58f36 100644 --- a/concept_erasure/leace.py +++ b/concept_erasure/leace.py @@ -46,6 +46,14 @@ def __call__(self, x: Tensor) -> Tensor: x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH return x_.type_as(x) + def to(self, device: torch.device | str) -> "LeaceEraser": + """Move eraser to a new device.""" + return LeaceEraser( + self.proj_left.to(device), + self.proj_right.to(device), + self.bias.to(device) if self.bias is not None else None, + ) + class LeaceFitter: """Fits an affine transform that surgically erases a concept from a representation. diff --git a/concept_erasure/oracle.py b/concept_erasure/oracle.py index ae0a2f2..e1ac16b 100644 --- a/concept_erasure/oracle.py +++ b/concept_erasure/oracle.py @@ -27,6 +27,10 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor: return x.sub(expected_x).type_as(x) + def to(self, device: torch.device | str) -> "OracleEraser": + """Move eraser to a new device.""" + return OracleEraser(self.coef.to(device), self.mean_z.to(device)) + class OracleFitter: """Compute stats needed for surgically erasing a concept Z from a random vector X. diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py index 75ad21d..480049f 100644 --- a/concept_erasure/quadratic.py +++ b/concept_erasure/quadratic.py @@ -38,6 +38,14 @@ def __call__(self, x: Tensor, z: Tensor) -> Tensor: # Efficiently group `x` by `z`, optimally transport each group, then coalesce return groupby(x, z).map(self.optimal_transport).coalesce() + def to(self, device: torch.device | str) -> "QuadraticEraser": + """Move eraser to a new device.""" + return QuadraticEraser( + self.class_means.to(device), + self.global_mean.to(device), + self.ot_maps.to(device), + ) + @dataclass(frozen=True) class QuadraticEditor: From ee37d7bb8a606cb03daeeb47bd8f28af152b8e2c Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 15 Oct 2024 09:31:46 +0000 Subject: [PATCH 2/2] fix gpt neo x import --- concept_erasure/scrubbing/neox.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/concept_erasure/scrubbing/neox.py b/concept_erasure/scrubbing/neox.py index b435ee7..0fb44e9 100644 --- a/concept_erasure/scrubbing/neox.py +++ b/concept_erasure/scrubbing/neox.py @@ -6,10 +6,12 @@ from tqdm.auto import tqdm from transformers import ( GPTNeoXForCausalLM, - GPTNeoXLayer, GPTNeoXModel, ) -from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention +from transformers.models.gpt_neox.modeling_gpt_neox import ( + GPTNeoXAttention, + GPTNeoXLayer, +) from concept_erasure import ConceptScrubber, ErasureMethod, LeaceFitter from concept_erasure.utils import assert_type