Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate Conversation ID #7238

Merged
merged 24 commits into from
Feb 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f00cf8b
Create decorator to validate if conversation_id exists
RomuloSouza Nov 5, 2020
90066ee
Add decorator to required functions
aleronupe Nov 10, 2020
03a7567
Merge branch 'master' of https://github.com/FGA-GCES/rasa into valida…
aleronupe Nov 10, 2020
9955405
Add changelog file
RomuloSouza Nov 10, 2020
faf75c2
Merge branch 'master' of https://github.com/FGA-GCES/rasa into valida…
aleronupe Nov 26, 2020
9572e3a
Make clean code changes
aleronupe Nov 27, 2020
91f5980
Merge branch 'validate_conversation_id' of https://github.com/FGA-GCE…
aleronupe Dec 3, 2020
dff15ae
Update changelog file and implement function type
aleronupe Dec 3, 2020
00cec5c
Add method exists in TrackerStore interface
RomuloSouza Dec 3, 2020
9dab7b4
Merge branch 'master' of https://github.com/RasaHQ/rasa into validate…
RomuloSouza Dec 3, 2020
95920dc
Create tests for endpoints that use decorator
aleronupe Dec 5, 2020
9f1560a
Merge branch 'master' of https://github.com/RasaHQ/rasa into validate…
aleronupe Dec 5, 2020
a45e2d5
Merge branch 'master' of https://github.com/RasaHQ/rasa into validate…
RomuloSouza Jan 11, 2021
ddfd8bd
Rewrite comments and change values to constants
RomuloSouza Jan 12, 2021
4618354
Extract agent from request and remove app param
RomuloSouza Jan 12, 2021
cd27837
Merge branch 'master' of https://github.com/RasaHQ/rasa into validate…
RomuloSouza Jan 12, 2021
9586cad
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 3, 2021
f43873d
Fix comments and remove unused code in tests
RomuloSouza Feb 3, 2021
9cd843b
Fix URL tests
aleronupe Feb 5, 2021
cbb6e7e
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 10, 2021
74a50df
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 11, 2021
724cc79
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 11, 2021
da43126
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 11, 2021
f5249af
Merge branch 'main' of https://github.com/RasaHQ/rasa into validate_c…
RomuloSouza Feb 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog/7022.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
The following endpoints now require the existence of the conversation for the specified conversation ID, raising an exception and returning a 404 status code.

* `GET /conversations/<conversation_id:path>/story`

* `POST /conversations/<conversation_id:path>/execute`

* `POST /conversations/<conversation_id:path>/predict`
16 changes: 15 additions & 1 deletion rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,23 @@ def create_tracker(
return tracker

def save(self, tracker):
"""Save method that will be overridden by specific tracker"""
"""Save method that will be overridden by specific tracker."""
raise NotImplementedError()

def exists(self, conversation_id: Text) -> bool:
"""Checks if tracker exists for the specified ID.

This method may be overridden by the specific tracker store for
faster implementations.

Args:
conversation_id: Conversation ID to check if the tracker exists.

Returns:
`True` if the tracker exists, `False` otherwise.
"""
return self.retrieve(conversation_id) is not None

def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
"""Retrieves tracker for the latest conversation session.

Expand Down
23 changes: 23 additions & 0 deletions rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from http import HTTPStatus
from inspect import isawaitable
from pathlib import Path
from http import HTTPStatus
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -154,6 +155,25 @@ def decorated(*args, **kwargs):
return decorator


def ensure_conversation_exists() -> Callable[..., HTTPResponse]:
"""Wraps a request handler ensuring the conversation exists."""

def decorator(f: Callable[..., HTTPResponse]) -> HTTPResponse:
@wraps(f)
def decorated(request: Request, *args: Any, **kwargs: Any) -> HTTPResponse:
conversation_id = kwargs["conversation_id"]
if request.app.agent.tracker_store.exists(conversation_id):
return f(request, *args, **kwargs)
else:
raise ErrorResponse(
HTTPStatus.NOT_FOUND, "Not found", "Conversation ID not found."
)

return decorated

return decorator


def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[[Any], Any]:
"""Wraps a request handler with token authentication."""

Expand Down Expand Up @@ -800,6 +820,7 @@ async def replace_events(request: Request, conversation_id: Text):
@app.get("/conversations/<conversation_id:path>/story")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
@ensure_conversation_exists()
async def retrieve_story(request: Request, conversation_id: Text):
"""Get an end-to-end story corresponding to this conversation."""
until_time = rasa.utils.endpoints.float_arg(request, "until")
Expand All @@ -826,6 +847,7 @@ async def retrieve_story(request: Request, conversation_id: Text):
@app.post("/conversations/<conversation_id:path>/execute")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
@ensure_conversation_exists()
async def execute_action(request: Request, conversation_id: Text):
request_params = request.json

Expand Down Expand Up @@ -934,6 +956,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons
@app.post("/conversations/<conversation_id:path>/predict")
@requires_auth(app, auth_token)
@ensure_loaded_agent(app)
@ensure_conversation_exists()
async def predict(request: Request, conversation_id: Text) -> HTTPResponse:
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of the scope of your PR / optional: We should actually use the lockstore here as processor.predict_next persists the updated tracker.

# Fetches the appropriate bot response in a json format
Expand Down
29 changes: 29 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,13 @@ async def test_put_tracker(rasa_app: SanicASGITestClient):
assert events.deserialise_events(evts) == test_events


async def test_predict_without_conversation_id(rasa_app: SanicASGITestClient):
_, response = await rasa_app.post("/conversations/non_existent_id/predict")

assert response.status == HTTPStatus.NOT_FOUND
assert response.json()["message"] == "Conversation ID not found."


async def test_sorted_predict(rasa_app: SanicASGITestClient):
await _create_tracker_for_sender(rasa_app, "sortedpredict")

Expand Down Expand Up @@ -1468,6 +1475,16 @@ async def test_execute(rasa_app: SanicASGITestClient):
assert parsed_content["messages"]


async def test_execute_without_conversation_id(rasa_app: SanicASGITestClient):
data = {INTENT_NAME_KEY: "utter_greet"}
_, response = await rasa_app.post(
"/conversations/non_existent_id/execute", json=data
)

assert response.status == HTTPStatus.NOT_FOUND
assert response.json()["message"] == "Conversation ID not found."


async def test_execute_with_missing_action_name(rasa_app: SanicASGITestClient):
test_sender = "test_execute_with_missing_action_name"
await _create_tracker_for_sender(rasa_app, test_sender)
Expand Down Expand Up @@ -1775,6 +1792,18 @@ async def test_get_story(
assert response.content.decode().strip() == expected


async def test_get_story_without_conversation_id(
rasa_app: SanicASGITestClient, monkeypatch: MonkeyPatch
):
conversation_id = "some-conversation-ID"
url = f"/conversations/{conversation_id}/story"

_, response = await rasa_app.get(url)

assert response.status == HTTPStatus.NOT_FOUND
assert response.json()["message"] == "Conversation ID not found."


async def test_get_story_does_not_update_conversation_session(
rasa_app: SanicASGITestClient, monkeypatch: MonkeyPatch
):
Expand Down