Adding isort #2375

Merged (15 commits) on Aug 11, 2023
12 changes: 12 additions & 0 deletions .github/workflows/format.yml
@@ -29,6 +29,7 @@ jobs:
pip install yapf==0.32.0
pip install toml==0.10.2
pip install black==22.10.0
pip install isort==5.12.0
- name: Running yapf
run: |
yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \
@@ -42,3 +43,14 @@ jobs:
sky/skylet/providers/gcp/ \
sky/skylet/providers/azure/ \
sky/skylet/providers/ibm/
- name: Running isort for black formatted files
run: |
isort --diff --check --profile black -l 88 -m 3 \
sky/skylet/providers/ibm/
- name: Running isort for yapf formatted files
run: |
isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \
--sg 'sky/skylet/providers/aws/**' \
--sg 'sky/skylet/providers/gcp/**' \
--sg 'sky/skylet/providers/azure/**' \
Comment on lines +53 to +55

Member:
+1 that we want to skip formatting sky/skylet/providers/{aws,gcp,azure}. The reason is that those files are forked from upstream Ray, and each time we do a Ray dependency upgrade, we diff against the upstream files to understand whether our genuine changes are affected.

However, it seems like this PR has changed files under those dirs?

.../ibm should be okay to be isort-ed since it's not forked from Ray.

--sg 'sky/skylet/providers/ibm/**'
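
For reference, the two CI checks above can be reproduced locally. The sketch below assumes isort==5.12.0 is installed and that it is run from the repository root; it spells out the long-form equivalents of the short flags used in the workflow (--sg is --skip-glob, -l is --line-length, -m 3 selects the black-compatible vertical hanging indent).

# Check the black-formatted ibm provider with black-compatible isort settings:
isort --check --diff --profile black --line-length 88 --multi-line 3 \
    sky/skylet/providers/ibm/

# Check everything else with the repo's default isort settings, skipping the
# directories forked from upstream Ray (and ibm, which is handled above):
isort --check --diff ./ \
    --skip-glob 'sky/skylet/ray_patches/**' \
    --skip-glob 'sky/skylet/providers/aws/**' \
    --skip-glob 'sky/skylet/providers/gcp/**' \
    --skip-glob 'sky/skylet/providers/azure/**' \
    --skip-glob 'sky/skylet/providers/ibm/**'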
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -1,7 +1,7 @@
# Configuration file for the Sphinx documentation builder.

import sys
import os
import sys

sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('../'))
3 changes: 2 additions & 1 deletion examples/docker/echo_app.py
@@ -8,9 +8,10 @@
# python echo_app.py

import random
import sky
import string

import sky

with sky.Dag() as dag:
# The setup command to build the container image
setup = 'docker build -t echo:v0 /echo_app'
4 changes: 2 additions & 2 deletions examples/example_app.py
@@ -15,10 +15,10 @@
Incorporate the notion of region/zone (affects pricing).
Incorporate the notion of per-account egress quota (affects pricing).
"""
import sky

import time_estimators

import sky


def make_application():
"""A simple application: train_op -> infer_op."""
3 changes: 2 additions & 1 deletion examples/horovod_distributed_tf_app.py
@@ -2,9 +2,10 @@
import json
from typing import Dict, List

import sky
import time_estimators

import sky

IPAddr = str

with sky.Dag() as dag:
2 changes: 1 addition & 1 deletion examples/local/launch_cloud_onprem.py
@@ -22,9 +22,9 @@
import tempfile
import textwrap
import uuid
import yaml

from click import testing as cli_testing
import yaml

from sky import cli
from sky import global_user_state
3 changes: 2 additions & 1 deletion examples/playground/storage_playground.py
@@ -2,7 +2,8 @@
# These are not exhaustive tests. Actual Tests are in tests/test_storage.py and
# tests/test_smoke.py.

from sky.data import storage, StoreType
from sky.data import storage
from sky.data import StoreType


def get_args():
11 changes: 5 additions & 6 deletions examples/ray_tune_examples/tune_ptl_example.py
@@ -1,15 +1,14 @@
### Source: https://docs.ray.io/en/latest/tune/examples/mnist_ptl_mini.html
import math
import os

import torch
from filelock import FileLock
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback

import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import torch
from torch.nn import functional as F


class LightningMNISTClassifier(pl.LightningModule):
22 changes: 13 additions & 9 deletions examples/spot/lightning_cifar10/train.py
@@ -1,21 +1,25 @@
# Code modified from https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/cifar10-baseline.html

import argparse
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import LightningModule
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.optim.swa_utils import AveragedModel
from torch.optim.swa_utils import update_bn
from torchmetrics.functional import accuracy

import argparse, glob
import torchvision

seed_everything(7)

16 changes: 7 additions & 9 deletions examples/spot/resnet_ddp/resnet_ddp.py
@@ -1,17 +1,15 @@
import argparse
import os
import random

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

import wandb


4 changes: 2 additions & 2 deletions examples/tpu/tpu_app_code/run_tpu.py
@@ -1,8 +1,8 @@
import tensorflow_datasets as tfds
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from transformers import TFDistilBertForSequenceClassification
from transformers import TFBertForSequenceClassification
from transformers import TFDistilBertForSequenceClassification

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
21 changes: 18 additions & 3 deletions format.sh
@@ -54,6 +54,14 @@ YAPF_EXCLUDES=(
'--exclude' 'sky/skylet/providers/ibm/**'
)

ISORT_YAPF_EXCLUDES=(
'--sg' 'build/**'
'--sg' 'sky/skylet/providers/aws/**'
'--sg' 'sky/skylet/providers/gcp/**'
'--sg' 'sky/skylet/providers/azure/**'
Comment on lines +59 to +61

Member:
(See other comment) Some files under these dirs seem formatted.

Contributor Author:

I added isort with the black profile (L120) since those files were already formatted by black (L105-106).
They do get formatted, but it seems that mostly what moved were the sky imports.

Member:

Discussed with @Michaelvll. The imports being formatted may still hinder humans (us :)) looking at diffs when we're upgrading Ray. Can we skip formatting sky/skylet/providers/{aws,gcp,azure} even with the black profile, both in format.sh and in .github/workflows/format.yml?

Contributor Author:

Yes, will do.

'--sg' 'sky/skylet/providers/ibm/**'
)

BLACK_INCLUDES=(
'sky/skylet/providers/aws'
'sky/skylet/providers/gcp'
@@ -86,9 +94,12 @@ format_changed() {

# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples llm
}

echo 'SkyPilot Black:'
black "${BLACK_INCLUDES[@]}"

## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
@@ -102,8 +113,12 @@ else
format_changed
fi
echo 'SkyPilot yapf: Done'
echo 'SkyPilot Black:'
black "${BLACK_INCLUDES[@]}"

echo 'SkyPilot isort:'
isort sky tests examples llm docs "${ISORT_YAPF_EXCLUDES[@]}"

isort --profile black -l 88 -m 3 "sky/skylet/providers/ibm"


# Run mypy
# TODO(zhwu): When more of the codebase is typed properly, the mypy flags
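
As a usage note, the visible portion of format.sh suggests the following local invocations (a sketch; only the --files flag is documented in the snippet above, so anything beyond that is an assumption):

# Run yapf, black, and isort over the files changed relative to the main
# branch (the default format_changed path):
./format.sh

# Restrict formatting to specific files; --files must be the first argument:
./format.sh --files sky/cli.py examples/example_app.py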
13 changes: 6 additions & 7 deletions llm/vicuna-llama-2/scripts/flash_attn_patch.py
@@ -1,17 +1,16 @@
from typing import List, Optional, Tuple
import logging
from typing import List, Optional, Tuple

from einops import rearrange
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
# pip3 install "flash-attn>=2.0"
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
import torch
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
from flash_attn.flash_attn_interface import ( # pip3 install "flash-attn>=2.0"
flash_attn_varlen_qkvpacked_func,)
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
self,
10 changes: 5 additions & 5 deletions llm/vicuna-llama-2/scripts/train.py
@@ -30,23 +30,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from dataclasses import dataclass
from dataclasses import field
import json
import pathlib
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


3 changes: 1 addition & 2 deletions llm/vicuna-llama-2/scripts/train_flash_attn.py
@@ -1,8 +1,7 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from flash_attn_patch import (
replace_llama_attn_with_flash_attn,)
from flash_attn_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

3 changes: 1 addition & 2 deletions llm/vicuna-llama-2/scripts/train_xformers.py
@@ -16,8 +16,7 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from xformers_patch import (
replace_llama_attn_with_xformers_attn,)
from xformers_patch import replace_llama_attn_with_xformers_attn

replace_llama_attn_with_xformers_attn()

2 changes: 1 addition & 1 deletion llm/vicuna-llama-2/scripts/xformers_patch.py
@@ -21,8 +21,8 @@
from typing import Optional, Tuple

import torch
import transformers.models.llama.modeling_llama
from torch import nn
import transformers.models.llama.modeling_llama

try:
import xformers.ops
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -22,3 +22,10 @@ python_version = "3.8"
follow_imports = "skip"
ignore_missing_imports = true
allow_redefinition = true

[tool.isort]
profile = "google"
line_length = 80
multi_line_output = 0
combine_as_imports = true
use_parentheses = true
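
These settings are picked up automatically whenever isort is invoked from the repository root, so the plain isort calls in format.sh and the workflow inherit them for the yapf-formatted tree. Roughly, the section corresponds to the CLI flags below (a sketch; the mapping is approximate, and -m 0 is isort's grid multi-line mode):

# Approximate CLI equivalent of the [tool.isort] section above:
isort --profile google --line-length 80 --multi-line 0 \
    --combine-as --use-parentheses sky tests examples llm docs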
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -6,6 +6,7 @@ black==22.10.0
# https://github.com/edaniszewski/pylint-quotes
pylint-quotes==0.2.3
toml==0.10.2
isort==5.12.0

# type checking
mypy==0.991
35 changes: 26 additions & 9 deletions sky/__init__.py
@@ -11,18 +11,35 @@
from sky import benchmark
from sky import clouds
from sky.clouds.service_catalog import list_accelerators
from sky.core import autostop
from sky.core import cancel
from sky.core import cost_report
from sky.core import down
from sky.core import download_logs
from sky.core import job_status
from sky.core import queue
from sky.core import spot_cancel
from sky.core import spot_queue
from sky.core import spot_status
from sky.core import start
from sky.core import status
from sky.core import stop
from sky.core import storage_delete
from sky.core import storage_ls
from sky.core import tail_logs
from sky.dag import Dag
from sky.execution import launch, exec, spot_launch # pylint: disable=redefined-builtin
from sky.data import Storage
from sky.data import StorageMode
from sky.data import StoreType
from sky.execution import exec # pylint: disable=redefined-builtin
from sky.execution import launch
from sky.execution import spot_launch
from sky.optimizer import Optimizer
from sky.optimizer import OptimizeTarget
from sky.resources import Resources
from sky.task import Task
from sky.optimizer import Optimizer, OptimizeTarget
from sky.data import Storage, StorageMode, StoreType
from sky.status_lib import ClusterStatus
from sky.skylet.job_lib import JobStatus
from sky.core import (status, start, stop, down, autostop, queue, cancel,
tail_logs, download_logs, job_status, spot_queue,
spot_status, spot_cancel, storage_ls, storage_delete,
cost_report)
from sky.status_lib import ClusterStatus
from sky.task import Task

# Aliases.
IBM = clouds.IBM
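
The expansion of the grouped imports above (for example, from sky.core import (...) and from sky.optimizer import Optimizer, OptimizeTarget) into one import per line is what the google profile produces; as far as I understand, it enforces single-line from-imports and lexicographic ordering within sections. A quick way to preview the effect on a file without modifying it:

# Show the diff isort would apply under the repo's pyproject.toml settings:
isort --diff sky/__init__.py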
2 changes: 1 addition & 1 deletion sky/adaptors/cloudflare.py
@@ -3,8 +3,8 @@

import contextlib
import functools
import threading
import os
import threading
from typing import Dict, Optional, Tuple

from sky.utils import ux_utils
2 changes: 1 addition & 1 deletion sky/adaptors/gcp.py
@@ -14,8 +14,8 @@ def wrapper(*args, **kwargs):
global googleapiclient, google
if googleapiclient is None or google is None:
try:
import googleapiclient as _googleapiclient
import google as _google
import googleapiclient as _googleapiclient
googleapiclient = _googleapiclient
google = _google
except ImportError: