From ea96fb341787712d2fb0f8c3009096c5b487aa60 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 27 Jun 2024 01:29:24 -0700 Subject: [PATCH] [Bugfix] Fix img_sizes Parsing in Phi3-Vision (#5888) --- vllm/model_executor/models/phi3v.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index dac832a686c2c..578e22beaa3d6 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -65,12 +65,6 @@ def __init__(self, wte=None) -> None: self.type_feature: str self.img_processor: CLIPVisionModel - def set_img_features(self, img_features: torch.FloatTensor) -> None: - self.img_features = img_features - - def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: - self.img_sizes = img_sizes - def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: LAYER_IDX = self.layer_idx @@ -144,21 +138,16 @@ def __init__(self, self.layer_idx = config.img_processor.get('layer_idx', -2) self.type_feature = config.img_processor.get('type_feature', 'patch') - def forward(self, - input_ids: torch.LongTensor, + def forward(self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, - image_sizes=None) -> torch.FloatTensor: + image_sizes: torch.Tensor) -> torch.FloatTensor: """process and merge text embeddings with image embeddings.""" + # (batch_size, max_num_crops, 3, height, width) img_embeds = pixel_values - img_sizes = image_sizes - if self.img_features is not None: - img_embeds = self.img_features.clone() - self.img_features = None - - if self.img_sizes is not None: - img_sizes = self.img_sizes + # (batch_size, 2) + img_sizes = image_sizes input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) @@ -190,11 +179,8 @@ def forward(self, output_imgs = [] output_len = [] - if isinstance(img_sizes, torch.Tensor): - img_sizes.squeeze_(0) - for _bs in range(bs): - h, w = img_sizes + h, w = img_sizes[_bs] h = h // 336 w = w // 336 B_ = h * w