Skip to content

Commit

Permalink
general updates for locating the globals
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mrojas committed Oct 1, 2024
1 parent 07dad6f commit 3a1df01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
4 changes: 0 additions & 4 deletions extras/sf_custom_importer/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Custom Importer for Snowflake Notebooks

This helper implements a custom importer that allows you to load code from an stage.

USAGE:

Add the sf_custom_importer.py to your stage imports:
15 changes: 11 additions & 4 deletions extras/sf_custom_importer/sf_custom_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def load(self):
finally:
os.chdir(current_dir) # Change back to it

def find_main_globals():
# Traverse the stack in the current thread
for frame_info in inspect.stack():
frame = frame_info.frame # Get the frame object
if frame.f_globals.get("__name__") == "__main__":
# If the frame is the main module, return its globals
return frame.f_globals
return None # Return None if no main module frame is found


STAGE_PREFIX="__STAGE_IMPORT__"
Expand All @@ -56,28 +64,27 @@ def __init__(self):
self.stages = {}
self.tempdirs_for_stage_files = {}
def load_stage_files(self,stage_name):
print("Loading module")
# Connect to Snowflake and load the file corresponding to the module
session = Session.builder.getOrCreate()
stage_files = [x[0] for x in session.sql(f"ls @{stage_name}").collect()]
tempdir_ref = tempfile.mkdtemp()
for stage_file in stage_files:
print(f"Downloading {stage_file} to {tempdir_ref}")
session.file.get(f"@{stage_file}", tempdir_ref)
print(f"Downloaded {stage_file} to {tempdir_ref}")
self.tempdirs_for_stage_files[stage_name] = tempdir_ref
def find_spec(self, fullname, path, target=None):
caller_globals = inspect.currentframe().f_back.f_globals
if fullname.startswith(STAGE_PREFIX) and fullname.endswith("__") and not "." in fullname:
caller_globals = find_main_globals()
stage_name = fullname.replace(STAGE_PREFIX,"")[:-2]
if stage_name == "":
stage_name = current_stage_for_import
if stage_name is None or stage_name.trim() == "":
if stage_name is None or stage_name.strip() == "":
raise Exception("No stage name specified")
file_loader = SnowflakeFileLoader(stage_name,caller_globals)
self.load_stage_files(stage_name)
return spec_from_loader(fullname, file_loader)
elif fullname.startswith(STAGE_PREFIX) and "." in fullname:
caller_globals = find_main_globals()
stage_name = fullname.split(".")[0].replace(STAGE_PREFIX,"")[:-2]
filename = fullname.split(".")[-1]
temp_dir = self.tempdirs_for_stage_files[stage_name]
Expand Down

0 comments on commit 3a1df01

Please sign in to comment.