Skip to content

Commit

Permalink
Add support for different output activations
Browse files Browse the repository at this point in the history
  • Loading branch information
Rubinjo committed Feb 13, 2023
1 parent 7397864 commit 48e1cd5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/innvestigate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Public API re-exports for the `innvestigate` package.
# (Post-commit state: `model_wo_softmax` was renamed to
# `model_wo_output_activation` to support arbitrary output activations.)
from innvestigate import analyzer  # noqa
from innvestigate.analyzer import create_analyzer  # noqa
from innvestigate.analyzer.base import NotAnalyzeableModelException  # noqa
from innvestigate.backend.graph import model_wo_output_activation  # noqa

__version__ = "2.0.1"
26 changes: 14 additions & 12 deletions src/innvestigate/backend/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
"get_layer_neuronwise_io",
"copy_layer_wo_activation",
"copy_layer",
"pre_softmax_tensors",
"model_wo_softmax",
"pre_output_tensors",
"model_wo_output_activation",
"get_model_layers",
"model_contains",
"trace_model_execution",
Expand Down Expand Up @@ -338,32 +338,34 @@ def copy_layer(
return get_layer_from_config(layer, config, weights=weights, **kwargs)


def pre_output_tensors(Xs: Tensor, activation: str = "softmax") -> list[Tensor]:
    """Find the tensors that were preceding a potential output activation.

    For each output tensor in ``Xs``, walks one step back through its Keras
    history: if the producing layer applies ``activation``, the layer's input
    (for a pure ``Activation`` layer) or a copy of the layer with its
    activation stripped is used instead.

    :param Xs: Output tensor or list of output tensors of a model.
    :param activation: Name of the activation to strip. Defaults to
        ``"softmax"`` so existing pre-rename call sites keep working.
    :return: List of tensors immediately preceding the activation.
    :raises Exception: If no layer with the given activation was found.
    """
    activation_found = False

    Xs = ibackend.to_list(Xs)
    ret = []
    for x in Xs:
        # Keras records (layer, node_index, tensor_index) on every tensor.
        layer, node_index, _tensor_index = x._keras_history
        if ichecks.contains_activation(layer, activation=activation):
            activation_found = True
            if isinstance(layer, klayers.Activation):
                # Standalone activation layer: simply take its input.
                ret.append(layer.get_input_at(node_index))
            else:
                # Activation fused into the layer: re-apply a copy of the
                # layer with the activation removed.
                layer_wo_act = copy_layer_wo_activation(layer)
                ret.append(layer_wo_act(layer.get_input_at(node_index)))

    if not activation_found:
        raise Exception(f"No {activation} found.")

    return ret


def model_wo_output_activation(model: Model, activation: str = "softmax") -> Model:
    """Create a new model without the final output activation.

    :param model: Keras model whose output activation should be stripped.
    :param activation: Name of the activation to remove. Defaults to
        ``"softmax"`` so the function remains a drop-in replacement for the
        former ``model_wo_softmax(model)``.
    :return: New model sharing the inputs of ``model`` but whose outputs are
        the tensors preceding the given activation.
    """
    return kmodels.Model(
        inputs=model.inputs,
        outputs=pre_output_tensors(model.outputs, activation),
        name=model.name,
    )


Expand Down

0 comments on commit 48e1cd5

Please sign in to comment.