Skip to content

Commit

Permalink
fix and enable mypy return check
Browse files Browse the repository at this point in the history
  • Loading branch information
joejuzl committed Jul 12, 2021
1 parent e5eadce commit 1ded5ef
Show file tree
Hide file tree
Showing 35 changed files with 163 additions and 123 deletions.
5 changes: 4 additions & 1 deletion rasa/cli/x.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import importlib.util
import logging
from multiprocessing import get_context
from multiprocessing.process import BaseProcess
import os
import signal
import sys
Expand Down Expand Up @@ -196,7 +197,9 @@ def _is_correct_event_broker(event_broker: EndpointConfig) -> bool:
)


def start_rasa_for_local_rasa_x(args: argparse.Namespace, rasa_x_token: Text) -> None:
def start_rasa_for_local_rasa_x(
args: argparse.Namespace, rasa_x_token: Text
) -> BaseProcess:
"""Starts the Rasa X API with Rasa as a background process."""

credentials_path, endpoints_path = _get_credentials_and_endpoints_paths(args)
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ async def run(
domain: "Domain",
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
_events = [SessionStarted(metadata=self.metadata)]
_events: List[Event] = [SessionStarted(metadata=self.metadata)]

if domain.session_config.carry_over_slots:
_events.extend(self._slot_set_events_from_tracker(tracker))
Expand Down Expand Up @@ -685,7 +685,7 @@ async def run(

events_json = response.get("events", [])
responses = response.get("responses", [])
bot_messages = await self._utter_responses(
bot_messages: List[Event] = await self._utter_responses(
responses, output_channel, nlg, tracker
)

Expand Down
10 changes: 5 additions & 5 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ async def validate_slots(
domain: Domain,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
"""Validate the extracted slots.
If a custom action is available for validating the slots, we call it to validate
Expand All @@ -445,7 +445,7 @@ async def validate_slots(
for the validated slots.
"""
logger.debug(f"Validating extracted slots: {slot_candidates}")
events = [
events: List[Union[SlotSet, Event]] = [
SlotSet(slot_name, value) for slot_name, value in slot_candidates.items()
]

Expand Down Expand Up @@ -506,7 +506,7 @@ async def validate(
domain: Domain,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
"""Extract and validate value of requested slot.
If nothing was extracted reject execution of the form action.
Expand Down Expand Up @@ -560,9 +560,9 @@ async def request_next_slot(
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
events_so_far: List[Event],
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
"""Request the next slot and response if needed, else return `None`."""
request_slot_events = []
request_slot_events: List[Union[SlotSet, Event]] = []

if await self.is_done(output_channel, nlg, tracker, domain, events_so_far):
# The custom action for slot validation decided to stop the form early
Expand Down
5 changes: 3 additions & 2 deletions rasa/core/actions/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ async def deactivate(
return await self._give_up(output_channel, nlg, tracker, domain)

# revert fallback events
return [UserUtteranceReverted()] + _message_clarification(tracker)
reverted_event: List[Event] = [UserUtteranceReverted()]
return reverted_event + _message_clarification(tracker)

async def _give_up(
self,
Expand Down Expand Up @@ -136,7 +137,7 @@ def _two_fallbacks_in_a_row(tracker: DialogueStateTracker) -> bool:

def _last_n_intent_names(
tracker: DialogueStateTracker, number_of_last_intent_names: int
) -> List[Text]:
) -> List[Optional[Text]]:
intent_names = []
for i in range(number_of_last_intent_names):
message = tracker.get_last_event_for(
Expand Down
19 changes: 16 additions & 3 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
TYPE_CHECKING,
Text,
Tuple,
Union,
)
import uuid

import aiohttp
Expand Down Expand Up @@ -52,6 +62,9 @@
from rasa.utils.endpoints import EndpointConfig
import rasa.utils.io

if TYPE_CHECKING:
from rasa.shared.core.generator import TrackerWithCachedStates

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -679,7 +692,7 @@ def _are_all_featurizers_using_a_max_history(self) -> bool:
"""Check if all featurizers are MaxHistoryTrackerFeaturizer."""

def has_max_history_featurizer(policy: Policy) -> bool:
return (
return bool(
policy.featurizer
and hasattr(policy.featurizer, "max_history")
and policy.featurizer.max_history is not None
Expand All @@ -700,7 +713,7 @@ async def load_data(
use_story_concatenation: bool = True,
debug_plots: bool = False,
exclusion_percentage: Optional[int] = None,
) -> List[DialogueStateTracker]:
) -> List[TrackerWithCachedStates]:
"""Load training data from a resource."""

max_history = self._max_history()
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/hangouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _text_button_card(text: Text, buttons: List) -> Union[Dict, None]:
logger.error(
"Buttons must be a list of dicts with 'title' and 'payload' as keys"
)
return
return None

hangouts_buttons.append(
{
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_output_channel(self) -> Optional["OutputChannel"]:
"Please use a different channel for external events in these "
"scenarios."
)
return
return None
return SocketIOOutput(self.sio, self.bot_message_evt)

def blueprint(
Expand Down
8 changes: 6 additions & 2 deletions rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def training_states_actions_and_entities(
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
) -> Tuple[
List[List[State]], List[List[Optional[Text]]], List[List[Dict[Text, Any]]]
]:
"""Transforms list of trackers to lists of states, actions and entity data.
Args:
Expand Down Expand Up @@ -543,7 +545,9 @@ def training_states_actions_and_entities(
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
) -> Tuple[
List[List[State]], List[List[Optional[Text]]], List[List[Dict[Text, Any]]]
]:
"""Transforms list of trackers to lists of states, actions and entity data.
Args:
Expand Down
8 changes: 3 additions & 5 deletions rasa/core/policies/form_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ACTIVE_LOOP,
LOOP_REJECTED,
)
from rasa.shared.core.domain import State, Domain
from rasa.shared.core.domain import State, Domain, SubStateValue
from rasa.shared.core.events import LoopInterrupted
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
Expand Down Expand Up @@ -55,9 +55,7 @@ def __init__(
)

@staticmethod
def _get_active_form_name(
state: State,
) -> Optional[Union[Text, Tuple[Union[float, Text]]]]:
def _get_active_form_name(state: State,) -> Optional[SubStateValue]:
return state.get(ACTIVE_LOOP, {}).get(LOOP_NAME)

@staticmethod
Expand Down Expand Up @@ -87,7 +85,7 @@ def _create_lookup_from_states(
self,
trackers_as_states: List[List[State]],
trackers_as_actions: List[List[Text]],
) -> Dict[Text, Text]:
) -> Dict[Text, SubStateValue]:
"""Add states to lookup dict"""
lookup = {}
for states in trackers_as_states:
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/policies/memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,12 @@ def _back_to_the_future(
# use first action, if we went first time and second action, if we went again
idx_to_use = idx_of_second_action if again else idx_of_first_action
if idx_to_use is None:
return
return None

# make second ActionExecuted the first one
events = tracker.applied_events()[idx_to_use:]
if not events:
return
return None

mcfly_tracker = tracker.init_copy()
for e in events:
Expand Down
10 changes: 5 additions & 5 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _create_feature_key(self, states: List[State]) -> Optional[Text]:
new_states.insert(0, state)

if not new_states:
return
return None

# we sort keys to make sure that the same states
# represented as dictionaries have the same json strings
Expand Down Expand Up @@ -421,7 +421,7 @@ def _get_slots_loops_from_states(
for states in trackers_as_states:
for state in states:
slots.update(set(state.get(SLOTS, {}).keys()))
active_loop = state.get(ACTIVE_LOOP, {}).get(LOOP_NAME)
active_loop: Optional[Text] = state.get(ACTIVE_LOOP, {}).get(LOOP_NAME)
if active_loop:
loops.add(active_loop)
return slots, loops
Expand Down Expand Up @@ -592,7 +592,7 @@ def _run_prediction_on_trackers(
trackers: List[TrackerWithCachedStates],
domain: Domain,
collect_sources: bool,
) -> Tuple[List[Text], Set[Text]]:
) -> Tuple[List[Text], Set[Optional[Text]]]:
if collect_sources:
self._rules_sources = defaultdict(list)

Expand Down Expand Up @@ -665,7 +665,7 @@ def _collect_rule_sources(

def _find_contradicting_and_used_in_stories_rules(
self, trackers: List[TrackerWithCachedStates], domain: Domain
) -> Tuple[List[Text], Set[Text]]:
) -> Tuple[List[Text], Set[Optional[Text]]]:
return self._run_prediction_on_trackers(trackers, domain, collect_sources=False)

def _analyze_rules(
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def predict_action_probabilities(

def _predict(
self, tracker: DialogueStateTracker, domain: Domain
) -> Tuple[PolicyPrediction, Text]:
) -> Tuple[PolicyPrediction, Optional[Text]]:
(
rules_action_name_from_text,
prediction_source_from_text,
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/policies/sklearn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def persist(self, path: Union[Text, Path]) -> None:
@classmethod
def load(
cls, path: Union[Text, Path], should_finetune: bool = False, **kwargs: Any
) -> Policy:
) -> Optional[Policy]:
"""See the docstring for `Policy.load`."""
filename = Path(path) / "sklearn_model.pkl"
zero_features_filename = Path(path) / "zero_state_features.pkl"
Expand All @@ -347,7 +347,7 @@ def load(
f"Failed to load dialogue model. Path {filename.absolute()} "
f"doesn't exist."
)
return
return None

featurizer = TrackerFeaturizer.load(path)
assert isinstance(featurizer, MaxHistoryTrackerFeaturizer), (
Expand Down
12 changes: 6 additions & 6 deletions rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def _create_data_for_entities(
self, entity_tags: Optional[List[List[Dict[Text, List["Features"]]]]]
) -> Optional[Data]:
if not self.config[ENTITY_RECOGNITION]:
return
return None

# check that there are real entity tags
if entity_tags and self._should_extract_entities(entity_tags):
Expand Down Expand Up @@ -723,11 +723,11 @@ def _create_optional_event_for_entities(
# entities belong only to the last user message
# and only if user text was used for prediction,
# a user message always comes after action listen
return
return None

if not self.config[ENTITY_RECOGNITION]:
# entity recognition is not turned on, no entities can be predicted
return
return None

# The batch dimension of entity prediction is not the same as batch size,
# rather it is the number of last (if max history featurizer else all)
Expand All @@ -743,7 +743,7 @@ def _create_optional_event_for_entities(

if ENTITY_ATTRIBUTE_TYPE not in predicted_tags:
# no entities detected
return
return None

# entities belong to the last message of the tracker
# convert the predicted tags to actual entities
Expand Down Expand Up @@ -819,7 +819,7 @@ def load(
should_finetune: bool = False,
epoch_override: int = defaults[EPOCHS],
**kwargs: Any,
) -> "TEDPolicy":
) -> Optional["TEDPolicy"]:
"""Loads a policy from the storage.
**Needs to load its featurizer**
Expand All @@ -831,7 +831,7 @@ def load(
f"Failed to load TED policy model. Path "
f"'{model_path.absolute()}' doesn't exist."
)
return
return None

tf_model_file = model_path / f"{SAVE_MODEL_FILE_NAME}.tf_model"

Expand Down
7 changes: 4 additions & 3 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from rasa.core.agent import Agent
from rasa.core.processor import MessageProcessor
from rasa.shared.core.generator import TrainingDataGenerator
from _typeshed import SupportsLessThan

from typing_extensions import TypedDict

Expand Down Expand Up @@ -215,14 +216,14 @@ def serialise(self) -> Tuple[PredictionList, PredictionList]:
filter(
lambda x: x.get(ENTITY_ATTRIBUTE_TEXT) == text, self.entity_targets
),
key=lambda x: x.get(ENTITY_ATTRIBUTE_START),
key=lambda x: x[ENTITY_ATTRIBUTE_START],
)
entity_predictions = sorted(
filter(
lambda x: x.get(ENTITY_ATTRIBUTE_TEXT) == text,
self.entity_predictions,
),
key=lambda x: x.get(ENTITY_ATTRIBUTE_START),
key=lambda x: x[ENTITY_ATTRIBUTE_START],
)

i_pred, i_target = 0, 0
Expand Down Expand Up @@ -420,7 +421,7 @@ def _clean_entity_results(
cleaned_entities = []

for r in tuple(entity_results):
cleaned_entity = {ENTITY_ATTRIBUTE_TEXT: text}
cleaned_entity: EntityPrediction = {ENTITY_ATTRIBUTE_TEXT: text}
for k in (
ENTITY_ATTRIBUTE_START,
ENTITY_ATTRIBUTE_END,
Expand Down
Loading

0 comments on commit 1ded5ef

Please sign in to comment.