From 06c3876d06216da06ccfb2421bd03542c589dd2e Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Thu, 17 Oct 2024 19:06:33 +0200 Subject: [PATCH] Introduce StagePath for working with files on stages (#1674) --- src/snowflake/cli/_plugins/git/commands.py | 2 +- src/snowflake/cli/_plugins/git/manager.py | 5 + src/snowflake/cli/_plugins/stage/commands.py | 2 +- src/snowflake/cli/_plugins/stage/diff.py | 32 +-- src/snowflake/cli/_plugins/stage/manager.py | 137 +++++----- src/snowflake/cli/api/entities/utils.py | 4 +- src/snowflake/cli/api/identifiers.py | 2 + src/snowflake/cli/api/stage_path.py | 241 ++++++++++++++++++ .../git/__snapshots__/test_git_commands.ambr | 28 +- tests/git/test_git_commands.py | 38 +-- tests/nativeapp/test_manager.py | 10 +- tests/stage/test_diff.py | 14 +- tests/stage/test_stage.py | 19 +- tests/stage/test_stage_path.py | 170 ++++++++++++ tests/streamlit/test_commands.py | 24 +- 15 files changed, 580 insertions(+), 148 deletions(-) create mode 100644 src/snowflake/cli/api/stage_path.py create mode 100644 tests/stage/test_stage_path.py diff --git a/src/snowflake/cli/_plugins/git/commands.py b/src/snowflake/cli/_plugins/git/commands.py index 6279b1c5c3..1f77644ec6 100644 --- a/src/snowflake/cli/_plugins/git/commands.py +++ b/src/snowflake/cli/_plugins/git/commands.py @@ -338,7 +338,7 @@ def execute( extension will be executed. """ results = GitManager().execute( - stage_path=repository_path, + stage_path_str=repository_path, on_error=on_error, variables=variables, requires_temporary_stage=True, diff --git a/src/snowflake/cli/_plugins/git/manager.py b/src/snowflake/cli/_plugins/git/manager.py index d5b7b15c5a..d9d950652d 100644 --- a/src/snowflake/cli/_plugins/git/manager.py +++ b/src/snowflake/cli/_plugins/git/manager.py @@ -26,6 +26,7 @@ UserStagePathParts, ) from snowflake.cli.api.identifiers import FQN +from snowflake.cli.api.stage_path import StagePath from snowflake.connector.cursor import SnowflakeCursor # Replace magic numbers with constants @@ -78,6 +79,10 @@ def get_directory_from_file_path(self, file_path: str) -> List[str]: class GitManager(StageManager): + @staticmethod + def build_path(stage_path: str) -> StagePathParts: + return StagePath.from_git_str(stage_path) + def show_branches(self, repo_name: str, like: str) -> SnowflakeCursor: return self._execute_query(f"show git branches like '{like}' in {repo_name}") diff --git a/src/snowflake/cli/_plugins/stage/commands.py b/src/snowflake/cli/_plugins/stage/commands.py index eb98496d51..979b9d72b0 100644 --- a/src/snowflake/cli/_plugins/stage/commands.py +++ b/src/snowflake/cli/_plugins/stage/commands.py @@ -212,7 +212,7 @@ def execute( e.g. `@stage/*.sql`, `@stage/dev/*`. Only files with `.sql` extension will be executed. """ results = StageManager().execute( - stage_path=stage_path, on_error=on_error, variables=variables + stage_path_str=stage_path, on_error=on_error, variables=variables ) return CollectionResult(results) diff --git a/src/snowflake/cli/_plugins/stage/diff.py b/src/snowflake/cli/_plugins/stage/diff.py index d1eac06869..39fa451f74 100644 --- a/src/snowflake/cli/_plugins/stage/diff.py +++ b/src/snowflake/cli/_plugins/stage/diff.py @@ -30,7 +30,7 @@ log = logging.getLogger(__name__) -StagePath = PurePosixPath # alias PurePosixPath as StagePath for clarity +StagePathType = PurePosixPath # alias PurePosixPath as StagePath for clarity @dataclass @@ -39,16 +39,16 @@ class DiffResult: Each collection is a list of stage paths ('/'-separated, regardless of the platform), relative to the stage root. """ - identical: List[StagePath] = field(default_factory=list) + identical: List[StagePathType] = field(default_factory=list) "Files with matching md5sums" - different: List[StagePath] = field(default_factory=list) + different: List[StagePathType] = field(default_factory=list) "Files that may be different between the stage and the local directory" - only_local: List[StagePath] = field(default_factory=list) + only_local: List[StagePathType] = field(default_factory=list) "Files that only exist in the local directory" - only_on_stage: List[StagePath] = field(default_factory=list) + only_on_stage: List[StagePathType] = field(default_factory=list) "Files that only exist on the stage" def has_changes(self) -> bool: @@ -83,12 +83,12 @@ def enumerate_files(path: Path) -> List[Path]: return paths -def strip_stage_name(path: str) -> StagePath: +def strip_stage_name(path: str) -> StagePathType: """Returns the given stage path without the stage name as the first part.""" - return StagePath(*path.split("/")[1:]) + return StagePathType(*path.split("/")[1:]) -def build_md5_map(list_stage_cursor: DictCursor) -> Dict[StagePath, Optional[str]]: +def build_md5_map(list_stage_cursor: DictCursor) -> Dict[StagePathType, Optional[str]]: """ Returns a mapping of relative stage paths to their md5sums. """ @@ -99,7 +99,7 @@ def build_md5_map(list_stage_cursor: DictCursor) -> Dict[StagePath, Optional[str def preserve_from_diff( - diff: DiffResult, stage_paths_to_sync: Collection[StagePath] + diff: DiffResult, stage_paths_to_sync: Collection[StagePathType] ) -> DiffResult: """ Returns a filtered version of the provided diff, keeping only the provided stage paths. @@ -163,7 +163,7 @@ def compute_stage_diff( return result -def get_stage_subpath(stage_path: StagePath) -> str: +def get_stage_subpath(stage_path: StagePathType) -> str: """ Returns the parent portion of a stage path, as a string, for inclusion in the fully qualified stage path. Note that '.' treated specially here, and so the return value of this call is not a `StagePath` instance. @@ -172,21 +172,21 @@ def get_stage_subpath(stage_path: StagePath) -> str: return "" if parent == "." else parent -def to_stage_path(filename: Path) -> StagePath: +def to_stage_path(filename: Path) -> StagePathType: """ Returns the stage file name, with the path separator suitably transformed if needed. """ - return StagePath(*filename.parts) + return StagePathType(*filename.parts) -def to_local_path(stage_path: StagePath) -> Path: +def to_local_path(stage_path: StagePathType) -> Path: return Path(*stage_path.parts) def delete_only_on_stage_files( stage_manager: StageManager, stage_fqn: str, - only_on_stage: List[StagePath], + only_on_stage: List[StagePathType], role: Optional[str] = None, ): """ @@ -200,7 +200,7 @@ def put_files_on_stage( stage_manager: StageManager, stage_fqn: str, deploy_root_path: Path, - stage_paths: List[StagePath], + stage_paths: List[StagePathType], role: Optional[str] = None, overwrite: bool = False, ): @@ -254,7 +254,7 @@ def sync_local_diff_with_stage( def _to_src_dest_pair( - stage_path: StagePath, bundle_map: Optional[BundleMap] + stage_path: StagePathType, bundle_map: Optional[BundleMap] ) -> Tuple[Optional[str], str]: if not bundle_map: return None, str(stage_path) diff --git a/src/snowflake/cli/_plugins/stage/manager.py b/src/snowflake/cli/_plugins/stage/manager.py index 6837ed1554..3b4773fcdf 100644 --- a/src/snowflake/cli/_plugins/stage/manager.py +++ b/src/snowflake/cli/_plugins/stage/manager.py @@ -40,6 +40,7 @@ from snowflake.cli.api.project.util import to_string_literal from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin +from snowflake.cli.api.stage_path import StagePath from snowflake.cli.api.utils.path_utils import path_resolver from snowflake.connector import DictCursor, ProgrammingError from snowflake.connector.cursor import SnowflakeCursor @@ -208,6 +209,10 @@ def __init__(self): super().__init__() self._python_exe_procedure = None + @staticmethod + def build_path(stage_path: str) -> StagePath: + return StagePath.from_stage_str(stage_path) + @staticmethod def get_standard_stage_prefix(name: str | FQN) -> str: if isinstance(name, FQN): @@ -245,9 +250,14 @@ def _to_uri(self, local_path: str): return uri return to_string_literal(uri) - def list_files(self, stage_name: str, pattern: str | None = None) -> DictCursor: - stage_name = self.get_standard_stage_prefix(stage_name) - query = f"ls {self.quote_stage_name(stage_name)}" + def list_files( + self, stage_name: str | StagePath, pattern: str | None = None + ) -> DictCursor: + if not isinstance(stage_name, StagePath): + stage_path = self.build_path(stage_name).path_for_sql() + else: + stage_path = stage_name.path_for_sql() + query = f"ls {stage_path}" if pattern is not None: query += f" pattern = '{pattern}'" return self._execute_query(query, cursor_class=DictCursor) @@ -262,27 +272,27 @@ def _assure_is_existing_directory(path: Path) -> None: def get( self, stage_path: str, dest_path: Path, parallel: int = 4 ) -> SnowflakeCursor: - stage_path = self.get_standard_stage_prefix(stage_path) + spath = self.build_path(stage_path) self._assure_is_existing_directory(dest_path) dest_directory = f"{dest_path}/" return self._execute_query( - f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}" + f"get {spath.path_for_sql()} {self._to_uri(dest_directory)} parallel={parallel}" ) def get_recursive( self, stage_path: str, dest_path: Path, parallel: int = 4 ) -> List[SnowflakeCursor]: - stage_path_parts = self._stage_path_part_factory(stage_path) + stage_root = self.build_path(stage_path) results = [] - for file_path in self.iter_stage(stage_path): - dest_directory = dest_path - for path_part in stage_path_parts.get_directory_from_file_path(file_path): - dest_directory = dest_directory / path_part - self._assure_is_existing_directory(dest_directory) + for file_path in self.iter_stage(stage_root): + local_dir = file_path.get_local_target_path( + target_dir=dest_path, stage_root=stage_root + ) + self._assure_is_existing_directory(local_dir) result = self._execute_query( - f"get {self.quote_stage_name(stage_path_parts.replace_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}" + f"get {file_path.path_for_sql()} {self._to_uri(f'{local_dir}/')} parallel={parallel}" ) results.append(result) @@ -304,28 +314,31 @@ def put( and switch back to the original role for the next commands to run. """ with self.use_role(role) if role else nullcontext(): - stage_path = self.get_standard_stage_prefix(stage_path) + spath = self.build_path(stage_path) local_resolved_path = path_resolver(str(local_path)) log.info("Uploading %s to %s", local_resolved_path, stage_path) cursor = self._execute_query( - f"put {self._to_uri(local_resolved_path)} {self.quote_stage_name(stage_path)} " + f"put {self._to_uri(local_resolved_path)} {spath.path_for_sql()} " f"auto_compress={str(auto_compress).lower()} parallel={parallel} overwrite={overwrite}" ) return cursor def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor: - source_path_parts = self._stage_path_part_factory(source_path) - destination_path_parts = self._stage_path_part_factory(destination_path) + source_stage_path = self.build_path(source_path) + # We copy only into stage + destination_stage_path = StagePath.from_stage_str(destination_path) - if isinstance(destination_path_parts, UserStagePathParts): + if destination_stage_path.is_user_stage(): raise ClickException( "Destination path cannot be a user stage. Please provide a named stage." ) - source = source_path_parts.get_standard_stage_path() - destination = destination_path_parts.get_standard_stage_path() - log.info("Copying files from %s to %s", source, destination) - query = f"copy files into {destination} from {source}" + log.info( + "Copying files from %s to %s", source_stage_path, destination_stage_path + ) + # Destination needs to end with / + dest = destination_stage_path.absolute_path().rstrip("/") + "/" + query = f"copy files into {dest} from {source_stage_path}" return self._execute_query(query) def remove( @@ -338,10 +351,8 @@ def remove( and switch back to the original role for the next commands to run. """ with self.use_role(role) if role else nullcontext(): - stage_name = self.get_standard_stage_prefix(stage_name) - path = path if path.startswith("/") else "/" + path - quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}") - return self._execute_query(f"remove {quoted_stage_name}") + stage_path = self.build_path(stage_name) / path + return self._execute_query(f"remove {stage_path.path_for_sql()}") def create( self, fqn: FQN, comment: Optional[str] = None, temporary: bool = False @@ -352,13 +363,17 @@ def create( query += f" comment='{comment}'" return self._execute_query(query) - def iter_stage(self, stage_path: str): - for file in self.list_files(stage_path).fetchall(): - yield file["name"] + def iter_stage(self, stage_path: StagePath): + for file in self.list_files(stage_path.absolute_path()).fetchall(): + if stage_path.is_user_stage(): + path = StagePath.get_user_stage() / file["name"] + else: + path = self.build_path(file["name"]) + yield path def execute( self, - stage_path: str, + stage_path_str: str, on_error: OnErrorType, variables: Optional[List[str]] = None, requires_temporary_stage: bool = False, @@ -367,11 +382,15 @@ def execute( ( stage_path_parts, original_path_parts, - ) = self._create_temporary_copy_of_stage(stage_path) + ) = self._create_temporary_copy_of_stage(stage_path_str) + stage_path = StagePath.from_stage_str( + stage_path_parts.get_standard_stage_path() + ) else: - stage_path_parts = self._stage_path_part_factory(stage_path) + stage_path_parts = self._stage_path_part_factory(stage_path_str) + stage_path = self.build_path(stage_path_str) - all_files_list = self._get_files_list_from_stage(stage_path_parts) + all_files_list = self._get_files_list_from_stage(stage_path.root_path()) all_files_with_stage_name_prefix = [ stage_path_parts.get_directory(file) for file in all_files_list @@ -397,7 +416,7 @@ def execute( if any(file.endswith(".py") for file in sorted_file_path_list): self._python_exe_procedure = self._bootstrap_snowpark_execution_environment( - stage_path_parts + stage_path ) for file_path in sorted_file_path_list: @@ -437,15 +456,14 @@ def _create_temporary_copy_of_stage( original_path_parts = self._stage_path_part_factory(stage_path) # noqa: SLF001 tmp_stage_name = f"snowflake_cli_tmp_stage_{int(time.time())}" - tmp_stage = ( - FQN.from_stage(tmp_stage_name).using_connection(conn=self._conn).identifier - ) + tmp_stage_fqn = FQN.from_stage(tmp_stage_name).using_connection(conn=self._conn) + tmp_stage = tmp_stage_fqn.identifier stage_path_parts = sm._stage_path_part_factory( # noqa: SLF001 tmp_stage + "/" + original_path_parts.directory ) # Create temporary stage, it will be dropped with end of session - sm.create(FQN.from_string(tmp_stage), temporary=True) + sm.create(tmp_stage_fqn, temporary=True) # Copy the content self.copy_files( @@ -459,14 +477,12 @@ def _create_temporary_copy_of_stage( return stage_path_parts, original_path_parts def _get_files_list_from_stage( - self, stage_path_parts: StagePathParts, pattern: str | None = None + self, stage_path: StagePath, pattern: str | None = None ) -> List[str]: - files_list_result = self.list_files( - stage_path_parts.stage, pattern=pattern - ).fetchall() + files_list_result = self.list_files(stage_path, pattern=pattern).fetchall() if not files_list_result: - raise ClickException(f"No files found on stage '{stage_path_parts.stage}'") + raise ClickException(f"No files found on stage '{stage_path}'") return [f["name"] for f in files_list_result] @@ -556,32 +572,34 @@ def _stage_path_part_factory(stage_path: str) -> StagePathParts: return UserStagePathParts(stage_path) return DefaultStagePathParts(stage_path) - def _check_for_requirements_file( - self, stage_path_parts: StagePathParts - ) -> List[str]: + def _check_for_requirements_file(self, stage_path: StagePath) -> List[str]: """Looks for requirements.txt file on stage.""" + current_dir = stage_path.parent if stage_path.is_file() else stage_path req_files_on_stage = self._get_files_list_from_stage( - stage_path_parts, pattern=r".*requirements\.txt$" + current_dir, pattern=r".*requirements\.txt$" ) if not req_files_on_stage: return [] # Construct all possible path for requirements file for this context - # We don't use os.path or pathlib to preserve compatibility on Windows req_file_name = "requirements.txt" - path_parts = stage_path_parts.path.split("/") possible_req_files = [] + while not current_dir.is_root(): + current_file = current_dir / req_file_name + possible_req_files.append(current_file) + current_dir = current_dir.parent - while path_parts: - current_file = "/".join([*path_parts, req_file_name]) - possible_req_files.append(str(current_file)) - path_parts = path_parts[:-1] + current_file = current_dir / req_file_name + possible_req_files.append(current_file) # Now for every possible path check if the file exists on stage, # if yes break, we use the first possible file - requirements_file = None + requirements_file: StagePath | None = None for req_file in possible_req_files: - if req_file in req_files_on_stage: + if ( + req_file.absolute_path(no_fqn=True, at_prefix=False) + in req_files_on_stage + ): requirements_file = req_file break @@ -590,19 +608,16 @@ def _check_for_requirements_file( return [] # req_file at this moment is the first found requirements file + requirements_path = requirements_file.with_stage(stage_path.stage) with SecurePath.temporary_directory() as tmp_dir: - self.get( - stage_path_parts.get_full_stage_path(requirements_file), tmp_dir.path - ) + self.get(str(requirements_path), tmp_dir.path) requirements = parse_requirements( requirements_file=tmp_dir / "requirements.txt" ) return [req.package_name for req in requirements] - def _bootstrap_snowpark_execution_environment( - self, stage_path_parts: StagePathParts - ): + def _bootstrap_snowpark_execution_environment(self, stage_path: StagePath): """Prepares Snowpark session for executing Python code remotely.""" if sys.version_info >= PYTHON_3_12: raise ClickException( @@ -613,7 +628,7 @@ def _bootstrap_snowpark_execution_environment( self.snowpark_session.add_packages("snowflake-snowpark-python") self.snowpark_session.add_packages("snowflake.core") - requirements = self._check_for_requirements_file(stage_path_parts) + requirements = self._check_for_requirements_file(stage_path) self.snowpark_session.add_packages(*requirements) @sproc(is_permanent=False) diff --git a/src/snowflake/cli/api/entities/utils.py b/src/snowflake/cli/api/entities/utils.py index c60a2224b4..31c94e9722 100644 --- a/src/snowflake/cli/api/entities/utils.py +++ b/src/snowflake/cli/api/entities/utils.py @@ -16,7 +16,7 @@ from snowflake.cli._plugins.nativeapp.utils import verify_exists, verify_no_directories from snowflake.cli._plugins.stage.diff import ( DiffResult, - StagePath, + StagePathType, compute_stage_diff, preserve_from_diff, sync_local_diff_with_stage, @@ -80,7 +80,7 @@ def generic_sql_error_handler( def _get_stage_paths_to_sync( local_paths_to_sync: List[Path], deploy_root: Path -) -> List[StagePath]: +) -> List[StagePathType]: """ Takes a list of paths (files and directories), returning a list of all files recursively relative to the deploy root. """ diff --git a/src/snowflake/cli/api/identifiers.py b/src/snowflake/cli/api/identifiers.py index 062fd55269..f003e7967e 100644 --- a/src/snowflake/cli/api/identifiers.py +++ b/src/snowflake/cli/api/identifiers.py @@ -122,6 +122,8 @@ def from_stage(cls, stage: str) -> "FQN": name = stage if stage.startswith("@"): name = stage[1:] + if stage.startswith("~"): + return cls(name="~", database=None, schema=None) return cls.from_string(name) @classmethod diff --git a/src/snowflake/cli/api/stage_path.py b/src/snowflake/cli/api/stage_path.py new file mode 100644 index 0000000000..2f7a4e3e6a --- /dev/null +++ b/src/snowflake/cli/api/stage_path.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import re +from pathlib import Path, PurePosixPath + +from snowflake.cli.api.identifiers import FQN +from snowflake.cli.api.project.util import ( + to_string_literal, +) + +USER_STAGE_PREFIX = "~" + + +class StagePath: + def __init__( + self, + stage_name: str, + path: str | PurePosixPath | None = None, + git_ref: str | None = None, + trailing_slash: bool = False, + ): + self._stage_name = self.strip_stage_prefixes(stage_name) + self._path = PurePosixPath(path) if path else PurePosixPath(".") + + self._trailing_slash = trailing_slash + # Check if user stage + self._is_user_stage = self._stage_name.startswith(USER_STAGE_PREFIX) + + # Setup git information + self._git_ref = None + self._is_git_repo = False + if git_ref: + self._git_ref = git_ref + self._is_git_repo = True + + @classmethod + def get_user_stage(cls) -> StagePath: + return cls.from_stage_str("~") + + @property + def stage(self) -> str: + return self._stage_name + + @property + def path(self) -> PurePosixPath: + return self._path + + @property + def stage_with_at(self) -> str: + return self.add_at_prefix(self._stage_name) + + def is_user_stage(self) -> bool: + return self._is_user_stage + + def is_git_repo(self) -> bool: + return self._is_git_repo + + @property + def git_ref(self) -> str | None: + return self._git_ref + + @staticmethod + def add_at_prefix(text: str): + if not text.startswith("@"): + return "@" + text + return text + + @staticmethod + def strip_at_prefix(text: str): + if text.startswith("@"): + return text[1:] + return text + + @staticmethod + def strip_snow_prefix(text: str): + if text.startswith("snow://"): + return text[len("snow://") :] + return text + + @classmethod + def strip_stage_prefixes(cls, text: str): + return cls.strip_at_prefix(cls.strip_snow_prefix(text)) + + @classmethod + def from_stage_str(cls, stage_str: str | FQN): + stage_str = cls.strip_stage_prefixes(str(stage_str)) + parts = stage_str.split("/", maxsplit=1) + parts = [p for p in parts if p] + if len(parts) == 2: + stage_string, path = parts + else: + stage_string = parts[0] + path = None + return cls( + stage_name=stage_string, path=path, trailing_slash=stage_str.endswith("/") + ) + + @classmethod + def from_git_str(cls, git_str: str): + """ + @configuration_repo / branches/main / scripts/setup.sql + @configuration_repo / branches/"foo/main" / scripts/setup.sql + """ + repo_name, git_ref, path = cls._split_repo_path( + cls.strip_stage_prefixes(git_str) + ) + return cls( + stage_name=repo_name, + path=path, + git_ref=git_ref, + trailing_slash=git_str.endswith("/"), + ) + + @staticmethod + def _split_repo_path(git_str: str) -> tuple[str, str, str]: + parts = [] + slash_index = 0 + skipping_mode = False + for current_idx, (char, next_char) in enumerate(zip(git_str[:-1], git_str[1:])): + if not skipping_mode: + if char != "/": + continue + + # Normal split + parts.append(git_str[slash_index:current_idx]) + slash_index = current_idx + 1 + + if next_char == '"': + skipping_mode = not skipping_mode + # Add last part + parts.append(git_str[slash_index:]) + repo_name = parts[0] + ref = parts[1] + "/" + parts[2] + path = "/".join(parts[3:]) if len(parts) > 2 else "" + return repo_name, ref, path + + def absolute_path(self, no_fqn=False, at_prefix=True) -> str: + stage_name = self._stage_name + if not self.is_user_stage() and no_fqn: + stage_name = FQN.from_string(self._stage_name).name + + path = PurePosixPath(stage_name) + if self.git_ref: + path = path / self.git_ref + if not self.is_root(): + path = path / self._path + + str_path = str(path) + if at_prefix: + str_path = self.add_at_prefix(str_path) + + if self._trailing_slash: + return str_path.rstrip("/") + "/" + return str_path + + def joinpath(self, path: str) -> StagePath: + if self.is_file(): + raise ValueError("Cannot join path to a file") + + return StagePath( + stage_name=self._stage_name, + path=PurePosixPath(self._path) / path.lstrip("/"), + git_ref=self._git_ref, + ) + + def __truediv__(self, path: str): + return self.joinpath(path) + + def with_stage(self, stage_name: str) -> StagePath: + """Returns a new path with new stage name""" + return StagePath( + stage_name=stage_name, + path=self._path, + git_ref=self._git_ref, + ) + + @property + def parts(self) -> tuple[str, ...]: + return self._path.parts + + @property + def name(self) -> str: + return self._path.name + + def is_dir(self) -> bool: + return "." not in self.name + + def is_file(self) -> bool: + return not self.is_dir() + + @property + def suffix(self) -> str: + return self._path.suffix + + @property + def stem(self) -> str: + return self._path.stem + + @property + def parent(self) -> StagePath: + return StagePath( + stage_name=self._stage_name, path=self._path.parent, git_ref=self._git_ref + ) + + def is_root(self) -> bool: + return self._path == PurePosixPath(".") + + def root_path(self) -> StagePath: + if self.is_git_repo(): + return StagePath(stage_name=self._stage_name, git_ref=self._git_ref) + return StagePath(stage_name=self._stage_name) + + def is_quoted(self) -> bool: + path = self.absolute_path() + return path.startswith("'") and path.endswith("'") + + def path_for_sql(self) -> str: + path = self.absolute_path() + if not re.fullmatch(r"@([\w./$])+", path): + return to_string_literal(path) + return path + + def quoted_absolute_path(self) -> str: + if self.is_quoted(): + return self.absolute_path() + return to_string_literal(self.absolute_path()) + + def relative_to(self, stage_path: StagePath) -> PurePosixPath: + return self.path.relative_to(stage_path.path) + + def get_local_target_path(self, target_dir: Path, stage_root: StagePath): + # Case for downloading @stage/aa/file.py with root @stage/aa + if self.relative_to(stage_root) == PurePosixPath("."): + return target_dir + return (target_dir / self.relative_to(stage_root)).parent + + def __str__(self): + return self.absolute_path() + + def __eq__(self, other): + return self.absolute_path() == other.absolute_path() diff --git a/tests/git/__snapshots__/test_git_commands.ambr b/tests/git/__snapshots__/test_git_commands.ambr index 724ab9245d..129fa12cef 100644 --- a/tests/git/__snapshots__/test_git_commands.ambr +++ b/tests/git/__snapshots__/test_git_commands.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_execute[@DB.SCHEMA.REPO/branches/main/s1.sql-@DB.SCHEMA.REPO/branches/main/-expected_files4] +# name: test_execute[@DB.SCHEMA.REPO/branches/main/s1.sql-@DB.SCHEMA.REPO/branches/main-expected_files4] ''' SUCCESS - @DB.SCHEMA.REPO/branches/main/s1.sql +--------------------------------------------------------+ @@ -10,7 +10,7 @@ ''' # --- -# name: test_execute[@DB.schema.REPO/branches/main/a/S3.sql-@DB.schema.REPO/branches/main/-expected_files5] +# name: test_execute[@DB.schema.REPO/branches/main/a/S3.sql-@DB.schema.REPO/branches/main-expected_files5] ''' SUCCESS - @DB.schema.REPO/branches/main/a/S3.sql +----------------------------------------------------------+ @@ -21,7 +21,7 @@ ''' # --- -# name: test_execute[@db.schema.repo/branches/main/-@db.schema.repo/branches/main/-expected_files2] +# name: test_execute[@db.schema.repo/branches/main/-@db.schema.repo/branches/main-expected_files2] ''' SUCCESS - @db.schema.repo/branches/main/s1.sql SUCCESS - @db.schema.repo/branches/main/a/S3.sql @@ -34,7 +34,7 @@ ''' # --- -# name: test_execute[@db.schema.repo/branches/main/s1.sql-@db.schema.repo/branches/main/-expected_files3] +# name: test_execute[@db.schema.repo/branches/main/s1.sql-@db.schema.repo/branches/main-expected_files3] ''' SUCCESS - @db.schema.repo/branches/main/s1.sql +--------------------------------------------------------+ @@ -45,7 +45,7 @@ ''' # --- -# name: test_execute[@repo/branches/main/-@repo/branches/main/-expected_files0] +# name: test_execute[@repo/branches/main/-@repo/branches/main-expected_files0] ''' SUCCESS - @repo/branches/main/s1.sql SUCCESS - @repo/branches/main/a/S3.sql @@ -58,7 +58,7 @@ ''' # --- -# name: test_execute[@repo/branches/main/a-@repo/branches/main/-expected_files1] +# name: test_execute[@repo/branches/main/a-@repo/branches/main-expected_files1] ''' SUCCESS - @repo/branches/main/a/S3.sql +------------------------------------------------+ @@ -69,7 +69,7 @@ ''' # --- -# name: test_execute_new_git_repository_list_files[@repo/branches/main/-@repo/branches/main/-expected_files0] +# name: test_execute_new_git_repository_list_files[@repo/branches/main/-@repo/branches/main-expected_files0] ''' SUCCESS - @repo/branches/main/S2.sql SUCCESS - @repo/branches/main/s1.sql @@ -84,7 +84,7 @@ ''' # --- -# name: test_execute_new_git_repository_list_files[@repo/branches/main/S2.sql-@repo/branches/main/-expected_files2] +# name: test_execute_new_git_repository_list_files[@repo/branches/main/S2.sql-@repo/branches/main-expected_files2] ''' SUCCESS - @repo/branches/main/S2.sql +----------------------------------------------+ @@ -95,7 +95,7 @@ ''' # --- -# name: test_execute_new_git_repository_list_files[@repo/branches/main/a/s3.sql-@repo/branches/main/-expected_files3] +# name: test_execute_new_git_repository_list_files[@repo/branches/main/a/s3.sql-@repo/branches/main-expected_files3] ''' SUCCESS - @repo/branches/main/a/s3.sql +------------------------------------------------+ @@ -106,7 +106,7 @@ ''' # --- -# name: test_execute_new_git_repository_list_files[@repo/branches/main/s1.sql-@repo/branches/main/-expected_files1] +# name: test_execute_new_git_repository_list_files[@repo/branches/main/s1.sql-@repo/branches/main-expected_files1] ''' SUCCESS - @repo/branches/main/s1.sql +----------------------------------------------+ @@ -117,7 +117,7 @@ ''' # --- -# name: test_execute_slash_in_repository_name[@db.schema.repo/branches/"feature/commit"/-@db.schema.repo/branches/"feature/commit"/-expected_files3] +# name: test_execute_slash_in_repository_name[@db.schema.repo/branches/"feature/commit"/-@db.schema.repo/branches/"feature/commit"-expected_files3] ''' SUCCESS - @db.schema.repo/branches/"feature/commit"/s1.sql SUCCESS - @db.schema.repo/branches/"feature/commit"/a/S3.sql @@ -130,7 +130,7 @@ ''' # --- -# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/-@repo/branches/"feature/commit"/-expected_files0] +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/-@repo/branches/"feature/commit"-expected_files0] ''' SUCCESS - @repo/branches/"feature/commit"/s1.sql SUCCESS - @repo/branches/"feature/commit"/a/S3.sql @@ -143,7 +143,7 @@ ''' # --- -# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/a/-@repo/branches/"feature/commit"/-expected_files2] +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/a/-@repo/branches/"feature/commit"-expected_files2] ''' SUCCESS - @repo/branches/"feature/commit"/a/S3.sql +------------------------------------------------------------+ @@ -154,7 +154,7 @@ ''' # --- -# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/s1.sql-@repo/branches/"feature/commit"/-expected_files1] +# name: test_execute_slash_in_repository_name[@repo/branches/"feature/commit"/s1.sql-@repo/branches/"feature/commit"-expected_files1] ''' SUCCESS - @repo/branches/"feature/commit"/s1.sql +----------------------------------------------------------+ diff --git a/tests/git/test_git_commands.py b/tests/git/test_git_commands.py index a8370de951..6679736523 100644 --- a/tests/git/test_git_commands.py +++ b/tests/git/test_git_commands.py @@ -19,6 +19,7 @@ import pytest from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED +from snowflake.cli.api.stage_path import StagePath from snowflake.connector import DictCursor, ProgrammingError EXAMPLE_URL = "https://github.com/an-example-repo.git" @@ -127,7 +128,8 @@ def test_copy_to_local_file_system( ctx = mock_ctx() mock_connector.return_value = ctx mock_iter.return_value = ( - x for x in [f"{repo_prefix}file.txt", f"{repo_prefix}dir/file_in_dir.txt"] + StagePath.from_git_str(x) + for x in [f"{repo_prefix}file.txt", f"{repo_prefix}dir/file_in_dir.txt"] ) mock_iter.__len__.return_value = 2 mock_result.result = {"file": "mock"} @@ -542,17 +544,17 @@ def test_api_integration_and_secrets_get_unique_names( [ ( "@repo/branches/main/", - "@repo/branches/main/", + "@repo/branches/main", ["/s1.sql", "/a/S3.sql"], ), ( "@repo/branches/main/a", - "@repo/branches/main/", + "@repo/branches/main", ["/a/S3.sql"], ), ( "@db.schema.repo/branches/main/", - "@db.schema.repo/branches/main/", + "@db.schema.repo/branches/main", [ "/s1.sql", "/a/S3.sql", @@ -560,17 +562,17 @@ def test_api_integration_and_secrets_get_unique_names( ), ( "@db.schema.repo/branches/main/s1.sql", - "@db.schema.repo/branches/main/", + "@db.schema.repo/branches/main", ["/s1.sql"], ), ( "@DB.SCHEMA.REPO/branches/main/s1.sql", - "@DB.SCHEMA.REPO/branches/main/", + "@DB.SCHEMA.REPO/branches/main", ["/s1.sql"], ), ( "@DB.schema.REPO/branches/main/a/S3.sql", - "@DB.schema.REPO/branches/main/", + "@DB.schema.REPO/branches/main", ["/a/S3.sql"], ), ], @@ -604,7 +606,7 @@ def test_execute( "create temporary stage if not exists IDENTIFIER('FOO.BAR.snowflake_cli_tmp_stage_123')" ) assert copy_call == mock.call( - f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}" + f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}/" ) assert ls_call == mock.call( f"ls @FOO.BAR.snowflake_cli_tmp_stage_123", cursor_class=DictCursor @@ -621,7 +623,7 @@ def test_execute( [ ( "@repo/branches/main/", - "@repo/branches/main/", + "@repo/branches/main", [ "/S2.sql", "/s1.sql", @@ -630,17 +632,17 @@ def test_execute( ), ( "@repo/branches/main/s1.sql", - "@repo/branches/main/", + "@repo/branches/main", ["/s1.sql"], ), ( "@repo/branches/main/S2.sql", - "@repo/branches/main/", + "@repo/branches/main", ["/S2.sql"], ), ( "@repo/branches/main/a/s3.sql", - "@repo/branches/main/", + "@repo/branches/main", ["/a/s3.sql"], ), ], @@ -674,7 +676,7 @@ def test_execute_new_git_repository_list_files( "create temporary stage if not exists IDENTIFIER('FOO.BAR.snowflake_cli_tmp_stage_123')" ) assert copy_call == mock.call( - f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}" + f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}/" ) assert ls_call == mock.call( f"ls @FOO.BAR.snowflake_cli_tmp_stage_123", cursor_class=DictCursor @@ -691,7 +693,7 @@ def test_execute_new_git_repository_list_files( [ ( '@repo/branches/"feature/commit"/', - '@repo/branches/"feature/commit"/', + '@repo/branches/"feature/commit"', [ "/s1.sql", "/a/S3.sql", @@ -699,21 +701,21 @@ def test_execute_new_git_repository_list_files( ), ( '@repo/branches/"feature/commit"/s1.sql', - '@repo/branches/"feature/commit"/', + '@repo/branches/"feature/commit"', [ "/s1.sql", ], ), ( '@repo/branches/"feature/commit"/a/', - '@repo/branches/"feature/commit"/', + '@repo/branches/"feature/commit"', [ "/a/S3.sql", ], ), ( '@db.schema.repo/branches/"feature/commit"/', - '@db.schema.repo/branches/"feature/commit"/', + '@db.schema.repo/branches/"feature/commit"', [ "/s1.sql", "/a/S3.sql", @@ -750,7 +752,7 @@ def test_execute_slash_in_repository_name( "create temporary stage if not exists IDENTIFIER('FOO.BAR.snowflake_cli_tmp_stage_123')" ) assert copy_call == mock.call( - f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}" + f"copy files into @FOO.BAR.snowflake_cli_tmp_stage_123/ from {expected_stage}/" ) assert ls_call == mock.call( f"ls @FOO.BAR.snowflake_cli_tmp_stage_123", cursor_class=DictCursor diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 54780ce85a..beadf762e0 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -48,7 +48,7 @@ from snowflake.cli._plugins.nativeapp.policy import AllowAlwaysPolicy from snowflake.cli._plugins.stage.diff import ( DiffResult, - StagePath, + StagePathType, ) from snowflake.cli._plugins.workspace.manager import WorkspaceManager from snowflake.cli.api.console import cli_console as cc @@ -123,7 +123,7 @@ def test_sync_deploy_root_with_stage( mock_cursor, ): mock_execute.return_value = mock_cursor([("old_role",)], []) - mock_diff_result = DiffResult(different=[StagePath("setup.sql")]) + mock_diff_result = DiffResult(different=[StagePathType("setup.sql")]) mock_compute_stage_diff.return_value = mock_diff_result mock_local_diff_with_stage.return_value = None current_working_directory = os.getcwd() @@ -184,12 +184,12 @@ def test_sync_deploy_root_with_stage( [ [ True, - [StagePath("only-stage.txt")], + [StagePathType("only-stage.txt")], False, ], [ False, - [StagePath("only-stage-1.txt"), StagePath("only-stage-2.txt")], + [StagePathType("only-stage-1.txt"), StagePathType("only-stage-2.txt")], True, ], ], @@ -1106,7 +1106,7 @@ def test_get_paths_to_sync( paths_to_sync = [Path(p) for p in paths_to_sync] result = _get_stage_paths_to_sync(paths_to_sync, Path("deploy/")) - assert result.sort() == [StagePath(p) for p in expected_result].sort() + assert result.sort() == [StagePathType(p) for p in expected_result].sort() @mock.patch(SQL_EXECUTOR_EXECUTE) diff --git a/tests/stage/test_diff.py b/tests/stage/test_diff.py index af225cd5e3..8ee1e462fd 100644 --- a/tests/stage/test_diff.py +++ b/tests/stage/test_diff.py @@ -24,7 +24,7 @@ from snowflake.cli._plugins.nativeapp.artifacts import BundleMap from snowflake.cli._plugins.stage.diff import ( DiffResult, - StagePath, + StagePathType, build_md5_map, compute_stage_diff, delete_only_on_stage_files, @@ -58,8 +58,8 @@ STAGE_LS_COLUMNS = ["name", "size", "md5", "last_modified"] -def as_stage_paths(paths: typing.Iterable[str]) -> List[StagePath]: - return [StagePath(p) for p in paths] +def as_stage_paths(paths: typing.Iterable[str]) -> List[StagePathType]: + return [StagePathType(p) for p in paths] def md5_of(contents: Union[str, bytes]) -> str: @@ -223,7 +223,7 @@ def test_get_stage_path_from_file(): local_files = enumerate_files(local_path) for local_file in local_files: relpath = str(local_file.relative_to(local_path)) - actual.append(get_stage_subpath(StagePath(relpath))) + actual.append(get_stage_subpath(StagePathType(relpath))) assert actual.sort() == expected @@ -284,9 +284,9 @@ def test_build_md5_map(mock_cursor): ) expected = { - StagePath("README.md"): "9b650974f65cc49be96a5ed34ac6d1fd", - StagePath("my.jar"): "fc605d0e2e50cf3e71873d57f4c598b0", - StagePath("ui/streamlit.py"): "a7dfdfaf892ecfc5f164914123c7f2cc", + StagePathType("README.md"): "9b650974f65cc49be96a5ed34ac6d1fd", + StagePathType("my.jar"): "fc605d0e2e50cf3e71873d57f4c598b0", + StagePathType("ui/streamlit.py"): "a7dfdfaf892ecfc5f164914123c7f2cc", } assert actual == expected diff --git a/tests/stage/test_stage.py b/tests/stage/test_stage.py index 95000be426..86944fd07a 100644 --- a/tests/stage/test_stage.py +++ b/tests/stage/test_stage.py @@ -19,6 +19,7 @@ import pytest from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli.api.errno import DOES_NOT_EXIST_OR_NOT_AUTHORIZED +from snowflake.cli.api.stage_path import StagePath from snowflake.connector import ProgrammingError from snowflake.connector.cursor import DictCursor @@ -188,7 +189,7 @@ def test_stage_copy_remote_to_local_quoted_uri_recursive( mock_execute, runner, mock_cursor, raw_path, expected_uri ): mock_execute.side_effect = [ - mock_cursor([{"name": "stageName/file"}], []), + mock_cursor([{"name": "stageName/file.py"}], []), mock_cursor([(raw_path)], ["file"]), ] with TemporaryDirectory() as tmp_dir: @@ -209,7 +210,7 @@ def test_stage_copy_remote_to_local_quoted_uri_recursive( assert result.exit_code == 0, result.output assert mock_execute.mock_calls == [ mock.call("ls @stageName", cursor_class=DictCursor), - mock.call(f"get @stageName/file {file_uri} parallel=4"), + mock.call(f"get @stageName/file.py {file_uri} parallel=4"), ] @@ -973,9 +974,9 @@ def test_execute_not_existing_stage(mock_execute, mock_cursor, runner): @pytest.mark.parametrize( "stage_path,expected_message", [ - ("exe/*.txt", "No files matched pattern 'exe/*.txt'"), - ("exe/directory", "No files matched pattern 'exe/directory'"), - ("exe/some_file.sql", "No files matched pattern 'exe/some_file.sql'"), + ("exe/*.txt", "No files matched pattern '@exe/*.txt'"), + ("exe/directory", "No files matched pattern '@exe/directory'"), + ("exe/some_file.sql", "No files matched pattern '@exe/some_file.sql'"), ], ) @mock.patch(f"{STAGE_MANAGER}._execute_query") @@ -1105,7 +1106,7 @@ def test_command_aliases(mock_connector, runner, mock_ctx, command, parameters): (["my_stage/dir/parallel/requirements.txt"], None, []), ( ["my_stage/dir/files/requirements.txt"], - "db.schema.my_stage/dir/files/requirements.txt", + "@db.schema.my_stage/dir/files/requirements.txt", ["aaa", "bbb"], ), ( @@ -1114,12 +1115,12 @@ def test_command_aliases(mock_connector, runner, mock_ctx, command, parameters): "my_stage/dir/requirements.txt", "my_stage/dir/files/requirements.txt", ], - "db.schema.my_stage/dir/files/requirements.txt", + "@db.schema.my_stage/dir/files/requirements.txt", ["aaa", "bbb"], ), ( ["my_stage/requirements.txt"], - "db.schema.my_stage/requirements.txt", + "@db.schema.my_stage/requirements.txt", ["aaa", "bbb"], ), ], @@ -1145,7 +1146,7 @@ def __call__(self, file_on_stage, target_dir): ): with mock.patch.object(StageManager, "get", get_mock) as get_mock: result = sm._check_for_requirements_file( # noqa: SLF001 - stage_path_parts=sm._stage_path_part_factory(input_path) # noqa: SLF001 + stage_path=StagePath.from_stage_str(input_path) ) assert result == packages diff --git a/tests/stage/test_stage_path.py b/tests/stage/test_stage_path.py new file mode 100644 index 0000000000..9bc4fe42cc --- /dev/null +++ b/tests/stage/test_stage_path.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import pytest +from snowflake.cli.api.stage_path import StagePath + +# (path, is_git_repo) +ROOT_STAGES = [ + ("~", False), + ("~/", False), + ("stage", False), + ("stage/", False), + ("db.schema.stage", False), + ("db.schema.stage/", False), + ("db.schema.repo/branches/main/", True), + ('db.schema.repo/branches/"main/with/slash"/', True), +] + +DIRECTORIES = [ + ("~/my_path", False), + ("~/my_path/", False), + ("stage/my_path", False), + ("stage/my_path/", False), + ("db.schema.stage/my_path", False), + ("db.schema.stage/my_path/", False), + ("db.schema.repo/branches/main/my_path", True), + ("db.schema.repo/branches/main/my_path/", True), + ('db.schema.repo/branches/"main/with/slash"/my_path', True), +] + +FILES = [ + ("~/file.py", False), + ("stage/file.py", False), + ("db.schema.stage/file.py", False), + ("repo/branches/main/file.py", True), + ("db.schema.repo/branches/main/file.py", True), + ('db.schema.repo/branches/"main/with/slash"/file.py', True), +] + +FILES_UNDER_PATH = [ + ("~/my_path/file.py", False), + ("stage/my_path/file.py", False), + ("db.schema.stage/my_path/file.py", False), + ("repo/branches/main/my_path/file.py", True), + ("db.schema.repo/branches/main/my_path/file.py", True), + ('db.schema.repo/branches/"main/with/slash"/my_path/file.py', True), +] + + +def with_at_prefix(test_data: list[tuple[str, bool]]): + return [(f"@{path}", is_git_repo) for path, is_git_repo in test_data] + + +def with_snow_prefix(test_data: list[tuple[str, bool]]): + return [(f"snow://{path}", is_git_repo) for path, is_git_repo in test_data] + + +def parametrize_with(data: list[tuple[str, bool]]): + return pytest.mark.parametrize( + "path, is_git_repo", [*data, *with_at_prefix(data), *with_snow_prefix(data)] + ) + + +def build_stage_path(path, is_git_repo): + if is_git_repo: + stage_path = StagePath.from_git_str(path) + else: + stage_path = StagePath.from_stage_str(path) + return stage_path + + +@parametrize_with(ROOT_STAGES) +def test_root_paths(path, is_git_repo): + stage_path = build_stage_path(path, is_git_repo) + assert stage_path.is_root() + assert stage_path.parts == () + assert stage_path.is_dir() + assert not stage_path.is_file() + assert stage_path.name == "" + assert stage_path.suffix == "" + assert stage_path.stem == "" + assert stage_path.stage == path.lstrip("@").replace("snow://", "").split("/")[0] + assert stage_path.absolute_path() == "@" + path.lstrip("@").replace("snow://", "") + + +@parametrize_with(DIRECTORIES) +def test_dir_paths(path, is_git_repo): + stage_path = build_stage_path(path, is_git_repo) + assert not stage_path.is_root() + assert stage_path.parts == ("my_path",) + assert stage_path.is_dir() + assert not stage_path.is_file() + assert stage_path.name == "my_path" + assert stage_path.suffix == "" + assert stage_path.stem == "my_path" + assert stage_path.stage == path.lstrip("@").replace("snow://", "").split("/")[0] + assert stage_path.absolute_path() == "@" + path.lstrip("@").replace("snow://", "") + + +@parametrize_with(FILES) +def test_file_paths(path, is_git_repo): + stage_path = build_stage_path(path, is_git_repo) + assert not stage_path.is_root() + assert stage_path.parts == ("file.py",) + assert not stage_path.is_dir() + assert stage_path.is_file() + assert stage_path.name == "file.py" + assert stage_path.suffix == ".py" + assert stage_path.stem == "file" + assert stage_path.stage == path.lstrip("@").replace("snow://", "").split("/")[0] + assert stage_path.absolute_path() == "@" + path.lstrip("@").replace( + "snow://", "" + ).rstrip("/") + + +@parametrize_with(FILES_UNDER_PATH) +def test_dir_with_file_paths(path, is_git_repo): + stage_path = build_stage_path(path, is_git_repo) + assert not stage_path.is_root() + assert stage_path.parts == ("my_path", "file.py") + assert not stage_path.is_dir() + assert stage_path.is_file() + assert stage_path.name == "file.py" + assert stage_path.suffix == ".py" + assert stage_path.stem == "file" + assert stage_path.stage == path.lstrip("@").replace("snow://", "").split("/")[0] + assert stage_path.absolute_path() == "@" + path.lstrip("@").replace( + "snow://", "" + ).rstrip("/") + + +def test_join_path(): + path = StagePath.from_stage_str("@my_stage/path") + new_path = path.joinpath("new_path").joinpath("file.py") + assert new_path.parts == ("path", "new_path", "file.py") + assert path.stage == new_path.stage + + +def test_join_path_using_division(): + path = StagePath.from_stage_str("@my_stage/path") + new_path = path / "new_path" / "file.py" + assert new_path.parts == ("path", "new_path", "file.py") + assert path.stage == new_path.stage + + +def test_path_starting_with_slash(): + path = StagePath.from_stage_str("@my_stage") + new_path = path.joinpath("/file.txt") + assert new_path.parts == ("file.txt",) + assert path.stage == new_path.stage + assert new_path.absolute_path() == "@my_stage/file.txt" + + +@parametrize_with(FILES_UNDER_PATH) +def test_parent_path(path, is_git_repo): + path = build_stage_path(path, is_git_repo) + parent_path = path.parent + assert parent_path.parts == ("my_path",) + assert path.stage == parent_path.stage + + +@pytest.mark.parametrize( + "stage_name, path", + [ + ("my_stage", "@my_stage/path/file.py"), + ("db.schema.my_stage", "@db.schema.my_stage/path/file.py"), + ], +) +def test_root_path(stage_name, path): + stage_path = StagePath.from_stage_str(path) + assert stage_path.root_path() == StagePath.from_stage_str(f"@{stage_name}") diff --git a/tests/streamlit/test_commands.py b/tests/streamlit/test_commands.py index 92065e3628..84968e5aab 100644 --- a/tests/streamlit/test_commands.py +++ b/tests/streamlit/test_commands.py @@ -565,12 +565,12 @@ def test_deploy_streamlit_main_and_pages_files_experimental( result = runner.invoke(["streamlit", "deploy", "--experimental"]) if enable_streamlit_versioned_stage: - root_path = ( - f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/versions/live" - ) + root_path = f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/versions/live" post_create_command = f"ALTER STREAMLIT MockDatabase.MockSchema.{STREAMLIT_NAME} ADD LIVE VERSION FROM LAST" else: - root_path = f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" + root_path = ( + f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" + ) if enable_streamlit_no_checkouts: post_create_command = None else: @@ -645,9 +645,7 @@ def test_deploy_streamlit_main_and_pages_files_experimental_double_deploy( assert result2.exit_code == 0, result2.output - root_path = ( - f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" - ) + root_path = f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" # Same as normal, except no ALTER query assert ctx.get_queries() == [ @@ -699,12 +697,12 @@ def test_deploy_streamlit_main_and_pages_files_experimental_no_stage( result = runner.invoke(["streamlit", "deploy", "--experimental"]) if enable_streamlit_versioned_stage: - root_path = ( - f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/versions/live" - ) + root_path = f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/versions/live" post_create_command = f"ALTER STREAMLIT MockDatabase.MockSchema.{STREAMLIT_NAME} ADD LIVE VERSION FROM LAST" else: - root_path = f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" + root_path = ( + f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" + ) post_create_command = ( f"ALTER streamlit MockDatabase.MockSchema.{STREAMLIT_NAME} CHECKOUT" ) @@ -748,9 +746,7 @@ def test_deploy_streamlit_main_and_pages_files_experimental_replace( with project_directory("example_streamlit"): result = runner.invoke(["streamlit", "deploy", "--experimental", "--replace"]) - root_path = ( - f"snow://streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" - ) + root_path = f"@streamlit/MockDatabase.MockSchema.{STREAMLIT_NAME}/default_checkout" assert result.exit_code == 0, result.output assert ctx.get_queries() == [ dedent(