49
49
INTENT_NAME_KEY ,
50
50
ENTITY_ATTRIBUTE_ROLE ,
51
51
ENTITY_ATTRIBUTE_GROUP ,
52
+ PREDICTED_CONFIDENCE_KEY ,
53
+ INTENT_RANKING_KEY ,
54
+ ENTITY_ATTRIBUTE_TEXT ,
55
+ ENTITY_ATTRIBUTE_START ,
56
+ ENTITY_ATTRIBUTE_CONFIDENCE ,
57
+ ENTITY_ATTRIBUTE_END ,
52
58
)
53
59
54
60
if TYPE_CHECKING :
61
+ from typing_extensions import TypedDict
62
+
55
63
from rasa .shared .core .trackers import DialogueStateTracker
56
64
65
+ EntityPrediction = TypedDict (
66
+ "EntityPrediction" ,
67
+ {
68
+ ENTITY_ATTRIBUTE_TEXT : Text ,
69
+ ENTITY_ATTRIBUTE_START : Optional [float ],
70
+ ENTITY_ATTRIBUTE_END : Optional [float ],
71
+ ENTITY_ATTRIBUTE_VALUE : Text ,
72
+ ENTITY_ATTRIBUTE_CONFIDENCE : float ,
73
+ ENTITY_ATTRIBUTE_TYPE : Text ,
74
+ ENTITY_ATTRIBUTE_GROUP : Optional [Text ],
75
+ ENTITY_ATTRIBUTE_ROLE : Optional [Text ],
76
+ "additional_info" : Any ,
77
+ },
78
+ total = False ,
79
+ )
80
+
81
+ IntentPrediction = TypedDict (
82
+ "IntentPrediction" , {INTENT_NAME_KEY : Text , PREDICTED_CONFIDENCE_KEY : float ,},
83
+ )
84
+ NLUPredictionData = TypedDict (
85
+ "NLUPredictionData" ,
86
+ {
87
+ INTENT : IntentPrediction ,
88
+ INTENT_RANKING_KEY : List [IntentPrediction ],
89
+ ENTITIES : List [EntityPrediction ],
90
+ "message_id" : Optional [Text ],
91
+ "metadata" : Dict ,
92
+ },
93
+ total = False ,
94
+ )
57
95
logger = logging .getLogger (__name__ )
58
96
59
97
@@ -369,7 +407,7 @@ def __init__(
369
407
text : Optional [Text ] = None ,
370
408
intent : Optional [Dict ] = None ,
371
409
entities : Optional [List [Dict ]] = None ,
372
- parse_data : Optional [Dict [ Text , Any ] ] = None ,
410
+ parse_data : Optional ["NLUPredictionData" ] = None ,
373
411
timestamp : Optional [float ] = None ,
374
412
input_channel : Optional [Text ] = None ,
375
413
message_id : Optional [Text ] = None ,
@@ -410,7 +448,7 @@ def __init__(
410
448
# happens during training
411
449
self .use_text_for_featurization = False
412
450
413
- self .parse_data = {
451
+ self .parse_data : "NLUPredictionData" = {
414
452
INTENT : self .intent ,
415
453
# Copy entities so that changes to `self.entities` don't affect
416
454
# `self.parse_data` and hence don't get persisted
@@ -426,7 +464,7 @@ def __init__(
426
464
@staticmethod
427
465
def _from_parse_data (
428
466
text : Text ,
429
- parse_data : Dict [ Text , Any ] ,
467
+ parse_data : "NLUPredictionData" ,
430
468
timestamp : Optional [float ] = None ,
431
469
input_channel : Optional [Text ] = None ,
432
470
message_id : Optional [Text ] = None ,
0 commit comments