[Refactor] Change the I/O format of the ladi-vton model #33
- Changed input data to be loaded in memory
- Previously, images were read from paths saved on disk
- The function now takes as input the parsing image, the 3D keypoint dictionary, the image bytes of the target and garment, and the garment mask

related to #31
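
A rough sketch of the new in-memory call flow inside inference_allModels, as introduced by this commit (schp_img, keypoint_dict, and garment_mask are assumed to come from the upstream human-parsing, keypoint, and clothing-mask steps; only calls that appear in the diff below are used):

    # hypothetical wiring of the new in-memory API; names follow the changed main.py
    finalResult_img = main_ladi_fromImageByte(category, target_bytes, schp_img,
                                              keypoint_dict, garment_bytes,
                                              garment_mask, ladi_models)        # returns a PIL image
    finalResult_img = main_cut_and_paste(category, target_bytes, finalResult_img, schp_img)
    finalResult_bytes = PIL2Byte(finalResult_img)                               # JPEG bytes
    # the /order endpoint then uploads the bytes and streams them back to the client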
Hyunmin-H committed Aug 17, 2023
1 parent e132d86 commit 0cc7df2
Showing 5 changed files with 220 additions and 12 deletions.
5 changes: 3 additions & 2 deletions backend/app/frontend.py
@@ -299,7 +299,7 @@ def main():
import time
t = time.time()
response = requests.post("http://localhost:8001/order", files=files)
response.raise_for_status() ## raise an exception if the status is not 200
# response.raise_for_status() ## raise an exception if the status is not 200
print('total processing time: ', time.time() - t)

empty_slot.empty()
@@ -310,7 +310,8 @@ def main():
if category =='upper_lower':
final_img = Image.open(os.path.join(final_result_dir, 'lower_body.png'))
else :
final_img = Image.open(os.path.join(final_result_dir, f'{category}.png'))
# final_img = Image.open(os.path.join(final_result_dir, f'{category}.png'))
final_img = response.content

st.write(' ')
st.write(' ')
24 changes: 15 additions & 9 deletions backend/app/main.py
@@ -1,9 +1,8 @@
from fastapi import FastAPI, UploadFile, File
from fastapi.param_functions import Depends
from pydantic import BaseModel, Field
from fastapi.responses import StreamingResponse
from uuid import UUID, uuid4
from typing import List, Union, Optional, Dict, Any
from datetime import datetime
from PIL import Image
import io

@@ -23,7 +22,7 @@

from get_clothing_mask import main_mask, main_mask_fromImageByte

from inference import main_ladi
from inference import main_ladi, main_ladi_fromImageByte
from face_cut_and_paste import main_cut_and_paste

import torch
@@ -39,6 +38,10 @@
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/backend/gcp')
from cloud_storage import GCSUploader, load_gcp_config_from_yaml


sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/backend/app')
from .utils import PIL2Byte

app = FastAPI()
ladi_models = None

@@ -158,13 +161,14 @@ def inference_allModels(target_bytes, garment_bytes, category, db_dir):

garment_mask.save('./garment_mask.jpg')

exit()
# ladi-vton
output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer')
os.makedirs(output_ladi_buffer_dir, exist_ok=True)

main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models)
main_cut_and_paste(category, db_dir)
# main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models)
finalResult_img = main_ladi_fromImageByte(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models)
finalResult_img = main_cut_and_paste(category, target_bytes, finalResult_img, schp_img)
return finalResult_img

def inference_ladi(category, db_dir, target_name='target.jpg'):
input_dir = os.path.join(db_dir, 'input')
@@ -231,7 +235,7 @@ async def make_order(files: List[UploadFile] = File(...)):
garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg')


inference_allModels('upper_body', db_dir)
finalResult_img = inference_allModels('upper_body', db_dir)
shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg')
inference_ladi('lower_body', db_dir, target_name='upper_body.jpg')

@@ -253,6 +257,8 @@ async def make_order(files: List[UploadFile] = File(...)):

gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg')

inference_allModels(target_bytes, garment_bytes, category, user_name)
finalResult_img = inference_allModels(target_bytes, garment_bytes, category, user_name)

return None
finalResult_bytes = PIL2Byte(finalResult_img)
gcs.upload_blob(finalResult_bytes, f'{input_dir}/ladi/buffer/final.jpg')
return StreamingResponse(io.BytesIO(finalResult_bytes), media_type="image/jpg")
7 changes: 7 additions & 0 deletions backend/app/utils.py
@@ -0,0 +1,7 @@

import io
def PIL2Byte(image):
image_bytes = io.BytesIO()
image.save(image_bytes, format="JPEG")
image_bytes = image_bytes.getvalue()
return image_bytes
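
A minimal round-trip sketch for this helper (assuming PIL2Byte from the new backend/app/utils.py is in scope; Image.new is only a placeholder input):

    import io
    from PIL import Image

    jpeg_bytes = PIL2Byte(Image.new("RGB", (384, 512), "white"))  # encode a PIL image to JPEG bytes
    restored = Image.open(io.BytesIO(jpeg_bytes))                 # decode the bytes back to a PIL image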
195 changes: 195 additions & 0 deletions model/ladi_vton/inference.py
@@ -15,6 +15,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, AutoProcessor

from dataset.dresscode import DressCodeDataset
from dataset.ganddddi import GanddddiDataset
from dataset.vitonhd import VitonHDDataset
from models.AutoencoderKL import AutoencoderKL

@@ -319,6 +320,200 @@ def main_ladi(category_, db_dir, output_buffer_dir, ladi_models, target_name='ta
with open(os.path.join(save_dir, f"metrics_{args.test_order}_{args.category}.json"), "w+") as f:
json.dump(metrics, f, indent=4)

def main_ladi_fromImageByte(category_, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models, target_name='target.jpg'):
args = parse_args()
args.category = category_

# Setup accelerator and device.
accelerator = Accelerator(mixed_precision=args.mixed_precision)
device = accelerator.device

# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

t = time.time()

weight_dtype = torch.float32
if args.mixed_precision == 'fp16':
weight_dtype = torch.float16

val_scheduler, text_encoder,vae , vision_encoder ,processor ,tokenizer ,unet ,emasc ,inversion_adapter, tps ,refinement = ladi_models

print('***Ladi load time', time.time() - t)


int_layers = [1, 2, 3, 4, 5]

# Enable xformers memory efficient attention if requested
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

# Load the datasets

outputlist = ['image', 'pose_map', 'inpaint_mask', 'im_mask', 'category', 'cloth']

test_dataset = GanddddiDataset(
target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask,
phase='test',
order=args.test_order,
radius=5,
outputlist=outputlist,
category=category_,
size=(512, 384),
target_name=target_name
)

test_dataloader = torch.utils.data.DataLoader(
test_dataset,
shuffle=False,
batch_size=args.batch_size,
num_workers=args.num_workers,
)



# Set to eval mode
text_encoder.eval()
vae.eval()
emasc.eval()
inversion_adapter.eval()
unet.eval()
tps.eval()
refinement.eval()
vision_encoder.eval()

# Create the pipeline
val_pipe = StableDiffusionTryOnePipeline(
text_encoder=text_encoder,
vae=vae,
tokenizer=tokenizer,
unet=unet,
scheduler=val_scheduler,
emasc=emasc,
emasc_int_layers=int_layers,
).to(device)

# Prepare the dataloader and create the output directory
test_dataloader = accelerator.prepare(test_dataloader)
save_dir = os.path.join(args.output_dir, args.test_order) ## modified
save_dir = args.output_dir
os.makedirs(save_dir, exist_ok=True)
generator = torch.Generator("cuda").manual_seed(args.seed)

# Generate the images
for idx, batch in enumerate(tqdm(test_dataloader)):
model_img = batch.get("image").to(weight_dtype)
mask_img = batch.get("inpaint_mask").to(weight_dtype)
if mask_img is not None:
mask_img = mask_img.to(weight_dtype)
pose_map = batch.get("pose_map").to(weight_dtype)
category = batch.get("category")
cloth = batch.get("cloth").to(weight_dtype)
im_mask = batch.get('im_mask').to(weight_dtype)

# Generate the warped cloth
# For sake of performance, the TPS parameters are predicted on a low resolution image

low_cloth = torchvision.transforms.functional.resize(cloth, (256, 192),
torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True)
low_im_mask = torchvision.transforms.functional.resize(im_mask, (256, 192),
torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True)
low_pose_map = torchvision.transforms.functional.resize(pose_map, (256, 192),
torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True)
agnostic = torch.cat([low_im_mask, low_pose_map], 1)

low_grid, theta, rx, ry, cx, cy, rg, cg = tps(low_cloth, agnostic)

# We upsample the grid to the original image size and warp the cloth using the predicted TPS parameters
highres_grid = torchvision.transforms.functional.resize(low_grid.permute(0, 3, 1, 2),
size=(512, 384),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True).permute(0, 2, 3, 1)

warped_cloth = F.grid_sample(cloth, highres_grid, padding_mode='border')

# Refine the warped cloth using the refinement network
warped_cloth = torch.cat([im_mask, pose_map, warped_cloth], 1)
warped_cloth = refinement(warped_cloth)
warped_cloth = warped_cloth.clamp(-1, 1)

# Get the visual features of the in-shop cloths
input_image = torchvision.transforms.functional.resize((cloth + 1) / 2, (224, 224),
antialias=True).clamp(0, 1)
processed_images = processor(images=input_image, return_tensors="pt")
clip_cloth_features = vision_encoder(
processed_images.pixel_values.to(model_img.device, dtype=weight_dtype)).last_hidden_state

# Compute the predicted PTEs
word_embeddings = inversion_adapter(clip_cloth_features.to(model_img.device))
word_embeddings = word_embeddings.reshape((word_embeddings.shape[0], args.num_vstar, -1))

category_text = {
'dresses': 'a dress',
'upper_body': 'an upper body garment',
'lower_body': 'a lower body garment',
}
text = [f'a photo of a model wearing {category_text[category]} {" $ " * args.num_vstar}' for
category in batch['category']]

# Tokenize text
tokenized_text = tokenizer(text, max_length=tokenizer.model_max_length, padding="max_length",
truncation=True, return_tensors="pt").input_ids
tokenized_text = tokenized_text.to(word_embeddings.device)

# Encode the text using the PTEs extracted from the in-shop cloths
encoder_hidden_states = encode_text_word_embedding(text_encoder, tokenized_text,
word_embeddings, args.num_vstar).last_hidden_state

# Generate images
generated_images = val_pipe(
image=model_img,
mask_image=mask_img,
pose_map=pose_map,
warped_cloth=warped_cloth,
prompt_embeds=encoder_hidden_states,
height=512,
width=384,
guidance_scale=args.guidance_scale,
num_images_per_prompt=1,
generator=generator,
cloth_input_type='warped',
num_inference_steps=args.num_inference_steps
).images

# Save images
for gen_image, cat in zip(generated_images, category):
pass
# if not os.path.exists(os.path.join(save_dir, cat)):
# os.makedirs(os.path.join(save_dir, cat))

return gen_image


# Free up memory
del val_pipe
del text_encoder
del vae
del emasc
del unet
del tps
del refinement
del vision_encoder
torch.cuda.empty_cache()

if args.compute_metrics:
metrics = compute_metrics(save_dir, args.test_order, args.dataset, args.category, ['all'],
None, args.vitonhd_dataroot)

with open(os.path.join(save_dir, f"metrics_{args.test_order}_{args.category}.json"), "w+") as f:
json.dump(metrics, f, indent=4)

# if __name__ == "__main__":
# main()
1 change: 0 additions & 1 deletion model/ladi_vton/src/dataset/dresscode.py
@@ -132,7 +132,6 @@ def __getitem__(self, index):
# Clothing image
cloth = Image.open(os.path.join(dataroot, 'input/buffer/garment', c_name))

############# needs fixing !! mask removed for testing
# mask = Image.open(os.path.join(dataroot, 'mask/buffer', c_name.replace(".jpg", ".png")))
mask = Image.open(os.path.join(dataroot, 'mask/buffer', c_name))
