Skip to content

Commit

Permalink
Add snowpark entities (#1959)
Browse files Browse the repository at this point in the history
* Actions

* fix

* text fix

* outline

* test fix

* test fix

* test fix

* Execute

* Execute

* Tests

* Tests

* get_deploy

* get_deploy

* Solution

* Fix directories
  • Loading branch information
sfc-gh-jsikorski authored Jan 13, 2025
1 parent 4ca0526 commit 8cc48be
Show file tree
Hide file tree
Showing 6 changed files with 522 additions and 13 deletions.
238 changes: 234 additions & 4 deletions src/snowflake/cli/_plugins/snowpark/snowpark_entity.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,259 @@
from typing import Generic, TypeVar
from enum import Enum
from pathlib import Path
from typing import Generic, List, Optional, TypeVar

from click import ClickException
from snowflake.cli._plugins.nativeapp.feature_flags import FeatureFlag
from snowflake.cli._plugins.snowpark import package_utils
from snowflake.cli._plugins.snowpark.common import DEFAULT_RUNTIME
from snowflake.cli._plugins.snowpark.package.anaconda_packages import (
AnacondaPackages,
AnacondaPackagesManager,
)
from snowflake.cli._plugins.snowpark.package_utils import (
DownloadUnavailablePackagesResult,
)
from snowflake.cli._plugins.snowpark.snowpark_entity_model import (
FunctionEntityModel,
ProcedureEntityModel,
)
from snowflake.cli._plugins.snowpark.zipper import zip_dir
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.entities.common import EntityBase
from snowflake.cli.api.secure_path import SecurePath
from snowflake.connector import ProgrammingError

T = TypeVar("T")


class CreateMode(
str, Enum
): # This should probably be moved to some common place, think where
create = "CREATE"
create_or_replace = "CREATE OR REPLACE"
create_if_not_exists = "CREATE IF NOT EXISTS"


class SnowparkEntity(EntityBase[Generic[T]]):
pass
def __init__(self, *args, **kwargs):

if not FeatureFlag.ENABLE_NATIVE_APP_CHILDREN.is_enabled():
raise NotImplementedError("Snowpark entity is not implemented yet")
super().__init__(*args, **kwargs)

def action_bundle(
self,
action_ctx: ActionContext,
output_dir: Path | None,
ignore_anaconda: bool,
skip_version_check: bool,
index_url: str | None = None,
allow_shared_libraries: bool = False,
*args,
**kwargs,
) -> List[Path]:
return self.bundle(
output_dir,
ignore_anaconda,
skip_version_check,
index_url,
allow_shared_libraries,
)

def action_deploy(
self, action_ctx: ActionContext, mode: CreateMode, *args, **kwargs
):
# TODO: After introducing bundle map, we should introduce file copying part here
return self._execute_query(self.get_deploy_sql(mode))

def action_drop(self, action_ctx: ActionContext, *args, **kwargs):
return self._execute_query(self.get_drop_sql())

def action_describe(self, action_ctx: ActionContext, *args, **kwargs):
return self._execute_query(self.get_describe_sql())

def action_execute(
self,
action_ctx: ActionContext,
execution_arguments: List[str] | None = None,
*args,
**kwargs,
):
return self._execute_query(self.get_execute_sql(execution_arguments))

def bundle(
self,
output_dir: Path | None,
ignore_anaconda: bool,
skip_version_check: bool,
index_url: str | None = None,
allow_shared_libraries: bool = False,
) -> List[Path]:
"""
Bundles the entity artifacts and dependencies into a directory.
Parameters:
output_dir: The directory to output the bundled artifacts to. Defaults to output dir in project root
ignore_anaconda: If True, ignores anaconda chceck and tries to download all packages using pip
skip_version_check: If True, skips version check when downloading packages
index_url: The index URL to use when downloading packages, if none set - default pip index is used (in most cases- Pypi)
allow_shared_libraries: If not set to True, using dependency with .so/.dll files will raise an exception
Returns:
"""
# 0 Create a directory for the entity
if not output_dir:
output_dir = self.root / "output" / "bundle" / "snowpark"
output_dir.mkdir(parents=True, exist_ok=True) # type: ignore

output_files = []

# 1 Check if requirements exits
if (self.root / "requirements.txt").exists():
download_results = self._process_requirements(
bundle_dir=output_dir, # type: ignore
archive_name="dependencies.zip",
requirements_file=SecurePath(self.root / "requirements.txt"),
ignore_anaconda=ignore_anaconda,
skip_version_check=skip_version_check,
index_url=index_url,
allow_shared_libraries=allow_shared_libraries,
)

# 3 get the artifacts list
artifacts = self.model.artifacts

for artifact in artifacts:
output_file = output_dir / artifact.dest / artifact.src.name

if artifact.src.is_file():
output_file.mkdir(parents=True, exist_ok=True)
SecurePath(artifact.src).copy(output_file)
elif artifact.is_dir():
output_file.mkdir(parents=True, exist_ok=True)

output_files.append(output_file)

return output_files

def check_if_exists(
self, action_ctx: ActionContext
) -> bool: # TODO it should return current state, so we know if update is necessary
try:
current_state = self.action_describe(action_ctx)
return True
except ProgrammingError:
return False

def get_deploy_sql(self, mode: CreateMode):
query = [
f"{mode.value} {self.model.type.upper()} {self.identifier}",
"COPY GRANTS",
f"RETURNS {self.model.returns}",
f"LANGUAGE PYTHON",
f"RUNTIME_VERSION '{self.model.runtime or DEFAULT_RUNTIME}'",
f"IMPORTS={','.join(self.model.imports)}", # TODO: Add source files here after introducing bundlemap
f"HANDLER='{self.model.handler}'",
]

if self.model.external_access_integrations:
query.append(self.model.get_external_access_integrations_sql())

if self.model.secrets:
query.append(self.model.get_secrets_sql())

if self.model.type == "procedure" and self.model.execute_as_caller:
query.append("EXECUTE AS CALLER")

return "\n".join(query)

def get_execute_sql(self, execution_arguments: List[str] | None = None):
raise NotImplementedError

def _process_requirements( # TODO: maybe leave all the logic with requirements here - so download, write requirements file etc.
self,
bundle_dir: Path,
archive_name: str, # TODO: not the best name, think of something else
requirements_file: Optional[SecurePath],
ignore_anaconda: bool,
skip_version_check: bool = False,
index_url: Optional[str] = None,
allow_shared_libraries: bool = False,
) -> DownloadUnavailablePackagesResult:
"""
Processes the requirements file and downloads the dependencies
Parameters:
"""
anaconda_packages_manager = AnacondaPackagesManager()
with SecurePath.temporary_directory() as tmp_dir:
requirements = package_utils.parse_requirements(requirements_file)
anaconda_packages = (
AnacondaPackages.empty()
if ignore_anaconda
else anaconda_packages_manager.find_packages_available_in_snowflake_anaconda()
)
download_result = package_utils.download_unavailable_packages(
requirements=requirements,
target_dir=tmp_dir,
anaconda_packages=anaconda_packages,
skip_version_check=skip_version_check,
pip_index_url=index_url,
)

if download_result.anaconda_packages:
anaconda_packages.write_requirements_file_in_snowflake_format(
file_path=SecurePath(bundle_dir / "requirements.txt"),
requirements=download_result.anaconda_packages,
)

if download_result.downloaded_packages_details:
if (
package_utils.detect_and_log_shared_libraries(
download_result.downloaded_packages_details
)
and not allow_shared_libraries
):
raise ClickException(
"Some packages contain shared (.so/.dll) libraries. "
"Try again with allow_shared_libraries_flag."
)

zip_dir(
source=tmp_dir,
dest_zip=bundle_dir / archive_name,
)

return download_result


class FunctionEntity(SnowparkEntity[FunctionEntityModel]):
"""
A single UDF
"""

pass
# TO THINK OF
# Where will we get imports? Should we rely on bundle map? Or should it be self-sufficient in this matter?

def get_execute_sql(
self, execution_arguments: List[str] | None = None, *args, **kwargs
):
if not execution_arguments:
execution_arguments = []
return (
f"SELECT {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})"
)


class ProcedureEntity(SnowparkEntity[ProcedureEntityModel]):
"""
A stored procedure
"""

pass
def get_execute_sql(
self,
execution_arguments: List[str] | None = None,
):
if not execution_arguments:
execution_arguments = []
return (
f"CALL {self.fqn}({', '.join([str(arg) for arg in execution_arguments])})"
)
7 changes: 2 additions & 5 deletions src/snowflake/cli/_plugins/streamlit/streamlit_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,12 @@ def get_deploy_sql(

return query + ";"

def get_drop_sql(self):
return f"DROP STREAMLIT {self._entity_model.fqn};"
def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_execute_sql(self):
return f"EXECUTE STREAMLIT {self._entity_model.fqn}();"

def get_share_sql(self, to_role: str) -> str:
return f"GRANT USAGE ON STREAMLIT {self.model.fqn.sql_identifier} TO ROLE {to_role};"

def get_usage_grant_sql(self, app_role: str, schema: Optional[str] = None) -> str:
entity_id = self.entity_id
streamlit_name = f"{schema}.{entity_id}" if schema else entity_id
Expand Down
47 changes: 45 additions & 2 deletions src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import functools
from enum import Enum
from pathlib import Path
from typing import Generic, Type, TypeVar, get_args

from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli.api.cli_global_context import span
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.sql_execution import SqlExecutor
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor


class EntityActions(str, Enum):
Expand Down Expand Up @@ -68,8 +73,8 @@ def __init__(self, entity_model: T, workspace_ctx: WorkspaceContext):
self._workspace_ctx = workspace_ctx

@property
def entity_id(self):
return self._entity_model.entity_id
def entity_id(self) -> str:
return self._entity_model.entity_id # type: ignore

@classmethod
def get_entity_model_type(cls) -> Type[T]:
Expand All @@ -94,6 +99,44 @@ def perform(
"""
return getattr(self, action)(action_ctx, *args, **kwargs)

@property
def root(self) -> Path:
return self._workspace_ctx.project_root

@property
def identifier(self) -> str:
return self.model.fqn.sql_identifier

@property
def fqn(self) -> FQN:
return self._entity_model.fqn # type: ignore[attr-defined]

@functools.cached_property
def _sql_executor(
self,
) -> SqlExecutor:
return get_sql_executor()

def _execute_query(self, sql: str) -> SnowflakeCursor:
return self._sql_executor.execute_query(sql)

@functools.cached_property
def _conn(self) -> SnowflakeConnection:
return self._sql_executor._conn # noqa

@property
def model(self):
return self._entity_model

def get_usage_grant_sql(self, app_role: str) -> str:
return f"GRANT USAGE ON {self.model.type.upper()} {self.identifier} TO ROLE {app_role};"

def get_describe_sql(self) -> str:
return f"DESCRIBE {self.model.type.upper()} {self.identifier};"

def get_drop_sql(self) -> str:
return f"DROP {self.model.type.upper()} {self.identifier};"


def get_sql_executor() -> SqlExecutor:
"""Returns an SQL Executor that uses the connection from the current CLI context"""
Expand Down
Loading

0 comments on commit 8cc48be

Please sign in to comment.