Skip to content

Commit

Permalink
cleanup type error: [list-item]
Browse files Browse the repository at this point in the history
  • Loading branch information
m-vdb committed Aug 31, 2020
1 parent 1699906 commit 0da7e77
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 25 deletions.
11 changes: 7 additions & 4 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,17 +677,20 @@ def has_user_affirmed(tracker: "DialogueStateTracker") -> bool:
def _revert_affirmation_events(tracker: "DialogueStateTracker") -> List[Event]:
revert_events = _revert_single_affirmation_events()

last_user_event = tracker.get_last_event_for(UserUttered)
last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

# User affirms the rephrased intent
rephrased_intent = tracker.last_executed_action_has(
name=ACTION_DEFAULT_ASK_REPHRASE_NAME, skip=1
)
if rephrased_intent:
revert_events += _revert_rephrasing_events()

last_user_event = tracker.get_last_event_for(UserUttered)
if not last_user_event:
return revert_events

last_user_event = copy.deepcopy(last_user_event)
last_user_event.parse_data["intent"]["confidence"] = 1.0

return revert_events + [last_user_event]


Expand Down
4 changes: 2 additions & 2 deletions rasa/core/actions/two_stage_fallback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import time
from typing import List, Text, Optional
from typing import List, Text, Optional, cast

from rasa.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME
from rasa.core.actions import action
Expand Down Expand Up @@ -195,7 +195,7 @@ def _second_affirmation_failed(tracker: DialogueStateTracker) -> bool:


def _message_clarification(tracker: DialogueStateTracker) -> List[Event]:
clarification = copy.deepcopy(tracker.latest_message)
clarification = copy.deepcopy(cast(Event, tracker.latest_message))
clarification.parse_data["intent"]["confidence"] = 1.0
clarification.timestamp = time.time()
return [ActionExecuted(ACTION_LISTEN_NAME), clarification]
2 changes: 1 addition & 1 deletion rasa/core/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def __init__(
) -> None:

super().__init__(state_featurizer, use_intent_probabilities)
self.max_history = max_history or self.MAX_HISTORY_DEFAULT
self.max_history: Optional[int] = max_history or self.MAX_HISTORY_DEFAULT
self.remove_duplicates = remove_duplicates

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions rasa/core/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def load(cls, path: Text) -> "Policy":

raise NotImplementedError("Policy must have the capacity to load itself.")

@staticmethod
def _default_predictions(domain: Domain) -> List[float]:
def _default_predictions(self, domain: Domain) -> List[float]:
"""Creates a list of zeros.
Args:
Expand Down
14 changes: 7 additions & 7 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING
from typing import List, Dict, Text, Optional, Any, Set, TYPE_CHECKING, Union

import re
from collections import defaultdict
Expand Down Expand Up @@ -152,8 +152,8 @@ def _prev_action_listen_in_state(state: Dict[Text, float]) -> bool:

@staticmethod
def _modified_states(
states: List[Dict[Text, float]]
) -> List[Optional[Dict[Text, float]]]:
states: List[Dict[Text, Union[int, float]]]
) -> List[Optional[Dict[Text, Union[int, float]]]]:
"""Modifies the states to create feature keys for form unhappy path conditions.
Args:
Expand All @@ -165,7 +165,7 @@ def _modified_states(
"""

indicator = PREV_PREFIX + RULE_SNIPPET_ACTION_NAME
state_only_with_action = {indicator: 1}
state_only_with_action: Dict[Text, Union[int, float]] = {indicator: 1}
# leave only last 2 dialogue turns to
# - capture previous meaningful action before action_listen
# - ignore previous intent
Expand Down Expand Up @@ -266,9 +266,7 @@ def train(

# only consider original trackers (no augmented ones)
training_trackers = [
t
for t in training_trackers
if not hasattr(t, "is_augmented") or not t.is_augmented
t for t in training_trackers if not getattr(t, "is_augmented", False)
]
# only use trackers from rule-based training data
rule_trackers = [t for t in training_trackers if t.is_rule_tracker]
Expand Down Expand Up @@ -414,6 +412,8 @@ def _find_action_from_form_happy_path(
)
return ACTION_LISTEN_NAME

return None

def _find_action_from_rules(
self, tracker: DialogueStateTracker, domain: Domain
) -> Optional[Text]:
Expand Down
7 changes: 4 additions & 3 deletions rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,10 @@ def _collect_user_uttered_predictions(
intent_gold = event.intent.get("name")
predicted_intent = predicted.get(INTENT, {}).get("name")

user_uttered_eval_store.add_to_store(
intent_predictions=[predicted_intent], intent_targets=[intent_gold]
)
if intent_gold:
user_uttered_eval_store.add_to_store(intent_targets=[intent_gold])
if predicted_intent:
user_uttered_eval_store.add_to_store(intent_predictions=[predicted_intent])

entity_gold = event.entities
predicted_entities = predicted.get(ENTITIES)
Expand Down
12 changes: 6 additions & 6 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from functools import partial
from multiprocessing import Process
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union, Set
from typing import Any, Callable, Deque, Dict, List, Optional, Text, Tuple, Union, Set

import numpy as np
from aiohttp import ClientError
Expand Down Expand Up @@ -1474,7 +1474,7 @@ async def record_messages(

async def _get_tracker_events_to_plot(
domain: Dict[Text, Any], file_importer: TrainingDataImporter, conversation_id: Text
) -> List[Union[Text, List[Event]]]:
) -> List[Union[Text, Deque[Event]]]:
training_trackers = await _get_training_trackers(file_importer, domain)
number_of_trackers = len(training_trackers)
if number_of_trackers > MAX_NUMBER_OF_TRAINING_STORIES_FOR_VISUALIZATION:
Expand All @@ -1487,10 +1487,10 @@ async def _get_tracker_events_to_plot(
)
training_trackers = []

training_data_events = [t.events for t in training_trackers]
events_including_current_user_id = training_data_events + [conversation_id]

return events_including_current_user_id
training_data_events: List[Union[Text, Deque[Event]]] = [
t.events for t in training_trackers
]
return training_data_events + [conversation_id]


async def _get_training_trackers(
Expand Down

0 comments on commit 0da7e77

Please sign in to comment.