diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 95c71872724..776df2014d0 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -29,6 +29,7 @@ jobs: pip install yapf==0.32.0 pip install toml==0.10.2 pip install black==22.10.0 + pip install isort==5.12.0 - name: Running yapf run: | yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \ @@ -42,3 +43,14 @@ jobs: sky/skylet/providers/gcp/ \ sky/skylet/providers/azure/ \ sky/skylet/providers/ibm/ + - name: Running isort for black formatted files + run: | + isort --diff --check --profile black -l 88 -m 3 \ + sky/skylet/providers/ibm/ + - name: Running isort for yapf formatted files + run: | + isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \ + --sg 'sky/skylet/providers/aws/**' \ + --sg 'sky/skylet/providers/gcp/**' \ + --sg 'sky/skylet/providers/azure/**' \ + --sg 'sky/skylet/providers/ibm/**' diff --git a/docs/source/conf.py b/docs/source/conf.py index 8469c793bfe..517c5a19edd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,7 @@ # Configuration file for the Sphinx documentation builder. -import sys import os +import sys sys.path.insert(0, os.path.abspath('.')) sys.path.insert(0, os.path.abspath('../')) diff --git a/examples/docker/echo_app.py b/examples/docker/echo_app.py index 0edfc5f3928..1dae60dd16f 100644 --- a/examples/docker/echo_app.py +++ b/examples/docker/echo_app.py @@ -8,9 +8,10 @@ # python echo_app.py import random -import sky import string +import sky + with sky.Dag() as dag: # The setup command to build the container image setup = 'docker build -t echo:v0 /echo_app' diff --git a/examples/example_app.py b/examples/example_app.py index c26c4970343..44b3b29a713 100644 --- a/examples/example_app.py +++ b/examples/example_app.py @@ -15,10 +15,10 @@ Incorporate the notion of region/zone (affects pricing). Incorporate the notion of per-account egress quota (affects pricing). """ -import sky - import time_estimators +import sky + def make_application(): """A simple application: train_op -> infer_op.""" diff --git a/examples/horovod_distributed_tf_app.py b/examples/horovod_distributed_tf_app.py index e25ffc0e6c9..273f653a710 100644 --- a/examples/horovod_distributed_tf_app.py +++ b/examples/horovod_distributed_tf_app.py @@ -2,9 +2,10 @@ import json from typing import Dict, List -import sky import time_estimators +import sky + IPAddr = str with sky.Dag() as dag: diff --git a/examples/local/launch_cloud_onprem.py b/examples/local/launch_cloud_onprem.py index b937381a12e..4068ab9c56d 100644 --- a/examples/local/launch_cloud_onprem.py +++ b/examples/local/launch_cloud_onprem.py @@ -22,9 +22,9 @@ import tempfile import textwrap import uuid -import yaml from click import testing as cli_testing +import yaml from sky import cli from sky import global_user_state diff --git a/examples/playground/storage_playground.py b/examples/playground/storage_playground.py index 9ac7eb76522..fba7ed300ec 100644 --- a/examples/playground/storage_playground.py +++ b/examples/playground/storage_playground.py @@ -2,7 +2,8 @@ # These are not exhaustive tests. Actual Tests are in tests/test_storage.py and # tests/test_smoke.py. -from sky.data import storage, StoreType +from sky.data import storage +from sky.data import StoreType def get_args(): diff --git a/examples/ray_tune_examples/tune_ptl_example.py b/examples/ray_tune_examples/tune_ptl_example.py index addfdb9aa39..b9788d344f9 100644 --- a/examples/ray_tune_examples/tune_ptl_example.py +++ b/examples/ray_tune_examples/tune_ptl_example.py @@ -1,15 +1,14 @@ ### Source: https://docs.ray.io/en/latest/tune/examples/mnist_ptl_mini.html import math +import os -import torch from filelock import FileLock -from torch.nn import functional as F -import pytorch_lightning as pl from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule -import os -from ray.tune.integration.pytorch_lightning import TuneReportCallback - +import pytorch_lightning as pl from ray import tune +from ray.tune.integration.pytorch_lightning import TuneReportCallback +import torch +from torch.nn import functional as F class LightningMNISTClassifier(pl.LightningModule): diff --git a/examples/spot/lightning_cifar10/train.py b/examples/spot/lightning_cifar10/train.py index 13cc21842b5..0df6f18484b 100644 --- a/examples/spot/lightning_cifar10/train.py +++ b/examples/spot/lightning_cifar10/train.py @@ -1,21 +1,25 @@ # Code modified from https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/cifar10-baseline.html +import argparse +import glob import os -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.transforms.dataset_normalizations import cifar10_normalization -from pytorch_lightning import LightningModule, Trainer, seed_everything -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning import LightningModule +from pytorch_lightning import seed_everything +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger +import torch +import torch.nn as nn +import torch.nn.functional as F from torch.optim.lr_scheduler import OneCycleLR -from torch.optim.swa_utils import AveragedModel, update_bn +from torch.optim.swa_utils import AveragedModel +from torch.optim.swa_utils import update_bn from torchmetrics.functional import accuracy - -import argparse, glob +import torchvision seed_everything(7) diff --git a/examples/spot/resnet_ddp/resnet_ddp.py b/examples/spot/resnet_ddp/resnet_ddp.py index 69ad01e8eec..89d6d37fc83 100644 --- a/examples/spot/resnet_ddp/resnet_ddp.py +++ b/examples/spot/resnet_ddp/resnet_ddp.py @@ -1,17 +1,15 @@ +import argparse +import os +import random + +import numpy as np import torch -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import DataLoader import torch.nn as nn import torch.optim as optim - +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler import torchvision import torchvision.transforms as transforms - -import argparse -import os -import random -import numpy as np - import wandb diff --git a/examples/tpu/tpu_app_code/run_tpu.py b/examples/tpu/tpu_app_code/run_tpu.py index 4ba7b65b3b6..2278dd730d2 100644 --- a/examples/tpu/tpu_app_code/run_tpu.py +++ b/examples/tpu/tpu_app_code/run_tpu.py @@ -1,8 +1,8 @@ -import tensorflow_datasets as tfds import tensorflow as tf +import tensorflow_datasets as tfds import tensorflow_text as tf_text -from transformers import TFDistilBertForSequenceClassification from transformers import TFBertForSequenceClassification +from transformers import TFDistilBertForSequenceClassification tpu = tf.distribute.cluster_resolver.TPUClusterResolver() tf.config.experimental_connect_to_cluster(tpu) diff --git a/format.sh b/format.sh index 9100aeadc73..60d9aa53e68 100755 --- a/format.sh +++ b/format.sh @@ -54,6 +54,14 @@ YAPF_EXCLUDES=( '--exclude' 'sky/skylet/providers/ibm/**' ) +ISORT_YAPF_EXCLUDES=( + '--sg' 'build/**' + '--sg' 'sky/skylet/providers/aws/**' + '--sg' 'sky/skylet/providers/gcp/**' + '--sg' 'sky/skylet/providers/azure/**' + '--sg' 'sky/skylet/providers/ibm/**' +) + BLACK_INCLUDES=( 'sky/skylet/providers/aws' 'sky/skylet/providers/gcp' @@ -86,9 +94,12 @@ format_changed() { # Format all files format_all() { - yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples + yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples llm } +echo 'SkyPilot Black:' +black "${BLACK_INCLUDES[@]}" + ## This flag formats individual files. --files *must* be the first command line ## arg to use this option. if [[ "$1" == '--files' ]]; then @@ -102,8 +113,12 @@ else format_changed fi echo 'SkyPilot yapf: Done' -echo 'SkyPilot Black:' -black "${BLACK_INCLUDES[@]}" + +echo 'SkyPilot isort:' +isort sky tests examples llm docs "${ISORT_YAPF_EXCLUDES[@]}" + +isort --profile black -l 88 -m 3 "sky/skylet/providers/ibm" + # Run mypy # TODO(zhwu): When more of the codebase is typed properly, the mypy flags diff --git a/llm/vicuna-llama-2/scripts/flash_attn_patch.py b/llm/vicuna-llama-2/scripts/flash_attn_patch.py index 8ba2ea4a7ca..6839646307f 100644 --- a/llm/vicuna-llama-2/scripts/flash_attn_patch.py +++ b/llm/vicuna-llama-2/scripts/flash_attn_patch.py @@ -1,17 +1,16 @@ -from typing import List, Optional, Tuple import logging +from typing import List, Optional, Tuple +from einops import rearrange +from flash_attn.bert_padding import pad_input +from flash_attn.bert_padding import unpad_input +# pip3 install "flash-attn>=2.0" +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func import torch from torch import nn - import transformers from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from einops import rearrange -from flash_attn.flash_attn_interface import ( # pip3 install "flash-attn>=2.0" - flash_attn_varlen_qkvpacked_func,) -from flash_attn.bert_padding import unpad_input, pad_input - def forward( self, diff --git a/llm/vicuna-llama-2/scripts/train.py b/llm/vicuna-llama-2/scripts/train.py index c1030013d72..9112a8dc527 100644 --- a/llm/vicuna-llama-2/scripts/train.py +++ b/llm/vicuna-llama-2/scripts/train.py @@ -30,23 +30,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import dataclass +from dataclasses import field import json -import pathlib import os +import pathlib import shutil import subprocess from typing import Dict, Optional +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template import torch from torch.utils.data import Dataset import transformers from transformers import Trainer from transformers.trainer_pt_utils import LabelSmoother -from fastchat.conversation import SeparatorStyle -from fastchat.model.model_adapter import get_conversation_template - IGNORE_TOKEN_ID = LabelSmoother.ignore_index diff --git a/llm/vicuna-llama-2/scripts/train_flash_attn.py b/llm/vicuna-llama-2/scripts/train_flash_attn.py index 22769c29156..a68657e23e1 100644 --- a/llm/vicuna-llama-2/scripts/train_flash_attn.py +++ b/llm/vicuna-llama-2/scripts/train_flash_attn.py @@ -1,8 +1,7 @@ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. -from flash_attn_patch import ( - replace_llama_attn_with_flash_attn,) +from flash_attn_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() diff --git a/llm/vicuna-llama-2/scripts/train_xformers.py b/llm/vicuna-llama-2/scripts/train_xformers.py index f0544fc7b9a..461df636156 100644 --- a/llm/vicuna-llama-2/scripts/train_xformers.py +++ b/llm/vicuna-llama-2/scripts/train_xformers.py @@ -16,8 +16,7 @@ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. -from xformers_patch import ( - replace_llama_attn_with_xformers_attn,) +from xformers_patch import replace_llama_attn_with_xformers_attn replace_llama_attn_with_xformers_attn() diff --git a/llm/vicuna-llama-2/scripts/xformers_patch.py b/llm/vicuna-llama-2/scripts/xformers_patch.py index 9ffde5fe9c3..2a9db753cd1 100644 --- a/llm/vicuna-llama-2/scripts/xformers_patch.py +++ b/llm/vicuna-llama-2/scripts/xformers_patch.py @@ -21,8 +21,8 @@ from typing import Optional, Tuple import torch -import transformers.models.llama.modeling_llama from torch import nn +import transformers.models.llama.modeling_llama try: import xformers.ops diff --git a/pyproject.toml b/pyproject.toml index 97ada381264..39d1cc344bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,3 +22,10 @@ python_version = "3.8" follow_imports = "skip" ignore_missing_imports = true allow_redefinition = true + +[tool.isort] +profile = "google" +line_length = 80 +multi_line_output = 0 +combine_as_imports = true +use_parentheses = true diff --git a/requirements-dev.txt b/requirements-dev.txt index 9e8427248ea..1bccad92c30 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,6 +6,7 @@ black==22.10.0 # https://github.com/edaniszewski/pylint-quotes pylint-quotes==0.2.3 toml==0.10.2 +isort==5.12.0 # type checking mypy==0.991 diff --git a/sky/__init__.py b/sky/__init__.py index c814cee3e62..8e5007bb38b 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -11,18 +11,35 @@ from sky import benchmark from sky import clouds from sky.clouds.service_catalog import list_accelerators +from sky.core import autostop +from sky.core import cancel +from sky.core import cost_report +from sky.core import down +from sky.core import download_logs +from sky.core import job_status +from sky.core import queue +from sky.core import spot_cancel +from sky.core import spot_queue +from sky.core import spot_status +from sky.core import start +from sky.core import status +from sky.core import stop +from sky.core import storage_delete +from sky.core import storage_ls +from sky.core import tail_logs from sky.dag import Dag -from sky.execution import launch, exec, spot_launch # pylint: disable=redefined-builtin +from sky.data import Storage +from sky.data import StorageMode +from sky.data import StoreType +from sky.execution import exec # pylint: disable=redefined-builtin +from sky.execution import launch +from sky.execution import spot_launch +from sky.optimizer import Optimizer +from sky.optimizer import OptimizeTarget from sky.resources import Resources -from sky.task import Task -from sky.optimizer import Optimizer, OptimizeTarget -from sky.data import Storage, StorageMode, StoreType -from sky.status_lib import ClusterStatus from sky.skylet.job_lib import JobStatus -from sky.core import (status, start, stop, down, autostop, queue, cancel, - tail_logs, download_logs, job_status, spot_queue, - spot_status, spot_cancel, storage_ls, storage_delete, - cost_report) +from sky.status_lib import ClusterStatus +from sky.task import Task # Aliases. IBM = clouds.IBM diff --git a/sky/adaptors/cloudflare.py b/sky/adaptors/cloudflare.py index dd75a6d9fc1..b70d04769ed 100644 --- a/sky/adaptors/cloudflare.py +++ b/sky/adaptors/cloudflare.py @@ -3,8 +3,8 @@ import contextlib import functools -import threading import os +import threading from typing import Dict, Optional, Tuple from sky.utils import ux_utils diff --git a/sky/adaptors/gcp.py b/sky/adaptors/gcp.py index 477fe6b78ec..6e611ee1f2b 100644 --- a/sky/adaptors/gcp.py +++ b/sky/adaptors/gcp.py @@ -14,8 +14,8 @@ def wrapper(*args, **kwargs): global googleapiclient, google if googleapiclient is None or google is None: try: - import googleapiclient as _googleapiclient import google as _google + import googleapiclient as _googleapiclient googleapiclient = _googleapiclient google = _google except ImportError: diff --git a/sky/adaptors/ibm.py b/sky/adaptors/ibm.py index 9cffdca89f7..5a2b4990fe4 100644 --- a/sky/adaptors/ibm.py +++ b/sky/adaptors/ibm.py @@ -2,13 +2,15 @@ # pylint: disable=import-outside-toplevel -from sky import sky_logging -import yaml -import os -import json -import requests import functools +import json import multiprocessing +import os + +import requests +import yaml + +from sky import sky_logging CREDENTIAL_FILE = '~/.ibm/credentials.yaml' logger = sky_logging.init_logger(__name__) @@ -28,11 +30,11 @@ def wrapper(*args, **kwargs): global ibm_boto3, ibm_botocore if None in [ibm_vpc, ibm_cloud_sdk_core, ibm_platform_services]: try: - import ibm_vpc as _ibm_vpc - import ibm_cloud_sdk_core as _ibm_cloud_sdk_core - import ibm_platform_services as _ibm_platform_services import ibm_boto3 as _ibm_boto3 import ibm_botocore as _ibm_botocore + import ibm_cloud_sdk_core as _ibm_cloud_sdk_core + import ibm_platform_services as _ibm_platform_services + import ibm_vpc as _ibm_vpc ibm_vpc = _ibm_vpc ibm_cloud_sdk_core = _ibm_cloud_sdk_core ibm_platform_services = _ibm_platform_services diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py index 79daa6f2434..f7b349e384b 100644 --- a/sky/adaptors/kubernetes.py +++ b/sky/adaptors/kubernetes.py @@ -5,7 +5,8 @@ import functools import os -from sky.utils import ux_utils, env_options +from sky.utils import env_options +from sky.utils import ux_utils kubernetes = None urllib3 = None diff --git a/sky/authentication.py b/sky/authentication.py index d5aa2ff1787..27029b982de 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -30,18 +30,19 @@ import uuid import colorama +from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.backends import default_backend import yaml from sky import clouds from sky import sky_logging -from sky.adaptors import gcp, ibm +from sky.adaptors import gcp +from sky.adaptors import ibm +from sky.skylet.providers.lambda_cloud import lambda_utils from sky.utils import common_utils from sky.utils import subprocess_utils from sky.utils import ux_utils -from sky.skylet.providers.lambda_cloud import lambda_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/backends/__init__.py b/sky/backends/__init__.py index 81007e4e201..4e97cb70e0d 100644 --- a/sky/backends/__init__.py +++ b/sky/backends/__init__.py @@ -1,7 +1,10 @@ """Sky Backends.""" -from sky.backends.backend import Backend, ResourceHandle -from sky.backends.cloud_vm_ray_backend import CloudVmRayBackend, CloudVmRayResourceHandle -from sky.backends.local_docker_backend import LocalDockerBackend, LocalDockerResourceHandle +from sky.backends.backend import Backend +from sky.backends.backend import ResourceHandle +from sky.backends.cloud_vm_ray_backend import CloudVmRayBackend +from sky.backends.cloud_vm_ray_backend import CloudVmRayResourceHandle +from sky.backends.local_docker_backend import LocalDockerBackend +from sky.backends.local_docker_backend import LocalDockerResourceHandle __all__ = [ 'Backend', 'ResourceHandle', 'CloudVmRayBackend', diff --git a/sky/backends/backend.py b/sky/backends/backend.py index 1dbc4ad00f3..28aa981b078 100644 --- a/sky/backends/backend.py +++ b/sky/backends/backend.py @@ -3,8 +3,8 @@ from typing import Dict, Generic, Optional import sky -from sky.utils import timeline from sky.usage import usage_lib +from sky.utils import timeline if typing.TYPE_CHECKING: from sky import resources diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 4e734f7e0fb..10c49445f4f 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -12,8 +12,7 @@ import textwrap import time import typing -from typing import (Any, Dict, List, Optional, Sequence, Set, Tuple, Union) -from typing_extensions import Literal +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import uuid import colorama @@ -25,6 +24,7 @@ from requests import adapters from requests.packages.urllib3.util import retry as retry_lib import rich.progress as rich_progress +from typing_extensions import Literal import yaml import sky @@ -35,15 +35,16 @@ from sky import exceptions from sky import global_user_state from sky import provision as provision_lib -from sky import skypilot_config from sky import sky_logging +from sky import skypilot_config from sky import spot as spot_lib from sky import status_lib from sky.backends import onprem_utils from sky.skylet import constants from sky.skylet import log_lib -from sky.utils import common_utils +from sky.usage import usage_lib from sky.utils import command_runner +from sky.utils import common_utils from sky.utils import env_options from sky.utils import log_utils from sky.utils import subprocess_utils @@ -51,7 +52,6 @@ from sky.utils import tpu_utils from sky.utils import ux_utils from sky.utils import validator -from sky.usage import usage_lib if typing.TYPE_CHECKING: from sky import resources @@ -1108,7 +1108,8 @@ def write_cluster_config( user_file_dir = os.path.expanduser(f'{SKY_USER_FILE_PATH}/') - from sky.skylet.providers.gcp import config as gcp_config # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky.skylet.providers.gcp import config as gcp_config config = common_utils.read_yaml(os.path.expanduser(config_dict['ray'])) vpc_name = gcp_config.get_usable_vpc(config) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index aa4d534d650..bed4413e1a1 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -4,8 +4,8 @@ import enum import getpass import inspect -import math import json +import math import os import pathlib import re @@ -17,43 +17,44 @@ import threading import time import typing -from typing import Dict, Iterable, List, Optional, Tuple, Union, Set +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import colorama import filelock import sky from sky import backends -from sky import clouds from sky import cloud_stores +from sky import clouds from sky import exceptions from sky import global_user_state +from sky import optimizer from sky import provision as provision_lib from sky import resources as resources_lib from sky import sky_logging -from sky import optimizer from sky import skypilot_config from sky import spot as spot_lib from sky import status_lib from sky import task as task_lib -from sky.data import data_utils -from sky.data import storage as storage_lib from sky.backends import backend_utils from sky.backends import onprem_utils from sky.backends import wheel_utils +from sky.data import data_utils +from sky.data import storage as storage_lib from sky.skylet import autostop_lib from sky.skylet import constants from sky.skylet import job_lib from sky.skylet import log_lib +from sky.skylet.providers.scp.node_provider import SCPError +from sky.skylet.providers.scp.node_provider import SCPNodeProvider from sky.usage import usage_lib -from sky.utils import common_utils from sky.utils import command_runner +from sky.utils import common_utils from sky.utils import log_utils from sky.utils import subprocess_utils from sky.utils import timeline from sky.utils import tpu_utils from sky.utils import ux_utils -from sky.skylet.providers.scp.node_provider import SCPNodeProvider, SCPError if typing.TYPE_CHECKING: from sky import dag @@ -3647,9 +3648,10 @@ def teardown_no_lock(self, region = config['provider']['region'] # pylint: disable=import-outside-toplevel - from sky.skylet.providers.oci.query_helper import oci_query_helper from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME + from sky.skylet.providers.oci.query_helper import oci_query_helper + # 0: All terminated successfully, failed count otherwise returncode = oci_query_helper.terminate_instances_by_tags( {TAG_RAY_CLUSTER_NAME: cluster_name}, region) diff --git a/sky/backends/docker_utils.py b/sky/backends/docker_utils.py index 58a10f6c34c..bedc4e419dd 100644 --- a/sky/backends/docker_utils.py +++ b/sky/backends/docker_utils.py @@ -8,9 +8,9 @@ import colorama -from sky.adaptors import docker from sky import sky_logging from sky import task as task_mod +from sky.adaptors import docker logger = sky_logging.init_logger(__name__) diff --git a/sky/backends/local_docker_backend.py b/sky/backends/local_docker_backend.py index 9b2f2651705..0e5cb713fb6 100644 --- a/sky/backends/local_docker_backend.py +++ b/sky/backends/local_docker_backend.py @@ -8,9 +8,9 @@ from rich import console as rich_console from sky import backends -from sky.adaptors import docker from sky import global_user_state from sky import sky_logging +from sky.adaptors import docker from sky.backends import backend_utils from sky.backends import docker_utils from sky.data import storage as storage_lib diff --git a/sky/benchmark/benchmark_utils.py b/sky/benchmark/benchmark_utils.py index 0b260f194cd..1f3d4722a8f 100644 --- a/sky/benchmark/benchmark_utils.py +++ b/sky/benchmark/benchmark_utils.py @@ -28,10 +28,10 @@ from sky.backends import backend_utils from sky.benchmark import benchmark_state from sky.skylet import constants -from sky.skylet import log_lib from sky.skylet import job_lib -from sky.utils import log_utils +from sky.skylet import log_lib from sky.utils import common_utils +from sky.utils import log_utils from sky.utils import subprocess_utils from sky.utils import ux_utils diff --git a/sky/callbacks/sky_callback/__init__.py b/sky/callbacks/sky_callback/__init__.py index d8ed4a4a538..a76a03e5cc7 100644 --- a/sky/callbacks/sky_callback/__init__.py +++ b/sky/callbacks/sky_callback/__init__.py @@ -1,10 +1,8 @@ -from sky_callback.api import ( - init, - step_begin, - step_end, - step, - step_iterator, -) +from sky_callback.api import init +from sky_callback.api import step +from sky_callback.api import step_begin +from sky_callback.api import step_end +from sky_callback.api import step_iterator from sky_callback.base import BaseCallback from sky_callback.utils import CallbackLoader as _CallbackLoader diff --git a/sky/callbacks/sky_callback/integrations/keras.py b/sky/callbacks/sky_callback/integrations/keras.py index e0aa098d7a4..d90d431b932 100644 --- a/sky/callbacks/sky_callback/integrations/keras.py +++ b/sky/callbacks/sky_callback/integrations/keras.py @@ -1,9 +1,10 @@ """SkyCallback integration with Keras.""" from typing import Dict, Optional -import tensorflow as tf from tensorflow import keras +import tensorflow as tf +# isort: split from sky_callback import base from sky_callback import utils diff --git a/sky/callbacks/sky_callback/integrations/pytorch_lightning.py b/sky/callbacks/sky_callback/integrations/pytorch_lightning.py index f549b8da3e5..e813da49588 100644 --- a/sky/callbacks/sky_callback/integrations/pytorch_lightning.py +++ b/sky/callbacks/sky_callback/integrations/pytorch_lightning.py @@ -2,7 +2,6 @@ from typing import Any, Optional import pytorch_lightning as pl - from sky_callback import base from sky_callback import utils diff --git a/sky/callbacks/sky_callback/integrations/transformers.py b/sky/callbacks/sky_callback/integrations/transformers.py index 6087d3c45b1..ce728efbc4c 100644 --- a/sky/callbacks/sky_callback/integrations/transformers.py +++ b/sky/callbacks/sky_callback/integrations/transformers.py @@ -3,6 +3,7 @@ import transformers +# isort: split from sky_callback import base from sky_callback import utils diff --git a/sky/callbacks/sky_callback/utils.py b/sky/callbacks/sky_callback/utils.py index 5cbde96b1d1..09859a17001 100644 --- a/sky/callbacks/sky_callback/utils.py +++ b/sky/callbacks/sky_callback/utils.py @@ -16,11 +16,13 @@ def keras(log_dir: Optional[str] = None, total_steps: Optional[int] = None): @staticmethod def pytorch_lightning(log_dir: Optional[str] = None, total_steps: Optional[int] = None): - from sky_callback.integrations.pytorch_lightning import SkyLightningCallback + from sky_callback.integrations.pytorch_lightning import ( + SkyLightningCallback) return SkyLightningCallback(log_dir=log_dir, total_steps=total_steps) @staticmethod def transformers(log_dir: Optional[str] = None, total_steps: Optional[int] = None): - from sky_callback.integrations.transformers import SkyTransformersCallback + from sky_callback.integrations.transformers import ( + SkyTransformersCallback) return SkyTransformersCallback(log_dir=log_dir, total_steps=total_steps) diff --git a/sky/cli.py b/sky/cli.py index 1345914c906..91939a06c94 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -62,18 +62,18 @@ from sky.data import storage_utils from sky.skylet import constants from sky.skylet import job_lib -from sky.utils import common_utils +from sky.skylet.providers.kubernetes import utils as kubernetes_utils +from sky.usage import usage_lib from sky.utils import command_runner +from sky.utils import common_utils from sky.utils import dag_utils from sky.utils import env_options -from sky.utils import kubernetes_utils from sky.utils import log_utils from sky.utils import schemas from sky.utils import subprocess_utils from sky.utils import timeline from sky.utils import ux_utils from sky.utils.cli_utils import status_utils -from sky.usage import usage_lib if typing.TYPE_CHECKING: from sky.backends import backend as backend_lib diff --git a/sky/cloud_stores.py b/sky/cloud_stores.py index 85a0f80ff99..db20b531cb8 100644 --- a/sky/cloud_stores.py +++ b/sky/cloud_stores.py @@ -10,9 +10,11 @@ import subprocess import urllib.parse +from sky.adaptors import aws +from sky.adaptors import cloudflare +from sky.adaptors import ibm from sky.clouds import gcp from sky.data import data_utils -from sky.adaptors import aws, cloudflare, ibm from sky.data.data_utils import Rclone diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py index 74fc0089d5d..d3d8aab0d9f 100644 --- a/sky/clouds/__init__.py +++ b/sky/clouds/__init__.py @@ -1,18 +1,21 @@ """Clouds in Sky.""" from sky.clouds.cloud import Cloud -from sky.clouds.cloud_registry import CLOUD_REGISTRY from sky.clouds.cloud import CloudImplementationFeatures from sky.clouds.cloud import Region from sky.clouds.cloud import Zone +from sky.clouds.cloud_registry import CLOUD_REGISTRY + +# NOTE: import the above first to avoid circular imports. +# isort: split from sky.clouds.aws import AWS from sky.clouds.azure import Azure from sky.clouds.gcp import GCP +from sky.clouds.ibm import IBM +from sky.clouds.kubernetes import Kubernetes from sky.clouds.lambda_cloud import Lambda from sky.clouds.local import Local -from sky.clouds.ibm import IBM -from sky.clouds.scp import SCP from sky.clouds.oci import OCI -from sky.clouds.kubernetes import Kubernetes +from sky.clouds.scp import SCP __all__ = [ 'IBM', diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 52dca3f699b..223d257d1a6 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -7,7 +7,7 @@ import subprocess import time import typing -from typing import Dict, Iterator, List, Optional, Tuple, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple from sky import clouds from sky import exceptions @@ -477,7 +477,9 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: cls._STATIC_CREDENTIAL_HELP_STR) # Fetch the AWS catalogs - from sky.clouds.service_catalog import aws_catalog # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky.clouds.service_catalog import aws_catalog + # Trigger the fetch of the availability zones mapping. aws_catalog.get_default_instance_type() return True, hints @@ -724,7 +726,8 @@ def check_quota_available(cls, region = resources.region use_spot = resources.use_spot - from sky.clouds.service_catalog import aws_catalog # pylint: disable=import-outside-toplevel,unused-import + # pylint: disable=import-outside-toplevel,unused-import + from sky.clouds.service_catalog import aws_catalog quota_code = aws_catalog.get_quota_code(instance_type, use_spot) diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 055357c8e2a..fbddf3a5425 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -240,7 +240,8 @@ def make_deploy_resources_variables( custom_resources = json.dumps(acc_dict, separators=(',', ':')) else: custom_resources = None - from sky.clouds.service_catalog import azure_catalog # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky.clouds.service_catalog import azure_catalog gen_version = azure_catalog.get_gen_version_from_instance_type( r.instance_type) image_config = self._get_image_config(gen_version, r.instance_type) diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index e228bbd952e..f5294f48236 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -12,8 +12,8 @@ from sky.utils import ux_utils if typing.TYPE_CHECKING: - from sky import status_lib from sky import resources as resources_lib + from sky import status_lib class CloudImplementationFeatures(enum.Enum): diff --git a/sky/clouds/cloud_registry.py b/sky/clouds/cloud_registry.py index db9a84e6e17..7cf815c5ee2 100644 --- a/sky/clouds/cloud_registry.py +++ b/sky/clouds/cloud_registry.py @@ -5,6 +5,7 @@ from typing import Optional, Type from sky.utils import ux_utils + if typing.TYPE_CHECKING: import sky diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index f2b421c5262..d77d9cdae36 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -1,5 +1,6 @@ """Google Cloud Platform.""" import dataclasses +import datetime import functools import json import os @@ -7,8 +8,7 @@ import subprocess import time import typing -import datetime -from typing import Dict, Iterator, List, Optional, Tuple, Set +from typing import Dict, Iterator, List, Optional, Set, Tuple import cachetools @@ -694,8 +694,8 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: """Checks if the user has access credentials to this cloud.""" try: # pylint: disable=import-outside-toplevel,unused-import - from google import auth # type: ignore # Check google-api-python-client installation. + from google import auth # type: ignore import googleapiclient # Check the installation of google-cloud-sdk. @@ -799,8 +799,9 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: 'some time.') # pylint: disable=import-outside-toplevel,unused-import - import googleapiclient.discovery import google.auth + import googleapiclient.discovery + from sky.skylet.providers.gcp import constants # This takes user's credential info from "~/.config/gcloud/application_default_credentials.json". # pylint: disable=line-too-long @@ -984,7 +985,8 @@ def check_quota_available(cls, resources: 'resources.Resources') -> bool: use_spot = resources.use_spot region = resources.region - from sky.clouds.service_catalog import gcp_catalog # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky.clouds.service_catalog import gcp_catalog quota_code = gcp_catalog.get_quota_code(accelerator, use_spot) @@ -1025,7 +1027,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], """Query the status of a cluster.""" del region # unused - from sky.utils import tpu_utils # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky.utils import tpu_utils use_tpu_vm = kwargs.pop('use_tpu_vm', False) label_filter_str = cls._label_filter_str(tag_filters) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index c50e3479b58..d6b41454bca 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -1,10 +1,11 @@ """IBM Web Services.""" -import colorama -import os import json +import os import typing from typing import Any, Dict, Iterator, List, Optional, Tuple +import colorama + from sky import clouds from sky import sky_logging from sky import status_lib diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 87c7244539d..0dd925e4863 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -9,9 +9,9 @@ from sky import exceptions from sky import status_lib from sky.adaptors import kubernetes +from sky.skylet.providers.kubernetes import utils as kubernetes_utils from sky.utils import common_utils from sky.utils import ux_utils -from sky.skylet.providers.kubernetes import utils as kubernetes_utils if typing.TYPE_CHECKING: # Renaming to avoid shadowing variables. diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 0e5ae888efb..487e2328262 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -1,9 +1,10 @@ """Lambda Cloud.""" import json -import requests import typing from typing import Dict, Iterator, List, Optional, Tuple +import requests + from sky import clouds from sky import exceptions from sky import status_lib diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index db24461c997..452d10692a0 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -6,20 +6,20 @@ - Hysun He (hysun.he@oracle.com) @ May 4, 2023: Support use the default image_id (configurable) if no image_id specified in the task yaml. """ -import os import json -import typing import logging +import os +import typing from typing import Dict, Iterator, List, Optional, Tuple from sky import clouds from sky import exceptions from sky import status_lib +from sky.adaptors import oci as oci_adaptor from sky.clouds import service_catalog +from sky.skylet.providers.oci.config import oci_conf from sky.utils import common_utils from sky.utils import ux_utils -from sky.adaptors import oci as oci_adaptor -from sky.skylet.providers.oci.config import oci_conf if typing.TYPE_CHECKING: # Renaming to avoid shadowing variables. diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 91daf39c4f2..96b321bab65 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -9,11 +9,11 @@ from typing import Dict, Iterator, List, Optional, Tuple from sky import clouds +from sky import exceptions +from sky import sky_logging from sky import status_lib from sky.clouds import service_catalog from sky.skylet.providers.scp import scp_utils -from sky import exceptions -from sky import sky_logging if typing.TYPE_CHECKING: # Renaming to avoid shadowing variables. diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index c47670f7f1a..53a70bad9f6 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -4,12 +4,10 @@ import typing from typing import Dict, List, Optional, Set, Tuple, Union -from sky.clouds.service_catalog.constants import ( - HOSTED_CATALOG_DIR_URL, - CATALOG_SCHEMA_VERSION, - LOCAL_CATALOG_DIR, -) from sky.clouds.service_catalog.config import use_default_catalog +from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION +from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL +from sky.clouds.service_catalog.constants import LOCAL_CATALOG_DIR if typing.TYPE_CHECKING: from sky.clouds import cloud diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 72c74afc80c..a04ab55897b 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -3,7 +3,6 @@ This module loads the service catalog file and can be used to query instance types and pricing information for AWS. """ -import colorama import glob import hashlib import os @@ -11,6 +10,7 @@ import typing from typing import Dict, List, Optional, Tuple +import colorama import pandas as pd from sky import exceptions diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 456d0f6236d..cd1297e673e 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -1,14 +1,14 @@ """Common utilities for service catalog.""" import ast +import difflib import hashlib import os import time from typing import Dict, List, NamedTuple, Optional, Tuple -import difflib import filelock -import requests import pandas as pd +import requests from sky import sky_logging from sky.clouds import cloud as cloud_lib diff --git a/sky/clouds/service_catalog/data_fetchers/analyze.py b/sky/clouds/service_catalog/data_fetchers/analyze.py index eab32f30c8c..54fe9991a7e 100644 --- a/sky/clouds/service_catalog/data_fetchers/analyze.py +++ b/sky/clouds/service_catalog/data_fetchers/analyze.py @@ -1,6 +1,7 @@ """Analyze the new catalog fetched with the original.""" from typing import List + import pandas as pd from sky.clouds.service_catalog import common diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_aws.py b/sky/clouds/service_catalog/data_fetchers/fetch_aws.py index ddedd1e56d1..006e9608c69 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_aws.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_aws.py @@ -6,8 +6,8 @@ import itertools from multiprocessing import pool as mp_pool import os -import sys import subprocess +import sys from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py index 65aa7b023fd..07053a89ba0 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py @@ -11,6 +11,7 @@ import csv import json import os + import requests ENDPOINT = 'https://cloud.lambdalabs.com/api/v1/instance-types' diff --git a/sky/clouds/service_catalog/ibm_catalog.py b/sky/clouds/service_catalog/ibm_catalog.py index 7813c952b2d..9f505bed896 100644 --- a/sky/clouds/service_catalog/ibm_catalog.py +++ b/sky/clouds/service_catalog/ibm_catalog.py @@ -5,11 +5,12 @@ instance types and pricing information for IBM. """ +from typing import Dict, List, Optional, Tuple + from sky import sky_logging +from sky.adaptors import ibm from sky.clouds import cloud from sky.clouds.service_catalog import common -from sky.adaptors import ibm -from typing import Dict, List, Optional, Tuple logger = sky_logging.init_logger(__name__) diff --git a/sky/clouds/service_catalog/oci_catalog.py b/sky/clouds/service_catalog/oci_catalog.py index 06652116768..0ad44dbbf8d 100644 --- a/sky/clouds/service_catalog/oci_catalog.py +++ b/sky/clouds/service_catalog/oci_catalog.py @@ -9,18 +9,20 @@ excluding those unsubscribed regions. """ -import typing import logging import threading +import typing from typing import Dict, List, Optional, Tuple + +from sky.adaptors import oci as oci_adaptor from sky.clouds.service_catalog import common from sky.skylet.providers.oci.config import oci_conf -from sky.adaptors import oci as oci_adaptor if typing.TYPE_CHECKING: - from sky.clouds import cloud import pandas as pd + from sky.clouds import cloud # pylint: disable=ungrouped-imports + logger = logging.getLogger(__name__) _df = None diff --git a/sky/core.py b/sky/core.py index 64a3161a943..ce18f5e516b 100644 --- a/sky/core.py +++ b/sky/core.py @@ -5,24 +5,24 @@ import colorama +from sky import backends from sky import clouds from sky import dag -from sky import task -from sky import backends from sky import data from sky import exceptions from sky import global_user_state from sky import sky_logging from sky import spot from sky import status_lib +from sky import task from sky.backends import backend_utils from sky.skylet import constants from sky.skylet import job_lib from sky.usage import usage_lib from sky.utils import log_utils +from sky.utils import subprocess_utils from sky.utils import tpu_utils from sky.utils import ux_utils -from sky.utils import subprocess_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/data/__init__.py b/sky/data/__init__.py index b0b97b8753b..653e6c58017 100644 --- a/sky/data/__init__.py +++ b/sky/data/__init__.py @@ -1,4 +1,6 @@ """Sky Data.""" -from sky.data.storage import Storage, StorageMode, StoreType +from sky.data.storage import Storage +from sky.data.storage import StorageMode +from sky.data.storage import StoreType __all__ = ['Storage', 'StorageMode', 'StoreType'] diff --git a/sky/data/data_transfer.py b/sky/data/data_transfer.py index fad2b7d1224..d56bc49ceba 100644 --- a/sky/data/data_transfer.py +++ b/sky/data/data_transfer.py @@ -24,7 +24,8 @@ from sky import clouds from sky import sky_logging -from sky.adaptors import aws, gcp +from sky.adaptors import aws +from sky.adaptors import gcp from sky.data import data_utils from sky.utils import log_utils from sky.utils import ux_utils @@ -45,7 +46,8 @@ def s3_to_gcs(s3_bucket_name: str, gs_bucket_name: str) -> None: s3_bucket_name: str; Name of the Amazon S3 Bucket gs_bucket_name: str; Name of the Google Cloud Storage Bucket """ - from oauth2client.client import GoogleCredentials # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from oauth2client.client import GoogleCredentials oauth_credentials = GoogleCredentials.get_application_default() storagetransfer = gcp.build('storagetransfer', diff --git a/sky/data/data_utils.py b/sky/data/data_utils.py index 0882f755e4b..0fa7a2ee40b 100644 --- a/sky/data/data_utils.py +++ b/sky/data/data_utils.py @@ -4,17 +4,20 @@ from enum import Enum from multiprocessing import pool import os +import re import subprocess import textwrap from typing import Any, Callable, Dict, List, Optional, Tuple import urllib.parse -import re from filelock import FileLock from sky import exceptions from sky import sky_logging -from sky.adaptors import aws, gcp, cloudflare, ibm +from sky.adaptors import aws +from sky.adaptors import cloudflare +from sky.adaptors import gcp +from sky.adaptors import ibm from sky.utils import ux_utils Client = Any diff --git a/sky/data/storage.py b/sky/data/storage.py index 9634ce774bd..34f14c48eff 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -12,22 +12,22 @@ from sky import check from sky import clouds +from sky import exceptions +from sky import global_user_state +from sky import sky_logging +from sky import status_lib from sky.adaptors import aws -from sky.adaptors import gcp from sky.adaptors import cloudflare +from sky.adaptors import gcp from sky.adaptors import ibm from sky.backends import backend_utils -from sky.utils import schemas from sky.data import data_transfer from sky.data import data_utils from sky.data import mounting_utils -from sky.data.data_utils import Rclone from sky.data import storage_utils -from sky import exceptions -from sky import global_user_state -from sky import sky_logging -from sky import status_lib +from sky.data.data_utils import Rclone from sky.utils import log_utils +from sky.utils import schemas from sky.utils import ux_utils if typing.TYPE_CHECKING: diff --git a/sky/data/storage_utils.py b/sky/data/storage_utils.py index 8ee493f66b6..044e00f5aeb 100644 --- a/sky/data/storage_utils.py +++ b/sky/data/storage_utils.py @@ -1,9 +1,10 @@ """Utility functions for the storage module.""" -import colorama import os import subprocess from typing import Any, Dict, List +import colorama + from sky import exceptions from sky import sky_logging from sky.utils import log_utils diff --git a/sky/execution.py b/sky/execution.py index 4abaf886ac6..783dbe2adb9 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -15,10 +15,10 @@ import copy import enum import getpass -import tempfile import os -import uuid +import tempfile from typing import Any, Dict, List, Optional, Union +import uuid import colorama @@ -28,21 +28,22 @@ from sky import exceptions from sky import global_user_state from sky import optimizer -from sky import skypilot_config from sky import sky_logging +from sky import skypilot_config from sky import spot from sky import task as task_lib from sky.backends import backend_utils from sky.clouds import gcp from sky.data import data_utils from sky.data import storage as storage_lib -from sky.usage import usage_lib from sky.skylet import constants +from sky.usage import usage_lib from sky.utils import common_utils from sky.utils import dag_utils +from sky.utils import env_options from sky.utils import log_utils -from sky.utils import env_options, timeline from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 37d7c9ba903..68ea52a488c 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -13,11 +13,11 @@ import sqlite3 import time import typing -from typing import Any, Dict, List, Tuple, Optional, Set +from typing import Any, Dict, List, Optional, Set, Tuple import uuid -from sky import status_lib from sky import clouds +from sky import status_lib from sky.adaptors import cloudflare from sky.data import storage as storage_lib from sky.utils import common_utils diff --git a/sky/optimizer.py b/sky/optimizer.py index 8d8e5682d05..49349615879 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -18,12 +18,14 @@ from sky import task as task_lib from sky.backends import backend_utils from sky.utils import env_options -from sky.utils import ux_utils from sky.utils import log_utils +from sky.utils import ux_utils if typing.TYPE_CHECKING: import networkx as nx - from sky import dag as dag_lib # pylint: disable=ungrouped-imports + + #pylint: disable=ungrouped-imports + from sky import dag as dag_lib logger = sky_logging.init_logger(__name__) @@ -829,6 +831,7 @@ def _optimize_objective( the total estimated cost/time of the DAG becomes the minimum. """ import networkx as nx # pylint: disable=import-outside-toplevel + # TODO: The output of this function is useful. Should generate a # text plan and print to both console and a log file. diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index 4465199f09e..ed23032d18c 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -3,11 +3,10 @@ This module provides a standard low-level interface that all providers supported by SkyPilot need to follow. """ -from typing import Any, Dict, List, Optional - import functools import importlib import inspect +from typing import Any, Dict, List, Optional from sky import status_lib diff --git a/sky/provision/aws/__init__.py b/sky/provision/aws/__init__.py index 8868e4b5fe2..c7c54fe617f 100644 --- a/sky/provision/aws/__init__.py +++ b/sky/provision/aws/__init__.py @@ -1,4 +1,6 @@ """AWS provisioner for SkyPilot.""" -from sky.provision.aws.instance import (cleanup_ports, query_instances, - terminate_instances, stop_instances) +from sky.provision.aws.instance import cleanup_ports +from sky.provision.aws.instance import query_instances +from sky.provision.aws.instance import stop_instances +from sky.provision.aws.instance import terminate_instances diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 0300a0f7b37..395adf9abb8 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -1,11 +1,11 @@ """AWS instance provisioning.""" -from typing import Dict, List, Any, Optional +from typing import Any, Dict, List, Optional from botocore import config +from sky import status_lib from sky.adaptors import aws from sky.utils import common_utils -from sky import status_lib BOTO_MAX_RETRIES = 12 # Tag uniquely identifying all nodes of a cluster diff --git a/sky/provision/gcp/__init__.py b/sky/provision/gcp/__init__.py index d2f335acacf..3fbb197adad 100644 --- a/sky/provision/gcp/__init__.py +++ b/sky/provision/gcp/__init__.py @@ -1,3 +1,5 @@ """GCP provisioner for SkyPilot.""" -from sky.provision.gcp.instance import stop_instances, terminate_instances, cleanup_ports +from sky.provision.gcp.instance import cleanup_ports +from sky.provision.gcp.instance import stop_instances +from sky.provision.gcp.instance import terminate_instances diff --git a/sky/resources.py b/sky/resources.py index 2e9065438d1..0b06befcdca 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -1,12 +1,13 @@ """Resources: compute requirements of Tasks.""" -from typing import Dict, List, Optional, Union, Set -from typing_extensions import Literal +from typing import Dict, List, Optional, Set, Union import colorama +from typing_extensions import Literal from sky import clouds from sky import global_user_state from sky import sky_logging +from sky import skypilot_config from sky import spot from sky.backends import backend_utils from sky.skylet import constants @@ -14,7 +15,6 @@ from sky.utils import schemas from sky.utils import tpu_utils from sky.utils import ux_utils -from sky import skypilot_config logger = sky_logging.init_logger(__name__) diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 6447a9af629..27361d10ee4 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -17,8 +17,8 @@ import os import platform import re -import warnings from typing import Dict, List +import warnings import setuptools diff --git a/sky/skylet/autostop_lib.py b/sky/skylet/autostop_lib.py index 5358910ba28..91c98f89da1 100644 --- a/sky/skylet/autostop_lib.py +++ b/sky/skylet/autostop_lib.py @@ -1,10 +1,11 @@ """Autostop utilities.""" import pickle -import psutil import shlex import time from typing import List, Optional +import psutil + from sky import sky_logging from sky.skylet import configs from sky.utils import common_utils diff --git a/sky/skylet/events.py b/sky/skylet/events.py index 98ec54ed239..b15de5ed150 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -11,8 +11,10 @@ import yaml from sky import sky_logging -from sky.backends import backend_utils, cloud_vm_ray_backend -from sky.skylet import autostop_lib, job_lib +from sky.backends import backend_utils +from sky.backends import cloud_vm_ray_backend +from sky.skylet import autostop_lib +from sky.skylet import job_lib from sky.spot import spot_utils from sky.utils import common_utils @@ -197,7 +199,8 @@ def _stop_cluster(self, autostop_config): def _stop_cluster_with_new_provisioner(self, autostop_config, cluster_config, provider_name): - from sky import provision as provision_lib # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from sky import provision as provision_lib autostop_lib.set_autostopping_started() cluster_name = cluster_config['cluster_name'] diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index dc12601c0fe..9ea178877bf 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -3,10 +3,10 @@ This is a remote utility module that provides job queue functionality. """ import enum +import getpass import json import os import pathlib -import psutil import shlex import subprocess import time @@ -15,7 +15,7 @@ import colorama import filelock -import getpass +import psutil from sky import sky_logging from sky.skylet import constants @@ -251,7 +251,8 @@ def _create_ray_job_submission_client(): logger.error('Failed to import ray') raise try: - from ray import job_submission # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from ray import job_submission except ImportError: logger.error( f'Failed to import job_submission with ray=={ray.__version__}') diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index 8321f21d755..ad82ebb78e0 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -8,9 +8,9 @@ import os import subprocess import sys -import time -import textwrap import tempfile +import textwrap +import time from typing import Dict, Iterator, List, Optional, Tuple, Union import colorama diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 349b88f6922..46a752d03ef 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -3,13 +3,15 @@ This file is dynamically generated by stubgen and added with the overloaded type hints for run_with_log(), as we need to determine the return type based on the value of require_outputs. """ -from sky import sky_logging as sky_logging -from sky.skylet import constants as constants, job_lib as job_lib -from sky.utils import log_utils as log_utils import typing from typing import Dict, List, Optional, Tuple, Union + from typing_extensions import Literal +from sky import sky_logging as sky_logging +from sky.skylet import constants as constants +from sky.skylet import job_lib as job_lib +from sky.utils import log_utils as log_utils class _ProcessingArgs: log_path: str diff --git a/sky/skylet/providers/aws/node_provider.py b/sky/skylet/providers/aws/node_provider.py index 17a1d153fa5..86297b748f7 100644 --- a/sky/skylet/providers/aws/node_provider.py +++ b/sky/skylet/providers/aws/node_provider.py @@ -24,9 +24,7 @@ resource_cache, client_cache, ) -from sky.skylet.providers.command_runner import SkyDockerCommandRunner from ray.autoscaler._private.cli_logger import cli_logger, cf -from ray.autoscaler._private.command_runner import SSHCommandRunner from ray.autoscaler._private.constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES from ray.autoscaler._private.log_timer import LogTimer from ray.autoscaler.node_provider import NodeProvider @@ -716,30 +714,6 @@ def fillout_available_node_types_resources( ) return cluster_config - def get_command_runner( - self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None, - ): - common_args = { - "log_prefix": log_prefix, - "node_id": node_id, - "provider": self, - "auth_config": auth_config, - "cluster_name": cluster_name, - "process_runner": process_runner, - "use_internal_ip": use_internal_ip, - } - if docker_config and docker_config["container_name"] != "": - return SkyDockerCommandRunner(docker_config, **common_args) - else: - return SSHCommandRunner(**common_args) - class AWSNodeProviderV2(AWSNodeProvider): """Same as V1, except head and workers use a SkyPilot IAM role. diff --git a/sky/skylet/providers/azure/azure-vm-template.json b/sky/skylet/providers/azure/azure-vm-template.json index 52e82dc532c..03ca08944c2 100644 --- a/sky/skylet/providers/azure/azure-vm-template.json +++ b/sky/skylet/providers/azure/azure-vm-template.json @@ -117,12 +117,6 @@ "metadata": { "description": "OS disk tier." } - }, - "cloudInitSetupCommands": { - "type": "string", - "metadata": { - "description": "Base64 encoded cloud-init setup commands." - } } }, "variables": { @@ -266,8 +260,7 @@ } ] } - }, - "customData": "[parameters('cloudInitSetupCommands')]" + } }, "priority": "[parameters('priority')]", "billingProfile": "[parameters('billingProfile')]" diff --git a/sky/skylet/providers/azure/config.py b/sky/skylet/providers/azure/config.py index f8738d95c68..a937102f579 100644 --- a/sky/skylet/providers/azure/config.py +++ b/sky/skylet/providers/azure/config.py @@ -81,32 +81,6 @@ def _configure_resource_group(config): with open(template_path, "r") as template_fp: template = json.load(template_fp) - # Setup firewall rules for ports - nsg_resource = None - for resource in template["resources"]: - if resource["type"] == "Microsoft.Network/networkSecurityGroups": - nsg_resource = resource - break - assert nsg_resource is not None, "Could not find NSG resource in template" - ports = config["provider"].get("ports", None) - if ports is not None: - ports = [str(port) for port in ports if port != 22] - nsg_resource["properties"]["securityRules"].append( - { - "name": "user-ports", - "properties": { - "priority": 1001, - "protocol": "TCP", - "access": "Allow", - "direction": "Inbound", - "sourceAddressPrefix": "*", - "sourcePortRange": "*", - "destinationAddressPrefix": "*", - "destinationPortRanges": ports, - }, - } - ) - logger.info("Using cluster name: %s", config["cluster_name"]) # set unique id for resources in this cluster diff --git a/sky/skylet/providers/azure/node_provider.py b/sky/skylet/providers/azure/node_provider.py index 9369499b42c..1a0306e5b67 100644 --- a/sky/skylet/providers/azure/node_provider.py +++ b/sky/skylet/providers/azure/node_provider.py @@ -15,9 +15,6 @@ bootstrap_azure, get_azure_sdk_function, ) -from sky.skylet.providers.command_runner import SkyDockerCommandRunner - -from ray.autoscaler._private.command_runner import SSHCommandRunner from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import ( TAG_RAY_CLUSTER_NAME, @@ -422,27 +419,3 @@ def _get_cached_node(self, node_id): @staticmethod def bootstrap_config(cluster_config): return bootstrap_azure(cluster_config) - - def get_command_runner( - self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None, - ): - common_args = { - "log_prefix": log_prefix, - "node_id": node_id, - "provider": self, - "auth_config": auth_config, - "cluster_name": cluster_name, - "process_runner": process_runner, - "use_internal_ip": use_internal_ip, - } - if docker_config and docker_config["container_name"] != "": - return SkyDockerCommandRunner(docker_config, **common_args) - else: - return SSHCommandRunner(**common_args) diff --git a/sky/skylet/providers/command_runner.py b/sky/skylet/providers/command_runner.py index 8bc2a48f1f3..db4495f9117 100644 --- a/sky/skylet/providers/command_runner.py +++ b/sky/skylet/providers/command_runner.py @@ -3,13 +3,13 @@ import os from typing import Dict -from sky.skylet import constants - from ray.autoscaler._private.cli_logger import cli_logger from ray.autoscaler._private.command_runner import DockerCommandRunner from ray.autoscaler._private.docker import check_docker_running_cmd from ray.autoscaler.sdk import get_docker_host_mount_location +from sky.skylet import constants + def docker_start_cmds( user, diff --git a/sky/skylet/providers/gcp/node_provider.py b/sky/skylet/providers/gcp/node_provider.py index 03e6557234e..1086a929820 100644 --- a/sky/skylet/providers/gcp/node_provider.py +++ b/sky/skylet/providers/gcp/node_provider.py @@ -12,7 +12,6 @@ construct_clients_from_provider_config, get_node_type, ) -from sky.skylet.providers.command_runner import SkyDockerCommandRunner from ray.autoscaler.tags import ( TAG_RAY_LAUNCH_CONFIG, @@ -20,7 +19,6 @@ TAG_RAY_USER_NODE_TYPE, ) from ray.autoscaler._private.cli_logger import cf, cli_logger -from ray.autoscaler._private.command_runner import SSHCommandRunner # The logic has been abstracted away here to allow for different GCP resources @@ -366,27 +364,3 @@ def _get_cached_node(self, node_id: str) -> GCPNode: @staticmethod def bootstrap_config(cluster_config): return bootstrap_gcp(cluster_config) - - def get_command_runner( - self, - log_prefix, - node_id, - auth_config, - cluster_name, - process_runner, - use_internal_ip, - docker_config=None, - ): - common_args = { - "log_prefix": log_prefix, - "node_id": node_id, - "provider": self, - "auth_config": auth_config, - "cluster_name": cluster_name, - "process_runner": process_runner, - "use_internal_ip": use_internal_ip, - } - if docker_config and docker_config["container_name"] != "": - return SkyDockerCommandRunner(docker_config, **common_args) - else: - return SSHCommandRunner(**common_args) diff --git a/sky/skylet/providers/ibm/node_provider.py b/sky/skylet/providers/ibm/node_provider.py index 068594d9f39..5e2a2d64493 100644 --- a/sky/skylet/providers/ibm/node_provider.py +++ b/sky/skylet/providers/ibm/node_provider.py @@ -23,30 +23,30 @@ import socket import threading import time -from pprint import pprint from pathlib import Path -from uuid import uuid4 - -from sky.adaptors import ibm +from pprint import pprint from typing import Any, Dict, List, Optional +from uuid import uuid4 from ray.autoscaler._private.cli_logger import cli_logger -from ray.autoscaler._private.util import hash_runtime_conf, hash_launch_conf +from ray.autoscaler._private.util import hash_launch_conf, hash_runtime_conf from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import ( NODE_KIND_HEAD, NODE_KIND_WORKER, TAG_RAY_CLUSTER_NAME, + TAG_RAY_FILE_MOUNTS_CONTENTS, + TAG_RAY_LAUNCH_CONFIG, TAG_RAY_NODE_KIND, TAG_RAY_NODE_NAME, - TAG_RAY_LAUNCH_CONFIG, + TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, TAG_RAY_USER_NODE_TYPE, - TAG_RAY_NODE_STATUS, - TAG_RAY_FILE_MOUNTS_CONTENTS, ) + +from sky.adaptors import ibm +from sky.skylet.providers.ibm.utils import RAY_RECYCLABLE, get_logger from sky.skylet.providers.ibm.vpc_provider import IBMVPCProvider -from sky.skylet.providers.ibm.utils import get_logger, RAY_RECYCLABLE logger = get_logger("node_provider_") diff --git a/sky/skylet/providers/ibm/utils.py b/sky/skylet/providers/ibm/utils.py index b2ace9a76ed..8b455ac2ef4 100644 --- a/sky/skylet/providers/ibm/utils.py +++ b/sky/skylet/providers/ibm/utils.py @@ -1,8 +1,8 @@ """holds common utility function/constants to be used by the providers.""" import logging -from pathlib import Path import time +from pathlib import Path RAY_RECYCLABLE = "ray-recyclable" diff --git a/sky/skylet/providers/ibm/vpc_provider.py b/sky/skylet/providers/ibm/vpc_provider.py index ef6c9147bac..6d691b765f7 100644 --- a/sky/skylet/providers/ibm/vpc_provider.py +++ b/sky/skylet/providers/ibm/vpc_provider.py @@ -4,15 +4,17 @@ nodes under the same subnet, tagged by the same cluster name. """ -from concurrent.futures import ThreadPoolExecutor -import uuid import copy -import time -import requests import json import textwrap +import time +import uuid +from concurrent.futures import ThreadPoolExecutor + +import requests + from sky.adaptors import ibm -from sky.skylet.providers.ibm.utils import get_logger, RAY_RECYCLABLE +from sky.skylet.providers.ibm.utils import RAY_RECYCLABLE, get_logger # pylint: disable=line-too-long logger = get_logger("vpc_provider_") diff --git a/sky/skylet/providers/kubernetes/__init__.py b/sky/skylet/providers/kubernetes/__init__.py index b09a3fe4183..278c1f11123 100644 --- a/sky/skylet/providers/kubernetes/__init__.py +++ b/sky/skylet/providers/kubernetes/__init__.py @@ -1,2 +1,3 @@ -from sky.skylet.providers.kubernetes.utils import get_head_ssh_port, get_port from sky.skylet.providers.kubernetes.node_provider import KubernetesNodeProvider +from sky.skylet.providers.kubernetes.utils import get_head_ssh_port +from sky.skylet.providers.kubernetes.utils import get_port diff --git a/sky/skylet/providers/kubernetes/node_provider.py b/sky/skylet/providers/kubernetes/node_provider.py index 3ab8414b2d2..b233462c2ac 100644 --- a/sky/skylet/providers/kubernetes/node_provider.py +++ b/sky/skylet/providers/kubernetes/node_provider.py @@ -5,13 +5,15 @@ from urllib.parse import urlparse from uuid import uuid4 +from ray.autoscaler._private.command_runner import SSHCommandRunner +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import NODE_KIND_HEAD +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME +from ray.autoscaler.tags import TAG_RAY_NODE_KIND + from sky.adaptors import kubernetes from sky.skylet.providers.kubernetes import config -from sky.skylet.providers.kubernetes import get_head_ssh_port from sky.skylet.providers.kubernetes import utils -from ray.autoscaler._private.command_runner import SSHCommandRunner -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import NODE_KIND_HEAD, TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND logger = logging.getLogger(__name__) @@ -116,7 +118,7 @@ def external_port(self, node_id): # TODO(romilb): Implement caching here for performance. # TODO(romilb): Multi-node would need more handling here. cluster_name = node_id.split('-ray-head')[0] - return get_head_ssh_port(cluster_name, self.namespace) + return utils.get_head_ssh_port(cluster_name, self.namespace) def internal_ip(self, node_id): pod = kubernetes.core_api().read_namespaced_pod(node_id, self.namespace) diff --git a/sky/skylet/providers/kubernetes/utils.py b/sky/skylet/providers/kubernetes/utils.py index 60bc99d0050..a6777ee23fd 100644 --- a/sky/skylet/providers/kubernetes/utils.py +++ b/sky/skylet/providers/kubernetes/utils.py @@ -1,7 +1,7 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple -from sky.utils import common_utils from sky.adaptors import kubernetes +from sky.utils import common_utils DEFAULT_NAMESPACE = 'default' diff --git a/sky/skylet/providers/lambda_cloud/lambda_utils.py b/sky/skylet/providers/lambda_cloud/lambda_utils.py index 82325bab7f9..8a376e1abb0 100644 --- a/sky/skylet/providers/lambda_cloud/lambda_utils.py +++ b/sky/skylet/providers/lambda_cloud/lambda_utils.py @@ -1,10 +1,11 @@ """Lambda Cloud helper functions.""" import json import os -import requests import time from typing import Any, Dict, List, Optional, Tuple +import requests + from sky.utils import common_utils CREDENTIALS_PATH = '~/.lambda_cloud/lambda_keys' diff --git a/sky/skylet/providers/lambda_cloud/node_provider.py b/sky/skylet/providers/lambda_cloud/node_provider.py index 24fdf0b22d5..ed33b52d196 100644 --- a/sky/skylet/providers/lambda_cloud/node_provider.py +++ b/sky/skylet/providers/lambda_cloud/node_provider.py @@ -1,26 +1,25 @@ import logging import os -import time from threading import RLock +import time from typing import Any, Dict, List, Optional from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import ( - TAG_RAY_CLUSTER_NAME, - TAG_RAY_USER_NODE_TYPE, - TAG_RAY_NODE_NAME, - TAG_RAY_NODE_STATUS, - STATUS_UP_TO_DATE, - TAG_RAY_NODE_KIND, - NODE_KIND_WORKER, - NODE_KIND_HEAD, -) -from sky.skylet.providers.lambda_cloud import lambda_utils +from ray.autoscaler.tags import NODE_KIND_HEAD +from ray.autoscaler.tags import NODE_KIND_WORKER +from ray.autoscaler.tags import STATUS_UP_TO_DATE +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME +from ray.autoscaler.tags import TAG_RAY_NODE_KIND +from ray.autoscaler.tags import TAG_RAY_NODE_NAME +from ray.autoscaler.tags import TAG_RAY_NODE_STATUS +from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE + from sky import authentication as auth +from sky.skylet.providers.lambda_cloud import lambda_utils from sky.utils import command_runner +from sky.utils import common_utils from sky.utils import subprocess_utils from sky.utils import ux_utils -from sky.utils import common_utils _TAG_PATH_PREFIX = '~/.sky/generated/lambda_cloud/metadata' _REMOTE_SSH_KEY_NAME = '~/.lambda_cloud/ssh_key_name' diff --git a/sky/skylet/providers/oci/config.py b/sky/skylet/providers/oci/config.py index ba91cb2c541..08e28dc8cbf 100644 --- a/sky/skylet/providers/oci/config.py +++ b/sky/skylet/providers/oci/config.py @@ -6,6 +6,7 @@ """ import logging import os + from sky import skypilot_config logger = logging.getLogger(__name__) diff --git a/sky/skylet/providers/oci/node_provider.py b/sky/skylet/providers/oci/node_provider.py index 82fbe3cbf7a..9efde707d2b 100644 --- a/sky/skylet/providers/oci/node_provider.py +++ b/sky/skylet/providers/oci/node_provider.py @@ -10,24 +10,22 @@ """ +import copy +from datetime import datetime import logging -import time import threading -import copy +import time -from datetime import datetime -from sky.skylet.providers.oci.config import oci_conf +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME +from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG +from ray.autoscaler.tags import TAG_RAY_NODE_KIND +from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE + +from sky.adaptors import oci as oci_adaptor from sky.skylet.providers.oci import utils +from sky.skylet.providers.oci.config import oci_conf from sky.skylet.providers.oci.query_helper import oci_query_helper -from sky.adaptors import oci as oci_adaptor - -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import ( - TAG_RAY_CLUSTER_NAME, - TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, - TAG_RAY_LAUNCH_CONFIG, -) logger = logging.getLogger(__name__) diff --git a/sky/skylet/providers/oci/query_helper.py b/sky/skylet/providers/oci/query_helper.py index 29601192e5d..590a687c29c 100644 --- a/sky/skylet/providers/oci/query_helper.py +++ b/sky/skylet/providers/oci/query_helper.py @@ -7,16 +7,18 @@ """ -import logging -import traceback -import time from datetime import datetime -import pandas as pd +import logging import re +import time +import traceback from typing import Optional -from sky.skylet.providers.oci.config import oci_conf -from sky.skylet.providers.oci import utils + +import pandas as pd + from sky.adaptors import oci as oci_adaptor +from sky.skylet.providers.oci import utils +from sky.skylet.providers.oci.config import oci_conf logger = logging.getLogger(__name__) diff --git a/sky/skylet/providers/oci/utils.py b/sky/skylet/providers/oci/utils.py index 6c783a630ec..5628cee2524 100644 --- a/sky/skylet/providers/oci/utils.py +++ b/sky/skylet/providers/oci/utils.py @@ -1,6 +1,6 @@ -from logging import Logger from datetime import datetime import functools +from logging import Logger def debug_enabled(logger: Logger): diff --git a/sky/skylet/providers/scp/node_provider.py b/sky/skylet/providers/scp/node_provider.py index 9fc77f69343..ad47f39ff76 100644 --- a/sky/skylet/providers/scp/node_provider.py +++ b/sky/skylet/providers/scp/node_provider.py @@ -4,28 +4,27 @@ to provide the functions accessing SCP nodes """ +import copy +from functools import wraps import logging import os -import time from threading import RLock +import time from typing import Any, Dict, List, Optional -import copy -from functools import wraps -from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler._private.cli_logger import cli_logger -from ray.autoscaler.tags import ( - TAG_RAY_CLUSTER_NAME, - TAG_RAY_USER_NODE_TYPE, - TAG_RAY_NODE_NAME, - TAG_RAY_LAUNCH_CONFIG, - TAG_RAY_NODE_STATUS, - STATUS_UP_TO_DATE, - TAG_RAY_NODE_KIND, - NODE_KIND_WORKER, - NODE_KIND_HEAD, -) from ray.autoscaler._private.util import hash_launch_conf +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import NODE_KIND_HEAD +from ray.autoscaler.tags import NODE_KIND_WORKER +from ray.autoscaler.tags import STATUS_UP_TO_DATE +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME +from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG +from ray.autoscaler.tags import TAG_RAY_NODE_KIND +from ray.autoscaler.tags import TAG_RAY_NODE_NAME +from ray.autoscaler.tags import TAG_RAY_NODE_STATUS +from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE + from sky.skylet.providers.scp import scp_utils from sky.skylet.providers.scp.config import ZoneConfig from sky.skylet.providers.scp.scp_utils import SCPCreationFailError diff --git a/sky/skylet/providers/scp/scp_utils.py b/sky/skylet/providers/scp/scp_utils.py index e55f4861556..1086f281971 100644 --- a/sky/skylet/providers/scp/scp_utils.py +++ b/sky/skylet/providers/scp/scp_utils.py @@ -4,17 +4,18 @@ """ import base64 import datetime +from functools import wraps import hashlib import hmac import json import logging import os -import requests import time -from functools import wraps from typing import Any, Dict, List, Optional from urllib import parse +import requests + CREDENTIALS_PATH = '~/.scp/scp_credential' API_ENDPOINT = 'https://openapi.samsungsdscloud.com' TEMP_VM_JSON_PATH = '/tmp/json/tmp_vm_body.json' diff --git a/sky/skylet/subprocess_daemon.py b/sky/skylet/subprocess_daemon.py index b3af9a19c3b..89f0fa8cb5a 100644 --- a/sky/skylet/subprocess_daemon.py +++ b/sky/skylet/subprocess_daemon.py @@ -5,15 +5,15 @@ """ import argparse -import requests import sys import time import psutil - -from sky.skylet import job_lib from ray.dashboard.modules.job import common as job_common from ray.dashboard.modules.job import sdk as job_sdk +import requests + +from sky.skylet import job_lib if __name__ == '__main__': diff --git a/sky/spot/__init__.py b/sky/spot/__init__.py index 0339b763531..c262ccdd8f3 100644 --- a/sky/spot/__init__.py +++ b/sky/spot/__init__.py @@ -1,22 +1,20 @@ """Modules for managed spot clusters.""" import pathlib -from sky.spot.constants import ( - SPOT_CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP, - SPOT_CONTROLLER_TEMPLATE, - SPOT_CONTROLLER_YAML_PREFIX, - SPOT_TASK_YAML_PREFIX, -) -from sky.spot.recovery_strategy import SPOT_STRATEGIES +from sky.spot.constants import SPOT_CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP +from sky.spot.constants import SPOT_CONTROLLER_TEMPLATE +from sky.spot.constants import SPOT_CONTROLLER_YAML_PREFIX +from sky.spot.constants import SPOT_TASK_YAML_PREFIX from sky.spot.recovery_strategy import SPOT_DEFAULT_STRATEGY -from sky.spot.spot_utils import SpotCodeGen -from sky.spot.spot_utils import SPOT_CONTROLLER_NAME +from sky.spot.recovery_strategy import SPOT_STRATEGIES from sky.spot.spot_utils import dump_job_table_cache from sky.spot.spot_utils import dump_spot_job_queue -from sky.spot.spot_utils import load_job_table_cache from sky.spot.spot_utils import format_job_table from sky.spot.spot_utils import is_spot_controller_up +from sky.spot.spot_utils import load_job_table_cache from sky.spot.spot_utils import load_spot_job_queue +from sky.spot.spot_utils import SPOT_CONTROLLER_NAME +from sky.spot.spot_utils import SpotCodeGen pathlib.Path(SPOT_TASK_YAML_PREFIX).expanduser().parent.mkdir(parents=True, exist_ok=True) diff --git a/sky/spot/dashboard/dashboard.py b/sky/spot/dashboard/dashboard.py index 2790d7d110c..b62cc523d4a 100644 --- a/sky/spot/dashboard/dashboard.py +++ b/sky/spot/dashboard/dashboard.py @@ -8,9 +8,9 @@ """ import datetime import pathlib -import yaml import flask +import yaml import sky from sky import spot diff --git a/sky/spot/spot_state.py b/sky/spot/spot_state.py index 335d5d20fb1..8c05578b774 100644 --- a/sky/spot/spot_state.py +++ b/sky/spot/spot_state.py @@ -6,7 +6,7 @@ import sqlite3 import time import typing -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import colorama diff --git a/sky/spot/spot_utils.py b/sky/spot/spot_utils.py index 38cbbc3a7f1..8af33f273c1 100644 --- a/sky/spot/spot_utils.py +++ b/sky/spot/spot_utils.py @@ -8,10 +8,10 @@ import time import typing from typing import Any, Dict, List, Optional, Tuple, Union -from typing_extensions import Literal import colorama import filelock +from typing_extensions import Literal from sky import backends from sky import exceptions @@ -22,14 +22,14 @@ from sky.skylet import constants from sky.skylet import job_lib from sky.skylet.log_lib import run_bash_command_with_log +from sky.spot import spot_state from sky.utils import common_utils from sky.utils import log_utils -from sky.spot import spot_state from sky.utils import subprocess_utils if typing.TYPE_CHECKING: - from sky import dag as dag_lib import sky + from sky import dag as dag_lib logger = sky_logging.init_logger(__name__) diff --git a/sky/task.py b/sky/task.py index feaab9c9149..87b41d6b570 100644 --- a/sky/task.py +++ b/sky/task.py @@ -4,7 +4,8 @@ import os import re import typing -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import (Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, + Union) import yaml @@ -13,8 +14,8 @@ from sky import exceptions from sky import global_user_state from sky.backends import backend_utils -from sky.data import storage as storage_lib from sky.data import data_utils +from sky.data import storage as storage_lib from sky.skylet import constants from sky.utils import schemas from sky.utils import ux_utils diff --git a/sky/usage/usage_lib.py b/sky/usage/usage_lib.py index c4b5b0003d9..bd488399d0e 100644 --- a/sky/usage/usage_lib.py +++ b/sky/usage/usage_lib.py @@ -1,9 +1,8 @@ """Logging events to Grafana Loki.""" -import enum -import click import contextlib import datetime +import enum import inspect import json import os @@ -12,6 +11,7 @@ import typing from typing import Any, Callable, Dict, List, Optional, Union +import click import requests import sky @@ -21,8 +21,8 @@ from sky.utils import env_options if typing.TYPE_CHECKING: - from sky import status_lib from sky import resources as resources_lib + from sky import status_lib from sky import task as task_lib logger = sky_logging.init_logger(__name__) diff --git a/sky/utils/__init__.py b/sky/utils/__init__.py index eff27bdd65b..e69de29bb2d 100644 --- a/sky/utils/__init__.py +++ b/sky/utils/__init__.py @@ -1,2 +0,0 @@ -"""Utility functions.""" -from sky.skylet.providers.kubernetes import utils as kubernetes_utils diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index c867bddd909..08fde49354d 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -1,6 +1,6 @@ """Runner for commands to be executed on the cluster.""" -import getpass import enum +import getpass import hashlib import os import pathlib @@ -9,9 +9,10 @@ from typing import List, Optional, Tuple, Union from sky import sky_logging -from sky.utils import common_utils, subprocess_utils from sky.skylet import constants from sky.skylet import log_lib +from sky.utils import common_utils +from sky.utils import subprocess_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index fc6e2f424d9..94a13468007 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -6,11 +6,13 @@ determine the return type based on the value of require_outputs. """ import enum import typing +from typing import List, Optional, Tuple, Union + +from typing_extensions import Literal + from sky import sky_logging as sky_logging from sky.skylet import log_lib as log_lib from sky.utils import subprocess_utils as subprocess_utils -from typing import List, Optional, Tuple, Union -from typing_extensions import Literal GIT_EXCLUDE: str RSYNC_DISPLAY_OPTION: str diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 2467b88003b..100e63d2338 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -14,9 +14,9 @@ import time from typing import Any, Callable, Dict, List, Optional, Union import uuid -import yaml import colorama +import yaml from sky import sky_logging diff --git a/sky/utils/log_utils.py b/sky/utils/log_utils.py index 3e4a1c33998..994bef81866 100644 --- a/sky/utils/log_utils.py +++ b/sky/utils/log_utils.py @@ -1,13 +1,12 @@ """Logging utils.""" import enum import threading -from typing import Optional, List - -import rich.console as rich_console +from typing import List, Optional import colorama import pendulum import prettytable +import rich.console as rich_console from sky import sky_logging diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 40b926774d4..7db5a5557c0 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -1,12 +1,12 @@ """Utility functions for subprocesses.""" from multiprocessing import pool -import psutil import random import subprocess import time from typing import Any, Callable, List, Optional, Tuple, Union import colorama +import psutil from sky import exceptions from sky import sky_logging diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index c69de256f44..f0e320d9c87 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -3,14 +3,13 @@ The timeline follows the trace event format defined here: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview """ # pylint: disable=line-too-long -import functools -from typing import Optional, Union, Callable - import atexit +import functools import json import os import threading import time +from typing import Callable, Optional, Union import filelock diff --git a/tests/conftest.py b/tests/conftest.py index 5a199a51017..925c473e2b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,9 @@ -import pytest import tempfile from typing import List from unittest.mock import patch import pandas as pd +import pytest # Usage: use # @pytest.mark.slow @@ -182,6 +182,7 @@ def generic_cloud(request) -> str: @pytest.fixture def enable_all_clouds(monkeypatch): from sky import clouds + # Monkey-patching is required because in the test environment, no cloud is # enabled. The optimizer checks the environment to find enabled clouds, and # only generates plans within these clouds. The tests assume that all three diff --git a/tests/stress/mountedstorage/read_parallel.py b/tests/stress/mountedstorage/read_parallel.py index 2f3c8a47f4a..e6bb9f409a5 100644 --- a/tests/stress/mountedstorage/read_parallel.py +++ b/tests/stress/mountedstorage/read_parallel.py @@ -1,8 +1,8 @@ # Read all files in a directory recursively in parallel -import os -from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor +import os def parse_args(): diff --git a/tests/test_onprem.py b/tests/test_onprem.py index d4c8aea94cd..d4799411760 100644 --- a/tests/test_onprem.py +++ b/tests/test_onprem.py @@ -3,10 +3,10 @@ import sys import tempfile import textwrap -from typing import List, Optional, Tuple, NamedTuple +from typing import List, NamedTuple, Optional, Tuple -import colorama from click import testing as cli_testing +import colorama import pytest import yaml diff --git a/tests/test_smoke.py b/tests/test_smoke.py index d9f4cd58335..25eb5181abb 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -24,6 +24,7 @@ import hashlib import inspect +import os import pathlib import shutil import subprocess @@ -33,7 +34,6 @@ from typing import Dict, List, NamedTuple, Optional, Tuple import urllib.parse import uuid -import os import warnings import colorama @@ -42,9 +42,11 @@ import sky from sky import global_user_state -from sky.adaptors import ibm from sky.adaptors import cloudflare -from sky.clouds import AWS, GCP, Azure +from sky.adaptors import ibm +from sky.clouds import AWS +from sky.clouds import Azure +from sky.clouds import GCP from sky.data import data_utils from sky.data import storage as storage_lib from sky.data.data_utils import Rclone diff --git a/tests/test_storage.py b/tests/test_storage.py index 80cd624f676..97a4dc59863 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,7 +1,7 @@ +import tempfile import time import pytest -import tempfile from sky import exceptions from sky.data import storage as storage_lib diff --git a/tests/test_wheels.py b/tests/test_wheels.py index 1d2a18ce789..b8ee53fff26 100644 --- a/tests/test_wheels.py +++ b/tests/test_wheels.py @@ -1,6 +1,6 @@ import os -import time import shutil +import time from sky.backends import wheel_utils diff --git a/tests/unit_tests/sky/clouds/test_gcp.py b/tests/unit_tests/sky/clouds/test_gcp.py index 5d8136296d3..95db1fef8ac 100644 --- a/tests/unit_tests/sky/clouds/test_gcp.py +++ b/tests/unit_tests/sky/clouds/test_gcp.py @@ -1,7 +1,11 @@ from unittest.mock import patch -from sky.clouds.gcp import GCP, GCPReservation, SpecificReservation + import pytest +from sky.clouds.gcp import GCP +from sky.clouds.gcp import GCPReservation +from sky.clouds.gcp import SpecificReservation + @pytest.mark.parametrize(( 'mock_return', 'expected' diff --git a/tests/unit_tests/sky/test_resources.py b/tests/unit_tests/sky/test_resources.py index ae458aa2b88..4a01842cad6 100644 --- a/tests/unit_tests/sky/test_resources.py +++ b/tests/unit_tests/sky/test_resources.py @@ -1,6 +1,7 @@ -from sky.resources import Resources from unittest.mock import Mock +from sky.resources import Resources + def test_get_reservations_available_resources(): mock = Mock()