Skip to content

Commit ac764ec

Browse files
committed
added tests and changelog
1 parent 6fd3fff commit ac764ec

File tree

3 files changed

+128
-19
lines changed

3 files changed

+128
-19
lines changed

changelog/8560.improvement.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement a new interface `run_inference` inside `RasaModel` that performs batch inference with TensorFlow models.

rasa/utils/tensorflow/models.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,25 @@
4848
logger = logging.getLogger(__name__)
4949

5050

51+
def _merge_batch_outputs(
52+
all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
53+
batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
54+
) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
55+
if not all_outputs:
56+
return batch_output
57+
for key, val in batch_output.items():
58+
if isinstance(val, np.ndarray):
59+
all_outputs[key] = np.concatenate(
60+
[all_outputs[key], batch_output[key]], axis=0
61+
)
62+
63+
elif isinstance(val, dict):
64+
# recurse and merge the inner dict first
65+
all_outputs[key] = _merge_batch_outputs(all_outputs[key], val)
66+
67+
return all_outputs
68+
69+
5170
# noinspection PyMethodOverriding
5271
class RasaModel(TmpKerasModel):
5372
"""Abstract custom Keras model.
@@ -289,30 +308,12 @@ def run_inference(
289308
# Only want x, since y is always None out of our data generators
290309
batch_in = next(data_iterator)[0]
291310
batch_out = self.rasa_predict(batch_in)
292-
outputs = self._merge_batch_outputs(outputs, batch_out)
311+
outputs = _merge_batch_outputs(outputs, batch_out)
293312
except StopIteration:
294313
# Generator ran out of batches, time to finish inferencing
295314
break
296315
return outputs
297316

298-
def _merge_batch_outputs(
299-
self,
300-
all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
301-
batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
302-
) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
303-
if not all_outputs:
304-
return batch_output
305-
for key, val in batch_output.items():
306-
if isinstance(val, np.ndarray):
307-
all_outputs[key] = np.concatenate(
308-
[all_outputs[key], batch_output[key]], axis=0
309-
)
310-
elif isinstance(val, dict):
311-
# recurse and merge the inner dict first
312-
all_outputs[key] = self._merge_batch_outputs(all_outputs[key], val)
313-
314-
return batch_output
315-
316317
@staticmethod
317318
def _empty_lists_to_none_in_dict(input_dict: Dict[Text, Any]) -> Dict[Text, Any]:
318319
"""Recursively replaces empty list or np array with None in a dictionary."""

tests/utils/tensorflow/test_models.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
from typing import Dict, Text, Union
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from rasa.utils.tensorflow.models import _merge_batch_outputs, RasaModel
7+
from rasa.utils.tensorflow.model_data import RasaModelData
8+
from rasa.shared.constants import DIAGNOSTIC_DATA
9+
from rasa.utils.tensorflow.model_data import FeatureArray
10+
11+
12+
@pytest.mark.parametrize(
    "existing_outputs, new_batch_outputs, expected_output",
    [
        (
            {"a": np.array([1, 2]), "b": np.array([3, 1])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([1, 2, 5, 6]), "b": np.array([3, 1, 2, 4])},
        ),
        (
            {},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
        ),
        (
            {"a": np.array([1, 2]), "b": {"c": np.array([3, 1])}},
            {"a": np.array([5, 6]), "b": {"c": np.array([2, 4])}},
            {"a": np.array([1, 2, 5, 6]), "b": {"c": np.array([3, 1, 2, 4])}},
        ),
    ],
)
def test_merging_batch_outputs(
    existing_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    new_batch_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    expected_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
):
    """Check `_merge_batch_outputs` on flat, empty and nested accumulators."""

    def assert_equal_dicts(
        actual: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
        expected: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    ) -> None:
        # Not prefixed with `test_`: this is a comparison helper, not a test.
        assert actual.keys() == expected.keys()
        for key, expected_val in expected.items():
            actual_val = actual[key]
            assert type(actual_val) is type(expected_val)

            if isinstance(expected_val, np.ndarray):
                assert np.array_equal(actual_val, expected_val)
            elif isinstance(expected_val, dict):
                assert_equal_dicts(actual_val, expected_val)
            else:
                # Fail loudly instead of letting an uncompared value pass.
                pytest.fail(
                    f"Unsupported value type under key '{key}': "
                    f"{type(expected_val)}"
                )

    predicted_output = _merge_batch_outputs(existing_outputs, new_batch_outputs)

    assert_equal_dicts(predicted_output, expected_output)
57+
58+
59+
@pytest.mark.parametrize(
    "batch_size, number_of_data_points, expected_number_of_batch_iterations",
    [(2, 3, 2), (1, 3, 3), (5, 3, 1)],
)
def test_batch_inference(
    batch_size: int,
    number_of_data_points: int,
    expected_number_of_batch_iterations: int,
):
    """Check the shapes `run_inference` returns for several batch sizes.

    The model's `batch_predict` is replaced by a stub, so only the batching
    and merging done by `run_inference` is exercised.
    """
    model = RasaModel()

    def batch_predict(batch_in: np.ndarray):
        # Echo the first element of the batch input and attach a single
        # diagnostic row for every call.
        return {
            "dummy_output": batch_in[0],
            DIAGNOSTIC_DATA: tf.constant(np.array([[1, 2]]), dtype=tf.int32),
        }

    # Swap in the stub so that the run_inference interface can be tested.
    model.batch_predict = batch_predict

    # Dummy model data: one 2-dimensional feature array with one row per
    # data point.
    sentence_features = FeatureArray(
        np.random.rand(number_of_data_points, 2), number_of_dimensions=2,
    )
    model_data = RasaModelData(
        label_key="label",
        label_sub_key="ids",
        data={"text": {"sentence": [sentence_features]}},
    )

    output = model.run_inference(model_data, batch_size=batch_size)

    # Every data point sent as input must show up in dummy_output.
    assert output["dummy_output"].shape[0] == number_of_data_points

    # The stub emits one diagnostic row per call, so the merged diagnostic
    # data must contain one row per batch passed to the model.
    assert output[DIAGNOSTIC_DATA].shape == (expected_number_of_batch_iterations, 2)

0 commit comments

Comments
 (0)