Skip to content

Commit

Permalink
updating wheel_loader for usages in notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mrojas committed Nov 21, 2024
1 parent 3a1df01 commit 0879086
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
18 changes: 17 additions & 1 deletion extras/wheel_loader/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,35 @@ def your_handler(arg1, arg2):
$$
```

Sometimes a package doesn't have a `.whl` but instead it has a `.tar.gz` or `.tgz` archive.
Sometimes a package doesn't have a `.whl` but instead it has a `.tar.gz` or `.tgz` archive.

In that case you can use:

```python
wheel_loader.load_tgz('packagename.tar.gz')
```

You can also do:

```python
wheel_loader.add_tars()
```

This loads all the `.tar.gz` and `.tgz` archives added to your imports.

The rest is just Python bliss :)

# Loading wheels in Snowflake notebooks

Since the release of Snowflake notebooks, some users have needed to load wheels into their notebooks.

Notebooks **already provide** a mechanism to [reference stage packages](https://docs.snowflake.com/en/user-guide/ui-snowsight/notebooks-import-packages#import-packages-from-a-snowflake-stage), but there are scenarios where you might prefer to keep a prepopulated stage and load all the wheel files in that stage.

This can be helpful for enforcing RBAC policies: if a user does not have permissions on a stage, they won't be able to load those wheels:
![error](./wheels_error1.png)

Or you may just want an easy way to load some custom packages into your notebooks with a couple of lines.

The wheel loader has been very helpful for me, so I hope this functionality is useful for notebook users as well.

This snippet was developed by [James Weakley](https://medium.com/@jamesweakley) in a [Medium post](https://medium.com/snowflake/running-pip-packages-in-snowflake-d43581a67439), check it out for more details.
56 changes: 41 additions & 15 deletions extras/wheel_loader/wheel_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
import logging
from functools import lru_cache
from snowflake.snowpark._internal.utils import is_in_stored_procedure
import tarfile

class FileLock:
Expand Down Expand Up @@ -56,8 +55,17 @@ def load_tgz(tgz_name,append=True, use_lock=True):
@lru_cache()
def load(whl_name,append=True, use_lock=True):
logging.info(f"loading wheel {whl_name}")
whl_path = Path(sys._xoptions['snowflake_import_directory']) / whl_name
extraction_path = Path('/tmp') / whl_name
if whl_name.startswith("@"):
# this is only expected to be used in notebooks
import snowflake.snowpark
session = snowflake.snowpark.Session.builder.getOrCreate()
os.makedirs("/tmp/whl_download/",exist_ok=True)
session.file.get(whl_name,"/tmp/whl_download/")
whl_path = Path("/tmp/whl_download/") / os.path.basename(whl_name)
extraction_path = Path('/tmp') / os.path.basename(whl_name)
else:
whl_path = Path(sys._xoptions['snowflake_import_directory']) / whl_name
extraction_path = Path('/tmp') / whl_name

if use_lock:
with FileLock():
Expand All @@ -83,7 +91,7 @@ def load(whl_name,append=True, use_lock=True):
# Add a directory to the pkg_resources working set
pkg_resources.working_set.add_entry(str(extraction_path))
except Exception as ex:
logging.error(f"failed to add {extraction_path} to pkg_resources working set: {ex}")
logging.warning(f"failed to add {extraction_path} to pkg_resources working set: {ex}")
return message

def setup_home():
Expand All @@ -96,27 +104,45 @@ def setup_home():

# this decorator makes sure that this does not get loaded more than once
@lru_cache(maxsize=1)
def add_wheels():
if not is_in_stored_procedure():
message = "Wheel loader can only be used in stored procedures"
logging.warning(message)
return message
def add_wheels(from_stage=None):
try:
import snowflake.snowpark._internal.utils
if not snowflake.snowpark._internal.utils.is_in_stored_procedure():
message = "Wheel loader can only be used in stored procedures"
logging.warning(message)
return message
except:
pass
setup_home()
wheels = [x for x in os.listdir(sys._xoptions['snowflake_import_directory']) if x.endswith('.whl')]
if from_stage and from_stage.startswith("@"):
try:
import snowflake.snowpark
session = snowflake.snowpark.Session.builder.getOrCreate()
wheels = [("@" + r[0]) for r in session.sql(f"ls {from_stage} pattern='.*whl'").collect()]
except Exception as exceptionStageFiles:
message = "Error while getting wheels from stage: " + str(exceptionStageFiles)
logging.warning(message)
return message
else:
wheels = [x for x in os.listdir(sys._xoptions['snowflake_import_directory']) if x.endswith('.whl')]
with FileLock():
for whl in wheels:
load(whl, False) # we use one lock for all
message = str(wheels) + " where loaded"
message = str(wheels) + " wheels where loaded"
logging.info(message)
return message

# this decorator makes sure that this does not get loaded more than once
@lru_cache(maxsize=1)
def add_tars():
if not is_in_stored_procedure():
message = "tgz loader can only be used in stored procedures"
logging.warning(message)
return message
try:
import snowflake.snowpark._internal.utils
if not snowflake.snowpark._internal.utils.is_in_stored_procedure():
message = "tgz loader can only be used in stored procedures"
logging.warning(message)
return message
except:
pass
setup_home()
tars = [x for x in os.listdir(sys._xoptions['snowflake_import_directory']) if x.endswith('.tgz') or x.endswith('tar.gz')]
with FileLock():
Expand Down
Binary file added extras/wheel_loader/wheels_error1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 0879086

Please sign in to comment.