Skip to content

Commit cd62d41

Browse files
authored
Merge pull request #8560 from RasaHQ/predict_generator
Implement interface for bulk inferencing in TF models
2 parents 78fce3b + 82f90e5 commit cd62d41

File tree

6 files changed

+205
-19
lines changed

6 files changed

+205
-19
lines changed

changelog/8560.improvement.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Implement a new interface `run_inference` inside `RasaModel` which performs batch inference through TensorFlow models.
2+
3+
`rasa_predict` inside `RasaModel` is now a private method and has been renamed to `_rasa_predict`.

rasa/core/policies/ted_policy.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -680,11 +680,11 @@ def predict_action_probabilities(
680680
tracker, domain, interpreter
681681
)
682682
model_data = self._create_model_data(tracker_state_features)
683-
output = self.model.rasa_predict(model_data)
683+
outputs = self.model.run_inference(model_data)
684684

685685
# take the last prediction in the sequence
686-
similarities = output["similarities"][:, -1, :]
687-
confidences = output["action_scores"][:, -1, :]
686+
similarities = outputs["similarities"][:, -1, :]
687+
confidences = outputs["action_scores"][:, -1, :]
688688
# take correct prediction from batch
689689
confidence, is_e2e_prediction = self._pick_confidence(
690690
confidences, similarities, domain
@@ -698,14 +698,14 @@ def predict_action_probabilities(
698698
)
699699

700700
optional_events = self._create_optional_event_for_entities(
701-
output, is_e2e_prediction, interpreter, tracker
701+
outputs, is_e2e_prediction, interpreter, tracker
702702
)
703703

704704
return self._prediction(
705705
confidence.tolist(),
706706
is_end_to_end_prediction=is_e2e_prediction,
707707
optional_events=optional_events,
708-
diagnostic_data=output.get(DIAGNOSTIC_DATA),
708+
diagnostic_data=outputs.get(DIAGNOSTIC_DATA),
709709
)
710710

711711
def _create_optional_event_for_entities(

rasa/nlu/classifiers/diet_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def _predict(
875875

876876
# create session data from message and convert it into a batch of 1
877877
model_data = self._create_model_data([message], training=False)
878-
return self.model.rasa_predict(model_data)
878+
return self.model.run_inference(model_data)
879879

880880
def _predict_label(
881881
self, predict_out: Optional[Dict[Text, tf.Tensor]]

rasa/utils/tensorflow/models.py

+75-11
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
CONSTRAIN_SIMILARITIES,
3636
MODEL_CONFIDENCE,
3737
)
38+
import rasa.utils.train_utils
3839
from rasa.utils.tensorflow import layers
3940
from rasa.utils.tensorflow import rasa_layers
4041
from rasa.utils.tensorflow.temp_keras_modules import TmpKerasModel
@@ -230,13 +231,13 @@ def _dynamic_signature(
230231
# the list
231232
return [element_spec]
232233

233-
def rasa_predict(
234-
self, model_data: RasaModelData
234+
def _rasa_predict(
235+
self, batch_in: Tuple[np.ndarray]
235236
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
236237
"""Custom prediction method that builds tf graph on the first call.
237238
238239
Args:
239-
model_data: The model data to use for prediction.
240+
batch_in: Prepared batch ready for input to `predict_step` method of model.
240241
241242
Return:
242243
Prediction output, including diagnostic data.
@@ -248,13 +249,12 @@ def rasa_predict(
248249
self.prepare_for_predict()
249250
self.prepared_for_prediction = True
250251

251-
batch_in = RasaBatchDataGenerator.prepare_batch(model_data.data)
252-
253252
if self._run_eagerly:
254253
outputs = tf_utils.to_numpy_or_python_type(self.predict_step(batch_in))
255-
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
256-
outputs[DIAGNOSTIC_DATA]
257-
)
254+
if DIAGNOSTIC_DATA in outputs:
255+
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
256+
outputs[DIAGNOSTIC_DATA]
257+
)
258258
return outputs
259259

260260
if self._tf_predict_step is None:
@@ -263,11 +263,75 @@ def rasa_predict(
263263
)
264264

265265
outputs = tf_utils.to_numpy_or_python_type(self._tf_predict_step(batch_in))
266-
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
267-
outputs[DIAGNOSTIC_DATA]
266+
if DIAGNOSTIC_DATA in outputs:
267+
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
268+
outputs[DIAGNOSTIC_DATA]
269+
)
270+
return outputs
271+
272+
def run_inference(
    self, model_data: RasaModelData, batch_size: Union[int, List[int]] = 1
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Implements bulk inferencing through the model.

    The input is split into batches by a single-epoch, non-shuffling data
    generator. Every batch is passed to `_rasa_predict` and the per-batch
    outputs are combined with `_merge_batch_outputs`.

    Args:
        model_data: Input data to be fed to the model.
        batch_size: Size of batches that the generator should create.

    Returns:
        Model outputs corresponding to the inputs fed. An empty dictionary
        is returned if the generator does not yield any batch.
    """
    outputs = {}
    data_generator, _ = rasa.utils.train_utils.create_data_generators(
        model_data=model_data, batch_sizes=batch_size, epochs=1, shuffle=False,
    )
    # Iterate with a plain `for` loop rather than `next()` wrapped in
    # `try/except StopIteration`: this way a `StopIteration` escaping from
    # prediction or merging is not mistaken for the end of the data and
    # cannot silently truncate the result.
    for batch in data_generator:
        # Each generated item holds input and output. Only the input is
        # needed, since the output is not consumed by our TF graphs.
        batch_in = batch[0]
        batch_out = self._rasa_predict(batch_in)
        outputs = self._merge_batch_outputs(outputs, batch_out)
    return outputs
270301

302+
@staticmethod
def _merge_batch_outputs(
    all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
    """Merges a batch's output into the output for all batches.

    Function assumes that the schema of batch output remains the same,
    i.e. keys and their value types do not change from one batch's
    output to another.

    Neither argument is modified: the result is always a new dictionary
    (nested dictionaries are new as well), so the dictionary returned for
    the first batch never aliases that batch's own output.

    Args:
        all_outputs: Existing output for all previous batches.
        batch_output: Output for a batch.

    Returns:
        Merged output with the output for current batch stacked
        below the output for all previous batches. Values which are
        neither numpy arrays nor dictionaries keep the value they
        have in `all_outputs`.

    Raises:
        KeyError: If `batch_output` holds an array or dictionary under a
            key which is missing in a non-empty `all_outputs`.
    """

    def _copy_nested(output: Dict[Text, Any]) -> Dict[Text, Any]:
        # Copy the dictionary structure only; arrays are shared because
        # merging never writes into an existing array.
        return {
            key: _copy_nested(val) if isinstance(val, dict) else val
            for key, val in output.items()
        }

    def _merge(existing: Dict[Text, Any], new: Dict[Text, Any]) -> Dict[Text, Any]:
        if not existing:
            # First batch: nothing to stack onto yet.
            return _copy_nested(new)

        merged = dict(existing)
        for key, val in new.items():
            if isinstance(val, np.ndarray):
                merged[key] = np.concatenate([merged[key], val], axis=0)
            elif isinstance(val, dict):
                # recurse and merge the inner dict first
                merged[key] = _merge(merged[key], val)
        return merged

    return _merge(all_outputs, batch_output)
334+
271335
@staticmethod
272336
def _empty_lists_to_none_in_dict(input_dict: Dict[Text, Any]) -> Dict[Text, Any]:
273337
"""Recursively replaces empty list or np array with None in a dictionary."""
@@ -339,7 +403,7 @@ def load(
339403
# predict on one data example to speed up prediction during inference
340404
# the first prediction always takes a bit longer to trace tf function
341405
if not finetune_mode and predict_data_example:
342-
model.rasa_predict(predict_data_example)
406+
model.run_inference(predict_data_example)
343407

344408
logger.debug("Finished loading the model.")
345409
return model

rasa/utils/train_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ def create_data_generators(
382382
batch_strategy: Text = SEQUENCE,
383383
eval_num_examples: int = 0,
384384
random_seed: Optional[int] = None,
385+
shuffle: bool = True,
385386
) -> Tuple[RasaBatchDataGenerator, Optional[RasaBatchDataGenerator]]:
386387
"""Create data generators for train and optional validation data.
387388
@@ -392,6 +393,7 @@ def create_data_generators(
392393
batch_strategy: The batch strategy to use.
393394
eval_num_examples: Number of examples to use for validation data.
394395
random_seed: The random seed.
396+
shuffle: Whether to shuffle data inside the data generator.
395397
396398
Returns:
397399
The training data generator and optional validation data generator.
@@ -406,15 +408,15 @@ def create_data_generators(
406408
batch_size=batch_sizes,
407409
epochs=epochs,
408410
batch_strategy=batch_strategy,
409-
shuffle=True,
411+
shuffle=shuffle,
410412
)
411413

412414
data_generator = RasaBatchDataGenerator(
413415
model_data,
414416
batch_size=batch_sizes,
415417
epochs=epochs,
416418
batch_strategy=batch_strategy,
417-
shuffle=True,
419+
shuffle=shuffle,
418420
)
419421

420422
return data_generator, validation_data_generator

tests/utils/tensorflow/test_models.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import pytest
2+
from typing import Dict, Text, Union, Tuple
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from rasa.utils.tensorflow.models import RasaModel
7+
from rasa.utils.tensorflow.model_data import RasaModelData
8+
from rasa.utils.tensorflow.model_data import FeatureArray
9+
from rasa.utils.tensorflow.constants import LABEL, IDS, SENTENCE
10+
from rasa.shared.nlu.constants import TEXT
11+
12+
13+
@pytest.mark.parametrize(
    "existing_outputs, new_batch_outputs, expected_output",
    [
        (
            {"a": np.array([1, 2]), "b": np.array([3, 1])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([1, 2, 5, 6]), "b": np.array([3, 1, 2, 4])},
        ),
        (
            {},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
        ),
        (
            {"a": np.array([1, 2]), "b": {"c": np.array([3, 1])}},
            {"a": np.array([5, 6]), "b": {"c": np.array([2, 4])}},
            {"a": np.array([1, 2, 5, 6]), "b": {"c": np.array([3, 1, 2, 4])}},
        ),
    ],
)
def test_merging_batch_outputs(
    existing_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    new_batch_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    expected_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
):
    """Checks that merging a batch output yields the expected stacked output.

    Covers flat outputs, an empty set of previous outputs and outputs
    holding a nested dictionary.
    """

    def _assert_same_outputs(
        actual: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
        expected: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    ) -> None:
        # Both dictionaries need the same keys, the same value types and,
        # recursively, equal array contents.
        assert expected.keys() == actual.keys()
        for key, actual_value in actual.items():
            expected_value = expected[key]
            assert type(actual_value) is type(expected_value)

            if isinstance(expected_value, np.ndarray):
                assert np.array_equal(actual_value, expected_value)
            elif isinstance(expected_value, dict):
                _assert_same_outputs(actual_value, expected_value)

    merged_output = RasaModel._merge_batch_outputs(
        existing_outputs, new_batch_outputs
    )

    _assert_same_outputs(merged_output, expected_output)
61+
62+
@pytest.mark.parametrize(
    "batch_size, number_of_data_points, expected_number_of_batch_iterations",
    [(2, 3, 2), (1, 3, 3), (5, 3, 1)],
)
def test_batch_inference(
    batch_size: int,
    number_of_data_points: int,
    expected_number_of_batch_iterations: int,
):
    """Checks that `run_inference` batches its input and merges the outputs."""

    def _fake_batch_predict(
        batch_in: Tuple[np.ndarray],
    ) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
        # "dummy_output" echoes the first input of the batch, whereas
        # "non_input_affected_output" is a constant with exactly one row,
        # no matter how many data points the batch holds.
        return {
            "dummy_output": batch_in[0],
            "non_input_affected_output": tf.constant(
                np.array([[1, 2]]), dtype=tf.int32
            ),
        }

    model = RasaModel()
    # Monkeypatch batch predict so that run_inference interface can be tested
    model.batch_predict = _fake_batch_predict

    # Create dummy model data to pass to model
    sentence_features = FeatureArray(
        np.random.rand(number_of_data_points, 2), number_of_dimensions=2,
    )
    model_data = RasaModelData(
        label_key=LABEL,
        label_sub_key=IDS,
        data={TEXT: {SENTENCE: [sentence_features]}},
    )

    output = model.run_inference(model_data, batch_size=batch_size)

    # Every data point sent as input has to show up in the echoed output.
    assert output["dummy_output"].shape[0] == number_of_data_points

    # The constant output contributes one row per predicted batch, hence
    # its first dimension has to match the number of batch iterations.
    expected_shape = (expected_number_of_batch_iterations, 2)
    assert output["non_input_affected_output"].shape == expected_shape

0 commit comments

Comments
 (0)