Skip to content

Commit 761cfc3

Browse files
authored
Merge pull request #9043 from RasaHQ/misc/mypy-index
fix `mypy` index issues
2 parents 7616dcc + c421602 commit 761cfc3

22 files changed

+111
-58
lines changed

rasa/cli/x.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import asyncio
33
import importlib.util
44
import logging
5-
from multiprocessing import get_context
5+
from multiprocessing import get_context, Process
66
import os
77
import signal
88
import sys
@@ -196,9 +196,10 @@ def _is_correct_event_broker(event_broker: EndpointConfig) -> bool:
196196
)
197197

198198

199-
def start_rasa_for_local_rasa_x(args: argparse.Namespace, rasa_x_token: Text) -> None:
199+
def start_rasa_for_local_rasa_x(
200+
args: argparse.Namespace, rasa_x_token: Text
201+
) -> Process:
200202
"""Starts the Rasa X API with Rasa as a background process."""
201-
202203
credentials_path, endpoints_path = _get_credentials_and_endpoints_paths(args)
203204
endpoints = AvailableEndpoints.read_endpoints(endpoints_path)
204205

rasa/core/actions/action.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from rasa.shared.core.trackers import DialogueStateTracker
5959
from rasa.core.nlg import NaturalLanguageGenerator
6060
from rasa.core.channels.channel import OutputChannel
61+
from rasa.shared.core.events import IntentPrediction
6162

6263
logger = logging.getLogger(__name__)
6364

@@ -851,9 +852,17 @@ async def run(
851852
domain: "Domain",
852853
) -> List[Event]:
853854
"""Runs action. Please see parent class for the full docstring."""
854-
intent_to_affirm = tracker.latest_message.intent.get(INTENT_NAME_KEY)
855+
latest_message = tracker.latest_message
856+
if latest_message is None:
857+
raise TypeError(
858+
"Cannot find last user message for detecting fallback affirmation."
859+
)
860+
861+
intent_to_affirm = latest_message.intent.get(INTENT_NAME_KEY)
855862

856-
intent_ranking = tracker.latest_message.parse_data.get(INTENT_RANKING_KEY, [])
863+
intent_ranking: List["IntentPrediction"] = latest_message.parse_data.get(
864+
INTENT_RANKING_KEY
865+
) or []
857866
if (
858867
intent_to_affirm == DEFAULT_NLU_FALLBACK_INTENT_NAME
859868
and len(intent_ranking) > 1

rasa/core/actions/two_stage_fallback.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ACTION_DEFAULT_ASK_REPHRASE_NAME,
2525
ACTION_TWO_STAGE_FALLBACK_NAME,
2626
)
27+
from rasa.shared.nlu.constants import INTENT, PREDICTED_CONFIDENCE_KEY
2728
from rasa.utils.endpoints import EndpointConfig
2829

2930

@@ -124,7 +125,7 @@ def _last_intent_name(tracker: DialogueStateTracker) -> Optional[Text]:
124125
if not last_message:
125126
return None
126127

127-
return last_message.intent.get("name")
128+
return last_message.intent_name
128129

129130

130131
def _two_fallbacks_in_a_row(tracker: DialogueStateTracker) -> bool:
@@ -179,6 +180,6 @@ def _message_clarification(tracker: DialogueStateTracker) -> List[Event]:
179180
)
180181

181182
clarification = copy.deepcopy(latest_message)
182-
clarification.parse_data["intent"]["confidence"] = 1.0
183+
clarification.parse_data[INTENT][PREDICTED_CONFIDENCE_KEY] = 1.0
183184
clarification.timestamp = time.time()
184185
return [ActionExecuted(ACTION_LISTEN_NAME), clarification]

rasa/core/channels/console.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import questionary
99
from aiohttp import ClientTimeout
1010
from prompt_toolkit.styles import Style
11-
from typing import Any
11+
from typing import Any, Generator
1212
from typing import Text, Optional, Dict, List
1313

1414
import rasa.shared.utils.cli
@@ -121,9 +121,9 @@ async def send_message_receive_block(
121121
return await resp.json()
122122

123123

124-
async def send_message_receive_stream(
124+
async def _send_message_receive_stream(
125125
server_url: Text, auth_token: Text, sender_id: Text, message: Text
126-
) -> None:
126+
) -> Generator[Dict[Text, Any], None, None]:
127127
payload = {"sender": sender_id, "message": message}
128128

129129
url = f"{server_url}/webhooks/rest/webhook?stream=true&token={auth_token}"
@@ -175,7 +175,7 @@ async def record_messages(
175175
break
176176

177177
if use_response_stream:
178-
bot_responses = send_message_receive_stream(
178+
bot_responses = _send_message_receive_stream(
179179
server_url, auth_token, sender_id, text
180180
)
181181
previous_response = None

rasa/core/policies/ensemble.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import rasa.core
1313
import rasa.core.training.training
1414
from rasa.core.constants import FALLBACK_POLICY_PRIORITY
15-
from rasa.shared.exceptions import RasaException
15+
from rasa.shared.exceptions import RasaException, InvalidConfigException
1616
import rasa.shared.utils.common
1717
import rasa.shared.utils.io
1818
import rasa.utils.io
@@ -636,7 +636,13 @@ def _pick_best_policy(
636636
if form_confidence > best_confidence:
637637
best_policy_name = form_policy_name
638638

639-
best_prediction = predictions[best_policy_name]
639+
best_prediction = predictions.get(best_policy_name)
640+
641+
if not best_prediction:
642+
raise InvalidConfigException(
643+
f"No prediction for policy '{best_policy_name}' found. Please check "
644+
f"your model configuration."
645+
)
640646

641647
policy_events += best_prediction.optional_events
642648

rasa/core/policies/rule_policy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(
145145
self._restrict_rules = restrict_rules
146146
self._check_for_contradictions = check_for_contradictions
147147

148-
self._rules_sources = None
148+
self._rules_sources = defaultdict(list)
149149

150150
# max history is set to `None` in order to capture any lengths of rule stories
151151
super().__init__(

rasa/core/policies/ted_policy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def predict_action_probabilities(
684684
tracker, domain, interpreter
685685
)
686686
model_data = self._create_model_data(tracker_state_features)
687-
outputs = self.model.run_inference(model_data)
687+
outputs: Dict[Text, np.ndarray] = self.model.run_inference(model_data)
688688

689689
# take the last prediction in the sequence
690690
similarities = outputs["similarities"][:, -1, :]

rasa/core/policies/two_stage_fallback.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ACTION_DEFAULT_ASK_REPHRASE_NAME,
2323
)
2424
from rasa.shared.core.domain import InvalidDomain, Domain
25-
from rasa.shared.nlu.constants import ACTION_NAME, INTENT_NAME_KEY
25+
from rasa.shared.nlu.constants import ACTION_NAME, INTENT_NAME_KEY, INTENT
2626

2727
if TYPE_CHECKING:
2828
from rasa.core.policies.ensemble import PolicyEnsemble
@@ -124,7 +124,7 @@ def predict_action_probabilities(
124124
) -> PolicyPrediction:
125125
"""Predicts the next action if NLU confidence is low."""
126126
nlu_data = tracker.latest_message.parse_data
127-
last_intent_name = nlu_data["intent"].get(INTENT_NAME_KEY, None)
127+
last_intent_name = nlu_data[INTENT].get(INTENT_NAME_KEY, None)
128128
should_nlu_fallback = self.should_nlu_fallback(
129129
nlu_data, tracker.latest_action.get(ACTION_NAME)
130130
)

rasa/core/test.py

+3-23
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
)
1919
from rasa.shared.core.domain import Domain
2020
from rasa.nlu.constants import (
21-
ENTITY_ATTRIBUTE_TEXT,
2221
RESPONSE_SELECTOR_DEFAULT_INTENT,
2322
RESPONSE_SELECTOR_RETRIEVAL_INTENTS,
2423
TOKENS_NAMES,
25-
ENTITY_ATTRIBUTE_CONFIDENCE,
2624
)
2725
from rasa.shared.nlu.constants import (
2826
INTENT,
@@ -38,8 +36,7 @@
3836
RESPONSE_SELECTOR,
3937
FULL_RETRIEVAL_INTENT_NAME_KEY,
4038
TEXT,
41-
ENTITY_ATTRIBUTE_GROUP,
42-
ENTITY_ATTRIBUTE_ROLE,
39+
ENTITY_ATTRIBUTE_TEXT,
4340
)
4441
from rasa.constants import RESULTS_FILE, PERCENTAGE_KEY
4542
from rasa.shared.core.events import (
@@ -56,24 +53,7 @@
5653
from rasa.core.agent import Agent
5754
from rasa.core.processor import MessageProcessor
5855
from rasa.shared.core.generator import TrainingDataGenerator
59-
60-
from typing_extensions import TypedDict
61-
62-
EntityPrediction = TypedDict(
63-
"EntityPrediction",
64-
{
65-
ENTITY_ATTRIBUTE_TEXT: Text,
66-
ENTITY_ATTRIBUTE_START: Optional[float],
67-
ENTITY_ATTRIBUTE_END: Optional[float],
68-
ENTITY_ATTRIBUTE_VALUE: Text,
69-
ENTITY_ATTRIBUTE_CONFIDENCE: float,
70-
ENTITY_ATTRIBUTE_TYPE: Text,
71-
ENTITY_ATTRIBUTE_GROUP: Optional[Text],
72-
ENTITY_ATTRIBUTE_ROLE: Optional[Text],
73-
"additional_info": Any,
74-
},
75-
total=False,
76-
)
56+
from rasa.shared.core.events import EntityPrediction
7757

7858
CONFUSION_MATRIX_STORIES_FILE = "story_confusion_matrix.png"
7959
REPORT_STORIES_FILE = "story_report.json"
@@ -911,7 +891,7 @@ async def test(
911891
num_failed = len(story_evaluation.failed_stories)
912892
num_correct = len(story_evaluation.successful_stories)
913893
num_convs = num_failed + num_correct
914-
if num_convs:
894+
if num_convs and isinstance(report, Dict):
915895
conv_accuracy = num_correct / num_convs
916896
report["conversation_accuracy"] = {
917897
"accuracy": conv_accuracy,

rasa/core/training/interactive.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ async def _predict_till_next_listen(
961961
predictions = result.get("scores")
962962
probabilities = [prediction["score"] for prediction in predictions]
963963
pred_out = int(np.argmax(probabilities))
964-
action_name = predictions[pred_out].get("action")
964+
action_name = predictions.get(pred_out, {}).get("action")
965965
policy = result.get("policy")
966966
confidence = result.get("confidence")
967967

rasa/nlu/classifiers/diet_classifier.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def __init__(
341341
self._check_config_parameters()
342342

343343
# transform numbers to labels
344-
self.index_label_id_mapping = index_label_id_mapping
344+
self.index_label_id_mapping = index_label_id_mapping or {}
345345

346346
self._entity_tag_specs = entity_tag_specs
347347

@@ -647,7 +647,8 @@ def _create_label_data(
647647
return label_data
648648

649649
def _use_default_label_features(self, label_ids: np.ndarray) -> List[FeatureArray]:
650-
all_label_features = self._label_data.get(LABEL, SENTENCE)[0]
650+
feature_arrays: List[FeatureArray] = self._label_data.get(LABEL, SENTENCE)
651+
all_label_features = feature_arrays[0]
651652
return [
652653
FeatureArray(
653654
np.array([all_label_features[label_id] for label_id in label_ids]),

rasa/nlu/constants.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import rasa.shared.nlu.constants
2-
2+
from rasa.shared.nlu.constants import ENTITY_ATTRIBUTE_CONFIDENCE
33

44
BILOU_ENTITIES = "bilou_entities"
55
BILOU_ENTITIES_ROLE = "bilou_entities_role"
66
BILOU_ENTITIES_GROUP = "bilou_entities_group"
77

8-
ENTITY_ATTRIBUTE_TEXT = "text"
9-
ENTITY_ATTRIBUTE_CONFIDENCE = "confidence"
108
ENTITY_ATTRIBUTE_CONFIDENCE_TYPE = (
119
f"{ENTITY_ATTRIBUTE_CONFIDENCE}_{rasa.shared.nlu.constants.ENTITY_ATTRIBUTE_TYPE}"
1210
)

rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _check_attribute_vocabulary(self, attribute: Text) -> bool:
169169
"""Checks if trained vocabulary exists in attribute's count vectorizer."""
170170
try:
171171
return hasattr(self.vectorizers[attribute], "vocabulary_")
172-
except (AttributeError, TypeError):
172+
except (AttributeError, KeyError):
173173
return False
174174

175175
def _get_attribute_vocabulary(self, attribute: Text) -> Optional[Dict[Text, int]]:
@@ -240,7 +240,7 @@ def __init__(
240240
self._attributes = self._attributes_for(self.analyzer)
241241

242242
# declare class instance for CountVectorizer
243-
self.vectorizers = vectorizers
243+
self.vectorizers = vectorizers or {}
244244

245245
self.finetune_mode = finetune_mode
246246

rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def required_components(cls) -> List[Type[Component]]:
6363
"pos": lambda token: token.data.get(POS_TAG_KEY)
6464
if POS_TAG_KEY in token.data
6565
else None,
66-
"pos2": lambda token: token.data.get(POS_TAG_KEY)[:2]
66+
"pos2": lambda token: token.data.get(POS_TAG_KEY, [])[:2]
6767
if "pos" in token.data
6868
else None,
6969
"upper": lambda token: token.text.isupper(),

rasa/nlu/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def _update_metadata_epochs(
358358
new_config: Optional[Dict] = None,
359359
finetuning_epoch_fraction: float = 1.0,
360360
) -> Metadata:
361+
new_config = new_config or {}
361362
for old_component_config, new_component_config in zip(
362363
model_metadata.metadata["pipeline"], new_config["pipeline"]
363364
):

rasa/nlu/test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
ENTITY_ATTRIBUTE_CONFIDENCE_TYPE,
3636
ENTITY_ATTRIBUTE_CONFIDENCE_ROLE,
3737
ENTITY_ATTRIBUTE_CONFIDENCE_GROUP,
38-
ENTITY_ATTRIBUTE_TEXT,
3938
)
4039
from rasa.shared.nlu.constants import (
4140
TEXT,
@@ -50,6 +49,7 @@
5049
NO_ENTITY_TAG,
5150
INTENT_NAME_KEY,
5251
PREDICTED_CONFIDENCE_KEY,
52+
ENTITY_ATTRIBUTE_TEXT,
5353
)
5454
from rasa.model import get_model
5555
from rasa.nlu.components import ComponentBuilder

rasa/shared/core/events.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,49 @@
4949
INTENT_NAME_KEY,
5050
ENTITY_ATTRIBUTE_ROLE,
5151
ENTITY_ATTRIBUTE_GROUP,
52+
PREDICTED_CONFIDENCE_KEY,
53+
INTENT_RANKING_KEY,
54+
ENTITY_ATTRIBUTE_TEXT,
55+
ENTITY_ATTRIBUTE_START,
56+
ENTITY_ATTRIBUTE_CONFIDENCE,
57+
ENTITY_ATTRIBUTE_END,
5258
)
5359

5460
if TYPE_CHECKING:
61+
from typing_extensions import TypedDict
62+
5563
from rasa.shared.core.trackers import DialogueStateTracker
5664

65+
EntityPrediction = TypedDict(
66+
"EntityPrediction",
67+
{
68+
ENTITY_ATTRIBUTE_TEXT: Text,
69+
ENTITY_ATTRIBUTE_START: Optional[float],
70+
ENTITY_ATTRIBUTE_END: Optional[float],
71+
ENTITY_ATTRIBUTE_VALUE: Text,
72+
ENTITY_ATTRIBUTE_CONFIDENCE: float,
73+
ENTITY_ATTRIBUTE_TYPE: Text,
74+
ENTITY_ATTRIBUTE_GROUP: Optional[Text],
75+
ENTITY_ATTRIBUTE_ROLE: Optional[Text],
76+
"additional_info": Any,
77+
},
78+
total=False,
79+
)
80+
81+
IntentPrediction = TypedDict(
82+
"IntentPrediction", {INTENT_NAME_KEY: Text, PREDICTED_CONFIDENCE_KEY: float,},
83+
)
84+
NLUPredictionData = TypedDict(
85+
"NLUPredictionData",
86+
{
87+
INTENT: IntentPrediction,
88+
INTENT_RANKING_KEY: List[IntentPrediction],
89+
ENTITIES: List[EntityPrediction],
90+
"message_id": Optional[Text],
91+
"metadata": Dict,
92+
},
93+
total=False,
94+
)
5795
logger = logging.getLogger(__name__)
5896

5997

@@ -369,7 +407,7 @@ def __init__(
369407
text: Optional[Text] = None,
370408
intent: Optional[Dict] = None,
371409
entities: Optional[List[Dict]] = None,
372-
parse_data: Optional[Dict[Text, Any]] = None,
410+
parse_data: Optional["NLUPredictionData"] = None,
373411
timestamp: Optional[float] = None,
374412
input_channel: Optional[Text] = None,
375413
message_id: Optional[Text] = None,
@@ -410,7 +448,7 @@ def __init__(
410448
# happens during training
411449
self.use_text_for_featurization = False
412450

413-
self.parse_data = {
451+
self.parse_data: "NLUPredictionData" = {
414452
INTENT: self.intent,
415453
# Copy entities so that changes to `self.entities` don't affect
416454
# `self.parse_data` and hence don't get persisted
@@ -426,7 +464,7 @@ def __init__(
426464
@staticmethod
427465
def _from_parse_data(
428466
text: Text,
429-
parse_data: Dict[Text, Any],
467+
parse_data: "NLUPredictionData",
430468
timestamp: Optional[float] = None,
431469
input_channel: Optional[Text] = None,
432470
message_id: Optional[Text] = None,

0 commit comments

Comments
 (0)