diff --git a/nemoguardrails/cli/__init__.py b/nemoguardrails/cli/__init__.py index 671567a23..fdfd91451 100644 --- a/nemoguardrails/cli/__init__.py +++ b/nemoguardrails/cli/__init__.py @@ -138,7 +138,7 @@ def server( if config: # We make sure there is no trailing separator, as that might break things in # single config mode. - api.app.rails_config_path = config[0].rstrip(os.path.sep) + api.app.rails_config_path = os.path.expanduser(config[0].rstrip(os.path.sep)) else: # If we don't have a config, we try to see if there is a local config folder local_path = os.getcwd() @@ -189,6 +189,6 @@ def version_callback(value: bool): def cli( _: Optional[bool] = typer.Option( None, "-v", "--version", callback=version_callback, is_eager=True - ) + ), ): pass diff --git a/nemoguardrails/server/api.py b/nemoguardrails/server/api.py index 31f18f9f7..22a764b3c 100644 --- a/nemoguardrails/server/api.py +++ b/nemoguardrails/server/api.py @@ -18,6 +18,7 @@ import json import logging import os.path +import re import time import warnings from typing import Any, List, Optional @@ -253,7 +254,11 @@ def _get_rails(config_ids: List[str]) -> LLMRails: base_path = os.path.abspath(app.rails_config_path) full_path = os.path.normpath(os.path.join(base_path, config_id)) - if not full_path.startswith(base_path + os.sep): + # @NOTE: (Rdinu) Reject config_ids that contain dangerous characters or sequences + if re.search(r"[\\/]|(\.\.)", config_id): + raise ValueError("Invalid config_id.") + + if os.path.commonprefix([full_path, base_path]) != base_path: raise ValueError("Access to the specified path is not allowed.") rails_config = RailsConfig.from_path(full_path)