SDK - Components - Added ComponentStore search #3884

Merged
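Below is a minimal usage sketch of the API this PR adds, based on the search() docstring in the diff (the query string is only an example):

from kfp import components

store = components.ComponentStore.default_store
store.search('xgboost')  # prints names and raw URLs of matching components
store.list()             # prints every component found in the configured repos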
141 changes: 140 additions & 1 deletion sdk/python/kfp/components/_component_store.py
@@ -4,10 +4,19 @@

from pathlib import Path
import copy
import hashlib
import json
import logging
import requests
from typing import Callable
import tempfile
from typing import Callable, Iterable
from . import _components as comp
from .structures import ComponentReference
from ._key_value_store import KeyValueStore


_COMPONENT_FILENAME = 'component.yaml'


class ComponentStore:
def __init__(self, local_search_paths=None, url_search_prefixes=None):
@@ -18,6 +27,10 @@ def __init__(self, local_search_paths=None, url_search_prefixes=None):
self._digests_subpath = 'versions/sha256'
self._tags_subpath = 'versions/tags'

cache_base_dir = Path(tempfile.gettempdir()) / '.kfp_components'
self._git_blob_hash_to_data_db = KeyValueStore(cache_dir=cache_base_dir / 'git_blob_hash_to_data')
self._url_to_info_db = KeyValueStore(cache_dir=cache_base_dir / 'url_to_info')
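# The two stores above are on-disk caches used by search(): the first maps git
# blob hashes to downloaded component.yaml contents, the second maps component
# URLs to metadata (name, git blob hash, digest).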

def load_component_from_url(self, url):
return comp.load_component_from_url(url)

@@ -129,6 +142,132 @@ def _load_component_from_ref(self, component_ref: ComponentReference) -> Callabl
component_ref = self._load_component_spec_in_component_ref(component_ref)
return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref)

def search(self, name: str):
'''Searches for components by name in the configured component store.

Prints the component name and URL for components that match the given name.
Only components on GitHub are currently supported.

Example::

kfp.components.ComponentStore.default_store.search('xgboost')

>>> Xgboost train https://raw.githubusercontent.com/.../components/XGBoost/Train/component.yaml
>>> Xgboost predict https://raw.githubusercontent.com/.../components/XGBoost/Predict/component.yaml
'''
self._refresh_component_cache()
for url in self._url_to_info_db.keys():
component_info = json.loads(self._url_to_info_db.try_get_value_bytes(url))
component_name = component_info['name']
if name.casefold() in component_name.casefold():
print('\t'.join([
component_name,
url,
]))

def list(self):
self.search('')

def _refresh_component_cache(self):
for url_search_prefix in self.url_search_prefixes:
if url_search_prefix.startswith('https://raw.githubusercontent.com/'):
logging.info('Searching for components in "{}"'.format(url_search_prefix))
for candidate in _list_candidate_component_uris_from_github_repo(url_search_prefix):
component_url = candidate['url']
if self._url_to_info_db.exists(component_url):
continue

logging.debug('Found new component URL: "{}"'.format(component_url))

blob_hash = candidate['git_blob_hash']
if not self._git_blob_hash_to_data_db.exists(blob_hash):
logging.debug('Downloading component spec from "{}"'.format(component_url))
response = _get_request_session().get(component_url)
response.raise_for_status()
component_data = response.content

# Verifying the hash
received_data_hash = _calculate_git_blob_hash(component_data)
if received_data_hash.lower() != blob_hash.lower():
raise RuntimeError(
'The downloaded component ({}) has incorrect hash: "{}" != "{}"'.format(
component_url, received_data_hash, blob_hash,
)
)

# Verifying that the component is loadable
try:
component_spec = comp._load_component_spec_from_component_text(component_data)
except Exception:
continue
self._git_blob_hash_to_data_db.store_value_bytes(blob_hash, component_data)
else:
component_data = self._git_blob_hash_to_data_db.try_get_value_bytes(blob_hash)
component_spec = comp._load_component_spec_from_component_text(component_data)

component_name = component_spec.name
self._url_to_info_db.store_value_text(component_url, json.dumps(dict(
name=component_name,
url=component_url,
git_blob_hash=blob_hash,
digest=_calculate_component_digest(component_data),
)))


def _get_request_session(max_retries: int = 3):
session = requests.Session()

retry_strategy = requests.packages.urllib3.util.retry.Retry(
total=max_retries,
backoff_factor=0.1,
status_forcelist=[413, 429, 500, 502, 503, 504],
method_whitelist=frozenset(['GET', 'POST']),
)

session.mount('https://', requests.adapters.HTTPAdapter(max_retries=retry_strategy))
session.mount('http://', requests.adapters.HTTPAdapter(max_retries=retry_strategy))

return session


def _calculate_git_blob_hash(data: bytes) -> str:
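# A git blob's object ID is the SHA-1 of the header b'blob <size>\0' followed by
# the blob contents, so this reproduces the 'sha' values reported by the GitHub API.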
return hashlib.sha1(b'blob ' + str(len(data)).encode('utf-8') + b'\x00' + data).hexdigest()


def _calculate_component_digest(data: bytes) -> str:
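# Normalize CRLF line endings before hashing so the digest is stable regardless
# of how the component file was downloaded or checked out.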
return hashlib.sha256(data.replace(b'\r\n', b'\n')).hexdigest()


def _list_candidate_component_uris_from_github_repo(url_search_prefix: str) -> Iterable[dict]:
(schema, _, host, org, repo, ref, path_prefix) = url_search_prefix.split('/', 6)
for page in range(1, 999):
search_url = (
'https://api.github.com/search/code?q=filename:{}+repo:{}/{}&page={}&per_page=1000'
).format(_COMPONENT_FILENAME, org, repo, page)
response = _get_request_session().get(search_url)
response.raise_for_status()
result = response.json()
items = result['items']
if not items:
break
for item in items:
html_url = item['html_url']
# Constructing direct content URL
# There is an API (/repos/:owner/:repo/git/blobs/:file_sha) for
# getting the blob content, but it requires decoding the content.
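# For example, https://github.com/<org>/<repo>/blob/<ref>/.../component.yaml
# becomes https://raw.githubusercontent.com/<org>/<repo>/<ref>/.../component.yaml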
raw_url = html_url.replace(
'https://github.com/', 'https://raw.githubusercontent.com/'
).replace('/blob/', '/', 1)
if not raw_url.endswith(_COMPONENT_FILENAME):
# GitHub matches component_test.yaml when searching for filename:"component.yaml"
continue
result_item = dict(
url=raw_url,
path=item['path'],
git_blob_hash=item['sha'],
)
yield result_item


ComponentStore.default_store = ComponentStore(
local_search_paths=[
62 changes: 62 additions & 0 deletions sdk/python/kfp/components/_key_value_store.py
@@ -0,0 +1,62 @@
import hashlib
from pathlib import Path


class KeyValueStore:
KEY_FILE_SUFFIX = '.key'
VALUE_FILE_SUFFIX = '.value'

def __init__(
self,
cache_dir: str,
):
cache_dir = Path(cache_dir)
hash_func = (lambda text: hashlib.sha256(text.encode('utf-8')).hexdigest())
self.cache_dir = cache_dir
self.hash_func = hash_func

def store_value_text(self, key: str, text: str) -> str:
return self.store_value_bytes(key, text.encode('utf-8'))

def store_value_bytes(self, key: str, data: bytes) -> str:
cache_id = self.hash_func(key)
self.cache_dir.mkdir(parents=True, exist_ok=True)
cache_key_file_path = self.cache_dir / (cache_id + KeyValueStore.KEY_FILE_SUFFIX)
cache_value_file_path = self.cache_dir / (cache_id + KeyValueStore.VALUE_FILE_SUFFIX)
if cache_key_file_path.exists():
old_key = cache_key_file_path.read_text()
if key != old_key:
raise RuntimeError(
'Cache is corrupted: File "{}" contains existing key '
'"{}" != new key "{}"'.format(cache_key_file_path, old_key, key)
)
if cache_value_file_path.exists():
old_data = cache_value_file_path.read_bytes()
if data != old_data:
# TODO: Add options to raise error when overwriting the value.
pass
cache_value_file_path.write_bytes(data)
cache_key_file_path.write_text(key)
return cache_id

def try_get_value_text(self, key: str) -> str:
result = self.try_get_value_bytes(key)
if result is None:
return None
return result.decode('utf-8')

def try_get_value_bytes(self, key: str) -> bytes:
cache_id = self.hash_func(key)
cache_value_file_path = self.cache_dir / (cache_id + KeyValueStore.VALUE_FILE_SUFFIX)
if cache_value_file_path.exists():
return cache_value_file_path.read_bytes()
return None

def exists(self, key: str) -> bool:
cache_id = self.hash_func(key)
cache_key_file_path = self.cache_dir / (cache_id + KeyValueStore.KEY_FILE_SUFFIX)
return cache_key_file_path.exists()

def keys(self):
for cache_key_file_path in self.cache_dir.glob('*' + KeyValueStore.KEY_FILE_SUFFIX):
yield Path(cache_key_file_path).read_text()
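
For reference, a minimal usage sketch of this key-value store (the cache directory and key below are placeholders, not values from the PR):

store = KeyValueStore(cache_dir='/tmp/kv_store_demo')  # hypothetical directory
store.store_value_text('https://example.com/component.yaml', 'name: Example component')
assert store.exists('https://example.com/component.yaml')
print(store.try_get_value_text('https://example.com/component.yaml'))  # 'name: Example component'
print(list(store.keys()))  # ['https://example.com/component.yaml']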