Skip to content

Commit 8d7e71e

Browse files
authored
Merge pull request #5673 from RasaHQ/johannes-73
Attention weight logging
2 parents e311c1f + b20a5e0 commit 8d7e71e

File tree

24 files changed

+806
-458
lines changed

24 files changed

+806
-458
lines changed

changelog/5673.improvement.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Expose diagnostic data for action and NLU predictions.
2+
3+
Add `diagnostic_data` field to the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects)
4+
and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects, which contain
5+
information about attention weights and other intermediate results of the inference computation.
6+
This information can be used for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit).
7+
8+
For examples of how to access the diagnostic data, see [here](https://gist.github.com/JEM-Mosig/c6e15b81ee70561cb72e361aff310d7e).

docs/docs/tuning-your-model.mdx

+36-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ Here is a summary of the available extractors and what they are best used for:
293293
|`MitieEntityExtractor` |MITIE |structured SVM |good for training custom entities |
294294
|`EntitySynonymMapper` |existing entities |N/A |maps known synonyms |
295295

296-
## Handling Class Imbalance
296+
## Improving Performance
297+
298+
### Handling Class Imbalance
297299

298300
Classification algorithms often do not perform well if there is a large class imbalance,
299301
for example if you have a lot of training data for some intents and very little training data for others.
@@ -312,6 +314,39 @@ pipeline:
312314
batch_strategy: sequence
313315
```
314316

317+
### Accessing Diagnostic Data
318+
319+
To gain a better understanding of what your models do, you can access intermediate results of the prediction process.
320+
To do this, you need to access the `diagnostic_data` field of the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects)
321+
and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects, which contain
322+
information about attention weights and other intermediate results of the inference computation.
323+
You can use this information for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit).
324+
325+
After you've [trained a model](.//command-line-interface.mdx#rasa-train), you can access diagnostic data for DIET,
326+
given a processed message, like this:
327+
328+
```python
329+
nlu_diagnostic_data = message.as_dict()[DIAGNOSTIC_DATA]
330+
331+
for component_name, diagnostic_data in nlu_diagnostic_data.items():
332+
attention_weights = diagnostic_data["attention_weights"]
333+
print(f"attention_weights for {component_name}:")
334+
print(attention_weights)
335+
336+
text_transformed = diagnostic_data["text_transformed"]
337+
print(f"\ntext_transformed for {component_name}:")
338+
print(text_transformed)
339+
```
340+
341+
And you can access diagnostic data for TED like this:
342+
343+
```python
344+
prediction = policy.predict_action_probabilities(
345+
GREET_RULE, domain, RegexInterpreter()
346+
)
347+
print(f"{prediction.diagnostic_data.get('attention_weights')}")
348+
```
349+
315350

316351
## Configuring Tensorflow
317352

rasa/core/policies/policy.py

+9
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def _prediction(
236236
events: Optional[List[Event]] = None,
237237
optional_events: Optional[List[Event]] = None,
238238
is_end_to_end_prediction: bool = False,
239+
diagnostic_data: Optional[Dict[Text, Any]] = None,
239240
) -> "PolicyPrediction":
240241
return PolicyPrediction(
241242
probabilities,
@@ -244,6 +245,7 @@ def _prediction(
244245
events,
245246
optional_events,
246247
is_end_to_end_prediction,
248+
diagnostic_data,
247249
)
248250

249251
def _metadata(self) -> Optional[Dict[Text, Any]]:
@@ -400,6 +402,7 @@ def __init__(
400402
events: Optional[List[Event]] = None,
401403
optional_events: Optional[List[Event]] = None,
402404
is_end_to_end_prediction: bool = False,
405+
diagnostic_data: Optional[Dict[Text, Any]] = None,
403406
) -> None:
404407
"""Creates a `PolicyPrediction`.
405408
@@ -417,13 +420,17 @@ def __init__(
417420
you return as they can potentially influence the conversation flow.
418421
is_end_to_end_prediction: `True` if the prediction used the text of the
419422
user message instead of the intent.
423+
diagnostic_data: Intermediate results or other information that is not
424+
necessary for Rasa to function, but intended for debugging and
425+
fine-tuning purposes.
420426
"""
421427
self.probabilities = probabilities
422428
self.policy_name = policy_name
423429
self.policy_priority = (policy_priority,)
424430
self.events = events or []
425431
self.optional_events = optional_events or []
426432
self.is_end_to_end_prediction = is_end_to_end_prediction
433+
self.diagnostic_data = diagnostic_data or {}
427434

428435
@staticmethod
429436
def for_action_name(
@@ -466,6 +473,8 @@ def __eq__(self, other: Any) -> bool:
466473
and self.events == other.events
467474
and self.optional_events == other.events
468475
and self.is_end_to_end_prediction == other.is_end_to_end_prediction
476+
# We do not compare `diagnostic_data`, because it has no effect on the
477+
# action prediction.
469478
)
470479

471480
@property

rasa/core/policies/ted_policy.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
3939
from rasa.core.policies.policy import Policy, PolicyPrediction
4040
from rasa.core.constants import DEFAULT_POLICY_PRIORITY, DIALOGUE
41+
from rasa.shared.constants import DIAGNOSTIC_DATA
4142
from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS, ACTION_LISTEN_NAME
4243
from rasa.shared.core.trackers import DialogueStateTracker
4344
from rasa.shared.core.generator import TrackerWithCachedStates
@@ -50,6 +51,7 @@
5051
Data,
5152
)
5253
from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
54+
import rasa.utils.tensorflow.numpy
5355
from rasa.utils.tensorflow.constants import (
5456
LABEL,
5557
IDS,
@@ -632,6 +634,9 @@ def predict_action_probabilities(
632634
confidence.tolist(),
633635
is_end_to_end_prediction=is_e2e_prediction,
634636
optional_events=optional_events,
637+
diagnostic_data=rasa.utils.tensorflow.numpy.values_to_numpy(
638+
output.get(DIAGNOSTIC_DATA)
639+
),
635640
)
636641

637642
def _create_optional_event_for_entities(
@@ -1050,14 +1055,23 @@ def _embed_dialogue(
10501055
self,
10511056
dialogue_in: tf.Tensor,
10521057
tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
1053-
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1054-
"""Create dialogue level embedding and mask."""
1058+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
1059+
"""Creates dialogue level embedding and mask.
1060+
1061+
Args:
1062+
dialogue_in: The encoded dialogue.
1063+
tf_batch_data: Batch in model data format.
1064+
1065+
Returns:
1066+
The dialogue embedding, the mask, and (for diagnostic purposes)
1067+
also the attention weights.
1068+
"""
10551069
dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)
10561070
mask = self._compute_mask(dialogue_lengths)
10571071

1058-
dialogue_transformed = self._tf_layers[f"transformer.{DIALOGUE}"](
1059-
dialogue_in, 1 - mask, self._training
1060-
)
1072+
dialogue_transformed, attention_weights = self._tf_layers[
1073+
f"transformer.{DIALOGUE}"
1074+
](dialogue_in, 1 - mask, self._training)
10611075
dialogue_transformed = tfa.activations.gelu(dialogue_transformed)
10621076

10631077
if self.use_only_last_dialogue_turns:
@@ -1069,7 +1083,7 @@ def _embed_dialogue(
10691083

10701084
dialogue_embed = self._tf_layers[f"embed.{DIALOGUE}"](dialogue_transformed)
10711085

1072-
return dialogue_embed, mask, dialogue_transformed
1086+
return dialogue_embed, mask, dialogue_transformed, attention_weights
10731087

10741088
def _encode_features_per_attribute(
10751089
self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
@@ -1615,6 +1629,7 @@ def batch_loss(
16151629
dialogue_embed,
16161630
dialogue_mask,
16171631
dialogue_transformer_output,
1632+
_,
16181633
) = self._embed_dialogue(dialogue_in, tf_batch_data)
16191634
dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)
16201635

@@ -1686,6 +1701,7 @@ def batch_predict(
16861701
dialogue_embed,
16871702
dialogue_mask,
16881703
dialogue_transformer_output,
1704+
attention_weights,
16891705
) = self._embed_dialogue(dialogue_in, tf_batch_data)
16901706
dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)
16911707

@@ -1698,7 +1714,11 @@ def batch_predict(
16981714
scores = self._tf_layers[f"loss.{LABEL}"].confidence_from_sim(
16991715
sim_all, self.config[SIMILARITY_TYPE]
17001716
)
1701-
predictions = {"action_scores": scores, "similarities": sim_all}
1717+
predictions = {
1718+
"action_scores": scores,
1719+
"similarities": sim_all,
1720+
DIAGNOSTIC_DATA: {"attention_weights": attention_weights},
1721+
}
17021722

17031723
if (
17041724
self.config[ENTITY_RECOGNITION]

rasa/nlu/classifiers/diet_classifier.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import rasa.shared.utils.io
1414
import rasa.utils.io as io_utils
1515
import rasa.nlu.utils.bilou_utils as bilou_utils
16+
import rasa.utils.tensorflow.numpy
17+
from rasa.shared.constants import DIAGNOSTIC_DATA
1618
from rasa.nlu.featurizers.featurizer import Featurizer
1719
from rasa.nlu.components import Component
1820
from rasa.nlu.classifiers.classifier import IntentClassifier
@@ -914,7 +916,7 @@ def _predict_entities(
914916
return entities
915917

916918
def process(self, message: Message, **kwargs: Any) -> None:
917-
"""Return the most likely label and its similarity to the input."""
919+
"""Augments the message with intents, entities, and diagnostic data."""
918920
out = self._predict(message)
919921

920922
if self.component_config[INTENT_CLASSIFICATION]:
@@ -928,12 +930,17 @@ def process(self, message: Message, **kwargs: Any) -> None:
928930

929931
message.set(ENTITIES, entities, add_to_output=True)
930932

933+
if out and DIAGNOSTIC_DATA in out:
934+
message.add_diagnostic_data(
935+
self.unique_name,
936+
rasa.utils.tensorflow.numpy.values_to_numpy(out.get(DIAGNOSTIC_DATA)),
937+
)
938+
931939
def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
932940
"""Persist this model into the passed directory.
933941
934942
Return the metadata necessary to load the model again.
935943
"""
936-
937944
if self.model is None:
938945
return {"file": None}
939946

@@ -1420,6 +1427,7 @@ def batch_loss(
14201427
text_in,
14211428
text_seq_ids,
14221429
lm_mask_bool_text,
1430+
_,
14231431
) = self._create_sequence(
14241432
tf_batch_data[TEXT][SEQUENCE],
14251433
tf_batch_data[TEXT][SENTENCE],
@@ -1569,7 +1577,7 @@ def batch_predict(
15691577

15701578
mask = self._compute_mask(sequence_lengths)
15711579

1572-
text_transformed, _, _, _ = self._create_sequence(
1580+
text_transformed, _, _, _, attention_weights = self._create_sequence(
15731581
tf_batch_data[TEXT][SEQUENCE],
15741582
tf_batch_data[TEXT][SENTENCE],
15751583
mask_sequence_text,
@@ -1579,6 +1587,11 @@ def batch_predict(
15791587

15801588
predictions: Dict[Text, tf.Tensor] = {}
15811589

1590+
predictions[DIAGNOSTIC_DATA] = {
1591+
"attention_weights": attention_weights,
1592+
"text_transformed": text_transformed,
1593+
}
1594+
15821595
if self.config[INTENT_CLASSIFICATION]:
15831596
predictions.update(
15841597
self._batch_predict_intents(sequence_lengths, text_transformed)

0 commit comments

Comments
 (0)