diff --git a/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py b/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py index 55d1113748..ad8d4ba7a3 100644 --- a/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py +++ b/src/snowflake/cli/_plugins/snowpark/snowpark_entity.py @@ -1,16 +1,266 @@ -from typing import Generic, TypeVar +import functools +from enum import Enum +from pathlib import Path +from typing import Generic, List, Optional, TypeVar +from click import ClickException +from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag +from snowflake.cli._plugins.snowpark import package_utils +from snowflake.cli._plugins.snowpark.common import DEFAULT_RUNTIME +from snowflake.cli._plugins.snowpark.package.anaconda_packages import ( + AnacondaPackages, + AnacondaPackagesManager, +) +from snowflake.cli._plugins.snowpark.package_utils import ( + DownloadUnavailablePackagesResult, +) from snowflake.cli._plugins.snowpark.snowpark_entity_model import ( FunctionEntityModel, ProcedureEntityModel, ) -from snowflake.cli.api.entities.common import EntityBase +from snowflake.cli._plugins.snowpark.zipper import zip_dir +from snowflake.cli._plugins.workspace.context import ActionContext +from snowflake.cli.api.entities.common import EntityBase, get_sql_executor +from snowflake.cli.api.secure_path import SecurePath +from snowflake.connector import ProgrammingError T = TypeVar("T") +class DeployMode( + str, Enum +): # This should probably be moved to some common place, think where + create = "CREATE" + create_or_replace = "CREATE OR REPLACE" + create_if_not_exists = "CREATE IF NOT EXISTS" + + class SnowparkEntity(EntityBase[Generic[T]]): - pass + def __init__(self, *args, **kwargs): + + if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled(): + raise NotImplementedError("Snowpark entity is not implemented yet") + super().__init__(*args, **kwargs) + + @property + def root(self): + return self._workspace_ctx.project_root + + @property + def identifier(self): + return self.model.fqn.sql_identifier + + @property + def fqn(self): + return self.model.fqn + + @functools.cached_property + def _sql_executor( + self, + ): # maybe this could be moved to parent class, as it is used in streamlit entity as well + return get_sql_executor() + + @functools.cached_property + def _conn(self): + return self._sql_executor._conn # noqa + + @property + def model(self): + return self._entity_model # noqa + + def action_bundle( + self, + action_ctx: ActionContext, + output_dir: Path | None, + ignore_anaconda: bool, + skip_version_check: bool, + index_url: str | None = None, + allow_shared_libraries: bool = False, + *args, + **kwargs, + ): + return self.bundle( + output_dir, + ignore_anaconda, + skip_version_check, + index_url, + allow_shared_libraries, + ) + + def action_deploy( + self, action_ctx: ActionContext, mode: DeployMode, *args, **kwargs + ): + # TODO: After introducing bundle map, we should introduce file copying part here + return self._sql_executor.execute_query(self.get_deploy_sql(mode)) + + def action_drop(self, action_ctx: ActionContext, *args, **kwargs): + return self._sql_executor.execute_query(self.get_drop_sql()) + + def action_describe(self, action_ctx: ActionContext, *args, **kwargs): + return self._sql_executor.execute_query(self.get_describe_sql()) + + def action_execute( + self, + action_ctx: ActionContext, + execution_arguments: List[str] | None = None, + *args, + **kwargs, + ): + return self._sql_executor.execute_query( + self.get_execute_sql(execution_arguments) + ) + + def bundle( + self, + output_dir: Path | None, + ignore_anaconda: bool, + skip_version_check: bool, + index_url: str | None = None, + allow_shared_libraries: bool = False, + ) -> List[Path]: + """ + Bundles the entity artifacts and dependencies into a directory. + Parameters: + output_dir: The directory to output the bundled artifacts to. Defaults to output dir in project root + ignore_anaconda: If True, ignores anaconda chceck and tries to download all packages using pip + skip_version_check: If True, skips version check when downloading packages + index_url: The index URL to use when downloading packages, if none set - default pip index is used (in most cases- Pypi) + allow_shared_libraries: If not set to True, using dependency with .so/.dll files will raise an exception + Returns: + """ + # 0 Create a directory for the entity + if not output_dir: + output_dir = self.root / "output" / self.model.stage + output_dir.mkdir(parents=True, exist_ok=True) # type: ignore + + output_files = [] + + # 1 Check if requirements exits + if (self.root / "requirements.txt").exists(): + download_results = self._process_requirements( + bundle_dir=output_dir, # type: ignore + archive_name="dependencies.zip", + requirements_file=SecurePath(self.root / "requirements.txt"), + ignore_anaconda=ignore_anaconda, + skip_version_check=skip_version_check, + index_url=index_url, + allow_shared_libraries=allow_shared_libraries, + ) + + # 3 get the artifacts list + artifacts = self.model.artifacts + + for artifact in artifacts: + output_file = output_dir / artifact.dest / artifact.src.name + + if artifact.src.is_file(): + output_file.mkdir(parents=True, exist_ok=True) + SecurePath(artifact.src).copy(output_file) + elif artifact.is_dir(): + output_file.mkdir(parents=True, exist_ok=True) + + output_files.append(output_file) + + return output_files + + def check_if_exists( + self, action_ctx: ActionContext + ) -> bool: # TODO it should return current state, so we know if update is necessary + try: + current_state = self.action_describe(action_ctx) + return True + except ProgrammingError: + return False + + def get_deploy_sql(self, mode: DeployMode): + query = [ + f"{mode.value} {self.model.type.upper()} {self.identifier}", + "COPY GRANTS", + f"RETURNS {self.model.returns}", + f"LANGUAGE PYTHON", + f"RUNTIME_VERSION '{self.model.runtime or DEFAULT_RUNTIME}'", + f"IMPORTS={','.join(self.model.imports)}", # TODO: Add source files here after introducing bundlemap + f"HANDLER='{self.model.handler}'", + ] + + if self.model.external_access_integrations: + query.append(self.model.get_external_access_integrations_sql()) + + if self.model.secrets: + query.append(self.model.get_secrets_sql()) + + if self.model.type == "procedure" and self.model.execute_as_caller: + query.append("EXECUTE AS CALLER") + + return "\n".join(query) + + def get_describe_sql(self): + return f"DESCRIBE {self.model.type.upper()} {self.identifier}" + + def get_drop_sql(self): + return f"DROP {self.model.type.upper()} {self.identifier}" + + def get_execute_sql(self, execution_arguments: List[str] | None = None): + raise NotImplementedError + + def get_usage_grant_sql(self, app_role: str): + return f"GRANT USAGE ON {self.model.type.upper()} {self.identifier} TO ROLE {app_role}" + + def _process_requirements( # TODO: maybe leave all the logic with requirements here - so download, write requirements file etc. + self, + bundle_dir: Path, + archive_name: str, # TODO: not the best name, think of something else + requirements_file: Optional[SecurePath], + ignore_anaconda: bool, + skip_version_check: bool = False, + index_url: Optional[str] = None, + allow_shared_libraries: bool = False, + ) -> DownloadUnavailablePackagesResult: + """ + Processes the requirements file and downloads the dependencies + Parameters: + + """ + anaconda_packages_manager = AnacondaPackagesManager() + with SecurePath.temporary_directory() as tmp_dir: + requirements = package_utils.parse_requirements(requirements_file) + anaconda_packages = ( + AnacondaPackages.empty() + if ignore_anaconda + else anaconda_packages_manager.find_packages_available_in_snowflake_anaconda() + ) + download_result = package_utils.download_unavailable_packages( + requirements=requirements, + target_dir=tmp_dir, + anaconda_packages=anaconda_packages, + skip_version_check=skip_version_check, + pip_index_url=index_url, + ) + + if download_result.anaconda_packages: + anaconda_packages.write_requirements_file_in_snowflake_format( + file_path=SecurePath(bundle_dir / "requirements.txt"), + requirements=download_result.anaconda_packages, + ) + + if download_result.downloaded_packages_details: + if ( + package_utils.detect_and_log_shared_libraries( + download_result.downloaded_packages_details + ) + and not allow_shared_libraries + ): + raise ClickException( + "Some packages contain shared (.so/.dll) libraries. " + "Try again with allow_shared_libraries_flag." + ) + + zip_dir( + source=tmp_dir, + dest_zip=bundle_dir / archive_name, + ) + + return download_result class FunctionEntity(SnowparkEntity[FunctionEntityModel]): @@ -18,7 +268,17 @@ class FunctionEntity(SnowparkEntity[FunctionEntityModel]): A single UDF """ - pass + # TO THINK OF + # Where will we get imports? Should we rely on bundle map? Or should it be self-sufficient in this matter? + + def get_execute_sql( + self, execution_arguments: List[str] | None = None, *args, **kwargs + ): + if not execution_arguments: + execution_arguments = [] + return ( + f"SELECT {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})" + ) class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]): @@ -26,4 +286,12 @@ class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]): A stored procedure """ - pass + def get_execute_sql( + self, + execution_arguments: List[str] | None = None, + ): + if not execution_arguments: + execution_arguments = [] + return ( + f"CALL {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})" + ) diff --git a/tests/snowpark/__snapshots__/test_snowpark_entity.ambr b/tests/snowpark/__snapshots__/test_snowpark_entity.ambr new file mode 100644 index 0000000000..a486d2f274 --- /dev/null +++ b/tests/snowpark/__snapshots__/test_snowpark_entity.ambr @@ -0,0 +1,63 @@ +# serializer version: 1 +# name: test_action_execute[None] + 'SELECT func1()' +# --- +# name: test_action_execute[execution_arguments1] + 'SELECT func1(arg1, arg2)' +# --- +# name: test_action_execute[execution_arguments2] + 'SELECT func1(foo, 42, bar)' +# --- +# name: test_function_get_execute_sql[None] + 'SELECT func1()' +# --- +# name: test_function_get_execute_sql[execution_arguments1] + 'SELECT func1(arg1, arg2)' +# --- +# name: test_function_get_execute_sql[execution_arguments2] + 'SELECT func1(foo, 42, bar)' +# --- +# name: test_get_deploy_sql[CREATE IF NOT EXISTS] + ''' + CREATE IF NOT EXISTS FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_get_deploy_sql[CREATE OR REPLACE] + ''' + CREATE OR REPLACE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_get_deploy_sql[CREATE] + ''' + CREATE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- +# name: test_nativeapp_children_interface + ''' + CREATE FUNCTION IDENTIFIER('func1') + COPY GRANTS + RETURNS string + LANGUAGE PYTHON + RUNTIME_VERSION '3.10' + IMPORTS= + HANDLER='app.func1_handler' + ''' +# --- diff --git a/tests/snowpark/test_snowpark_entity.py b/tests/snowpark/test_snowpark_entity.py new file mode 100644 index 0000000000..c87c83e0fe --- /dev/null +++ b/tests/snowpark/test_snowpark_entity.py @@ -0,0 +1,175 @@ +from pathlib import Path +from unittest import mock + +import pytest +import yaml +from snowflake.cli._plugins.snowpark.package.anaconda_packages import ( + AnacondaPackages, + AvailablePackage, +) +from snowflake.cli._plugins.snowpark.snowpark_entity import ( + DeployMode, + FunctionEntity, + ProcedureEntity, +) +from snowflake.cli._plugins.snowpark.snowpark_entity_model import FunctionEntityModel +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext + +from tests.testing_utils.mock_config import mock_config_key + +CONNECTOR = "snowflake.connector.connect" +CONTEXT = "" +EXECUTE_QUERY = "snowflake.cli.api.sql_execution.BaseSqlExecutor.execute_query" +ANACONDA_PACKAGES = "snowflake.cli._plugins.snowpark.package.anaconda_packages.AnacondaPackagesManager.find_packages_available_in_snowflake_anaconda" + + +@pytest.fixture +def example_function_workspace( + project_directory, +): # TODO: try to make a common fixture for all entities + with mock_config_key("enable_native_app_children", True): + with project_directory("snowpark_functions_v2") as pdir: + with Path(pdir / "snowflake.yml").open() as definition_file: + definition = yaml.safe_load(definition_file) + model = FunctionEntityModel( + **definition.get("entities", {}).get("func1") + ) + + workspace_context = WorkspaceContext( + console=mock.MagicMock(), + project_root=pdir, + get_default_role=lambda: "test_role", + get_default_warehouse=lambda: "test_warehouse", + ) + + return ( + FunctionEntity(workspace_ctx=workspace_context, entity_model=model), + ActionContext( + get_entity=lambda *args: None, + ), + ) + + +def test_cannot_instantiate_without_feature_flag(): + with pytest.raises(NotImplementedError) as err: + FunctionEntity() + assert str(err.value) == "Snowpark entity is not implemented yet" + + with pytest.raises(NotImplementedError) as err: + ProcedureEntity() + assert str(err.value) == "Snowpark entity is not implemented yet" + + +@mock.patch(ANACONDA_PACKAGES) +def test_nativeapp_children_interface( + mock_anaconda, example_function_workspace, snapshot +): + mock_anaconda.return_value = AnacondaPackages( + { + "pandas": AvailablePackage("pandas", "1.2.3"), + "numpy": AvailablePackage("numpy", "1.2.3"), + "snowflake_snowpark_python": AvailablePackage( + "snowflake_snowpark_python", "1.2.3" + ), + } + ) + + sl, action_context = example_function_workspace + + sl.bundle(None, False, False, None, False) + bundle_artifact = ( + sl.root / "output" / sl.model.stage / "my_snowpark_project" / "app.py" + ) + deploy_sql_str = sl.get_deploy_sql(DeployMode.create) + grant_sql_str = sl.get_usage_grant_sql(app_role="app_role") + + assert bundle_artifact.exists() + assert deploy_sql_str == snapshot + assert ( + grant_sql_str == f"GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE app_role" + ) + + +@mock.patch(EXECUTE_QUERY) +def test_action_describe(mock_execute, example_function_workspace): + entity, action_context = example_function_workspace + result = entity.action_describe(action_context) + + mock_execute.assert_called_with("DESCRIBE FUNCTION IDENTIFIER('func1')") + + +@mock.patch(EXECUTE_QUERY) +def test_action_drop(mock_execute, example_function_workspace): + entity, action_context = example_function_workspace + result = entity.action_drop(action_context) + + mock_execute.assert_called_with("DROP FUNCTION IDENTIFIER('func1')") + + +@pytest.mark.parametrize( + "execution_arguments", [None, ["arg1", "arg2"], ["foo", 42, "bar"]] +) +@mock.patch(EXECUTE_QUERY) +def test_action_execute( + mock_execute, execution_arguments, example_function_workspace, snapshot +): + entity, action_context = example_function_workspace + result = entity.action_execute(action_context, execution_arguments) + + mock_execute.assert_called_with(snapshot) + + +@mock.patch(ANACONDA_PACKAGES) +def test_bundle(mock_anaconda, example_function_workspace): + mock_anaconda.return_value = AnacondaPackages( + { + "pandas": AvailablePackage("pandas", "1.2.3"), + "numpy": AvailablePackage("numpy", "1.2.3"), + "snowflake_snowpark_python": AvailablePackage( + "snowflake_snowpark_python", "1.2.3" + ), + } + ) + entity, action_context = example_function_workspace + entity.action_bundle(action_context, None, False, False, None, False) + + output = entity.root / "output" / entity._entity_model.stage # noqa + assert output.exists() + assert (output / "my_snowpark_project" / "app.py").exists() + + +def test_describe_function_sql(example_function_workspace): + entity, _ = example_function_workspace + assert entity.get_describe_sql() == "DESCRIBE FUNCTION IDENTIFIER('func1')" + + +def test_drop_function_sql(example_function_workspace): + entity, _ = example_function_workspace + assert entity.get_drop_sql() == "DROP FUNCTION IDENTIFIER('func1')" + + +@pytest.mark.parametrize( + "execution_arguments", [None, ["arg1", "arg2"], ["foo", 42, "bar"]] +) +def test_function_get_execute_sql( + execution_arguments, example_function_workspace, snapshot +): + entity, _ = example_function_workspace + assert entity.get_execute_sql(execution_arguments) == snapshot + + +@pytest.mark.parametrize( + "mode", + [DeployMode.create, DeployMode.create_or_replace, DeployMode.create_if_not_exists], +) +def test_get_deploy_sql(mode, example_function_workspace, snapshot): + entity, _ = example_function_workspace + assert entity.get_deploy_sql(mode) == snapshot + + +def test_get_usage_grant_sql(example_function_workspace): + entity, _ = example_function_workspace + assert ( + entity.get_usage_grant_sql("test_role") + == "GRANT USAGE ON FUNCTION IDENTIFIER('func1') TO ROLE test_role" + )