diff --git a/changelog/5117.improvement.rst b/changelog/5117.improvement.rst new file mode 100644 index 000000000000..8b2f824ff0bf --- /dev/null +++ b/changelog/5117.improvement.rst @@ -0,0 +1,3 @@ +New command-line argument --conversation-id will be added and wiil give the ability to +set specific conversation ID for each shell session, if not passed will be random. + diff --git a/rasa/cli/shell.py b/rasa/cli/shell.py index ebbe4ac0dc1e..572ce1424b84 100644 --- a/rasa/cli/shell.py +++ b/rasa/cli/shell.py @@ -1,5 +1,6 @@ import argparse import logging +import uuid from typing import List @@ -7,7 +8,6 @@ from rasa.cli.utils import print_error from rasa.exceptions import ModelNotFound - logger = logging.getLogger(__name__) @@ -26,7 +26,15 @@ def add_subparser( ) shell_parser.set_defaults(func=shell) + shell_parser.add_argument( + "--conversation-id", + default=uuid.uuid4().hex, + required=False, + help="Set the conversation ID.", + ) + run_subparsers = shell_parser.add_subparsers() + shell_nlu_subparser = run_subparsers.add_parser( "nlu", parents=parents, @@ -34,6 +42,7 @@ def add_subparser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, help="Interprets messages on the command line using your NLU model.", ) + shell_nlu_subparser.set_defaults(func=shell_nlu) arguments.set_shell_arguments(shell_parser) diff --git a/rasa/core/channels/console.py b/rasa/core/channels/console.py index 054b44eb0deb..2fd9b4d8bf13 100644 --- a/rasa/core/channels/console.py +++ b/rasa/core/channels/console.py @@ -12,7 +12,6 @@ from rasa.cli import utils as cli_utils from rasa.core import utils from rasa.core.channels.channel import RestInput -from rasa.core.channels.channel import UserMessage from rasa.core.constants import DEFAULT_SERVER_URL from rasa.core.interpreter import INTENT_MESSAGE_PREFIX from rasa.utils.io import DEFAULT_ENCODING @@ -109,9 +108,9 @@ async def send_message_receive_stream( async def record_messages( + sender_id, server_url=DEFAULT_SERVER_URL, auth_token="", - sender_id=UserMessage.DEFAULT_SENDER_ID, max_message_limit=None, use_response_stream=True, ) -> int: diff --git a/rasa/core/run.py b/rasa/core/run.py index 9455332624b8..6ec641838ce2 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid import os import shutil from functools import partial @@ -87,6 +88,7 @@ def configure_app( port: int = constants.DEFAULT_SERVER_PORT, endpoints: Optional[AvailableEndpoints] = None, log_file: Optional[Text] = None, + conversation_id: Optional[Text] = uuid.uuid4().hex, ): """Run the agent.""" from rasa import server @@ -124,8 +126,10 @@ async def configure_async_logging(): async def run_cmdline_io(running_app: Sanic): """Small wrapper to shut down the server once cmd io is done.""" await asyncio.sleep(1) # allow server to start + await console.record_messages( - server_url=constants.DEFAULT_SERVER_FORMAT.format("http", port) + server_url=constants.DEFAULT_SERVER_FORMAT.format("http", port), + sender_id=conversation_id, ) logger.info("Killing Sanic server now.") @@ -153,6 +157,7 @@ def serve_application( ssl_keyfile: Optional[Text] = None, ssl_ca_file: Optional[Text] = None, ssl_password: Optional[Text] = None, + conversation_id: Optional[Text] = uuid.uuid4().hex, ): from rasa import server @@ -171,6 +176,7 @@ def serve_application( port=port, endpoints=endpoints, log_file=log_file, + conversation_id=conversation_id, ) ssl_context = server.create_ssl_context( diff --git a/rasa/core/training/interactive.py b/rasa/core/training/interactive.py index cad7709162f0..943d493aebf0 100644 --- a/rasa/core/training/interactive.py +++ b/rasa/core/training/interactive.py @@ -1624,7 +1624,7 @@ def run_interactive_learning( else: p = None - app = run.configure_app(enable_api=True) + app = run.configure_app(enable_api=True, conversation_id="default") endpoints = AvailableEndpoints.read_endpoints(server_args.get("endpoints")) # before_server_start handlers make sure the agent is loaded before the diff --git a/tests/cli/test_rasa_shell.py b/tests/cli/test_rasa_shell.py index 7301db203ec0..6f156ba1ba3a 100644 --- a/tests/cli/test_rasa_shell.py +++ b/tests/cli/test_rasa_shell.py @@ -5,9 +5,10 @@ def test_shell_help(run: Callable[..., RunResult]): output = run("shell", "--help") - help_text = """usage: rasa shell [-h] [-v] [-vv] [--quiet] [-m MODEL] [--log-file LOG_FILE] - [--endpoints ENDPOINTS] [-p PORT] [-t AUTH_TOKEN] - [--cors [CORS [CORS ...]]] [--enable-api] + help_text = """usage: rasa shell [-h] [-v] [-vv] [--quiet] + [--conversation-id CONVERSATION_ID] [-m MODEL] + [--log-file LOG_FILE] [--endpoints ENDPOINTS] [-p PORT] + [-t AUTH_TOKEN] [--cors [CORS [CORS ...]]] [--enable-api] [--remote-storage REMOTE_STORAGE] [--ssl-certificate SSL_CERTIFICATE] [--ssl-keyfile SSL_KEYFILE] [--ssl-ca-file SSL_CA_FILE] diff --git a/tests/core/test_channels.py b/tests/core/test_channels.py index d6bedb5e5dcc..3473dfdbfb4e 100644 --- a/tests/core/test_channels.py +++ b/tests/core/test_channels.py @@ -114,7 +114,9 @@ async def test_console_input(): ) await console.record_messages( - server_url="https://example.com", max_message_limit=3 + server_url="https://example.com", + max_message_limit=3, + sender_id="default", ) r = latest_request(