Skip to content

Commit

Permalink
📽️ Olive StrEnumBase IntEnumBase (#1290)
Browse files Browse the repository at this point in the history
## Describe your changes

python/cpython#100458
QuaRot requires Python >= 3.11, where the `(str, Enum)` mixin pattern does not
work as expected. This PR adds Olive `StrEnumBase` and `IntEnumBase` classes that pick the right enum base for the running Python version.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
  • Loading branch information
trajepl committed Aug 9, 2024
1 parent 7b4cefe commit 213ac7a
Show file tree
Hide file tree
Showing 21 changed files with 82 additions and 64 deletions.
5 changes: 3 additions & 2 deletions examples/utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand All @@ -11,13 +10,15 @@
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from transformers import PreTrainedTokenizer

from olive.common.utils import StrEnumBase

if TYPE_CHECKING:
from kv_cache_utils import Cache, IOBoundCache
from numpy.typing import NDArray
from onnx import ValueInfoProto


class AdapterMode(Enum):
class AdapterMode(StrEnumBase):
"""Enum for adapter modes."""

inputs = "inputs"
Expand Down
4 changes: 2 additions & 2 deletions olive/auto_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

import logging
from copy import deepcopy
from enum import Enum
from typing import List, Optional

from olive.auto_optimizer.regulate_mixins import RegulatePassConfigMixin
from olive.common.config_utils import ConfigBase
from olive.common.pydantic_v1 import validator
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
from olive.evaluator.olive_evaluator import OliveEvaluatorConfig
from olive.hardware.accelerator import AcceleratorSpec
Expand All @@ -19,7 +19,7 @@
logger = logging.getLogger(__name__)


class Precision(str, Enum):
class Precision(StrEnumBase):
FP32 = "fp32"
FP16 = "fp16"
INT8 = "int8"
Expand Down
5 changes: 2 additions & 3 deletions olive/common/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import inspect
import json
import logging
from enum import Enum
from functools import partial
from pathlib import Path
from types import FunctionType, MethodType
Expand All @@ -14,7 +13,7 @@
import yaml

from olive.common.pydantic_v1 import BaseModel, create_model, root_validator, validator
from olive.common.utils import hash_function, hash_object
from olive.common.utils import StrEnumBase, hash_function, hash_object

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -212,7 +211,7 @@ def gather_nested_field(cls, values):
return values


class CaseInsensitiveEnum(str, Enum):
class CaseInsensitiveEnum(StrEnumBase):
"""StrEnum class that is insensitive to the case of the input string.
Note: Only insensitive when creating the enum object like `CaseInsensitiveEnum("value")`.
Expand Down
4 changes: 2 additions & 2 deletions olive/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from olive.common.utils import StrEnumBase


class OS(str, Enum):
class OS(StrEnumBase):
WINDOWS = "Windows"
LINUX = "Linux"

Expand Down
24 changes: 21 additions & 3 deletions olive/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,39 @@
import shlex
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from olive.common.constants import OS

logger = logging.getLogger(__name__)


# Version-portable enum bases.
#
# Python 3.11 introduced ``enum.StrEnum`` and changed how enums mixed with
# ``str``/``int`` are rendered, so the base class is chosen from the running
# interpreter version. Project enums should inherit from ``StrEnumBase`` /
# ``IntEnumBase`` instead of spelling out ``(str, Enum)`` / ``(int, Enum)``.
if sys.version_info >= (3, 11):
    from enum import IntEnum, StrEnum

    class StrEnumBase(StrEnum):
        """Base class for string-valued enums (Python >= 3.11: ``enum.StrEnum``).

        Members are ``str`` instances and ``str(member)`` is the member's value.
        """

    class IntEnumBase(IntEnum):
        """Base class for integer-valued enums (Python >= 3.11: ``enum.IntEnum``)."""

else:
    from enum import Enum

    class StrEnumBase(str, Enum):
        """Base class for string-valued enums (Python < 3.11: ``(str, Enum)`` mixin).

        Members are ``str`` instances and ``str(member)`` is the member's value.
        """

        def __str__(self) -> str:
            # Without this override ``Enum.__str__`` would yield
            # "ClassName.MEMBER" here, while ``StrEnum`` on 3.11+ yields the
            # raw value; returning the value keeps both branches consistent.
            return str.__str__(self)

    class IntEnumBase(int, Enum):
        """Base class for integer-valued enums (Python < 3.11: ``(int, Enum)`` mixin)."""

        # NOTE(review): ``str(member)`` is "ClassName.MEMBER" on this branch but
        # the plain number with ``enum.IntEnum`` on 3.11+. Left unchanged so that
        # numeric format specs (e.g. ``f"{member:d}"``) keep working; avoid
        # relying on ``str()`` of IntEnumBase members.


def run_subprocess(cmd, env=None, cwd=None, check=False):
logger.debug("Running command: %s", cmd)

assert isinstance(cmd, (str, list)), f"cmd must be a string or a list, got {type(cmd)}."
windows = platform.system() == OS.WINDOWS
windows = platform.system() == "Windows"
if isinstance(cmd, str):
# In posix model, the cmd string will be handled with specific posix rules.
# https://docs.python.org/3.8/library/shlex.html#parsing-rules
Expand Down
6 changes: 3 additions & 3 deletions olive/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from olive.common.utils import StrEnumBase


class Framework(str, Enum):
class Framework(StrEnumBase):
"""Framework of the model."""

ONNX = "ONNX"
Expand All @@ -16,7 +16,7 @@ class Framework(str, Enum):
OPENVINO = "OpenVINO"


class ModelFileFormat(str, Enum):
class ModelFileFormat(StrEnumBase):
"""Given a framework, there might be 1 or more on-disk model file format(s), model save/Load logic may differ."""

ONNX = "ONNX"
Expand Down
4 changes: 2 additions & 2 deletions olive/data/component/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from enum import Enum
from pathlib import Path
from random import Random
from typing import Callable, Dict, List, Optional, Union
Expand All @@ -14,11 +13,12 @@
from olive.common.config_utils import ConfigBase, validate_config, validate_object
from olive.common.pydantic_v1 import validator
from olive.common.user_module_loader import UserModuleLoader
from olive.common.utils import StrEnumBase
from olive.data.component.dataset import BaseDataset
from olive.data.constants import IGNORE_INDEX


class TextGenStrategy(str, Enum):
class TextGenStrategy(StrEnumBase):
"""Strategy for tokenizing a dataset."""

LINE_BY_LINE = "line-by-line" # each line is a sequence, in order of appearance
Expand Down
10 changes: 5 additions & 5 deletions olive/data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from enum import Enum
from olive.common.utils import StrEnumBase

# index for targets that should be ignored when computing metrics
IGNORE_INDEX = -100


class DataComponentType(Enum):
class DataComponentType(StrEnumBase):
"""enumerate for the different types of data components."""

# dataset component type: to load data into memory
Expand All @@ -22,13 +22,13 @@ class DataComponentType(Enum):
DATALOADER = "dataloader"


class DataContainerType(Enum):
class DataContainerType(StrEnumBase):
"""enumerate for the different types of data containers."""

DATA_CONTAINER = "data_container"


class DefaultDataComponent(Enum):
class DefaultDataComponent(StrEnumBase):
"""enumerate for the default data components."""

LOAD_DATASET = "default_load_dataset"
Expand All @@ -37,7 +37,7 @@ class DefaultDataComponent(Enum):
DATALOADER = "default_dataloader"


class DefaultDataContainer(Enum):
class DefaultDataContainer(StrEnumBase):
"""enumerate for the default data containers."""

DATA_CONTAINER = "DataContainer"
6 changes: 3 additions & 3 deletions olive/engine/packaging/packaging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from typing import Optional, Union

from olive.common.config_utils import CaseInsensitiveEnum, ConfigBase, NestedConfig, validate_config
from olive.common.constants import BASE_IMAGE
from olive.common.pydantic_v1 import validator
from olive.common.utils import StrEnumBase


class PackagingType(CaseInsensitiveEnum):
Expand Down Expand Up @@ -43,7 +43,7 @@ class DockerfilePackagingConfig(CommonPackagingConfig):
requirements_file: Optional[str] = None


class InferencingServerType(str, Enum):
class InferencingServerType(StrEnumBase):
AzureMLOnline = "AzureMLOnline"
AzureMLBatch = "AzureMLBatch"

Expand All @@ -54,7 +54,7 @@ class InferenceServerConfig(ConfigBase):
scoring_script: str


class AzureMLModelModeType(str, Enum):
class AzureMLModelModeType(StrEnumBase):
download = "download"
copy = "copy"

Expand Down
10 changes: 5 additions & 5 deletions olive/evaluator/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from olive.common.config_utils import ConfigBase, validate_config
from olive.common.pydantic_v1 import validator
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
from olive.evaluator.accuracy import AccuracyBase
from olive.evaluator.metric_config import LatencyMetricConfig, MetricGoal, ThroughputMetricConfig, get_user_config_class

logger = logging.getLogger(__name__)


class MetricType(str, Enum):
class MetricType(StrEnumBase):
# TODO(trajep): support throughput
ACCURACY = "accuracy"
LATENCY = "latency"
THROUGHPUT = "throughput"
CUSTOM = "custom"


class AccuracySubType(str, Enum):
class AccuracySubType(StrEnumBase):
ACCURACY_SCORE = "accuracy_score"
F1_SCORE = "f1_score"
PRECISION = "precision"
Expand All @@ -32,7 +32,7 @@ class AccuracySubType(str, Enum):
PERPLEXITY = "perplexity"


class LatencySubType(str, Enum):
class LatencySubType(StrEnumBase):
# unit: millisecond
AVG = "avg"
MAX = "max"
Expand All @@ -45,7 +45,7 @@ class LatencySubType(str, Enum):
P999 = "p999"


class ThroughputSubType(str, Enum):
class ThroughputSubType(StrEnumBase):
# unit: token per second, tps
AVG = "avg"
MAX = "max"
Expand Down
4 changes: 2 additions & 2 deletions olive/hardware/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# --------------------------------------------------------------------------
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

from olive.common.utils import StrEnumBase
from olive.hardware.constants import DEVICE_TO_EXECUTION_PROVIDERS

logger = logging.getLogger(__name__)


class Device(str, Enum):
class Device(StrEnumBase):
CPU = "cpu"
CPU_SPR = "cpu_spr"
GPU = "gpu"
Expand Down
6 changes: 3 additions & 3 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import logging
import os
import tempfile
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Union

from olive.common.utils import IntEnumBase, StrEnumBase
from olive.hardware.accelerator import AcceleratorSpec, Device
from olive.model import HfModelHandler, ONNXModelHandler
from olive.model.utils import resolve_onnx_path
Expand All @@ -28,15 +28,15 @@ class ModelBuilder(Pass):
See https://github.com/microsoft/onnxruntime-genai
"""

class Precision(str, Enum):
class Precision(StrEnumBase):
FP32 = "fp32"
FP16 = "fp16"
INT4 = "int4"

def __str__(self) -> str:
return self.value

class AccuracyLevel(int, Enum):
class AccuracyLevel(IntEnumBase):
fp32 = 1
fp16 = 2
bf16 = 3
Expand Down
6 changes: 3 additions & 3 deletions olive/passes/onnx/nvmo_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Union

from olive.common.config_utils import validate_config
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import OliveModelHandler
Expand All @@ -34,15 +34,15 @@
class NVModelOptQuantization(Pass):
"""Quantize ONNX model with Nvidia-ModelOpt."""

class Precision(str, Enum):
class Precision(StrEnumBase):
FP8 = "fp8"
INT8 = "int8"
INT4 = "int4"

def __str__(self) -> str:
return self.value

class Algorithm(str, Enum):
class Algorithm(StrEnumBase):
RTN = "RTN"
AWQ = "AWQ"

Expand Down
Loading

0 comments on commit 213ac7a

Please sign in to comment.