|
5 | 5 | from dataclasses import dataclass
|
6 | 6 | from functools import cached_property
|
7 | 7 | from pathlib import Path
|
8 |
| -from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict, |
9 |
| - TypeVar) |
| 8 | +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, |
| 9 | + TypedDict, TypeVar) |
10 | 10 |
|
11 | 11 | import pytest
|
12 | 12 | import torch
|
13 | 13 | import torch.nn as nn
|
14 | 14 | import torch.nn.functional as F
|
15 | 15 | from PIL import Image
|
16 | 16 | from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
|
17 |
| - AutoProcessor, AutoTokenizer, BatchEncoding) |
| 17 | + AutoTokenizer, BatchEncoding) |
18 | 18 |
|
19 | 19 | from vllm import LLM, SamplingParams
|
20 | 20 | from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
21 | 21 | from vllm.distributed import (destroy_distributed_environment,
|
22 | 22 | destroy_model_parallel)
|
23 | 23 | from vllm.inputs import TextPrompt
|
24 | 24 | from vllm.logger import init_logger
|
25 |
| -from vllm.multimodal import MultiModalData |
26 |
| -from vllm.multimodal.image import ImageFeatureData, ImagePixelData |
| 25 | + |
| 26 | +if TYPE_CHECKING: |
| 27 | + from vllm.multimodal import MultiModalData |
| 28 | +else: |
| 29 | + # it will call torch.cuda.device_count() |
| 30 | + MultiModalData = None |
27 | 31 | from vllm.sequence import SampleLogprobs
|
28 | 32 | from vllm.utils import cuda_device_count_stateless, is_cpu
|
29 | 33 |
|
@@ -63,6 +67,10 @@ def for_hf(self) -> Image.Image:
|
63 | 67 | return self.pil_image
|
64 | 68 |
|
65 | 69 | def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
|
| 70 | + # don't put this import at the top level |
| 71 | + # it will call torch.cuda.device_count() |
| 72 | + from vllm.multimodal.image import ImageFeatureData # noqa: F401 |
| 73 | + from vllm.multimodal.image import ImagePixelData |
66 | 74 | image_input_type = vision_config.image_input_type
|
67 | 75 | ImageInputType = VisionLanguageConfig.ImageInputType
|
68 | 76 |
|
@@ -216,6 +224,9 @@ def __init__(
|
216 | 224 | )
|
217 | 225 |
|
218 | 226 | try:
|
| 227 | + # don't put this import at the top level |
| 228 | + # it will call torch.cuda.device_count() |
| 229 | + from transformers import AutoProcessor # noqa: F401 |
219 | 230 | self.processor = AutoProcessor.from_pretrained(
|
220 | 231 | model_name,
|
221 | 232 | torch_dtype=torch_dtype,
|
|
0 commit comments