Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add deactivate SSL certification verity option #179

Merged
merged 12 commits into from
Dec 24, 2022
15 changes: 13 additions & 2 deletions mim/commands/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@
cls=OptionEatAll,
required=True,
help='Config ids to download, such as resnet18_8xb16_cifar10')
@click.option(
'--no-check-certificate',
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
'check_certificate',
is_flag=True,
default=True,
help='Ignore ssl certificate check')
@click.option(
'--dest', 'dest_root', type=str, help='Destination of saving checkpoints.')
def cli(package: str,
configs: List[str],
check_certificate: bool = True,
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
dest_root: Optional[str] = None) -> None:
"""Download checkpoints from url and parse configs from package.

Expand All @@ -47,17 +54,20 @@ def cli(package: str,
> mim download mmcls --config resnet18_8xb16_cifar10
> mim download mmcls --config resnet18_8xb16_cifar10 --dest .
"""
download(package, configs, dest_root)
download(package, configs, check_certificate, dest_root)


def download(package: str,
configs: List[str],
check_certificate: bool = True,
dest_root: Optional[str] = None) -> List[str]:
"""Download checkpoints from url and parse configs from package.

Args:
package (str): Name of package.
configs (List[str]): List of config ids.
check_certificate (bool): Whether to check the ssl certificate.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
Default: True.
dest_root (Optional[str]): Destination directory to save checkpoint and
config. Default: None.
"""
Expand Down Expand Up @@ -112,7 +122,8 @@ def download(package: str,
echo_success(f'{filename} exists in {dest_root}')
else:
# TODO: check checkpoint hash when all the models are ready.
download_from_file(checkpoint_url, checkpoint_path)
download_from_file(checkpoint_url, checkpoint_path,
check_certificate)
kim3321 marked this conversation as resolved.
Show resolved Hide resolved

echo_success(
f'Successfully downloaded {filename} to {dest_root}')
Expand Down
14 changes: 11 additions & 3 deletions mim/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,19 @@ def get_github_url(package: str) -> str:

def get_content_from_url(url: str,
timeout: int = 15,
stream: bool = False) -> Response:
stream: bool = False,
check_certificate: bool = True) -> Response:
"""Get content from url.

Args:
url (str): Url for getting content.
timeout (int): Set the socket timeout. Default: 15.
check_certificate (bool): Whether to check the ssl certificate.
Default: True.
"""
try:
response = requests.get(url, timeout=timeout, stream=stream)
response = requests.get(
url, timeout=timeout, stream=stream, verify=check_certificate)
except InvalidURL as err:
raise highlighted_error(err) # type: ignore
except Timeout as err:
Expand All @@ -157,19 +161,23 @@ def get_content_from_url(url: str,
@typing.no_type_check
def download_from_file(url: str,
dest_path: str,
check_certificate: bool,
kim3321 marked this conversation as resolved.
Show resolved Hide resolved
hash_prefix: Optional[str] = None) -> None:
"""Download object at the given URL to a local path.

Args:
url (str): URL of the object to download.
dest_path (str): Path where object will be saved.
check_certificate (bool): Whether to check the ssl certificate.
Default: True.
hash_prefix (string, optional): If not None, the SHA256 downloaded
file should start with `hash_prefix`. Default: None.
"""
if hash_prefix is not None:
sha256 = hashlib.sha256()

response = get_content_from_url(url, stream=True)
response = get_content_from_url(
url, stream=True, check_certificate=check_certificate)
size = int(response.headers.get('content-length'))
with open(dest_path, 'wb') as fw:
content_iter = response.iter_content(chunk_size=1024)
Expand Down