Commit: optimized (21bef83)

hlky committed Oct 6, 2024
1 parent: 27fea01
Showing 1 changed file with 47 additions and 46 deletions.

examples/inference/distributed/florence2.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import io
 import json
 import os
 import pathlib
@@ -32,6 +31,20 @@
 from accelerate import PartialState
 
 
+"""
+Additional requirements: flash_attn einops timm
+pip install flash_attn einops timm
+Example:
+accelerate launch --num_processes=2 florence2.py --data_path "https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-0000.tar" --output_path outputs --batch_size 12 --num_workers 1 --prompt "<CAPTION>"
+On 2x4090: 420it [03:15, 2.15it/s] (~25.8 images/s)
+With --prompt "<DETAILED_CAPTION>": 420it [08:16, 1.18s/it] (~10.17 images/s)
+"""
+
+
 def main(
     data_path: str,
     output_path: str,
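
Note: the throughput figures in the new docstring are consistent with the example command's --batch_size 12. A quick arithmetic check (batch size taken from the command above):

    # 2.15 it/s at 12 images per iteration ("<CAPTION>")
    print(2.15 * 12)   # 25.8 images/s
    # 1.18 s/it at 12 images per iteration ("<DETAILED_CAPTION>")
    print(12 / 1.18)   # ~10.17 images/s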
@@ -74,45 +87,35 @@ def __call__(self, x):
             else:
                 return True
 
-    def pil_to_bytes(pil_image: Image.Image, format="PNG"):
-        byte_arr = io.BytesIO()
-        pil_image.save(byte_arr, format=format)
-        im_bytes = byte_arr.getvalue()
-        return im_bytes
-
-    def preprocess_fn(sample):
+    def preprocess_fn(sample, processor):
         image: Image.Image = sample["jpg"].convert("RGB")
         img_hash = insecure_hashlib.sha1(image.tobytes()).hexdigest()
-        img_bytes = pil_to_bytes(image)
+        inputs = processor(
+            text=prompt,
+            images=image,
+            return_tensors="pt",
+        )
         return {
-            "img_bytes": img_bytes,
+            "input_ids": inputs["input_ids"],
+            "pixel_values": inputs["pixel_values"],
+            "image": image,
             "img_hash": img_hash,
             "original_caption": sample["txt"],
-            "height": image.height,
-            "width": image.width,
         }
 
-    def collate_fn(examples, processor):
-        original_images = [Image.open(io.BytesIO(sample["img_bytes"])).convert("RGB") for sample in examples]
-        inputs = processor(
-            text=[prompt] * len(original_images),
-            images=original_images,
-            return_tensors="pt",
-        )
-
-        img_bytes = [example["img_bytes"] for example in examples]
+    def collate_fn(examples):
+        input_ids = torch.cat([example["input_ids"] for example in examples])
+        pixel_values = torch.cat([example["pixel_values"] for example in examples])
+        images = [example["image"] for example in examples]
+        img_hashes = [example["img_hash"] for example in examples]
         captions = [example["original_caption"] for example in examples]
-        heights = [example["height"] for example in examples]
-        widths = [example["width"] for example in examples]
-        inputs.update(
-            {
-                "img_bytes": img_bytes,
-                "original_captions": captions,
-                "heights": heights,
-                "widths": widths,
-            }
-        )
-        return dict(inputs)
+        return {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "images": images,
+            "img_hashes": img_hashes,
+            "original_captions": captions,
+        }
 
     exist_filter = ExistsFilter(output_dir)
     dataset = (
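
Note: the core of this change is moving the processor call out of collate_fn and into preprocess_fn, so the expensive per-sample tokenization and image preprocessing happen in the .map() stage, while collation shrinks to plain tensor concatenation. This works because the processor returns tensors with a leading batch dimension of 1. A minimal, self-contained sketch of the new collation step (shapes are illustrative, not Florence-2's actual ones):

    import torch

    # Each preprocessed sample carries (1, ...)-shaped tensors, so a batch
    # is formed with a single torch.cat along dim 0.
    examples = [
        {"input_ids": torch.zeros(1, 16, dtype=torch.long),
         "pixel_values": torch.zeros(1, 3, 224, 224)}
        for _ in range(4)
    ]
    input_ids = torch.cat([e["input_ids"] for e in examples])        # (4, 16)
    pixel_values = torch.cat([e["pixel_values"] for e in examples])  # (4, 3, 224, 224)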
@@ -121,16 +124,17 @@ def collate_fn(examples, processor):
             handler=wds.warn_and_continue,
             nodesplitter=None,
             shardshuffle=False,
+            empty_check=False,
         )
         .decode("pil", handler=wds.warn_and_continue)
-        .map(preprocess_fn, handler=wds.warn_and_continue)
+        .map(partial(preprocess_fn, processor=processor), handler=wds.warn_and_continue)
     )
     if len(exist_filter.current_training_img_hashes) > 0:
         dataset = dataset.select(exist_filter)
     dataset = dataset.batched(
         batch_size,
         partial=False,
-        collation_fn=partial(collate_fn, processor=processor),
+        collation_fn=collate_fn,
     )
     dataloader = wds.WebLoader(
         dataset,
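
Note: WebDataset's .map() calls its function with the sample as the only argument, so the extra processor parameter is bound ahead of time with functools.partial; conversely, collate_fn no longer takes a processor, so the partial there is dropped. A small self-contained sketch of the binding (dummy objects, same mechanics):

    from functools import partial

    def preprocess_fn(sample, processor):
        # stands in for the per-sample processor work in the diff above
        return (sample, processor)

    processor = object()  # placeholder for the HF processor instance
    bound = partial(preprocess_fn, processor=processor)
    assert bound("sample") == preprocess_fn("sample", processor)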
@@ -146,24 +150,22 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, processor):
                 item = output_queue.get(timeout=5)
                 if item is None:
                     break
-                original_captions, predictions, img_bytes, heights, widths = item
+                original_captions, predictions, images, img_hashes = item
                 predicted_captions = processor.batch_decode(
                     predictions,
                     skip_special_tokens=False,
                 )
-                for caption, pred_caption, img_byte, height, width in zip(
-                    original_captions, predicted_captions, img_bytes, heights, widths
+                for caption, pred_caption, image, img_hash in zip(
+                    original_captions, predicted_captions, images, img_hashes
                 ):
                     processed_caption = processor.post_process_generation(
-                        pred_caption, task=prompt, image_size=(width, height)
+                        pred_caption, task=prompt, image_size=(image.width, image.height)
                     )[prompt]
-                    original_image = Image.open(io.BytesIO(img_byte)).convert("RGB")
-                    hash_image = insecure_hashlib.sha1(original_image.tobytes()).hexdigest()
-                    img_path = output_dir.joinpath(f"{hash_image}.jpg")
-                    original_image.save(img_path)
+                    img_path = output_dir.joinpath(f"{img_hash}.jpg")
+                    image.save(img_path)
 
                     caption_dict = {"original": caption, "predicted": processed_caption}
-                    with output_dir.joinpath(f"{hash_image}_caption.json").open("w") as f:
+                    with output_dir.joinpath(f"{img_hash}_caption.json").open("w") as f:
                         json.dump(caption_dict, f, indent=4)
 
             except queue.Empty:
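
Note: save_results is the consumer side of a producer/consumer queue: it polls with a timeout, treats None as a shutdown sentinel, and after this commit receives ready-made PIL images and precomputed hashes instead of re-decoding bytes and re-hashing pixels. A stripped-down sketch of the drain loop (handler body elided):

    import queue

    def drain(output_queue: queue.Queue):
        while True:
            try:
                item = output_queue.get(timeout=5)
            except queue.Empty:
                continue  # producer may still be generating; keep polling
            if item is None:  # sentinel: producer is done
                break
            # ... decode predictions and save images/captions here ...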
@@ -189,9 +191,8 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, processor):
                 (
                     batch["original_captions"],
                     outputs,
-                    batch["img_bytes"],
-                    batch["heights"],
-                    batch["widths"],
+                    batch["images"],
+                    batch["img_hashes"],
                 )
             )
     finally:
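
Note: beyond the simpler plumbing, the old path encoded every image to PNG bytes in preprocess_fn, decoded it back in collate_fn and again in save_results, and hashed the pixels a second time; the new path hashes once and hands the PIL object straight through the queue. The deleted round-trip, reconstructed for illustration:

    import io
    from PIL import Image

    def pil_to_bytes(pil_image: Image.Image, format="PNG"):
        # the removed helper: a full PNG encode that was promptly undone
        byte_arr = io.BytesIO()
        pil_image.save(byte_arr, format=format)
        return byte_arr.getvalue()

    img = Image.new("RGB", (256, 256))
    restored = Image.open(io.BytesIO(pil_to_bytes(img))).convert("RGB")  # old path
    passed = img  # new path: the object itself crosses the thread boundary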