Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Start to support generic intervention output, and adaptor-like tuning #177

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .models.interventions import NoiseIntervention
from .models.interventions import SigmoidMaskIntervention
from .models.interventions import AutoencoderIntervention
from .models.interventions import InterventionOutput


# Utils
Expand Down
17 changes: 14 additions & 3 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
TrainableIntervention,
SkipIntervention,
CollectIntervention,
BoundlessRotatedSpaceIntervention
BoundlessRotatedSpaceIntervention,
InterventionOutput
)

from torch import optim
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, config, model, backend, **kwargs):
self.is_model_stateless = is_stateless(model)
self.config.model_type = str(type(model)) # backfill
self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
self.as_adaptor = kwargs["as_adaptor"] if "as_adaptor" in kwargs else False

self.model_has_grad = False
if self.use_fast:
Expand Down Expand Up @@ -224,6 +226,8 @@ def __init__(self, config, model, backend, **kwargs):
# cached swapped activations (hot)
self.hot_activations = {}

self.aux_loss = []

# temp fields should not be accessed outside
self._batched_setter_activation_select = {}
"""
Expand Down Expand Up @@ -1509,7 +1513,8 @@ def _intervention_setter(
] # batch_size

def hook_callback(model, args, kwargs, output=None):
if self._is_generation:
# if it is None, we use it as adaptor.
if unit_locations_base[key_i] is not None and self._is_generation:
is_prompt = self._key_setter_call_counter[key] == 0
if not self._intervene_on_prompt or is_prompt:
self._key_setter_call_counter[key] += 1
Expand Down Expand Up @@ -1555,6 +1560,10 @@ def hook_callback(model, args, kwargs, output=None):
intervention,
subspaces[key_i] if subspaces is not None else None,
)
if isinstance(intervened_representation, InterventionOutput):
if intervened_representation.loss is not None:
self.aux_loss.append(intervened_representation.loss)
intervened_representation = intervened_representation.output
else:
intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -1852,7 +1861,9 @@ def forward(
activations_sources = source_representations
if sources is not None and not isinstance(sources, list):
sources = [sources]


self.aux_loss.clear()

self._cleanup_states()

# if no source input or intervention, we return base
Expand Down
14 changes: 14 additions & 0 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Union, List, Any

from .layers import RotateLayer, LowRankRotateLayer, SubspaceLowRankRotateLayer, AutoencoderLayer
from .basic_utils import sigmoid_boundary
from .intervention_utils import _can_use_fast, _do_intervention_by_swap

from dataclasses import dataclass
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class InterventionOutput(ModelOutput):
"""
Output of the IntervenableModel, including original outputs, intervened outputs, and collected activations.
"""
output: Optional[Any] = None
loss: Optional[Any] = None


class Intervention(torch.nn.Module):

Expand Down
14 changes: 11 additions & 3 deletions pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,14 @@ def do_intervention(
source_representation_f = bhsd_to_bs_hd(source_representation)
else:
assert False # what's going on?
intervened_representation = intervention(

intervention_output = intervention(
base_representation_f, source_representation_f, subspaces
)
if isinstance(intervention_output, InterventionOutput):
intervened_representation = intervention_output.output
else:
intervened_representation = intervention_output

post_d = intervened_representation.shape[-1]

Expand All @@ -481,7 +485,11 @@ def do_intervention(
else:
assert False # what's going on?

return intervened_representation
if not isinstance(intervention_output, InterventionOutput):
return intervened_representation

intervention_output.output = intervened_representation
return intervention_output


def simple_output_to_subcomponent(output, representation_type, model_config):
Expand Down
Loading