Skip to content

Commit

Permalink
Use FQN for tmp stage
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Oct 4, 2024
1 parent 8895ff0 commit f8f554e
Showing 1 changed file with 52 additions and 15 deletions.
67 changes: 52 additions & 15 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def get_standard_stage_directory_path(self) -> str:
return path + "/"
return path

def strip_stage_prefix(self, path: str):
raise NotImplementedError


@dataclass
class DefaultStagePathParts(StagePathParts):
Expand Down Expand Up @@ -142,6 +145,13 @@ def replace_stage_prefix(self, file_path: str) -> str:
file_path_without_prefix = Path(file_path).parts[OMIT_FIRST]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def strip_stage_prefix(self, file_path: str) -> str:
if file_path.startswith("@"):
file_path = file_path[OMIT_FIRST]
if file_path.startswith(self.stage_name):
return file_path[len(self.stage_name) :]
return file_path

def add_stage_prefix(self, file_path: str) -> str:
stage = self.stage.rstrip("/")
return f"{stage}/{file_path.lstrip('/')}"
Expand Down Expand Up @@ -312,8 +322,8 @@ def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor
"Destination path cannot be a user stage. Please provide a named stage."
)

source = source_path_parts.get_standard_stage_path()
destination = destination_path_parts.get_standard_stage_directory_path()
source = source_path_parts.full_path
destination = destination_path_parts.full_path
log.info("Copying files from %s to %s", source, destination)
query = f"copy files into {destination} from {source}"
return self._execute_query(query)
Expand All @@ -336,8 +346,8 @@ def remove(
def create(
self, fqn: FQN, comment: Optional[str] = None, temporary: bool = False
) -> SnowflakeCursor:
temporary_str = "temporary" if temporary else ""
query = f"create {temporary_str} stage if not exists {fqn.sql_identifier}"
temporary_str = "temporary " if temporary else ""
query = f"create {temporary_str}stage if not exists {fqn.sql_identifier}"
if comment:
query += f" comment='{comment}'"
return self._execute_query(query)
Expand All @@ -354,7 +364,10 @@ def execute(
requires_temporary_stage: bool = False,
):
if requires_temporary_stage:
stage_path_parts = self._create_temporary_copy_of_stage(stage_path)
(
stage_path_parts,
original_path_parts,
) = self._create_temporary_copy_of_stage(stage_path)
else:
stage_path_parts = self._stage_path_part_factory(stage_path)

Expand Down Expand Up @@ -389,24 +402,39 @@ def execute(

for file_path in sorted_file_path_list:
file_stage_path = stage_path_parts.add_stage_prefix(file_path)

# For better reporting push down the information about original
# path if execution happens from temporary stage
if requires_temporary_stage:
original_path = original_path_parts.add_stage_prefix(file_path)
else:
original_path = file_stage_path

if file_path.endswith(".py"):
result = self._execute_python(
file_stage_path=file_stage_path,
on_error=on_error,
variables=python_variables,
original_file=original_path,
)
else:
result = self._call_execute_immediate(
file_stage_path=file_stage_path,
variables=sql_variables,
on_error=on_error,
original_file=original_path,
)
results.append(result)

return results

def _create_temporary_copy_of_stage(self, stage_path: str) -> StagePathParts:
tmp_stage = f"snowflake_cli_tmp_stage_{int(time.time())}"
def _create_temporary_copy_of_stage(
self, stage_path: str
) -> tuple[StagePathParts, StagePathParts]:
tmp_stage_name = f"snowflake_cli_tmp_stage_{int(time.time())}"
tmp_stage = (
FQN.from_stage(tmp_stage_name).using_connection(conn=self._conn).identifier
)
sm = StageManager()

# Create temporary stage, it will be dropped with end of session
Expand All @@ -420,10 +448,14 @@ def _create_temporary_copy_of_stage(self, stage_path: str) -> StagePathParts:

# Copy the content
self.copy_files(
source_path=original_path_parts.stage_name,
destination_path=stage_path_parts.stage_name,
source_path=original_path_parts.get_full_stage_path(
original_path_parts.stage_name
),
destination_path=stage_path_parts.get_full_stage_path(
stage_path_parts.stage_name
),
)
return stage_path_parts
return stage_path_parts, original_path_parts

def _get_files_list_from_stage(
self, stage_path_parts: StagePathParts, pattern: str | None = None
Expand Down Expand Up @@ -504,16 +536,17 @@ def _call_execute_immediate(
file_stage_path: str,
variables: Optional[str],
on_error: OnErrorType,
original_file: str,
) -> Dict:
try:
query = f"execute immediate from {self.quote_stage_name(file_stage_path)}"
if variables:
query += variables
self._execute_query(query)
return StageManager._success_result(file=file_stage_path)
return StageManager._success_result(file=original_file)
except ProgrammingError as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
return StageManager._error_result(file=file_stage_path, msg=e.msg)
return StageManager._error_result(file=original_file, msg=e.msg)

@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
Expand Down Expand Up @@ -606,7 +639,11 @@ def _python_execution_procedure(
return _python_execution_procedure

def _execute_python(
self, file_stage_path: str, on_error: OnErrorType, variables: Dict
self,
file_stage_path: str,
on_error: OnErrorType,
variables: Dict,
original_file: str,
):
"""
Executes Python file from stage using a Snowpark temporary procedure.
Expand All @@ -616,7 +653,7 @@ def _execute_python(

try:
self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables) # type: ignore
return StageManager._success_result(file=file_stage_path)
return StageManager._success_result(file=original_file)
except SnowparkSQLException as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
return StageManager._error_result(file=file_stage_path, msg=e.message)
return StageManager._error_result(file=original_file, msg=e.message)

0 comments on commit f8f554e

Please sign in to comment.