Response Selector in Cross Validation #4976

Merged · 18 commits · Dec 18, 2019
3 changes: 3 additions & 0 deletions changelog/4976.improvement.rst
@@ -0,0 +1,3 @@
+``rasa test nlu --cross-validation`` now also includes an evaluation of the response selector.
+As a result, the train and test F1-score, accuracy, and precision are logged for the response selector.
+A report is also generated in the ``results`` folder under the name ``response_selection_report.json``.
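
For reference, a minimal invocation exercising this change might look as follows; the data path is an assumption based on Rasa's usual project layout, not something specified in this PR:

    rasa test nlu --nlu data/nlu.md --cross-validation --folds 5

After the run, the response selector report should appear as ``results/response_selection_report.json`` alongside the existing intent and entity reports.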
4 changes: 2 additions & 2 deletions rasa/cli/arguments/test.py
@@ -142,7 +142,7 @@ def add_test_nlu_argument_group(
"-f",
"--folds",
required=False,
default=10,
default=5,
help="Number of cross validation folds (cross validation only).",
)
comparison_arguments = parser.add_argument_group("Comparison Mode")
@@ -182,7 +182,7 @@ def add_test_core_model_param(parser: argparse.ArgumentParser):


def add_no_plot_param(
-   parser: argparse.ArgumentParser, default: bool = False, required: bool = False,
+   parser: argparse.ArgumentParser, default: bool = False, required: bool = False
) -> None:
    parser.add_argument(
        "--no-plot",
136 changes: 106 additions & 30 deletions rasa/nlu/test.py
@@ -61,6 +61,7 @@

IntentMetrics = Dict[Text, List[float]]
EntityMetrics = Dict[Text, Dict[Text, List[float]]]
+ResponseSelectionMetrics = Dict[Text, List[float]]


def plot_confusion_matrix(
@@ -868,14 +869,12 @@ def align_entity_predictions(
"""

true_token_labels = []
entities_by_extractors = {
entities_by_extractors: Dict[Text, List] = {
extractor: [] for extractor in extractors
} # type: Dict[Text, List]
}
for p in result.entity_predictions:
entities_by_extractors[p["extractor"]].append(p)
extractor_labels = {
extractor: [] for extractor in extractors
} # type: Dict[Text, List]
extractor_labels: Dict[Text, List] = {extractor: [] for extractor in extractors}
for t in result.tokens:
true_token_labels.append(determine_token_labels(t, result.entity_targets, None))
for extractor, entities in entities_by_extractors.items():
@@ -1078,11 +1077,11 @@ def run_evaluation(
    interpreter.pipeline = remove_pretrained_extractors(interpreter.pipeline)
    test_data = training_data.load_data(data_path, interpreter.model_metadata.language)

-   result = {
+   result: Dict[Text, Optional[Dict]] = {
        "intent_evaluation": None,
        "entity_evaluation": None,
        "response_selection_evaluation": None,
-   }  # type: Dict[Text, Optional[Dict]]
+   }

    if output_directory:
        io_utils.create_directory(output_directory)
@@ -1128,7 +1127,10 @@ def generate_folds(

    skf = StratifiedKFold(n_splits=n, shuffle=True)
    x = td.intent_examples
-   y = [example.get("intent") for example in x]
+
+   # Get labels with the response key appended to the intent name, because we
+   # want a stratified split on all intents (including retrieval intents, if any)
+   y = [example.get_combined_intent_response_key() for example in x]
    for i_fold, (train_index, test_index) in enumerate(skf.split(x, y)):
        logger.debug(f"Fold: {i_fold}")
        train = [x[i] for i in train_index]
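
The comment above is the crux of this hunk: stratifying on the plain intent would treat all examples of a retrieval intent as a single class, so individual response keys could be left out of a fold entirely. A standalone sketch of the idea, using plain (text, label) pairs and hypothetical faq/... response keys in place of Rasa's Message objects and get_combined_intent_response_key():

    from sklearn.model_selection import StratifiedKFold

    # Hypothetical (text, combined-label) pairs standing in for Rasa Messages.
    # Retrieval-intent examples carry an "intent/response_key" label.
    examples = [
        ("hi", "greet"),
        ("hello", "greet"),
        ("how do I train a model?", "faq/ask_howto"),
        ("how does training work?", "faq/ask_howto"),
        ("which channels are supported?", "faq/ask_channels"),
        ("what messaging channels exist?", "faq/ask_channels"),
    ]
    texts = [text for text, _ in examples]
    labels = [label for _, label in examples]

    skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
    for i_fold, (train_index, test_index) in enumerate(skf.split(texts, labels)):
        # Stratifying on the combined key keeps every response key represented
        # in both splits; stratifying on "faq" alone would not guarantee that.
        print(i_fold, sorted(labels[i] for i in test_index))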
@@ -1150,11 +1152,15 @@
def combine_result(
    intent_metrics: IntentMetrics,
    entity_metrics: EntityMetrics,
+   response_selection_metrics: ResponseSelectionMetrics,
    interpreter: Interpreter,
    data: TrainingData,
    intent_results: Optional[List[IntentEvaluationResult]] = None,
    entity_results: Optional[List[EntityEvaluationResult]] = None,
-) -> Tuple[IntentMetrics, EntityMetrics]:
+   response_selection_results: Optional[
+       List[ResponseSelectionEvaluationResult]
+   ] = None,
+) -> Tuple[IntentMetrics, EntityMetrics, ResponseSelectionMetrics]:
    """Collects intent and entity metrics for cross-validation folds.
    If `intent_results` or `entity_results` is provided as a list, prediction results
    are also collected.
@@ -1163,8 +1169,10 @@
    (
        intent_current_metrics,
        entity_current_metrics,
+       response_selection_current_metrics,
        current_intent_results,
        current_entity_results,
+       current_response_selection_results,
    ) = compute_metrics(interpreter, data)

    if intent_results is not None:
@@ -1173,15 +1181,28 @@
    if entity_results is not None:
        entity_results += current_entity_results

+   if response_selection_results is not None:
+       response_selection_results += current_response_selection_results
+
    for k, v in intent_current_metrics.items():
        intent_metrics[k] = v + intent_metrics[k]

+   for k, v in response_selection_current_metrics.items():
+       response_selection_metrics[k] = v + response_selection_metrics[k]
+
    for extractor, extractor_metric in entity_current_metrics.items():
        entity_metrics[extractor] = {
            k: v + entity_metrics[extractor][k] for k, v in extractor_metric.items()
        }

-   return intent_metrics, entity_metrics
+   return intent_metrics, entity_metrics, response_selection_metrics


+def _contains_entity_labels(entity_results: List[EntityEvaluationResult]) -> bool:
+
+   for result in entity_results:
+       if result.entity_targets or result.entity_predictions:
+           return True
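
The accumulation pattern `combine_result` relies on is easy to miss: every metrics container is a defaultdict(list), and each fold's scores are concatenated onto it, so after all folds each metric key holds one value per fold. A minimal sketch with hypothetical numbers:

    from collections import defaultdict
    from typing import Dict, List, Text

    IntentMetrics = Dict[Text, List[float]]

    def combine(accumulated: IntentMetrics, current: IntentMetrics) -> IntentMetrics:
        # Same list-concatenation step as `metrics[k] = v + metrics[k]` above.
        for key, values in current.items():
            accumulated[key] = values + accumulated[key]
        return accumulated

    test_metrics: IntentMetrics = defaultdict(list)
    for fold_metrics in ({"F1-score": [0.83]}, {"F1-score": [0.79]}):  # two hypothetical folds
        combine(test_metrics, fold_metrics)

    print(dict(test_metrics))  # {'F1-score': [0.79, 0.83]}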


def cross_validate(
@@ -1194,14 +1215,15 @@
    confmat: Optional[Text] = None,
    histogram: Optional[Text] = None,
    disable_plotting: bool = False,
-) -> Tuple[CVEvaluationResult, CVEvaluationResult]:
+) -> Tuple[CVEvaluationResult, CVEvaluationResult, CVEvaluationResult]:

    """Stratified cross validation on data.

    Args:
        data: Training Data
        n_folds: integer, number of cv folds
        nlu_config: nlu config file
-       report: path to folder where reports are stored
+       output: path to folder where reports are stored
        successes: if true successful predictions are written to a file
        errors: if true incorrect predictions are written to a file
        confmat: path to file that will show the confusion matrix
@@ -1222,38 +1244,58 @@
    trainer = Trainer(nlu_config)
    trainer.pipeline = remove_pretrained_extractors(trainer.pipeline)

-   intent_train_metrics = defaultdict(list)  # type: IntentMetrics
-   intent_test_metrics = defaultdict(list)  # type: IntentMetrics
-   entity_train_metrics = defaultdict(lambda: defaultdict(list))  # type: EntityMetrics
-   entity_test_metrics = defaultdict(lambda: defaultdict(list))  # type: EntityMetrics
+   intent_train_metrics: IntentMetrics = defaultdict(list)
+   intent_test_metrics: IntentMetrics = defaultdict(list)
+   entity_train_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
+   entity_test_metrics: EntityMetrics = defaultdict(lambda: defaultdict(list))
+   response_selection_train_metrics: ResponseSelectionMetrics = defaultdict(list)
+   response_selection_test_metrics: ResponseSelectionMetrics = defaultdict(list)

-   intent_test_results = []  # type: List[IntentEvaluationResult]
-   entity_test_results = []  # type: List[EntityEvaluationResult]
+   intent_test_results: List[IntentEvaluationResult] = []
+   entity_test_results: List[EntityEvaluationResult] = []
+   response_selection_test_results: List[ResponseSelectionEvaluationResult] = ([])
    intent_classifier_present = False
-   extractors = set()  # type: Set[Text]
+   response_selector_present = False
+   entity_evaluation_possible = False
+   extractors: Set[Text] = set()

    for train, test in generate_folds(n_folds, data):
        interpreter = trainer.train(train)

        # calculate train accuracy
-       combine_result(intent_train_metrics, entity_train_metrics, interpreter, train)
+       combine_result(
+           intent_train_metrics,
+           entity_train_metrics,
+           response_selection_train_metrics,
+           interpreter,
+           train,
+       )
        # calculate test accuracy
        combine_result(
            intent_test_metrics,
            entity_test_metrics,
+           response_selection_test_metrics,
            interpreter,
            test,
            intent_test_results,
            entity_test_results,
+           response_selection_test_results,
        )

        if not extractors:
            extractors = get_entity_extractors(interpreter)
+           entity_evaluation_possible = (
+               entity_evaluation_possible
+               or _contains_entity_labels(entity_test_results)
+           )

        if is_intent_classifier_present(interpreter):
            intent_classifier_present = True

-   if intent_classifier_present:
+       if is_response_selector_present(interpreter):
+           response_selector_present = True
+
+   if intent_classifier_present and intent_test_results:
logger.info("Accumulated test folds intent evaluation results:")
evaluate_intents(
intent_test_results,
@@ -1265,13 +1307,25 @@
disable_plotting,
)

-   if extractors:
+   if extractors and entity_evaluation_possible:
        logger.info("Accumulated test folds entity evaluation results:")
        evaluate_entities(entity_test_results, extractors, output, successes, errors)

+   if response_selector_present and response_selection_test_results:
+       logger.info("Accumulated test folds response selection evaluation results:")
+       evaluate_response_selections(response_selection_test_results, output)
+
+   if not entity_evaluation_possible:
+       entity_test_metrics = defaultdict(lambda: defaultdict(list))
+       entity_train_metrics = defaultdict(lambda: defaultdict(list))

    return (
        CVEvaluationResult(dict(intent_train_metrics), dict(intent_test_metrics)),
        CVEvaluationResult(dict(entity_train_metrics), dict(entity_test_metrics)),
+       CVEvaluationResult(
+           dict(response_selection_train_metrics),
+           dict(response_selection_test_metrics),
+       ),
    )
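
All three returned values share one shape. Judging by the .train/.test attribute access in rasa/test.py further down, CVEvaluationResult behaves like a simple train/test pair; a sketch under that assumption, with made-up per-fold scores:

    from collections import namedtuple

    # Assumption: CVEvaluationResult is a plain train/test pair, matching
    # the `.train` / `.test` access in rasa/test.py below.
    CVEvaluationResult = namedtuple("CVEvaluationResult", ["train", "test"])

    response_selection = CVEvaluationResult(
        train={"Accuracy": [0.98, 0.97, 0.99]},  # hypothetical per-fold scores
        test={"Accuracy": [0.90, 0.88, 0.92]},
    )
    print(response_selection.test["Accuracy"])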


@@ -1290,8 +1344,10 @@ def compute_metrics(
) -> Tuple[
    IntentMetrics,
    EntityMetrics,
+   ResponseSelectionMetrics,
    List[IntentEvaluationResult],
    List[EntityEvaluationResult],
+   List[ResponseSelectionEvaluationResult],
]:
    """Computes metrics for intent classification and entity extraction.
    Returns intent and entity metrics, and prediction results.
@@ -1303,12 +1359,34 @@

    intent_results = remove_empty_intent_examples(intent_results)

-   intent_metrics = _compute_metrics(
-       intent_results, "intent_target", "intent_prediction"
+   response_selection_results = remove_empty_response_examples(
+       response_selection_results
    )
-   entity_metrics = _compute_entity_metrics(entity_results, interpreter)

-   return (intent_metrics, entity_metrics, intent_results, entity_results)
+   intent_metrics = {}
+   if intent_results:
+       intent_metrics = _compute_metrics(
+           intent_results, "intent_target", "intent_prediction"
+       )
+
+   entity_metrics = {}
+   if entity_results:
+       entity_metrics = _compute_entity_metrics(entity_results, interpreter)
+
+   response_selection_metrics = {}
+   if response_selection_results:
+       response_selection_metrics = _compute_metrics(
+           response_selection_results, "response_target", "response_prediction"
+       )
+
+   return (
+       intent_metrics,
+       entity_metrics,
+       response_selection_metrics,
+       intent_results,
+       entity_results,
+       response_selection_results,
+   )
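
Both _compute_metrics calls above funnel target/prediction string pairs into the same scoring helper. A self-contained sketch of that kind of computation using scikit-learn; the metric names mirror the Dict[Text, List[float]] aliases at the top of the file, and the weighted averaging is an assumption rather than a guarantee about Rasa's exact configuration:

    from typing import Dict, List, Text

    from sklearn.metrics import accuracy_score, f1_score, precision_score

    def compute_scores(
        targets: List[Text], predictions: List[Text]
    ) -> Dict[Text, List[float]]:
        # Hypothetical helper returning the Dict[Text, List[float]] metric shape.
        return {
            "Accuracy": [accuracy_score(targets, predictions)],
            "F1-score": [f1_score(targets, predictions, average="weighted")],
            "Precision": [precision_score(targets, predictions, average="weighted")],
        }

    targets = ["faq/ask_howto", "faq/ask_channels", "faq/ask_howto"]
    predictions = ["faq/ask_howto", "faq/ask_channels", "faq/ask_channels"]
    print(compute_scores(targets, predictions))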


def compare_nlu(
@@ -1407,7 +1485,7 @@ def _compute_metrics(
    ],
    target_key: Text,
    target_prediction: Text,
-) -> IntentMetrics:
+) -> Union[IntentMetrics, ResponseSelectionMetrics]:
    """Computes evaluation metrics for a given corpus and
    returns the results
    """
@@ -1425,9 +1503,7 @@
) -> EntityMetrics:
    """Computes entity evaluation metrics and returns the results"""

-   entity_metric_results = defaultdict(
-       lambda: defaultdict(list)
-   )  # type: EntityMetrics
+   entity_metric_results: EntityMetrics = defaultdict(lambda: defaultdict(list))
    extractors = get_entity_extractors(interpreter)

    if not extractors:
8 changes: 7 additions & 1 deletion rasa/test.py
@@ -202,7 +202,9 @@ def perform_nlu_cross_validation(
    data = rasa.nlu.training_data.load_data(nlu)
    data = drop_intents_below_freq(data, cutoff=folds)
    kwargs = utils.minimal_kwargs(kwargs, cross_validate)
-   results, entity_results = cross_validate(data, folds, nlu_config, output, **kwargs)
+   results, entity_results, response_selection_results = cross_validate(
+       data, folds, nlu_config, output, **kwargs
+   )
    logger.info(f"CV evaluation (n={folds})")

if any(results):
@@ -213,3 +215,7 @@
logger.info("Entity evaluation results")
return_entity_results(entity_results.train, "train")
return_entity_results(entity_results.test, "test")
if any(response_selection_results):
logger.info("Response Selection evaluation results")
return_results(response_selection_results.train, "train")
return_results(response_selection_results.test, "test")
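
For context, return_results only has the per-fold lists to work with, so summarizing a result set comes down to averaging across folds. A sketch of such a summary (a hypothetical helper, not Rasa's actual return_results implementation):

    import logging
    from typing import Dict, List

    import numpy as np

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def log_results(results: Dict[str, List[float]], dataset_name: str) -> None:
        # Sketch only: mean and standard deviation of each metric across folds.
        for metric, scores in results.items():
            logger.info(
                f"{dataset_name} {metric}: {np.mean(scores):.3f} ({np.std(scores):.3f})"
            )

    log_results({"F1-score": [0.83, 0.86], "Accuracy": [0.90, 0.92]}, "test")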