Skip to content

Commit

Permalink
revert formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Sep 28, 2024
1 parent 53c2503 commit 3356a2a
Show file tree
Hide file tree
Showing 39 changed files with 116 additions and 14 deletions.
2 changes: 1 addition & 1 deletion torchao/_models/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import torch.nn.functional as F

from torchao.quantization.utils import _MultiInput
from torchao.quantization.utils import _lm_eval_available, _MultiInput

import lm_eval
try: # lm_eval version 0.4
Expand Down
1 change: 1 addition & 0 deletions torchao/_models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def update(self, input_pos, k_val, v_val):
return k_out, v_out


from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
from torchao.quantization.utils import quantize_activation_per_token_absmax

class AffineQuantizedKVCache(nn.Module):
Expand Down
4 changes: 3 additions & 1 deletion torchao/_models/sam/eval_combo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import tqdm
import torch
import fire
from metrics import calculate_miou, create_result_entry
from data import build_data, setup_coco_img_ids
import math
import segment_anything_fast
import time
import resource

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
from torchao.dtypes import SemiSparseLayoutType
from torchao.dtypes import SemiSparseLayoutType, MarlinSparseLayoutType
from torchao.utils import unwrap_tensor_subclass
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

Expand Down
7 changes: 6 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from typing import Tuple, Optional, Union
from collections import defaultdict
import functools
import math
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
Expand All @@ -14,6 +16,9 @@
quantize_affine_floatx,
dequantize_affine_floatx,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
LayoutType,
Expand Down Expand Up @@ -1625,7 +1630,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
)

def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
from torchao.sparsity.marlin import marlin_24_workspace
from torchao.sparsity.marlin import marlin_24_workspace, const
from torchao.ops import marlin_24_gemm

assert isinstance(weight_tensor, AffineQuantizedTensor)
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/floatx/floatx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchao.dtypes.utils import (
LayoutType,
)
from torchao.quantization.quant_api import _get_linear_subclass_inserter
from dataclasses import dataclass
from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls

Expand Down
4 changes: 3 additions & 1 deletion torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
A simple module swap UX for a float8 version of `torch.nn.Linear`.
"""

import dataclasses
import enum

from typing import Optional

Expand All @@ -31,7 +33,7 @@
ScaledMMConfig,
)

from torchao.float8.float8_utils import e4m3_dtype
from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax

from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
Expand Down
2 changes: 1 addition & 1 deletion torchao/float8/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Callable, Optional
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
Expand Down
2 changes: 2 additions & 0 deletions torchao/float8/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
tensor_already_casted_to_fp8,
)

from torchao.float8.float8_utils import (
amax_history_to_scale,
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
Expand Down
1 change: 1 addition & 0 deletions torchao/kernel/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import pathlib
import pickle

import torch
import triton
Expand Down
1 change: 1 addition & 0 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import os
import torch

Expand Down
1 change: 1 addition & 0 deletions torchao/kernel/intmm_triton.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import os

import torch

Expand Down
1 change: 1 addition & 0 deletions torchao/prototype/autoround/autoround_llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
from typing import Optional

import torch
Expand Down
1 change: 1 addition & 0 deletions torchao/prototype/autoround/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

import torchao.prototype.autoround.utils as ar_utils
import torchao.quantization as ao_quant
Expand Down
2 changes: 2 additions & 0 deletions torchao/prototype/dora/dora_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import bitsandbytes as bnb
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit
from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear

from prototypes.dora.kernels.matmul import triton_mm
from prototypes.dora.kernels.smallk import triton_mm_small_k
Expand Down
6 changes: 6 additions & 0 deletions torchao/prototype/dora/kernels/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
import triton.language as tl

# Re-exports
from triton.ops.matmul import (
early_config_prune,
estimate_matmul_time,
get_configs_io_bound,
get_higher_dtype,
)
from triton.runtime import Config


Expand Down
1 change: 1 addition & 0 deletions torchao/prototype/galore/kernels/adam_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def make_data(M, N, rank, dtype):


if __name__ == "__main__":
from triton.testing import do_bench

M = N = 4096
rank = 128
Expand Down
3 changes: 3 additions & 0 deletions torchao/prototype/hqq/example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch
from torchao.prototype.hqq.core import HQQQuantizer
from torchao.dtypes.affine_quantized_tensor import (
to_affine_quantized_intx,
ZeroPointDomain,
PlainAQTLayout,
PlainLayoutType,
TensorCoreTiledAQTLayout,
TensorCoreTiledLayoutType,
MappingType,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import sys

import torch
import torch.nn as nn
from torchao.quantization import quantize_
import random

from naive_intNwo import intN_weight_only

import copy
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict

from transformers import AutoModelForCausalLM, AutoTokenizer
from ax.service.ax_client import AxClient, ObjectiveProperties
import torch.multiprocessing as mp
from ax.modelbridge.cross_validation import cross_validate
from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples
from BO_acc_throughput import define_parameter_list

Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
import sys

import torch
import time
import torch.nn as nn
from torchao.quantization import quantize_
import random
from naive_intNwo import intN_weight_only

import copy
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from ax.service.ax_client import AxClient, ObjectiveProperties
import torch.multiprocessing as mp

import os
import sys
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
from datetime import datetime
import torchao
import torch._dynamo.config
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.quant_api import int4_weight_only
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao._models._eval import TransformerEvalWrapper, InputRecorder

from torchao.dtypes import TensorCoreTiledLayoutType

from torchao._models.llama.generate import (
device_sync,
multinomial_sample_one_no_sync,
logits_to_probs,
sample,
prefill,
decode_one_token,
model_forward,
encode_tokens,
_load_model,
)
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/prototype/mixed_precision/scripts/fit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
import numpy as np
import os
from tqdm import tqdm
import transformers
from datasets import load_dataset
import random
from torch.nn.attention import SDPBackend, sdpa_kernel

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
traindata = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch
import numpy as np
import os
from tqdm import tqdm
import transformers
from datasets import load_dataset
import random
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.autograd.functional import hvp

def group_product(xs, ys):
return [torch.sum(x * y) for (x, y) in zip(xs, ys)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import torch
import torchvision.models as models
import numpy as np
import os
from tqdm import tqdm
import transformers
from datasets import load_dataset
import random
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.autograd.functional import vhp
from torch.autograd.functional import hvp, vhp


def group_product(xs, ys):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import csv
import sys

import torch
import torch.nn as nn
from torchao.quantization import quantize_
import random

from naive_intNwo import intN_weight_only

import copy
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import get_task_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torchao.quantization.quant_primitives import (
_get_and_check_qmin_qmax,
choose_qparams_affine,
fake_quantize_affine,
ZeroPointDomain,
MappingType,
)
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/prototype/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, List, Optional

import torch
import torch.nn.functional as F

from torchao.dtypes import (
TensorCoreTiledLayoutType,
Expand Down Expand Up @@ -87,6 +88,7 @@ def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32):
quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32))
"""
# avoid circular dep
from torchao.dtypes import to_affine_quantized_intx

def _apply_weight_fake_quant(weight: torch.Tensor):
mapping_type = MappingType.SYMMETRIC
Expand Down
5 changes: 4 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import torch
import torchao
import torch.nn as nn
from typing import Callable, Union, Optional, Tuple
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple
import types

from torchao.dtypes.uintx.uintx import UintxLayoutType
Expand All @@ -37,6 +38,7 @@
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)
from .subclass import (
QuantizedLinearWeightBase,
Expand All @@ -54,6 +56,7 @@
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Expand Down
Loading

0 comments on commit 3356a2a

Please sign in to comment.