Skip to content

Commit

Permalink
Merge pull request #6002 from RasaHQ/domain-validation-errors
Browse files Browse the repository at this point in the history
Catch domain exception due to invalid YAML
  • Loading branch information
ricwo authored Jun 12, 2020
2 parents d2d8df5 + fa99718 commit 8304e1c
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 44 deletions.
3 changes: 3 additions & 0 deletions changelog/5976.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix server crashes that occurred when Rasa Open Source pulls a model from a
:ref:`model server <server_fetch_from_server>` and an exception was thrown during
model loading (such as a domain with invalid YAML).
73 changes: 54 additions & 19 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,41 +71,76 @@ async def load_from_server(agent: "Agent", model_server: EndpointConfig) -> "Age
return agent


def _load_and_set_updated_model(
agent: "Agent", model_directory: Text, fingerprint: Text
):
"""Load the persisted model into memory and set the model on the agent."""
def _load_interpreter(
agent: "Agent", nlu_path: Optional[Text]
) -> NaturalLanguageInterpreter:
"""Load the NLU interpreter at `nlu_path`.
Args:
agent: Instance of `Agent` to inspect for an interpreter if `nlu_path` is
`None`.
nlu_path: NLU model path.
Returns:
The NLU interpreter.
"""
if nlu_path:
from rasa.core.interpreter import RasaNLUInterpreter

logger.debug(f"Found new model with fingerprint {fingerprint}. Loading...")
return RasaNLUInterpreter(model_directory=nlu_path)

core_path, nlu_path = get_model_subdirectories(model_directory)
return agent.interpreter or RegexInterpreter()

if nlu_path:
from rasa.core.interpreter import RasaNLUInterpreter

interpreter = RasaNLUInterpreter(model_directory=nlu_path)
else:
interpreter = (
agent.interpreter if agent.interpreter is not None else RegexInterpreter()
)
def _load_domain_and_policy_ensemble(
core_path: Optional[Text],
) -> Tuple[Optional[Domain], Optional[PolicyEnsemble]]:
"""Load the domain and policy ensemble from the model at `core_path`.
Args:
core_path: Core model path.
Returns:
An instance of `Domain` and `PolicyEnsemble` if `core_path` is not `None`.
"""
policy_ensemble = None
domain = None

if core_path:
policy_ensemble = PolicyEnsemble.load(core_path)
domain_path = os.path.join(os.path.abspath(core_path), DEFAULT_DOMAIN_PATH)
domain = Domain.load(domain_path)

return domain, policy_ensemble


def _load_and_set_updated_model(
agent: "Agent", model_directory: Text, fingerprint: Text
) -> None:
"""Load the persisted model into memory and set the model on the agent.
Args:
agent: Instance of `Agent` to update with the new model.
model_directory: Rasa model directory.
fingerprint: Fingerprint of the supplied model at `model_directory`.
"""
logger.debug(f"Found new model with fingerprint {fingerprint}. Loading...")

core_path, nlu_path = get_model_subdirectories(model_directory)

try:
policy_ensemble = None
if core_path:
policy_ensemble = PolicyEnsemble.load(core_path)
interpreter = _load_interpreter(agent, nlu_path)
domain, policy_ensemble = _load_domain_and_policy_ensemble(core_path)

agent.update_model(
domain, policy_ensemble, fingerprint, interpreter, model_directory
)

logger.debug("Finished updating agent to new model.")
except Exception:
except Exception as e:
logger.exception(
"Failed to load policy and update agent. "
"The previous model will stay loaded instead."
f"Failed to update model. The previous model will stay loaded instead. "
f"Error: {e}"
)


Expand Down
20 changes: 11 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sanic.request import Request
from sanic.testing import SanicTestClient

from typing import Tuple, Iterator
from typing import Iterator, Callable

import pytest
from _pytest.tmpdir import TempdirFactory
Expand All @@ -25,7 +25,7 @@
from rasa.core.exporter import Exporter
from rasa.core.policies import Policy
from rasa.core.policies.memoization import AugmentedMemoizationPolicy
from rasa.core.run import _create_app_without_api
import rasa.core.run
from rasa.core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa.model import get_model
from rasa.train import train_async
Expand Down Expand Up @@ -60,7 +60,7 @@ def event_loop(request: Request) -> Iterator[asyncio.AbstractEventLoop]:


@pytest.fixture(scope="session")
async def _trained_default_agent(tmpdir_factory: TempdirFactory) -> Tuple[Agent, str]:
async def _trained_default_agent(tmpdir_factory: TempdirFactory) -> Agent:
model_path = tmpdir_factory.mktemp("model").strpath

agent = Agent(
Expand Down Expand Up @@ -154,8 +154,10 @@ def default_config() -> List[Policy]:


@pytest.fixture(scope="session")
def trained_async(tmpdir_factory):
async def _train(*args, output_path=None, **kwargs):
def trained_async(tmpdir_factory: TempdirFactory) -> Callable:
async def _train(
*args: Any, output_path: Optional[Text] = None, **kwargs: Any
) -> Optional[Text]:
if output_path is None:
output_path = str(tmpdir_factory.mktemp("models"))

Expand All @@ -166,7 +168,7 @@ async def _train(*args, output_path=None, **kwargs):

@pytest.fixture(scope="session")
async def trained_rasa_model(
trained_async,
trained_async: Callable,
default_domain_path: Text,
default_nlu_data: Text,
default_stories_file: Text,
Expand All @@ -182,7 +184,7 @@ async def trained_rasa_model(

@pytest.fixture(scope="session")
async def trained_core_model(
trained_async,
trained_async: Callable,
default_domain_path: Text,
default_nlu_data: Text,
default_stories_file: Text,
Expand All @@ -198,7 +200,7 @@ async def trained_core_model(

@pytest.fixture(scope="session")
async def trained_nlu_model(
trained_async,
trained_async: Callable,
default_domain_path: Text,
default_config: List[Policy],
default_nlu_data: Text,
Expand Down Expand Up @@ -243,7 +245,7 @@ async def rasa_server_secured(default_agent: Agent) -> Sanic:

@pytest.fixture
async def rasa_server_without_api() -> Sanic:
app = _create_app_without_api()
app = rasa.core.run._create_app_without_api()
channel.register([RestInput()], app, "/webhooks/")
return app

Expand Down
71 changes: 55 additions & 16 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import asyncio
from typing import Text
from pathlib import Path
from typing import Text, Dict, Any, Optional, Callable
from unittest.mock import Mock

import pytest
from _pytest.logging import LogCaptureFixture
from _pytest.monkeypatch import MonkeyPatch
from pytest_sanic.utils import TestClient
from sanic import Sanic, response
from sanic.request import Request
from sanic.response import StreamingHTTPResponse
from uvloop.loop import Loop

import rasa.core
from rasa.core.policies.ted_policy import TEDPolicy
Expand All @@ -19,12 +27,12 @@
from tests.core.conftest import DEFAULT_DOMAIN_PATH_WITH_SLOTS


def model_server_app(model_path: Text, model_hash: Text = "somehash"):
def model_server_app(model_path: Text, model_hash: Text = "somehash") -> Sanic:
app = Sanic(__name__)
app.number_of_model_requests = 0

@app.route("/model", methods=["GET"])
async def model(request):
async def model(request: Request) -> StreamingHTTPResponse:
"""Simple HTTP model server responding with a trained model."""

if model_hash == request.headers.get("If-None-Match"):
Expand All @@ -42,12 +50,14 @@ async def model(request):


@pytest.fixture()
def model_server(loop, sanic_client, trained_moodbot_path: Text):
def model_server(
loop: Loop, sanic_client: Callable, trained_moodbot_path: Text
) -> TestClient:
app = model_server_app(trained_moodbot_path, model_hash="somehash")
return loop.run_until_complete(sanic_client(app))


async def test_training_data_is_reproducible(tmpdir, default_domain):
async def test_training_data_is_reproducible(tmpdir: Path, default_domain: Domain):
training_data_file = "examples/moodbot/data/stories.md"
agent = Agent(
"examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
Expand Down Expand Up @@ -110,7 +120,7 @@ async def test_agent_train(trained_moodbot_path: Text):
],
)
async def test_agent_parse_message_using_nlu_interpreter(
default_agent, text_message_data, expected
default_agent: Agent, text_message_data: Text, expected: Dict[Text, Any]
):
result = await default_agent.parse_message_using_nlu_interpreter(text_message_data)
assert result == expected
Expand All @@ -133,7 +143,7 @@ async def test_agent_handle_message(default_agent: Agent):
]


def test_agent_wrong_use_of_load(tmpdir, default_domain):
def test_agent_wrong_use_of_load(tmpdir: Path, default_domain):
training_data_file = "examples/moodbot/data/stories.md"
agent = Agent(
"examples/moodbot/domain.yml", policies=[AugmentedMemoizationPolicy()]
Expand All @@ -146,7 +156,7 @@ def test_agent_wrong_use_of_load(tmpdir, default_domain):


async def test_agent_with_model_server_in_thread(
model_server, moodbot_domain, moodbot_metadata
model_server: TestClient, moodbot_domain: Domain, moodbot_metadata: Any
):
model_endpoint_config = EndpointConfig.from_dict(
{"url": model_server.make_url("/model"), "wait_time_between_pulls": 2}
Expand All @@ -171,7 +181,9 @@ async def test_agent_with_model_server_in_thread(
jobs.kill_scheduler()


async def test_wait_time_between_pulls_without_interval(model_server, monkeypatch):
async def test_wait_time_between_pulls_without_interval(
model_server: TestClient, monkeypatch: MonkeyPatch
):
monkeypatch.setattr(
"rasa.core.agent.schedule_model_pulling", lambda *args: 1 / 0
) # will raise an exception
Expand All @@ -181,12 +193,33 @@ async def test_wait_time_between_pulls_without_interval(model_server, monkeypatc
)

agent = Agent()
# schould not call schedule_model_pulling, if it does, this will raise
# should not call schedule_model_pulling, if it does, this will raise
await rasa.core.agent.load_from_server(agent, model_server=model_endpoint_config)
jobs.kill_scheduler()


async def test_load_agent(trained_rasa_model: str):
async def test_pull_model_with_invalid_domain(
model_server: TestClient, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture
):
# mock `Domain.load()` as if the domain contains invalid YAML
error_message = "domain is invalid"
mock_load = Mock(side_effect=InvalidDomain(error_message))

monkeypatch.setattr(Domain, "load", mock_load)
model_endpoint_config = EndpointConfig.from_dict(
{"url": model_server.make_url("/model"), "wait_time_between_pulls": None}
)

agent = Agent()
await rasa.core.agent.load_from_server(agent, model_server=model_endpoint_config)

# `Domain.load()` was called
mock_load.assert_called_once()

# error was logged
assert error_message in caplog.text


async def test_load_agent(trained_rasa_model: Text):
agent = await load_agent(model_path=trained_rasa_model)

assert agent.tracker_store is not None
Expand All @@ -198,7 +231,9 @@ async def test_load_agent(trained_rasa_model: str):
"domain, policy_config",
[({"forms": ["restaurant_form"]}, {"policies": [{"name": "MemoizationPolicy"}]})],
)
def test_form_without_form_policy(domain, policy_config):
def test_form_without_form_policy(
domain: Dict[Text, Any], policy_config: Dict[Text, Any]
):
with pytest.raises(InvalidDomain) as execinfo:
Agent(
domain=Domain.from_dict(domain),
Expand All @@ -219,7 +254,9 @@ def test_form_without_form_policy(domain, policy_config):
)
],
)
def test_trigger_without_mapping_policy(domain, policy_config):
def test_trigger_without_mapping_policy(
domain: Dict[Text, Any], policy_config: Dict[Text, Any]
):
with pytest.raises(InvalidDomain) as execinfo:
Agent(
domain=Domain.from_dict(domain),
Expand All @@ -232,7 +269,9 @@ def test_trigger_without_mapping_policy(domain, policy_config):
"domain, policy_config",
[({"intents": ["affirm"]}, {"policies": [{"name": "TwoStageFallbackPolicy"}]})],
)
def test_two_stage_fallback_without_deny_suggestion(domain, policy_config):
def test_two_stage_fallback_without_deny_suggestion(
domain: Dict[Text, Any], policy_config: Dict[Text, Any]
):
with pytest.raises(InvalidDomain) as execinfo:
Agent(
domain=Domain.from_dict(domain),
Expand Down Expand Up @@ -272,6 +311,6 @@ async def test_load_agent_on_not_existing_path():
None,
],
)
async def test_agent_load_on_invalid_model_path(model_path):
async def test_agent_load_on_invalid_model_path(model_path: Optional[Text]):
with pytest.raises(ValueError):
Agent.load(model_path)

0 comments on commit 8304e1c

Please sign in to comment.