diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f383ceefa2..78117f4c78 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,10 +1,10 @@ * @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka # Native Apps Owners -src/snowflake/cli/plugins/nativeapp/ @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests/nativeapp/ @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests_integration/test_nativeapp.py @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi +src/snowflake/cli/plugins/nativeapp/ @snowflakedb/nade +tests/nativeapp/ @sfc-gh-bgoel @snowflakedb/nade +tests_integration/test_nativeapp.py @snowflakedb/nade # Project Definition Owners -src/snowflake/cli/api/project/schemas/native_app.py @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests/project/ @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi +src/snowflake/cli/api/project/schemas/native_app.py @snowflakedb/nade +tests/project/ @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka @snowflakedb/nade diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6fe4066565..eda7e65d96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,3 +25,32 @@ repos: hooks: - id: mypy additional_dependencies: [types-all] + - repo: local + hooks: + - id: check-print-in-code + language: pygrep + name: "Check for print statements" + entry: "print\\(|echo\\(" + pass_filenames: true + files: ^src/snowflake/.*\.py$ + exclude: > + (?x) + ^src/snowflake/cli/api/console/.*$| + ^src/snowflake/cli/app/printing.py$| + ^src/snowflake/cli/app/dev/.*$| + ^src/snowflake/cli/templates/.*$| + ^src/snowflake/cli/api/utils/rendering.py$| + ^src/snowflake/cli/plugins/spcs/common.py$ + - id: check-app-imports-in-api + language: pygrep + name: "No top level cli.app imports in cli.api" + entry: "^from snowflake\\.cli\\.app" + pass_filenames: true + files: ^src/snowflake/cli/api/.*\.py$ + - id: avoid-snowcli + language: pygrep + name: "Prefer snowflake CLI over snowcli" + entry: "snowcli" + pass_filenames: true + files: ^src/.*\.py$ + exclude: ^src/snowflake/cli/app/telemetry.py$ diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f2119d22fa..33589aa8ee 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -4,6 +4,8 @@ ## New additions * Added support for fully qualified name (`database.schema.name`) in `name` parameter in streamlit project definition +* Added support for fully qualified image repository names in `spcs image-repository` commands. +* Added `--if-not-exists` option to `create` commands for `service`, and `compute-pool`. Added `--replace` and `--if-not-exists` options for `image-repository create`. ## Fixes and improvements * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. diff --git a/pyproject.toml b/pyproject.toml index b90d636551..3fbc255223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "setuptools==69.1.1", "snowflake-connector-python[secure-local-storage]==3.7.1", "strictyaml==1.7.3", - "tomlkit==0.12.4", + "tomlkit==0.12.3", "typer==0.9.0", "urllib3>=1.21.1,<2.3", "GitPython==3.1.42", diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index df241f3e69..6aa90f0d3b 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Any, Callable, Optional +from inspect import signature +from typing import Any, Callable, List, Optional, Tuple import click import typer +from click import ClickException from snowflake.cli.api.cli_global_context import cli_context_manager +from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.formats import OutputFormat DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]} @@ -13,6 +16,103 @@ _CLI_BEHAVIOUR = "Global configuration" +class OverrideableOption: + """ + Class that allows you to generate instances of typer.models.OptionInfo with some default properties while allowing + specific values to be overriden. + + Custom parameters: + - mutually_exclusive (Tuple[str]|List[str]): A list of parameter names that this Option is not compatible with. If this Option has + a truthy value and any of the other parameters in the mutually_exclusive list has a truthy value, a + ClickException will be thrown. Note that mutually_exclusive can contain an option's own name but does not require + it. + """ + + def __init__( + self, + default: Any, + *param_decls: str, + mutually_exclusive: Optional[List[str] | Tuple[str]] = None, + **kwargs, + ): + self.default = default + self.param_decls = param_decls + self.mutually_exclusive = mutually_exclusive + self.kwargs = kwargs + + def __call__(self, **kwargs) -> typer.models.OptionInfo: + """ + Returns a typer.models.OptionInfo instance initialized with the specified default values along with any overrides + from kwargs. Note that if you are overriding param_decls, you must pass an iterable of strings, you cannot use + positional arguments like you can with typer.Option. Does not modify the original instance. + """ + default = kwargs.get("default", self.default) + param_decls = kwargs.get("param_decls", self.param_decls) + mutually_exclusive = kwargs.get("mutually_exclusive", self.mutually_exclusive) + if not isinstance(param_decls, list) and not isinstance(param_decls, tuple): + raise TypeError("param_decls must be a list or tuple") + passed_kwargs = self.kwargs.copy() + passed_kwargs.update(kwargs) + if passed_kwargs.get("callback", None) or mutually_exclusive: + passed_kwargs["callback"] = self._callback_factory( + passed_kwargs.get("callback", None), mutually_exclusive + ) + for non_kwarg in ["default", "param_decls", "mutually_exclusive"]: + passed_kwargs.pop(non_kwarg, None) + return typer.Option(default, *param_decls, **passed_kwargs) + + class InvalidCallbackSignature(ClickException): + def __init__(self, callback): + super().__init__( + f"Signature {signature(callback)} is not valid for an OverrideableOption callback function. Must have at most one parameter with each of the following types: (typer.Context, typer.CallbackParam, Any Other Type)" + ) + + def _callback_factory( + self, callback, mutually_exclusive: Optional[List[str] | Tuple[str]] + ): + callback = callback if callback else lambda x: x + + # inspect existing_callback to make sure signature is valid + existing_params = signature(callback).parameters + # at most one parameter with each type in [typer.Context, typer.CallbackParam, any other type] + limits = [ + lambda x: x == typer.Context, + lambda x: x == typer.CallbackParam, + lambda x: x != typer.Context and x != typer.CallbackParam, + ] + for limit in limits: + if len([v for v in existing_params.values() if limit(v.annotation)]) > 1: + raise self.InvalidCallbackSignature(callback) + + def generated_callback(ctx: typer.Context, param: typer.CallbackParam, value): + if mutually_exclusive: + for name in mutually_exclusive: + if value and ctx.params.get( + name, False + ): # if the current parameter is set to True and a previous parameter is also Truthy + curr_opt = param.opts[0] + other_opt = [x for x in ctx.command.params if x.name == name][ + 0 + ].opts[0] + raise click.ClickException( + f"Options '{curr_opt}' and '{other_opt}' are incompatible." + ) + + # pass args to existing callback based on its signature (this is how Typer infers callback args) + passed_params = {} + for existing_param in existing_params: + annotation = existing_params[existing_param].annotation + if annotation == typer.Context: + passed_params[existing_param] = ctx + elif annotation == typer.CallbackParam: + passed_params[existing_param] = param + else: + passed_params[existing_param] = value + return callback(**passed_params) + + return generated_callback + + def _callback(provide_setter: Callable[[], Callable[[Any], Any]]): def callback(value): set_value = provide_setter() @@ -73,7 +173,7 @@ def callback(value): def _password_callback(value: str): if value: - click.echo(PLAIN_PASSWORD_MSG) + cli_console.message(PLAIN_PASSWORD_MSG) return _callback(lambda: cli_context_manager.connection_context.set_password)(value) @@ -181,6 +281,7 @@ def _password_callback(value: str): callback=_callback(lambda: cli_context_manager.set_silent), is_flag=True, rich_help_panel=_CLI_BEHAVIOUR, + is_eager=True, ) VerboseOption = typer.Option( @@ -209,6 +310,32 @@ def _password_callback(value: str): help='Regular expression for filtering objects by name. For example, `list --like "my%"` lists all objects that begin with “my”.', ) +# If IfExistsOption, IfNotExistsOption, or ReplaceOption are used with names other than those in CREATE_MODE_OPTION_NAMES, +# you must also override mutually_exclusive if you want to retain the validation that at most one of these flags is +# passed. +CREATE_MODE_OPTION_NAMES = ["if_exists", "if_not_exists", "replace"] + +IfExistsOption = OverrideableOption( + False, + "--if-exists", + help="Only apply this operation if the specified object exists.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + +IfNotExistsOption = OverrideableOption( + False, + "--if-not-exists", + help="Only apply this operation if the specified object does not already exist.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + +ReplaceOption = OverrideableOption( + False, + "--replace", + help="Replace this object if it already exists.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + def experimental_option( experimental_behaviour_description: Optional[str] = None, @@ -268,31 +395,3 @@ def _callback(project_path: Optional[str]): callback=_callback, show_default=False, ) - - -class OverrideableOption: - """ - Class that allows you to generate instances of typer.models.OptionInfo with some default properties while allowing specific values to be overriden. - """ - - def __init__(self, default: Any, *param_decls: str, **kwargs): - self.default = default - self.param_decls = param_decls - self.kwargs = kwargs - - def __call__(self, **kwargs) -> typer.models.OptionInfo: - """ - Returns a typer.models.OptionInfo instance initialized with the specified default values along with any overrides - from kwargs.Note that if you are overriding param_decls, - you must pass an iterable of strings, you cannot use positional arguments like you can with typer.Option. - Does not modify the original instance. - """ - default = kwargs.get("default", self.default) - param_decls = kwargs.get("param_decls", self.param_decls) - if not isinstance(param_decls, list) and not isinstance(param_decls, tuple): - raise TypeError("param_decls must be a list or tuple") - passed_kwargs = self.kwargs.copy() - passed_kwargs.update(kwargs) - passed_kwargs.pop("default", None) - passed_kwargs.pop("param_decls", None) - return typer.Option(default, *param_decls, **passed_kwargs) diff --git a/src/snowflake/cli/api/commands/snow_typer.py b/src/snowflake/cli/api/commands/snow_typer.py index 1cfc5596ad..f0cf14afb1 100644 --- a/src/snowflake/cli/api/commands/snow_typer.py +++ b/src/snowflake/cli/api/commands/snow_typer.py @@ -12,8 +12,6 @@ from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS from snowflake.cli.api.exceptions import CommandReturnTypeError from snowflake.cli.api.output.types import CommandResult -from snowflake.cli.app.printing import print_result -from snowflake.cli.app.telemetry import flush_telemetry, log_command_usage log = logging.getLogger(__name__) @@ -73,12 +71,16 @@ def pre_execute(): Pay attention to make this method safe to use if performed operations are not necessary for executing the command in proper way. """ + from snowflake.cli.app.telemetry import log_command_usage + log.debug("Executing command pre execution callback") log_command_usage() @staticmethod def process_result(result): """Command result processor""" + from snowflake.cli.app.printing import print_result + # Because we still have commands like "logs" that do not return anything. # We should improve it in future. if not result: @@ -100,5 +102,7 @@ def post_execute(): Callback executed after running any command callable. Pay attention to make this method safe to use if performed operations are not necessary for executing the command in proper way. """ + from snowflake.cli.app.telemetry import flush_telemetry + log.debug("Executing command post execution callback") flush_telemetry() diff --git a/src/snowflake/cli/api/secure_path.py b/src/snowflake/cli/api/secure_path.py index 11d8d5b6b7..5287fc6dea 100644 --- a/src/snowflake/cli/api/secure_path.py +++ b/src/snowflake/cli/api/secure_path.py @@ -265,7 +265,7 @@ def temporary_directory(cls): Works similarly to tempfile.TemporaryDirectory """ - with tempfile.TemporaryDirectory(prefix="snowcli") as tmpdir: + with tempfile.TemporaryDirectory(prefix="snowflake-cli") as tmpdir: log.info("Created temporary directory %s", tmpdir) yield SecurePath(tmpdir) log.info("Removing temporary directory %s", tmpdir) diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index 248c3fa21f..3b2a04fcba 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -5,9 +5,8 @@ from functools import cached_property from io import StringIO from textwrap import dedent -from typing import Iterable, Optional +from typing import Iterable, Optional, Tuple -from click import ClickException from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.exceptions import ( DatabaseNotProvidedError, @@ -19,6 +18,7 @@ unquote_identifier, ) from snowflake.cli.api.utils.cursor import find_first_row +from snowflake.cli.api.utils.naming_utils import from_qualified_name from snowflake.connector.cursor import DictCursor, SnowflakeCursor from snowflake.connector.errors import ProgrammingError @@ -82,44 +82,27 @@ def use_role(self, new_role: str): if is_different_role: self._execute_query(f"use role {prev_role}") - def _execute_schema_query(self, query: str, **kwargs): - self.check_database_and_schema() - return self._execute_query(query, **kwargs) - - def check_database_and_schema(self) -> None: + def _execute_schema_query(self, query: str, name: Optional[str] = None, **kwargs): """ - Checks if the connection database and schema are set and that they actually exist in Snowflake. + Check that a database and schema are provided before executing the query. Useful for operating on schema level objects. """ - self.check_schema_exists(self._conn.database, self._conn.schema) + self.check_database_and_schema_provided(name) + return self._execute_query(query, **kwargs) - def check_database_exists(self, database: str) -> None: + def check_database_and_schema_provided(self, name: Optional[str] = None) -> None: """ - Checks that database is provided and that it is a valid database in - Snowflake. Note that this could fail for a variety of reasons, - including not authorized to use database, database doesn't exist, - database is not a valid identifier, and more. + Checks if a database and schema are provided, either through the connection context or a qualified name. """ + if name: + _, schema, database = from_qualified_name(name) + else: + schema, database = None, None + schema = schema or self._conn.schema + database = database or self._conn.database if not database: raise DatabaseNotProvidedError() - try: - self._execute_query(f"USE DATABASE {database}") - except ProgrammingError as e: - raise ClickException(f"Exception occurred: {e}.") from e - - def check_schema_exists(self, database: str, schema: str) -> None: - """ - Checks that schema is provided and that it is a valid schema in Snowflake. - Note that this could fail for a variety of reasons, - including not authorized to use schema, schema doesn't exist, - schema is not a valid identifier, and more. - """ - self.check_database_exists(database) if not schema: raise SchemaNotProvidedError() - try: - self._execute_query(f"USE {database}.{schema}") - except ProgrammingError as e: - raise ClickException(f"Exception occurred: {e}.") from e def to_fully_qualified_name( self, name: str, database: Optional[str] = None, schema: Optional[str] = None @@ -131,9 +114,7 @@ def to_fully_qualified_name( if not database: if not self._conn.database: - raise ClickException( - "Default database not specified in connection details." - ) + raise DatabaseNotProvidedError() database = self._conn.database if len(current_parts) == 2: @@ -150,29 +131,65 @@ def get_name_from_fully_qualified_name(name): Returns name of the object from the fully-qualified name. Assumes that [name] is in format [[database.]schema.]name """ - return name.split(".")[-1] + return from_qualified_name(name)[0] + + @staticmethod + def _qualified_name_to_in_clause(name: str) -> Tuple[str, Optional[str]]: + unqualified_name, schema, database = from_qualified_name(name) + if database: + in_clause = f"in schema {database}.{schema}" + elif schema: + in_clause = f"in schema {schema}" + else: + in_clause = None + return unqualified_name, in_clause + + class InClauseWithQualifiedNameError(ValueError): + def __init__(self): + super().__init__("non-empty 'in_clause' passed with qualified 'name'") def show_specific_object( self, object_type_plural: str, - unqualified_name: str, + name: str, name_col: str = "name", in_clause: str = "", check_schema: bool = False, ) -> Optional[dict]: """ Executes a "show like" query for a particular entity with a - given (unqualified) name. This command is useful when the corresponding + given (optionally qualified) name. This command is useful when the corresponding "describe " query does not provide the information you seek. + + Note that this command is analogous to describe and should only return a single row. + If the target object type is a schema level object, then check_schema should be set to True + so that the function will verify that a database and schema are provided, either through + the connection or a qualified name, before executing the query. """ - if check_schema: - self.check_database_and_schema() + + unqualified_name, name_in_clause = self._qualified_name_to_in_clause(name) + if in_clause and name_in_clause: + raise self.InClauseWithQualifiedNameError() + elif name_in_clause: + in_clause = name_in_clause show_obj_query = f"show {object_type_plural} like {identifier_to_show_like_pattern(unqualified_name)} {in_clause}".strip() - show_obj_cursor = self._execute_query( # type: ignore - show_obj_query, cursor_class=DictCursor - ) + + if check_schema: + show_obj_cursor = self._execute_schema_query( # type: ignore + show_obj_query, name=name, cursor_class=DictCursor + ) + else: + show_obj_cursor = self._execute_query( # type: ignore + show_obj_query, cursor_class=DictCursor + ) + if show_obj_cursor.rowcount is None: raise SnowflakeSQLExecutionError(show_obj_query) + elif show_obj_cursor.rowcount > 1: + raise ProgrammingError( + f"Received multiple rows from result of SQL statement: {show_obj_query}. Usage of 'show_specific_object' may not be properly scoped." + ) + show_obj_row = find_first_row( show_obj_cursor, lambda row: row[name_col] == unquote_identifier(unqualified_name), diff --git a/src/snowflake/cli/api/utils/naming_utils.py b/src/snowflake/cli/api/utils/naming_utils.py new file mode 100644 index 0000000000..895698cc6b --- /dev/null +++ b/src/snowflake/cli/api/utils/naming_utils.py @@ -0,0 +1,27 @@ +import re +from typing import Optional, Tuple + +from snowflake.cli.api.project.util import ( + VALID_IDENTIFIER_REGEX, +) + + +def from_qualified_name(name: str) -> Tuple[str, Optional[str], Optional[str]]: + """ + Takes in an object name in the form [[database.]schema.]name. Returns a tuple (name, [schema], [database]) + """ + # TODO: Use regex to match object name to a valid identifier or valid identifier (args). Second case is for sprocs and UDFs + qualifier_pattern = rf"(?:(?P{VALID_IDENTIFIER_REGEX})\.)?(?:(?P{VALID_IDENTIFIER_REGEX})\.)?(?P.*)" + result = re.fullmatch(qualifier_pattern, name) + + if result is None: + raise ValueError(f"'{name}' is not a valid qualified name") + + unqualified_name = result.group("name") + if result.group("second_qualifier") is not None: + database = result.group("first_qualifier") + schema = result.group("second_qualifier") + else: + database = None + schema = result.group("first_qualifier") + return unqualified_name, schema, database diff --git a/src/snowflake/cli/app/__main__.py b/src/snowflake/cli/app/__main__.py index 68aadbe2f4..c187979075 100644 --- a/src/snowflake/cli/app/__main__.py +++ b/src/snowflake/cli/app/__main__.py @@ -2,10 +2,11 @@ import sys -from snowflake.cli.app.cli_app import app +from snowflake.cli.app.cli_app import app_factory def main(*args): + app = app_factory() app(*args) diff --git a/src/snowflake/cli/app/cli_app.py b/src/snowflake/cli/app/cli_app.py index cca66eb61c..a7e22e08e9 100644 --- a/src/snowflake/cli/app/cli_app.py +++ b/src/snowflake/cli/app/cli_app.py @@ -28,10 +28,9 @@ setup_pycharm_remote_debugger_if_provided, ) from snowflake.cli.app.main_typer import SnowCliMainTyper -from snowflake.cli.app.printing import print_result +from snowflake.cli.app.printing import MessageResult, print_result from snowflake.connector.config_manager import CONFIG_MANAGER -app: SnowCliMainTyper = SnowCliMainTyper() log = logging.getLogger(__name__) _api = Api(plugin_config_provider=PluginConfigProviderImpl()) @@ -104,7 +103,7 @@ def _commands_structure_callback(value: bool): @_do_not_execute_on_completion def _version_callback(value: bool): if value: - typer.echo(f"Snowflake CLI version: {__about__.VERSION}") + print_result(MessageResult(f"Snowflake CLI version: {__about__.VERSION}")) _exit_with_cleanup() @@ -126,93 +125,98 @@ def _info_callback(value: bool): _exit_with_cleanup() -@app.callback() -def default( - version: bool = typer.Option( - None, - "--version", - help="Shows version of the Snowflake CLI", - callback=_version_callback, - is_eager=True, - ), - docs: bool = typer.Option( - None, - "--docs", - hidden=True, - help="Generates Snowflake CLI documentation", - callback=_docs_callback, - is_eager=True, - ), - structure: bool = typer.Option( - None, - "--structure", - hidden=True, - help="Prints Snowflake CLI structure of commands", - callback=_commands_structure_callback, - is_eager=True, - ), - info: bool = typer.Option( - None, - "--info", - help="Shows information about the Snowflake CLI", - callback=_info_callback, - ), - configuration_file: Path = typer.Option( - None, - "--config-file", - help="Specifies Snowflake CLI configuration file that should be used", - exists=True, - dir_okay=False, - is_eager=True, - callback=_config_init_callback, - ), - pycharm_debug_library_path: str = typer.Option( - None, - "--pycharm-debug-library-path", - hidden=True, - ), - pycharm_debug_server_host: str = typer.Option( - "localhost", - "--pycharm-debug-server-host", - hidden=True, - ), - pycharm_debug_server_port: int = typer.Option( - 12345, - "--pycharm-debug-server-port", - hidden=True, - ), - disable_external_command_plugins: bool = typer.Option( - None, - "--disable-external-command-plugins", - help="Disable external command plugins", - callback=_disable_external_command_plugins_callback, - is_eager=True, - hidden=True, - ), - # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! - # --- - # This is a hidden artificial option used only to guarantee execution of commands registration - # and make this guaranty not dependent on other callbacks. - # Commands registration is invoked as soon as all callbacks - # decorated with "_commands_registration.before" are executed - # but if there are no such callbacks (at the result of possible future changes) - # then we need to invoke commands registration manually. - # - # This option is also responsible for resetting registration state for test purposes. - commands_registration: bool = typer.Option( - True, - "--commands-registration", - help="Commands registration", - hidden=True, - is_eager=True, - callback=_commands_registration_callback, - ), -) -> None: - """ - Snowflake CLI tool for developers. - """ - setup_pycharm_remote_debugger_if_provided( - pycharm_debug_library_path=pycharm_debug_library_path, - pycharm_debug_server_host=pycharm_debug_server_host, - pycharm_debug_server_port=pycharm_debug_server_port, - ) +def app_factory() -> SnowCliMainTyper: + app = SnowCliMainTyper() + + @app.callback() + def default( + version: bool = typer.Option( + None, + "--version", + help="Shows version of the Snowflake CLI", + callback=_version_callback, + is_eager=True, + ), + docs: bool = typer.Option( + None, + "--docs", + hidden=True, + help="Generates Snowflake CLI documentation", + callback=_docs_callback, + is_eager=True, + ), + structure: bool = typer.Option( + None, + "--structure", + hidden=True, + help="Prints Snowflake CLI structure of commands", + callback=_commands_structure_callback, + is_eager=True, + ), + info: bool = typer.Option( + None, + "--info", + help="Shows information about the Snowflake CLI", + callback=_info_callback, + ), + configuration_file: Path = typer.Option( + None, + "--config-file", + help="Specifies Snowflake CLI configuration file that should be used", + exists=True, + dir_okay=False, + is_eager=True, + callback=_config_init_callback, + ), + pycharm_debug_library_path: str = typer.Option( + None, + "--pycharm-debug-library-path", + hidden=True, + ), + pycharm_debug_server_host: str = typer.Option( + "localhost", + "--pycharm-debug-server-host", + hidden=True, + ), + pycharm_debug_server_port: int = typer.Option( + 12345, + "--pycharm-debug-server-port", + hidden=True, + ), + disable_external_command_plugins: bool = typer.Option( + None, + "--disable-external-command-plugins", + help="Disable external command plugins", + callback=_disable_external_command_plugins_callback, + is_eager=True, + hidden=True, + ), + # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! + # --- + # This is a hidden artificial option used only to guarantee execution of commands registration + # and make this guaranty not dependent on other callbacks. + # Commands registration is invoked as soon as all callbacks + # decorated with "_commands_registration.before" are executed + # but if there are no such callbacks (at the result of possible future changes) + # then we need to invoke commands registration manually. + # + # This option is also responsible for resetting registration state for test purposes. + commands_registration: bool = typer.Option( + True, + "--commands-registration", + help="Commands registration", + hidden=True, + is_eager=True, + callback=_commands_registration_callback, + ), + ) -> None: + """ + Snowflake CLI tool for developers. + """ + setup_pycharm_remote_debugger_if_provided( + pycharm_debug_library_path=pycharm_debug_library_path, + pycharm_debug_server_host=pycharm_debug_server_host, + pycharm_debug_server_port=pycharm_debug_server_port, + ) + + return app diff --git a/src/snowflake/cli/app/loggers.py b/src/snowflake/cli/app/loggers.py index c9c43c62b6..68b6ccb4f3 100644 --- a/src/snowflake/cli/app/loggers.py +++ b/src/snowflake/cli/app/loggers.py @@ -10,7 +10,7 @@ from snowflake.cli.api.exceptions import InvalidLogsConfiguration from snowflake.cli.api.secure_path import SecurePath -_DEFAULT_LOG_FILENAME = "snowcli.log" +_DEFAULT_LOG_FILENAME = "snowflake-cli.log" @dataclass diff --git a/src/snowflake/cli/app/main_typer.py b/src/snowflake/cli/app/main_typer.py index aeae62b2cc..3498d91f1c 100644 --- a/src/snowflake/cli/app/main_typer.py +++ b/src/snowflake/cli/app/main_typer.py @@ -3,16 +3,16 @@ import sys import typer -from rich import print as rich_print from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS, DebugOption +from snowflake.cli.api.console import cli_console def _handle_exception(exception: Exception): if cli_context.enable_tracebacks: raise exception else: - rich_print( + cli_console.warning( "\nAn unexpected exception occurred. Use --debug option to see the traceback. Exception message:\n\n" + exception.__str__() ) diff --git a/src/snowflake/cli/plugins/connection/commands.py b/src/snowflake/cli/plugins/connection/commands.py index a92f9fdbf8..14bfd54982 100644 --- a/src/snowflake/cli/plugins/connection/commands.py +++ b/src/snowflake/cli/plugins/connection/commands.py @@ -2,7 +2,6 @@ import logging -import click import typer from click import ClickException, Context, Parameter # type: ignore from click.core import ParameterSource # type: ignore @@ -21,6 +20,7 @@ get_connection, set_config_value, ) +from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.output.types import ( CollectionResult, @@ -81,7 +81,7 @@ def callback(value: str): def _password_callback(ctx: Context, param: Parameter, value: str): if value and ctx.get_parameter_source(param.name) == ParameterSource.COMMANDLINE: # type: ignore - click.echo(PLAIN_PASSWORD_MSG) + cli_console.warning(PLAIN_PASSWORD_MSG) return value diff --git a/src/snowflake/cli/plugins/object/common.py b/src/snowflake/cli/plugins/object/common.py index b53ac670e8..43424f9bbc 100644 --- a/src/snowflake/cli/plugins/object/common.py +++ b/src/snowflake/cli/plugins/object/common.py @@ -4,8 +4,7 @@ from click import ClickException from snowflake.cli.api.commands.flags import OverrideableOption from snowflake.cli.api.project.util import ( - QUOTED_IDENTIFIER_REGEX, - UNQUOTED_IDENTIFIER_REGEX, + VALID_IDENTIFIER_REGEX, is_valid_identifier, to_string_literal, ) @@ -34,11 +33,9 @@ def __init__(self): def _parse_tag(tag: str) -> Tag: import re - identifier_pattern = re.compile( - f"(?P{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})" - ) - value_pattern = re.compile(f"(?P.+)") - result = re.fullmatch(f"{identifier_pattern.pattern}={value_pattern.pattern}", tag) + identifier_pattern = rf"(?P{VALID_IDENTIFIER_REGEX})" + value_pattern = r"(?P.+)" + result = re.fullmatch(rf"{identifier_pattern}={value_pattern}", tag) if result is not None: try: return Tag(result.group("tag_name"), result.group("tag_value")) diff --git a/src/snowflake/cli/plugins/object/stage/commands.py b/src/snowflake/cli/plugins/object/stage/commands.py index bf0540eca1..f5282f9ee6 100644 --- a/src/snowflake/cli/plugins/object/stage/commands.py +++ b/src/snowflake/cli/plugins/object/stage/commands.py @@ -41,7 +41,7 @@ def copy( help="Source path for copy operation. Can be either stage path or local." ), destination_path: str = typer.Argument( - help="Target path for copy operation. Should be stage if source is local or local if source is stage.", + help="Target directory path for copy operation. Should be stage if source is local or local if source is stage.", ), overwrite: bool = typer.Option( False, diff --git a/src/snowflake/cli/plugins/snowpark/commands.py b/src/snowflake/cli/plugins/snowpark/commands.py index ec95b51c4b..f938d9f5a2 100644 --- a/src/snowflake/cli/plugins/snowpark/commands.py +++ b/src/snowflake/cli/plugins/snowpark/commands.py @@ -12,6 +12,7 @@ with_project_definition, ) from snowflake.cli.api.commands.flags import ( + ReplaceOption, execution_identifier_argument, ) from snowflake.cli.api.commands.project_initialisation import add_init_command @@ -50,12 +51,6 @@ help="Manages procedures and functions.", ) -ReplaceOption = typer.Option( - False, - "--replace", - help="Replaces procedure or function, even if no detected changes to metadata", -) - ObjectTypeArgument = typer.Argument( help="Type of Snowpark object", case_sensitive=False, @@ -67,7 +62,9 @@ @app.command("deploy", requires_connection=True) @with_project_definition("snowpark") def deploy( - replace: bool = ReplaceOption, + replace: bool = ReplaceOption( + help="Replaces procedure or function, even if no detected changes to metadata" + ), **options, ) -> CommandResult: """ @@ -116,7 +113,7 @@ def deploy( stage_manager = StageManager() stage_name = stage_manager.to_fully_qualified_name(stage_name) stage_manager.create( - stage_name=stage_name, comment="deployments managed by snowcli" + stage_name=stage_name, comment="deployments managed by Snowflake CLI" ) packages = get_snowflake_packages() diff --git a/src/snowflake/cli/plugins/spcs/common.py b/src/snowflake/cli/plugins/spcs/common.py index a4bb5acded..07d8a4ce2a 100644 --- a/src/snowflake/cli/plugins/spcs/common.py +++ b/src/snowflake/cli/plugins/spcs/common.py @@ -66,12 +66,16 @@ def validate_and_set_instances(min_instances, max_instances, instance_name): def handle_object_already_exists( - error: ProgrammingError, object_type: ObjectType, object_name: str + error: ProgrammingError, + object_type: ObjectType, + object_name: str, + replace_available: bool = False, ): if error.errno == 2002: raise ObjectAlreadyExistsError( object_type=object_type, name=unquote_identifier(object_name), + replace_available=replace_available, ) else: raise error diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py index 05979e052f..f069f4a2c3 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py @@ -2,9 +2,7 @@ import typer from click import ClickException -from snowflake.cli.api.commands.flags import ( - OverrideableOption, -) +from snowflake.cli.api.commands.flags import IfNotExistsOption, OverrideableOption from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import CommandResult, SingleQueryResult from snowflake.cli.api.project.util import is_valid_object_name @@ -90,6 +88,7 @@ def create( ), auto_suspend_secs: int = AutoSuspendSecsOption(), comment: Optional[str] = CommentOption(help=_COMMENT_HELP), + if_not_exists: bool = IfNotExistsOption(), **options, ) -> CommandResult: """ @@ -105,6 +104,7 @@ def create( initially_suspended=initially_suspended, auto_suspend_secs=auto_suspend_secs, comment=comment, + if_not_exists=if_not_exists, ) return SingleQueryResult(cursor) diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py index 416e16020a..4e7720e271 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py @@ -22,9 +22,13 @@ def create( initially_suspended: bool, auto_suspend_secs: int, comment: Optional[str], + if_not_exists: bool, ) -> SnowflakeCursor: + create_statement = "CREATE COMPUTE POOL" + if if_not_exists: + create_statement = f"{create_statement} IF NOT EXISTS" query = f"""\ - CREATE COMPUTE POOL {pool_name} + {create_statement} {pool_name} MIN_NODES = {min_nodes} MAX_NODES = {max_nodes} INSTANCE_FAMILY = {instance_family} diff --git a/src/snowflake/cli/plugins/spcs/image_repository/commands.py b/src/snowflake/cli/plugins/spcs/image_repository/commands.py index 301e77b2bc..b765ba142c 100644 --- a/src/snowflake/cli/plugins/spcs/image_repository/commands.py +++ b/src/snowflake/cli/plugins/spcs/image_repository/commands.py @@ -4,7 +4,9 @@ import requests import typer from click import ClickException +from snowflake.cli.api.commands.flags import IfNotExistsOption, ReplaceOption from snowflake.cli.api.commands.snow_typer import SnowTyper +from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import ( CollectionResult, MessageResult, @@ -22,7 +24,7 @@ def _repo_name_callback(name: str): - if not is_valid_object_name(name, max_depth=0, allow_quoted=True): + if not is_valid_object_name(name, max_depth=2, allow_quoted=False): raise ClickException( f"'{name}' is not a valid image repository name. Note that image repository names must be unquoted identifiers. The same constraint also applies to database and schema names where you create an image repository." ) @@ -38,12 +40,18 @@ def _repo_name_callback(name: str): @app.command(requires_connection=True) def create( name: str = REPO_NAME_ARGUMENT, + replace: bool = ReplaceOption(), + if_not_exists: bool = IfNotExistsOption(), **options, ): """ Creates a new image repository in the current schema. """ - return SingleQueryResult(ImageRepositoryManager().create(name=name)) + return SingleQueryResult( + ImageRepositoryManager().create( + name=name, replace=replace, if_not_exists=if_not_exists + ) + ) @app.command("list-images", requires_connection=True) @@ -119,7 +127,7 @@ def list_tags( ) if response.status_code != 200: - print("Call to the registry failed", response.text) + cli_console.warning(f"Call to the registry failed {response.text}") data = json.loads(response.text) if "tags" in data: diff --git a/src/snowflake/cli/plugins/spcs/image_repository/manager.py b/src/snowflake/cli/plugins/spcs/image_repository/manager.py index f04cd52cac..4bb6ef4163 100644 --- a/src/snowflake/cli/plugins/spcs/image_repository/manager.py +++ b/src/snowflake/cli/plugins/spcs/image_repository/manager.py @@ -1,10 +1,6 @@ from urllib.parse import urlparse -from click import ClickException from snowflake.cli.api.constants import ObjectType -from snowflake.cli.api.project.util import ( - is_valid_unquoted_identifier, -) from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.plugins.spcs.common import handle_object_already_exists from snowflake.connector.errors import ProgrammingError @@ -21,17 +17,13 @@ def get_role(self): return self._conn.role def get_repository_url(self, repo_name: str, with_scheme: bool = True): - if not is_valid_unquoted_identifier(repo_name): - raise ValueError( - f"repo_name '{repo_name}' is not a valid unquoted Snowflake identifier" - ) - # we explicitly do not allow this function to be used without connection database and schema set + repo_row = self.show_specific_object( "image repositories", repo_name, check_schema=True ) if repo_row is None: - raise ClickException( - f"Image repository '{repo_name}' does not exist in database '{self.get_database()}' and schema '{self.get_schema()}' or not authorized." + raise ProgrammingError( + f"Image repository '{self.to_fully_qualified_name(repo_name)}' does not exist or not authorized." ) if with_scheme: return f"https://{repo_row['repository_url']}" @@ -51,8 +43,26 @@ def get_repository_api_url(self, repo_url): return f"{scheme}://{host}/v2{path}" - def create(self, name: str): + def create( + self, + name: str, + if_not_exists: bool, + replace: bool, + ): + if if_not_exists and replace: + raise ValueError( + "'replace' and 'if_not_exists' options are mutually exclusive for ImageRepositoryManager.create" + ) + elif replace: + create_statement = "create or replace image repository" + elif if_not_exists: + create_statement = "create image repository if not exists" + else: + create_statement = "create image repository" + try: - return self._execute_schema_query(f"create image repository {name}") + return self._execute_schema_query(f"{create_statement} {name}", name=name) except ProgrammingError as e: - handle_object_already_exists(e, ObjectType.IMAGE_REPOSITORY, name) + handle_object_already_exists( + e, ObjectType.IMAGE_REPOSITORY, name, replace_available=True + ) diff --git a/src/snowflake/cli/plugins/spcs/services/commands.py b/src/snowflake/cli/plugins/spcs/services/commands.py index b9acfe4959..976e9b0bd2 100644 --- a/src/snowflake/cli/plugins/spcs/services/commands.py +++ b/src/snowflake/cli/plugins/spcs/services/commands.py @@ -4,9 +4,7 @@ import typer from click import ClickException -from snowflake.cli.api.commands.flags import ( - OverrideableOption, -) +from snowflake.cli.api.commands.flags import IfNotExistsOption, OverrideableOption from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import ( CommandResult, @@ -95,6 +93,7 @@ def create( query_warehouse: Optional[str] = QueryWarehouseOption(), tags: Optional[List[Tag]] = TagOption(help="Tag for the service."), comment: Optional[str] = CommentOption(help=_COMMENT_HELP), + if_not_exists: bool = IfNotExistsOption(), **options, ) -> CommandResult: """ @@ -114,6 +113,7 @@ def create( query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=if_not_exists, ) return SingleQueryResult(cursor) diff --git a/src/snowflake/cli/plugins/spcs/services/manager.py b/src/snowflake/cli/plugins/spcs/services/manager.py index 5e006a6834..4f16df4ca4 100644 --- a/src/snowflake/cli/plugins/spcs/services/manager.py +++ b/src/snowflake/cli/plugins/spcs/services/manager.py @@ -27,11 +27,14 @@ def create( query_warehouse: Optional[str], tags: Optional[List[Tag]], comment: Optional[str], + if_not_exists: bool, ) -> SnowflakeCursor: spec = self._read_yaml(spec_path) - + create_statement = "CREATE SERVICE" + if if_not_exists: + create_statement = f"{create_statement} IF NOT EXISTS" query = f"""\ - CREATE SERVICE {service_name} + {create_statement} {service_name} IN COMPUTE POOL {compute_pool} FROM SPECIFICATION $$ {spec} diff --git a/src/snowflake/cli/plugins/streamlit/commands.py b/src/snowflake/cli/plugins/streamlit/commands.py index 2e713917d1..c2b68df6d6 100644 --- a/src/snowflake/cli/plugins/streamlit/commands.py +++ b/src/snowflake/cli/plugins/streamlit/commands.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import Optional import click import typer @@ -10,6 +9,7 @@ with_experimental_behaviour, with_project_definition, ) +from snowflake.cli.api.commands.flags import ReplaceOption from snowflake.cli.api.commands.project_initialisation import add_init_command from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import ( @@ -70,12 +70,7 @@ def _check_file_exists_if_not_default(ctx: click.Context, value): @with_project_definition("streamlit") @with_experimental_behaviour() def streamlit_deploy( - replace: Optional[bool] = typer.Option( - False, - "--replace", - help="Replace the Streamlit if it already exists.", - is_flag=True, - ), + replace: bool = ReplaceOption(help="Replace the Streamlit if it already exists."), open_: bool = typer.Option( False, "--open", help="Whether to open Streamlit in a browser.", is_flag=True ), diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index f485746644..42f4795802 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -1,4 +1,33 @@ # serializer version: 1 +# name: test_help_messages[] + ''' + + Usage: default [OPTIONS] COMMAND [ARGS]... + + Snowflake CLI tool for developers. + + ╭─ Options ────────────────────────────────────────────────────────────────────╮ + │ --version Shows version of the Snowflake CLI │ + │ --info Shows information about the Snowflake CLI │ + │ --config-file FILE Specifies Snowflake CLI configuration file that │ + │ should be used │ + │ [default: None] │ + │ --help -h Show this message and exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Commands ───────────────────────────────────────────────────────────────────╮ + │ app Manages a Snowflake Native App │ + │ connection Manages connections to Snowflake. │ + │ object Manages Snowflake objects like warehouses and stages │ + │ snowpark Manages procedures and functions. │ + │ spcs Manages Snowpark Container Services compute pools, services, │ + │ image registries, and image repositories. │ + │ sql Executes Snowflake query. │ + │ streamlit Manages Streamlit in Snowflake. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + + ''' +# --- # name: test_help_messages[app.bundle] ''' @@ -988,9 +1017,9 @@ │ either stage path or local. │ │ [default: None] │ │ [required] │ - │ * destination_path TEXT Target path for copy operation. Should be │ - │ stage if source is local or local if source │ - │ is stage. │ + │ * destination_path TEXT Target directory path for copy operation. │ + │ Should be stage if source is local or local │ + │ if source is stage. │ │ [default: None] │ │ [required] │ ╰──────────────────────────────────────────────────────────────────────────────╯ @@ -1888,6 +1917,11 @@ │ --comment TEXT Comment for the │ │ compute pool. │ │ [default: None] │ + │ --if-not-exists Only apply this │ + │ operation if the │ + │ specified object │ + │ does not already │ + │ exist. │ │ --help -h Show this │ │ message and │ │ exit. │ @@ -2609,7 +2643,10 @@ │ [required] │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Options ────────────────────────────────────────────────────────────────────╮ - │ --help -h Show this message and exit. │ + │ --replace Replace this object if it already exists. │ + │ --if-not-exists Only apply this operation if the specified object │ + │ does not already exist. │ + │ --help -h Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Connection configuration ───────────────────────────────────────────────────╮ │ --connection,--environment -c TEXT Name of the connection, as defined │ @@ -3160,6 +3197,11 @@ │ --comment TEXT Comment for the │ │ service. │ │ [default: None] │ + │ --if-not-exists Only apply this │ + │ operation if the │ + │ specified object │ + │ does not already │ + │ exist. │ │ --help -h Show this │ │ message and │ │ exit. │ diff --git a/tests/__snapshots__/test_snow_connector.ambr b/tests/__snapshots__/test_snow_connector.ambr index 9a4cf332e9..c923550308 100644 --- a/tests/__snapshots__/test_snow_connector.ambr +++ b/tests/__snapshots__/test_snow_connector.ambr @@ -19,7 +19,7 @@ use schema schemaValue; - create stage if not exists namedStageValue comment='deployments managed by snowcli'; + create stage if not exists namedStageValue comment='deployments managed by Snowflake CLI'; put file://file_pathValue @namedStageValuepathValue auto_compress=false parallel=4 overwrite=overwriteValue; @@ -45,7 +45,7 @@ use schema schemaValue; - create stage if not exists snow://embeddedStageValue comment='deployments managed by snowcli'; + create stage if not exists snow://embeddedStageValue comment='deployments managed by Snowflake CLI'; put file://file_pathValue snow://embeddedStageValuepathValue auto_compress=false parallel=4 overwrite=overwriteValue; diff --git a/tests/api/commands/__snapshots__/test_flags.ambr b/tests/api/commands/__snapshots__/test_flags.ambr new file mode 100644 index 0000000000..6ffec5bafb --- /dev/null +++ b/tests/api/commands/__snapshots__/test_flags.ambr @@ -0,0 +1,28 @@ +# serializer version: 1 +# name: test_format + ''' + Usage: default object stage list [OPTIONS] STAGE_NAME + Try 'default object stage list --help' for help. + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Invalid value for '--format': 'invalid_format' is not one of 'TABLE', │ + │ 'JSON'. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- +# name: test_mutually_exclusive_options_error + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--option2' and '--option1' are incompatible. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- +# name: test_overrideable_option_callback_with_mutually_exclusive + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--option2' and '--option1' are incompatible. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- diff --git a/tests/api/commands/test_flags.py b/tests/api/commands/test_flags.py index f2e1de910f..5c7405fc29 100644 --- a/tests/api/commands/test_flags.py +++ b/tests/api/commands/test_flags.py @@ -1,22 +1,25 @@ -from snowflake.cli.api.commands.flags import PLAIN_PASSWORD_MSG, PasswordOption +from unittest import mock +from unittest.mock import Mock, create_autospec, patch + +import click.core +import pytest +import typer +from snowflake.cli.api.commands.flags import ( + PLAIN_PASSWORD_MSG, + OverrideableOption, + PasswordOption, +) from typer import Typer +from typer.core import TyperOption from typer.testing import CliRunner -def test_format(runner): +def test_format(runner, snapshot): result = runner.invoke( ["object", "stage", "list", "stage_name", "--format", "invalid_format"] ) - assert result.output == ( - """Usage: default object stage list [OPTIONS] STAGE_NAME -Try 'default object stage list --help' for help. -╭─ Error ──────────────────────────────────────────────────────────────────────╮ -│ Invalid value for '--format': 'invalid_format' is not one of 'TABLE', │ -│ 'JSON'. │ -╰──────────────────────────────────────────────────────────────────────────────╯ -""" - ) + assert result.output == snapshot def test_password_flag(): @@ -30,3 +33,170 @@ def _(password: str = PasswordOption): result = runner.invoke(app, ["--password", "dummy"], catch_exceptions=False) assert result.exit_code == 0 assert PLAIN_PASSWORD_MSG in result.output + + +@patch("snowflake.cli.api.commands.flags.typer.Option") +def test_overrideable_option_returns_typer_option(mock_option): + mock_option_info = Mock(spec=typer.models.OptionInfo) + mock_option.return_value = mock_option_info + default = 1 + param_decls = ["--option"] + help_message = "help message" + + option = OverrideableOption(default, *param_decls, help=help_message)() + mock_option.assert_called_once_with(default, *param_decls, help=help_message) + assert option == mock_option_info + + +def test_overrideable_option_is_overrideable(): + original_param_decls = ("--option",) + original = OverrideableOption(1, *original_param_decls, help="original help") + + new_default = 2 + new_help = "new help" + modified = original(default=new_default, help=new_help) + + assert modified.default == new_default + assert modified.help == new_help + assert modified.param_decls == original_param_decls + + +_MUTEX_OPTION_1 = OverrideableOption( + False, "--option1", mutually_exclusive=["option_1", "option_2"] +) +_MUTEX_OPTION_2 = OverrideableOption( + False, "--option2", mutually_exclusive=["option_1", "option_2"] +) + + +@pytest.mark.parametrize("set1, set2", [(False, False), (False, True), (True, False)]) +def test_mutually_exclusive_options_no_error(set1, set2): + app = Typer() + + @app.command() + def _(option_1: bool = _MUTEX_OPTION_1(), option_2: bool = _MUTEX_OPTION_2()): + pass + + command = [] + if set1: + command.append("--option1") + if set2: + command.append("--option2") + runner = CliRunner() + result = runner.invoke(app, command) + assert result.exit_code == 0 + + +def test_mutually_exclusive_options_error(snapshot): + app = Typer() + + @app.command() + def _(option_1: bool = _MUTEX_OPTION_1(), option_2: bool = _MUTEX_OPTION_2()): + pass + + command = ["--option1", "--option2"] + runner = CliRunner() + result = runner.invoke(app, command) + assert result.exit_code == 1 + assert result.output == snapshot + + +def test_overrideable_option_callback_passthrough(): + def callback(value): + return value + 1 + + app = Typer() + + @app.command() + def _(option: int = OverrideableOption(..., "--option", callback=callback)()): + print(option) + + runner = CliRunner() + result = runner.invoke(app, ["--option", "0"]) + assert result.exit_code == 0 + assert result.output.strip() == "1" + + +def test_overrideable_option_callback_with_context(): + # tests that generated_callback will correctly map ctx and param arguments to the original callback + def callback(value, param: typer.CallbackParam, ctx: typer.Context): + assert isinstance(value, int) + assert isinstance(param, TyperOption) + assert isinstance(ctx, click.core.Context) + return value + + app = Typer() + + @app.command() + def _(option: int = OverrideableOption(..., "--option", callback=callback)()): + pass + + runner = CliRunner() + result = runner.invoke(app, ["--option", "0"]) + assert result.exit_code == 0 + + +class _InvalidCallbackSignatureNamespace: + # dummy functions for test_overrideable_option_invalid_callback_signature + + # too many parameters + @staticmethod + def callback1( + ctx: typer.Context, param: typer.CallbackParam, value1: int, value2: float + ): + pass + + # untyped Context and CallbackParam + @staticmethod + def callback2(ctx, param, value): + pass + + # multiple untyped values + @staticmethod + def callback3(ctx: typer.Context, value1, value2): + pass + + +@pytest.mark.parametrize( + "callback", + [ + _InvalidCallbackSignatureNamespace.callback1, + _InvalidCallbackSignatureNamespace.callback2, + _InvalidCallbackSignatureNamespace.callback3, + ], +) +def test_overrideable_option_invalid_callback_signature(callback): + invalid_callback_option = OverrideableOption(None, "--option", callback=callback) + with pytest.raises(OverrideableOption.InvalidCallbackSignature): + invalid_callback_option() + + +def test_overrideable_option_callback_with_mutually_exclusive(snapshot): + """ + Tests that is both 'callback' and 'mutually_exclusive' are passed to OverrideableOption, both are respected. This + is mainly for the rare use case where you are using 'mutually_exclusive' with non-flag options. + """ + + def passthrough(value): + return value + + mock_callback = create_autospec(passthrough) + app = Typer() + + @app.command() + def _( + option_1: int = _MUTEX_OPTION_1(default=None, callback=mock_callback), + option_2: int = _MUTEX_OPTION_2(default=None, callback=mock_callback), + ): + pass + + runner = CliRunner() + + # test that callback is called on the option values + runner.invoke(app, ["--option1", "1"]) + mock_callback.assert_has_calls([mock.call(value=1), mock.call(value=None)]) + + # test that we can't provide both options as non-falsey values without throwing error + result = runner.invoke(app, ["--option1", "1", "--option2", "2"]) + assert result.exit_code == 1 + assert result.output == snapshot diff --git a/tests/api/commands/test_snow_typer.py b/tests/api/commands/test_snow_typer.py index b0c7136762..08bb095dc3 100644 --- a/tests/api/commands/test_snow_typer.py +++ b/tests/api/commands/test_snow_typer.py @@ -147,21 +147,21 @@ def test_command_with_connection_options(cli, snapshot): assert result.output == snapshot -@mock.patch("snowflake.cli.api.commands.snow_typer.log_command_usage") +@mock.patch("snowflake.cli.app.telemetry.log_command_usage") def test_snow_typer_pre_execute_sends_telemetry(mock_log_command_usage, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 mock_log_command_usage.assert_called_once_with() -@mock.patch("snowflake.cli.api.commands.snow_typer.flush_telemetry") +@mock.patch("snowflake.cli.app.telemetry.flush_telemetry") def test_snow_typer_post_execute_sends_telemetry(mock_flush_telemetry, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 mock_flush_telemetry.assert_called_once_with() -@mock.patch("snowflake.cli.api.commands.snow_typer.print_result") +@mock.patch("snowflake.cli.app.printing.print_result") def test_snow_typer_result_callback_sends_telemetry(mock_print_result, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 diff --git a/tests/api/utils/test_naming_utils.py b/tests/api/utils/test_naming_utils.py new file mode 100644 index 0000000000..ea03f2cf78 --- /dev/null +++ b/tests/api/utils/test_naming_utils.py @@ -0,0 +1,15 @@ +import pytest +from snowflake.cli.api.utils.naming_utils import from_qualified_name + + +@pytest.mark.parametrize( + "qualified_name, expected", + [ + ("func(number, number)", ("func(number, number)", None, None)), + ("name", ("name", None, None)), + ("schema.name", ("name", "schema", None)), + ("db.schema.name", ("name", "schema", "db")), + ], +) +def test_from_fully_qualified_name(qualified_name, expected): + assert from_qualified_name(qualified_name) == expected diff --git a/tests/conftest.py b/tests/conftest.py index 1b07455cdd..c19e1610b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,6 @@ from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import QueryResult from snowflake.cli.app import loggers -from snowflake.cli.app.cli_app import app pytest_plugins = ["tests.testing_utils.fixtures", "tests.project.fixtures"] @@ -72,7 +71,9 @@ def make_mock_cursor(mock_cursor): @pytest.fixture(name="faker_app") -def make_faker_app(_create_mock_cursor): +def make_faker_app(runner, _create_mock_cursor): + app = runner.app + @app.command("Faker") @with_output @global_options diff --git a/tests/snowpark/test_function.py b/tests/snowpark/test_function.py index 231d0f655f..75da0b0dd3 100644 --- a/tests/snowpark/test_function.py +++ b/tests/snowpark/test_function.py @@ -31,7 +31,7 @@ def test_deploy_function( assert result.exit_code == 0, result.output assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( @@ -78,7 +78,7 @@ def test_deploy_function_with_external_access( assert result.exit_code == 0, result.output assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( @@ -159,7 +159,7 @@ def test_deploy_function_no_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", ] @@ -197,7 +197,7 @@ def test_deploy_function_needs_update_because_packages_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ @@ -246,7 +246,7 @@ def test_deploy_function_needs_update_because_handler_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( diff --git a/tests/snowpark/test_procedure.py b/tests/snowpark/test_procedure.py index efdd2b7e3c..b83f709f56 100644 --- a/tests/snowpark/test_procedure.py +++ b/tests/snowpark/test_procedure.py @@ -52,7 +52,7 @@ def test_deploy_procedure( ] ) assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(tmp).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ @@ -117,7 +117,7 @@ def test_deploy_procedure_with_external_access( ] ) assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( diff --git a/tests/spcs/__snapshots__/test_image_repository.ambr b/tests/spcs/__snapshots__/test_image_repository.ambr new file mode 100644 index 0000000000..a7de666b3e --- /dev/null +++ b/tests/spcs/__snapshots__/test_image_repository.ambr @@ -0,0 +1,19 @@ +# serializer version: 1 +# name: test_create_cli + ''' + +-----------------------------------------------------------+ + | key | value | + |--------+--------------------------------------------------| + | status | Image Repository TEST_REPO successfully created. | + +-----------------------------------------------------------+ + + ''' +# --- +# name: test_create_cli_replace_and_if_not_exists_fails + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--if-not-exists' and '--replace' are incompatible. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- diff --git a/tests/spcs/test_compute_pool.py b/tests/spcs/test_compute_pool.py index 440a0278a7..63104becf3 100644 --- a/tests/spcs/test_compute_pool.py +++ b/tests/spcs/test_compute_pool.py @@ -41,6 +41,7 @@ def test_create(mock_execute_query): initially_suspended=initially_suspended, auto_suspend_secs=auto_suspend_secs, comment=comment, + if_not_exists=False, ) expected_query = " ".join( [ @@ -81,6 +82,7 @@ def test_create_pool_cli_defaults(mock_create, runner): initially_suspended=False, auto_suspend_secs=3600, comment=None, + if_not_exists=False, ) @@ -104,6 +106,7 @@ def test_create_pool_cli(mock_create, runner): "7200", "--comment", "this is a test", + "--if-not-exists", ] ) assert result.exit_code == 0, result.output @@ -116,6 +119,7 @@ def test_create_pool_cli(mock_create, runner): initially_suspended=True, auto_suspend_secs=7200, comment=to_string_literal("this is a test"), + if_not_exists=True, ) @@ -123,8 +127,8 @@ def test_create_pool_cli(mock_create, runner): "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" ) @patch("snowflake.cli.plugins.spcs.compute_pool.manager.handle_object_already_exists") -def test_create_repository_already_exists(mock_handle, mock_execute): - pool_name = "test_object" +def test_create_compute_pool_already_exists(mock_handle, mock_execute): + pool_name = "test_pool" mock_execute.side_effect = SPCS_OBJECT_EXISTS_ERROR ComputePoolManager().create( pool_name=pool_name, @@ -135,12 +139,46 @@ def test_create_repository_already_exists(mock_handle, mock_execute): initially_suspended=True, auto_suspend_secs=7200, comment=to_string_literal("this is a test"), + if_not_exists=False, ) mock_handle.assert_called_once_with( SPCS_OBJECT_EXISTS_ERROR, ObjectType.COMPUTE_POOL, pool_name ) +@patch( + "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" +) +def test_create_compute_pool_if_not_exists(mock_execute_query): + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + result = ComputePoolManager().create( + pool_name="test_pool", + min_nodes=1, + max_nodes=1, + instance_family="test_family", + auto_resume=True, + initially_suspended=False, + auto_suspend_secs=3600, + comment=None, + if_not_exists=True, + ) + expected_query = " ".join( + [ + "CREATE COMPUTE POOL IF NOT EXISTS test_pool", + "MIN_NODES = 1", + "MAX_NODES = 1", + "INSTANCE_FAMILY = test_family", + "AUTO_RESUME = True", + "INITIALLY_SUSPENDED = False", + "AUTO_SUSPEND_SECS = 3600", + ] + ) + actual_query = " ".join(mock_execute_query.mock_calls[0].args[0].split()) + assert expected_query == actual_query + assert result == cursor + + @patch( "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" ) diff --git a/tests/spcs/test_image_repository.py b/tests/spcs/test_image_repository.py index 9332c8379a..21b140f297 100644 --- a/tests/spcs/test_image_repository.py +++ b/tests/spcs/test_image_repository.py @@ -4,7 +4,6 @@ from unittest.mock import Mock import pytest -from click import ClickException from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.exceptions import ( DatabaseNotProvidedError, @@ -12,6 +11,7 @@ ) from snowflake.cli.plugins.spcs.image_repository.manager import ImageRepositoryManager from snowflake.connector.cursor import SnowflakeCursor +from snowflake.connector.errors import ProgrammingError from tests.spcs.test_common import SPCS_OBJECT_EXISTS_ERROR @@ -44,37 +44,68 @@ ] +@pytest.mark.parametrize( + "replace, if_not_exists, expected_query", + [ + (False, False, "create image repository test_repo"), + (False, True, "create image repository if not exists test_repo"), + (True, False, "create or replace image repository test_repo"), + # (True, True) is an invalid case as OR REPLACE and IF NOT EXISTS are mutually exclusive. + ], +) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._execute_schema_query" ) -def test_create( - mock_execute, -): +def test_create(mock_execute, replace, if_not_exists, expected_query): repo_name = "test_repo" cursor = Mock(spec=SnowflakeCursor) mock_execute.return_value = cursor - result = ImageRepositoryManager().create(name=repo_name) - expected_query = "create image repository test_repo" - mock_execute.assert_called_once_with(expected_query) + result = ImageRepositoryManager().create( + name=repo_name, replace=replace, if_not_exists=if_not_exists + ) + mock_execute.assert_called_once_with(expected_query, name=repo_name) assert result == cursor +def test_create_replace_and_if_not_exist(): + with pytest.raises(ValueError) as e: + ImageRepositoryManager().create( + name="test_repo", replace=True, if_not_exists=True + ) + assert "mutually exclusive" in str(e.value) + + @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.create" ) -def test_create_cli(mock_create, mock_cursor, runner): +def test_create_cli(mock_create, mock_cursor, runner, snapshot): repo_name = "test_repo" cursor = mock_cursor( rows=[[f"Image Repository {repo_name.upper()} successfully created."]], columns=["status"], ) mock_create.return_value = cursor - result = runner.invoke(["spcs", "image-repository", "create", repo_name]) - mock_create.assert_called_once_with(name=repo_name) - assert result.exit_code == 0, result.output - assert ( - f"Image Repository {repo_name.upper()} successfully created." in result.output + command = ["spcs", "image-repository", "create", repo_name] + result = runner.invoke(command) + mock_create.assert_called_once_with( + name=repo_name, replace=False, if_not_exists=False ) + assert result.exit_code == 0, result.output + assert result.output == snapshot + + +def test_create_cli_replace_and_if_not_exists_fails(runner, snapshot): + command = [ + "spcs", + "image-repository", + "create", + "test_repo", + "--replace", + "--if-not-exists", + ] + result = runner.invoke(command) + assert result.exit_code == 1 + assert result.output == snapshot @mock.patch( @@ -86,9 +117,12 @@ def test_create_cli(mock_create, mock_cursor, runner): def test_create_repository_already_exists(mock_handle, mock_execute): repo_name = "test_object" mock_execute.side_effect = SPCS_OBJECT_EXISTS_ERROR - ImageRepositoryManager().create(repo_name) + ImageRepositoryManager().create(repo_name, replace=False, if_not_exists=False) mock_handle.assert_called_once_with( - SPCS_OBJECT_EXISTS_ERROR, ObjectType.IMAGE_REPOSITORY, repo_name + SPCS_OBJECT_EXISTS_ERROR, + ObjectType.IMAGE_REPOSITORY, + repo_name, + replace_available=True, ) @@ -191,13 +225,10 @@ def test_get_repository_url_cli(mock_url, runner): assert result.output.strip() == repo_url -@mock.patch( - "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url(mock_get_row, mock_check_database_and_schema): +def test_get_repository_url(mock_get_row): expected_row = MOCK_ROWS_DICT[0] mock_get_row.return_value = expected_row result = ImageRepositoryManager().get_repository_url(repo_name="IMAGES") @@ -209,13 +240,10 @@ def test_get_repository_url(mock_get_row, mock_check_database_and_schema): assert result == f"https://{expected_row['repository_url']}" -@mock.patch( - "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url_no_scheme(mock_get_row, mock_check_database_and_schema): +def test_get_repository_url_no_scheme(mock_get_row): expected_row = MOCK_ROWS_DICT[0] mock_get_row.return_value = expected_row result = ImageRepositoryManager().get_repository_url( @@ -229,26 +257,21 @@ def test_get_repository_url_no_scheme(mock_get_row, mock_check_database_and_sche assert result == expected_row["repository_url"] -@mock.patch( - "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url_no_repo_found( - mock_get_row, mock_conn, mock_check_database_and_schema -): +def test_get_repository_url_no_repo_found(mock_get_row, mock_conn): mock_get_row.return_value = None mock_conn.database = "DB" mock_conn.schema = "SCHEMA" - with pytest.raises(ClickException) as e: + with pytest.raises(ProgrammingError) as e: ImageRepositoryManager().get_repository_url(repo_name="IMAGES") assert ( - e.value.message - == "Image repository 'IMAGES' does not exist in database 'DB' and schema 'SCHEMA' or not authorized." + e.value.msg + == "Image repository 'DB.SCHEMA.IMAGES' does not exist or not authorized." ) mock_get_row.assert_called_once_with( "image repositories", "IMAGES", check_schema=True @@ -258,17 +281,17 @@ def test_get_repository_url_no_repo_found( @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) -def test_get_repository_url_no_database(mock_conn): +def test_get_repository_url_no_database_provided(mock_conn): mock_conn.database = None with pytest.raises(DatabaseNotProvidedError): - ImageRepositoryManager().get_repository_url("test_repo") + ImageRepositoryManager().get_repository_url("IMAGES") @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) -@mock.patch("snowflake.cli.api.sql_execution.SqlExecutionMixin.check_database_exists") -def test_get_repository_url_no_schema(mock_check_database_exists, mock_conn): +def test_get_repository_url_no_schema_provided(mock_conn): + mock_conn.database = "DB" mock_conn.schema = None with pytest.raises(SchemaNotProvidedError): - ImageRepositoryManager().get_repository_url("test_repo") + ImageRepositoryManager().get_repository_url("IMAGES") diff --git a/tests/spcs/test_jobs.py b/tests/spcs/test_jobs.py index b714fc2f8e..4658a03979 100644 --- a/tests/spcs/test_jobs.py +++ b/tests/spcs/test_jobs.py @@ -2,7 +2,10 @@ from tempfile import TemporaryDirectory from unittest import mock +import pytest + +@pytest.mark.skip("Snowpark Container Services Job not supported.") @mock.patch("snowflake.connector.connect") def test_create_job(mock_connector, runner, mock_ctx): ctx = mock_ctx() @@ -40,6 +43,7 @@ def test_create_job(mock_connector, runner, mock_ctx): ) +@pytest.mark.skip("Snowpark Container Services Job not supported.") @mock.patch("snowflake.connector.connect") def test_job_status(mock_connector, runner, mock_ctx): ctx = mock_ctx() @@ -51,6 +55,7 @@ def test_job_status(mock_connector, runner, mock_ctx): assert ctx.get_query() == "CALL SYSTEM$GET_JOB_STATUS('jobName')" +@pytest.mark.skip("Snowpark Container Services Job not supported.") @mock.patch("snowflake.connector.connect") def test_job_logs(mock_connector, runner, mock_ctx): ctx = mock_ctx() diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index 24919a3dff..e7ee202e75 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -72,6 +72,7 @@ def test_create_service(mock_execute_query, other_directory): query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) expected_query = " ".join( [ @@ -120,6 +121,7 @@ def test_create_service_cli_defaults(mock_create, other_directory, runner): query_warehouse=None, tags=[], comment=None, + if_not_exists=False, ) @@ -155,6 +157,7 @@ def test_create_service_cli(mock_create, other_directory, runner): '"$trange name"=normal value', "--comment", "this is a test", + "--if-not-exists", ] ) assert result.exit_code == 0, result.output @@ -169,6 +172,7 @@ def test_create_service_cli(mock_create, other_directory, runner): query_warehouse="test_warehouse", tags=[Tag("name", "value"), Tag('"$trange name"', "normal value")], comment=to_string_literal("this is a test"), + if_not_exists=True, ) @@ -195,14 +199,15 @@ def test_create_service_with_invalid_spec(mock_read_yaml): query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._read_yaml") @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") @patch("snowflake.cli.plugins.spcs.services.manager.handle_object_already_exists") -def test_create_repository_already_exists(mock_handle, mock_execute, mock_read_yaml): - service_name = "test_object" +def test_create_service_already_exists(mock_handle, mock_execute, mock_read_yaml): + service_name = "test_service" compute_pool = "test_pool" spec_path = "/path/to/spec.yaml" min_instances = 42 @@ -221,12 +226,47 @@ def test_create_repository_already_exists(mock_handle, mock_execute, mock_read_y query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) mock_handle.assert_called_once_with( SPCS_OBJECT_EXISTS_ERROR, ObjectType.SERVICE, service_name ) +@patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") +def test_create_service_if_not_exists(mock_execute_query, other_directory): + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + tmp_dir = Path(other_directory) + spec_path = tmp_dir / "spec.yml" + spec_path.write_text(SPEC_CONTENT) + result = ServiceManager().create( + service_name="test_service", + compute_pool="test_pool", + spec_path=spec_path, + min_instances=1, + max_instances=1, + auto_resume=True, + external_access_integrations=None, + query_warehouse=None, + tags=None, + comment=None, + if_not_exists=True, + ) + expected_query = " ".join( + [ + "CREATE SERVICE IF NOT EXISTS test_service", + "IN COMPUTE POOL test_pool", + f"FROM SPECIFICATION $$ {json.dumps(SPEC_DICT)} $$", + "MIN_INSTANCES = 1 MAX_INSTANCES = 1", + "AUTO_RESUME = True", + ] + ) + actual_query = " ".join(mock_execute_query.mock_calls[0].args[0].split()) + assert expected_query == actual_query + assert result == cursor + + @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") def test_status(mock_execute_query): service_name = "test_service" diff --git a/tests/test_help_messages.py b/tests/test_help_messages.py index fcafaae5b3..477679b622 100644 --- a/tests/test_help_messages.py +++ b/tests/test_help_messages.py @@ -18,6 +18,7 @@ def _iter_through_commands(command, path): yield from _iter_through_commands(subcommand, path) path.pop() + yield [] # "snow" with no commands builtin_plugins = load_only_builtin_command_plugins() for plugin in builtin_plugins: spec = plugin.command_spec diff --git a/tests/test_sql.py b/tests/test_sql.py index 95a70b24e4..f433dc8bec 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -4,8 +4,10 @@ import pytest from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.util import identifier_to_show_like_pattern from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector.cursor import DictCursor +from snowflake.connector.errors import ProgrammingError from tests.testing_utils.result_assertions import assert_that_result_is_usage_error @@ -169,3 +171,80 @@ def test_show_specific_object_sql_execution_error(mock_execute): mock_execute.assert_called_once_with( r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor ) + + +@pytest.mark.parametrize( + "name, name_split, expected_name, expected_in_clause", + [ + ( + "func(number, number)", + ("func(number, number)", None, None), + "func(number, number)", + None, + ), + ("name", ("name", None, None), "name", None), + ("schema.name", ("name", "schema", None), "name", "in schema schema"), + ("db.schema.name", ("name", "schema", "db"), "name", "in schema db.schema"), + ], +) +@mock.patch("snowflake.cli.api.sql_execution.from_qualified_name") +def test_qualified_name_to_in_clause( + mock_from_qualified_name, name, name_split, expected_name, expected_in_clause +): + mock_from_qualified_name.return_value = name_split + assert SqlExecutionMixin._qualified_name_to_in_clause(name) == ( # noqa: SLF001 + expected_name, + expected_in_clause, + ) + mock_from_qualified_name.assert_called_once_with(name) + + +@mock.patch("snowflake.cli.plugins.sql.manager.SqlExecutionMixin._execute_query") +@mock.patch( + "snowflake.cli.api.sql_execution.SqlExecutionMixin._qualified_name_to_in_clause" +) +def test_show_specific_object_qualified_name( + mock_qualified_name_to_in_clause, mock_execute_query, mock_cursor +): + name = "db.schema.obj" + unqualified_name = "obj" + name_in_clause = "in schema db.schema" + mock_columns = ["name", "created_on"] + mock_row_dict = {c: r for c, r in zip(mock_columns, [unqualified_name, "date"])} + cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns) + mock_execute_query.return_value = cursor + + mock_qualified_name_to_in_clause.return_value = (unqualified_name, name_in_clause) + SqlExecutionMixin().show_specific_object("objects", name) + mock_execute_query.assert_called_once_with( + f"show objects like {identifier_to_show_like_pattern(unqualified_name)} {name_in_clause}", + cursor_class=DictCursor, + ) + + +@mock.patch( + "snowflake.cli.api.sql_execution.SqlExecutionMixin._qualified_name_to_in_clause" +) +def test_show_specific_object_qualified_name_and_in_clause_error( + mock_qualified_name_to_in_clause, +): + object_name = "db.schema.name" + mock_qualified_name_to_in_clause.return_value = ("name", "in schema db.schema") + with pytest.raises(SqlExecutionMixin.InClauseWithQualifiedNameError): + SqlExecutionMixin().show_specific_object( + "objects", object_name, in_clause="in database db" + ) + mock_qualified_name_to_in_clause.assert_called_once_with(object_name) + + +@mock.patch("snowflake.cli.api.sql_execution.SqlExecutionMixin._execute_query") +def test_show_specific_object_multiple_rows(mock_execute_query): + cursor = mock.Mock(spec=DictCursor) + cursor.rowcount = 2 + mock_execute_query.return_value = cursor + with pytest.raises(ProgrammingError) as err: + SqlExecutionMixin().show_specific_object("objects", "name", name_col="id") + assert "Received multiple rows" in err.value.msg + mock_execute_query.assert_called_once_with( + r"show objects like 'NAME'", cursor_class=DictCursor + ) diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index fe643b5b59..5c30a0bb23 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -13,6 +13,7 @@ import pytest import strictyaml from snowflake.cli.api.project.definition import merge_left +from snowflake.cli.app.cli_app import app_factory from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError from strictyaml import as_document @@ -80,7 +81,7 @@ def dot_packages_directory(temp_dir): @pytest.fixture() def mock_ctx(mock_cursor): - return lambda cursor=mock_cursor(["row"], []): MockConnectionCtx(cursor) + yield lambda cursor=mock_cursor(["row"], []): MockConnectionCtx(cursor) class MockConnectionCtx(mock.MagicMock): @@ -200,9 +201,8 @@ def package_file(): @pytest.fixture(scope="function") def runner(test_snowcli_config): - from snowflake.cli.app.cli_app import app - - return SnowCLIRunner(app, test_snowcli_config) + app = app_factory() + yield SnowCLIRunner(app, test_snowcli_config) @pytest.fixture diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 404b4eba87..633c7d4046 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -14,7 +14,7 @@ import strictyaml from snowflake.cli.api.cli_global_context import cli_context_manager from snowflake.cli.api.project.definition import merge_left -from snowflake.cli.app.cli_app import app +from snowflake.cli.app.cli_app import app_factory from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -113,7 +113,8 @@ def invoke_with_connection( @pytest.fixture def runner(test_snowcli_config_provider): - return SnowCLIRunner(app, test_snowcli_config_provider) + app = app_factory() + yield SnowCLIRunner(app, test_snowcli_config_provider) class QueryResultJsonEncoderError(RuntimeError): diff --git a/tests_integration/test_external_plugins.py b/tests_integration/test_external_plugins.py index c7112a36c0..cdbc88d83b 100644 --- a/tests_integration/test_external_plugins.py +++ b/tests_integration/test_external_plugins.py @@ -93,7 +93,10 @@ def test_loading_of_installed_plugins_if_all_plugins_enabled( @pytest.mark.integration def test_loading_of_installed_plugins_if_only_one_plugin_is_enabled( - runner, install_plugins, caplog, reset_command_registration_state + runner, + install_plugins, + caplog, + reset_command_registration_state, ): runner.use_config("config_with_enabled_only_one_external_plugin.toml") @@ -111,8 +114,18 @@ def test_loading_of_installed_plugins_if_only_one_plugin_is_enabled( @pytest.mark.integration +@pytest.mark.parametrize( + "config_value", + ( + pytest.param("1", id="integer as value"), + pytest.param('"True"', id="string as value"), + ), +) def test_enabled_value_must_be_boolean( - runner, snowflake_home, reset_command_registration_state + config_value, + runner, + snowflake_home, + reset_command_registration_state, ): def _use_config_with_value(value): config = Path(snowflake_home) / "config.toml" @@ -123,19 +136,18 @@ def _use_config_with_value(value): ) runner.use_config(config) - for value in ["1", '"True"']: - _use_config_with_value(value) - result = runner.invoke_with_config(["--help"]) - output = result.output.splitlines() - assert all( - [ - "Error" in output[0], - 'Invalid plugin configuration. [multilingual-hello]: "enabled" must be a' - in output[1], - "boolean" in output[2], - ] - ) - reset_command_registration_state() + _use_config_with_value(config_value) + result = runner.invoke_with_config(("--help,")) + + first, second, third, *_ = result.output.splitlines() + assert "Error" in first, first + assert ( + 'Invalid plugin configuration. [multilingual-hello]: "enabled" must be a' + in second + ), second + assert "boolean" in third, third + + reset_command_registration_state() def _assert_that_no_error_logs(caplog):