
Commit 923db75

Merge pull request #6038 from RasaHQ/diet-entity-confidence
DIETClassifier adds a confidence value to entity predictions
2 parents: 4151446 + a9b0c76

7 files changed: +266 -13

changelog/5481.improvement.rst (+1, new file)

@@ -0,0 +1 @@
+``DIETClassifier`` now also assigns a confidence value to entity predictions.
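
For illustration only, here is a sketch of what an entity predicted by ``DIETClassifier`` may look like with the new confidence value attached. The field names follow the test fixture added in tests/nlu/test_evaluation.py further down; the concrete values are invented:

# Illustrative sketch, not output from a real model run.
entity = {
    "start": 7,
    "end": 13,
    "value": "Berlin",
    "entity": "city",
    "confidence_entity": 0.87,  # newly reported confidence for the entity
    "extractor": "DIETClassifier",
}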

docs/nlu/entity-extraction.rst (+2 -2)

@@ -60,9 +60,9 @@ exactly. Instead it will return the trained synonym.

 .. note::

-    The ``confidence`` will be set by the ``CRFEntityExtractor`` component. The
+    The ``confidence`` will be set by the ``CRFEntityExtractor`` and the ``DIETClassifier`` component. The
     ``DucklingHTTPExtractor`` will always return ``1``. The ``SpacyEntityExtractor`` extractor
-    and ``DIETClassifier`` do not provide this information and return ``null``.
+    does not provide this information and returns ``null``.


 Some extractors, like ``duckling``, may include additional information. For example:

rasa/nlu/classifiers/diet_classifier.py (+16 -6)

@@ -798,10 +798,13 @@ def _predict_entities(
         if predict_out is None:
             return []

-        predicted_tags = self._entity_label_to_tags(predict_out)
+        predicted_tags, confidence_values = self._entity_label_to_tags(predict_out)

         entities = self.convert_predictions_into_entities(
-            message.text, message.get(TOKENS_NAMES[TEXT], []), predicted_tags
+            message.text,
+            message.get(TOKENS_NAMES[TEXT], []),
+            predicted_tags,
+            confidence_values,
         )

         entities = self.add_extractor_name(entities)
@@ -811,20 +814,24 @@ def _predict_entities(

     def _entity_label_to_tags(
         self, predict_out: Dict[Text, Any]
-    ) -> Dict[Text, List[Text]]:
+    ) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]:
         predicted_tags = {}
+        confidence_values = {}

         for tag_spec in self._entity_tag_specs:
             predictions = predict_out[f"e_{tag_spec.tag_name}_ids"].numpy()
+            confidences = predict_out[f"e_{tag_spec.tag_name}_scores"].numpy()
+            confidences = [float(c) for c in confidences[0]]
             tags = [tag_spec.ids_to_tags[p] for p in predictions[0]]

             if self.component_config[BILOU_FLAG]:
                 tags = bilou_utils.ensure_consistent_bilou_tagging(tags)
                 tags = bilou_utils.remove_bilou_prefixes(tags)

             predicted_tags[tag_spec.tag_name] = tags
+            confidence_values[tag_spec.tag_name] = confidences

-        return predicted_tags
+        return predicted_tags, confidence_values

     def process(self, message: Message, **kwargs: Any) -> None:
         """Return the most likely label and its similarity to the input."""
@@ -1479,7 +1486,7 @@ def _calculate_entity_loss(
         logits = self._tf_layers[f"embed.{tag_name}.logits"](inputs)

         # should call first to build weights
-        pred_ids = self._tf_layers[f"crf.{tag_name}"](logits, sequence_lengths)
+        pred_ids, _ = self._tf_layers[f"crf.{tag_name}"](logits, sequence_lengths)
         # pytype cannot infer that 'self._tf_layers["crf"]' has the method '.loss'
         # pytype: disable=attribute-error
         loss = self._tf_layers[f"crf.{tag_name}"].loss(
@@ -1671,9 +1678,12 @@ def _batch_predict_entities(
         _input = tf.concat([_input, _tags], axis=-1)

         _logits = self._tf_layers[f"embed.{name}.logits"](_input)
-        pred_ids = self._tf_layers[f"crf.{name}"](_logits, sequence_lengths - 1)
+        pred_ids, confidences = self._tf_layers[f"crf.{name}"](
+            _logits, sequence_lengths - 1
+        )

         predictions[f"e_{name}_ids"] = pred_ids
+        predictions[f"e_{name}_scores"] = confidences

         if name == ENTITY_ATTRIBUTE_TYPE:
             # use the entity tags as additional input for the role
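
To make the changed return value of ``_entity_label_to_tags`` concrete: it now returns two parallel dictionaries, one with per-token tags and one with per-token confidences. A purely illustrative sketch, assuming a single tag spec named "entity" and BILOU prefixes already removed (values invented):

# Illustrative sketch of the two return values for a three-token message.
predicted_tags = {"entity": ["O", "O", "city"]}
confidence_values = {"entity": [0.99, 0.98, 0.87]}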

rasa/nlu/test.py (+1 -1)

@@ -53,7 +53,7 @@
 # performs entity extraction but those two classifiers don't
 ENTITY_PROCESSORS = {"EntitySynonymMapper", "ResponseSelector"}

-EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor"}
+EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor", "DIETClassifier"}

 CVEvaluationResult = namedtuple("Results", "train test")

rasa/utils/tensorflow/crf.py (+217, new file)

@@ -0,0 +1,217 @@
+import tensorflow as tf
+
+from tensorflow_addons.utils.types import TensorLike
+from typeguard import typechecked
+from typing import Tuple
+
+
+# original code taken from
+# https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py
+# (modified to our neeeds)
+
+
+class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
+    """Computes the forward decoding in a linear-chain CRF."""
+
+    @typechecked
+    def __init__(self, transition_params: TensorLike, **kwargs) -> None:
+        """Initialize the CrfDecodeForwardRnnCell.
+
+        Args:
+            transition_params: A [num_tags, num_tags] matrix of binary
+                potentials. This matrix is expanded into a
+                [1, num_tags, num_tags] in preparation for the broadcast
+                summation occurring within the cell.
+        """
+        super().__init__(**kwargs)
+        self._transition_params = tf.expand_dims(transition_params, 0)
+        self._num_tags = transition_params.shape[0]
+
+    @property
+    def state_size(self) -> int:
+        return self._num_tags
+
+    @property
+    def output_size(self) -> int:
+        return self._num_tags
+
+    def build(self, input_shape):
+        super().build(input_shape)
+
+    def call(
+        self, inputs: TensorLike, state: TensorLike
+    ) -> Tuple[tf.Tensor, tf.Tensor]:
+        """Build the CrfDecodeForwardRnnCell.
+
+        Args:
+            inputs: A [batch_size, num_tags] matrix of unary potentials.
+            state: A [batch_size, num_tags] matrix containing the previous step's
+                score values.
+
+        Returns:
+            output: A [batch_size, num_tags * 2] matrix of backpointers and scores.
+            new_state: A [batch_size, num_tags] matrix of new score values.
+        """
+        state = tf.expand_dims(state[0], 2)
+        transition_scores = state + self._transition_params
+        new_state = inputs + tf.reduce_max(transition_scores, [1])
+
+        backpointers = tf.argmax(transition_scores, 1)
+        backpointers = tf.cast(backpointers, tf.float32)
+
+        # apply softmax to transition_scores to get scores in range from 0 to 1
+        scores = tf.reduce_max(tf.nn.softmax(transition_scores, axis=1), [1])
+
+        # In the RNN implementation only the first value that is returned from a cell
+        # is kept throughout the RNN, so that you will have the values from each time
+        # step in the final output. As we need the backpointers as well as the scores
+        # for each time step, we concatenate them.
+        return tf.concat([backpointers, scores], axis=1), new_state
+
+
+def crf_decode_forward(
+    inputs: TensorLike,
+    state: TensorLike,
+    transition_params: TensorLike,
+    sequence_lengths: TensorLike,
+) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Computes forward decoding in a linear-chain CRF.
+
+    Args:
+        inputs: A [batch_size, num_tags] matrix of unary potentials.
+        state: A [batch_size, num_tags] matrix containing the previous step's
+            score values.
+        transition_params: A [num_tags, num_tags] matrix of binary potentials.
+        sequence_lengths: A [batch_size] vector of true sequence lengths.
+
+    Returns:
+        output: A [batch_size, num_tags * 2] matrix of backpointers and scores.
+        new_state: A [batch_size, num_tags] matrix of new score values.
+    """
+    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+    mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
+    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
+    crf_fwd_layer = tf.keras.layers.RNN(
+        crf_fwd_cell, return_sequences=True, return_state=True
+    )
+    return crf_fwd_layer(inputs, state, mask=mask)
+
+
+def crf_decode_backward(
+    backpointers: TensorLike, scores: TensorLike, state: TensorLike
+) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Computes backward decoding in a linear-chain CRF.
+
+    Args:
+        backpointers: A [batch_size, num_tags] matrix of backpointer of next step
+            (in time order).
+        scores: A [batch_size, num_tags] matrix of scores of next step (in time order).
+        state: A [batch_size, 1] matrix of tag index of next step.
+
+    Returns:
+        new_tags: A [batch_size, num_tags] tensor containing the new tag indices.
+        new_scores: A [batch_size, num_tags] tensor containing the new score values.
+    """
+    backpointers = tf.transpose(backpointers, [1, 0, 2])
+    scores = tf.transpose(scores, [1, 0, 2])
+
+    def _scan_fn(_state: TensorLike, _inputs: TensorLike) -> tf.Tensor:
+        _state = tf.cast(tf.squeeze(_state, axis=[1]), dtype=tf.int32)
+        idxs = tf.stack([tf.range(tf.shape(_inputs)[0]), _state], axis=1)
+        return tf.expand_dims(tf.gather_nd(_inputs, idxs), axis=-1)
+
+    output_tags = tf.scan(_scan_fn, backpointers, state)
+    # the dtype of the input parameters of tf.scan need to match
+    # convert state to float32 to match the type of scores
+    state = tf.cast(state, dtype=tf.float32)
+    output_scores = tf.scan(_scan_fn, scores, state)
+
+    return tf.transpose(output_tags, [1, 0, 2]), tf.transpose(output_scores, [1, 0, 2])
+
+
+def crf_decode(
+    potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
+) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    """Decode the highest scoring sequence of tags.
+
+    Args:
+        potentials: A [batch_size, max_seq_len, num_tags] tensor of
+            unary potentials.
+        transition_params: A [num_tags, num_tags] matrix of
+            binary potentials.
+        sequence_length: A [batch_size] vector of true sequence lengths.
+
+    Returns:
+        decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
+            Contains the highest scoring tag indices.
+        decode_scores: A [batch_size, max_seq_len] matrix, containing the score of
+            `decode_tags`.
+        best_score: A [batch_size] vector, containing the best score of `decode_tags`.
+    """
+    sequence_length = tf.cast(sequence_length, dtype=tf.int32)
+
+    # If max_seq_len is 1, we skip the algorithm and simply return the
+    # argmax tag and the max activation.
+    def _single_seq_fn():
+        decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
+        decode_scores = tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2)
+        best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
+        return decode_tags, decode_scores, best_score
+
+    def _multi_seq_fn():
+        # Computes forward decoding. Get last score and backpointers.
+        initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
+        initial_state = tf.squeeze(initial_state, axis=[1])
+        inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])
+
+        sequence_length_less_one = tf.maximum(
+            tf.constant(0, dtype=tf.int32), sequence_length - 1
+        )
+
+        output, last_score = crf_decode_forward(
+            inputs, initial_state, transition_params, sequence_length_less_one
+        )
+
+        # output is a matrix of size [batch-size, max-seq-length, num-tags * 2]
+        # split the matrix on axis 2 to get the backpointers and scores, which are
+        # both of size [batch-size, max-seq-length, num-tags]
+        backpointers, scores = tf.split(output, 2, axis=2)
+
+        backpointers = tf.cast(backpointers, dtype=tf.int32)
+        backpointers = tf.reverse_sequence(
+            backpointers, sequence_length_less_one, seq_axis=1
+        )

+        scores = tf.reverse_sequence(scores, sequence_length_less_one, seq_axis=1)
+
+        initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
+        initial_state = tf.expand_dims(initial_state, axis=-1)
+
+        initial_score = tf.reduce_max(tf.nn.softmax(last_score, axis=1), axis=[1])
+        initial_score = tf.expand_dims(initial_score, axis=-1)
+
+        decode_tags, decode_scores = crf_decode_backward(
+            backpointers, scores, initial_state
+        )
+
+        decode_tags = tf.squeeze(decode_tags, axis=[2])
+        decode_tags = tf.concat([initial_state, decode_tags], axis=1)
+        decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1)
+
+        decode_scores = tf.squeeze(decode_scores, axis=[2])
+        decode_scores = tf.concat([initial_score, decode_scores], axis=1)
+        decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1)
+
+        best_score = tf.reduce_max(last_score, axis=1)
+
+        return decode_tags, decode_scores, best_score
+
+    if potentials.shape[1] is not None:
+        # shape is statically know, so we just execute
+        # the appropriate code path
+        if potentials.shape[1] == 1:
+            return _single_seq_fn()
+
+        return _multi_seq_fn()
+
+    return tf.cond(tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn)
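
A minimal usage sketch of the new ``crf_decode`` with toy tensors; this is not part of the commit, it only follows the shapes documented in the docstrings above and assumes the module is importable as rasa.utils.tensorflow.crf:

import tensorflow as tf

from rasa.utils.tensorflow.crf import crf_decode

batch_size, max_seq_len, num_tags = 2, 5, 3

# random unary potentials, binary potentials, and the true sequence lengths
potentials = tf.random.uniform((batch_size, max_seq_len, num_tags))
transition_params = tf.random.uniform((num_tags, num_tags))
sequence_length = tf.constant([5, 3], dtype=tf.int32)

decode_tags, decode_scores, best_score = crf_decode(
    potentials, transition_params, sequence_length
)

print(decode_tags.shape)    # (2, 5) int32 tag index per token
print(decode_scores.shape)  # (2, 5) softmax-based score per token
print(best_score.shape)     # (2,)  best path score per sequence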

rasa/utils/tensorflow/layers.py (+14 -4)

@@ -2,6 +2,7 @@
 from typing import List, Optional, Text, Tuple, Callable, Union, Any
 import tensorflow as tf
 import tensorflow_addons as tfa
+import rasa.utils.tensorflow.crf
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras import backend as K
 from rasa.utils.tensorflow.constants import SOFTMAX, MARGIN, COSINE, INNER
@@ -460,7 +461,9 @@ def build(self, input_shape: tf.TensorShape) -> None:
         self.built = True

     # noinspection PyMethodOverriding
-    def call(self, logits: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor:
+    def call(
+        self, logits: tf.Tensor, sequence_lengths: tf.Tensor
+    ) -> Tuple[tf.Tensor, tf.Tensor]:
         """Decodes the highest scoring sequence of tags.

         Arguments:
@@ -471,16 +474,23 @@ def call(self, logits: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor:
         Returns:
             A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
             Contains the highest scoring tag indices.
+            A [batch_size, max_seq_len] matrix, with dtype `tf.float32`.
+            Contains the confidence values of the highest scoring tag indices.
         """
-        pred_ids, _ = tfa.text.crf.crf_decode(
+        predicted_ids, scores, _ = rasa.utils.tensorflow.crf.crf_decode(
             logits, self.transition_params, sequence_lengths
         )
         # set prediction index for padding to `0`
         mask = tf.sequence_mask(
-            sequence_lengths, maxlen=tf.shape(pred_ids)[1], dtype=pred_ids.dtype
+            sequence_lengths,
+            maxlen=tf.shape(predicted_ids)[1],
+            dtype=predicted_ids.dtype,
         )

-        return pred_ids * mask
+        confidence_values = scores * tf.cast(mask, tf.float32)
+        predicted_ids = predicted_ids * mask
+
+        return predicted_ids, confidence_values

     def loss(
         self, logits: tf.Tensor, tag_indices: tf.Tensor, sequence_lengths: tf.Tensor
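
The padding handling added to ``call`` can be reproduced in isolation with plain TensorFlow. A small, purely illustrative sketch with toy tensors (no Rasa objects involved):

import tensorflow as tf

# toy predictions for two sequences with true lengths 3 and 1 (max length 3)
predicted_ids = tf.constant([[2, 1, 3], [3, 2, 1]], dtype=tf.int32)
scores = tf.constant([[0.9, 0.8, 0.5], [0.7, 0.4, 0.3]], dtype=tf.float32)
sequence_lengths = tf.constant([3, 1], dtype=tf.int32)

mask = tf.sequence_mask(
    sequence_lengths, maxlen=tf.shape(predicted_ids)[1], dtype=predicted_ids.dtype
)

confidence_values = scores * tf.cast(mask, tf.float32)  # padded positions -> 0.0
predicted_ids = predicted_ids * mask                    # padded tag indices -> 0

print(predicted_ids.numpy())      # [[2 1 3] [3 0 0]]
print(confidence_values.numpy())  # [[0.9 0.8 0.5] [0.7 0.  0. ]]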

tests/nlu/test_evaluation.py (+15)

@@ -253,6 +253,21 @@ def test_determine_token_labels_with_extractors():
             ["CRFEntityExtractor"],
             0.87,
         ),
+        (
+            Token("pizza", 4),
+            [
+                {
+                    "start": 4,
+                    "end": 9,
+                    "value": "pizza",
+                    "entity": "food",
+                    "confidence_entity": 0.87,
+                    "extractor": "DIETClassifier",
+                }
+            ],
+            ["DIETClassifier"],
+            0.87,
+        ),
     ],
 )
 def test_get_entity_confidences(
