Progress in app + sys mem fallback fix #76

Merged · 3 commits · Oct 13, 2024

app.py: 12 additions & 2 deletions

@@ -151,12 +151,16 @@ def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):
     return img
 
 # Function to generate text-to-video
-def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution):
+def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, progress=gr.Progress()):
+    progress(0, desc="Loading model")
     print("[DEBUG] generate_text_to_video called.")
     variant = '768p' if resolution == "768p" else '384p'
     height = height_high if resolution == "768p" else height_low
     width = width_high if resolution == "768p" else width_low
 
+    def progress_callback(i, m):
+        progress(i/m)
+
     # Initialize model based on user options using cached function
     try:
         model, torch_dtype_selected = initialize_model_cached(variant)
@@ -179,6 +183,7 @@ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, r
             output_type="pil",
             cpu_offloading=cpu_offloading,
             save_memory=True,
+            callback=progress_callback,
         )
         print("[INFO] Text-to-video generation completed.")
     except Exception as e:
@@ -195,7 +200,8 @@ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, r
     return video_path
 
 # Function to generate image-to-video
-def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution):
+def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution, progress=gr.Progress()):
+    progress(0, desc="Loading model")
     print("[DEBUG] generate_image_to_video called.")
     variant = '768p' if resolution == "768p" else '384p'
     height = height_high if resolution == "768p" else height_low
@@ -208,6 +214,9 @@ def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolutio
         print(f"[ERROR] Error processing image: {e}")
         return f"Error processing image: {e}"
 
+    def progress_callback(i, m):
+        progress(i/m)
+
     # Initialize model based on user options using cached function
     try:
         model, torch_dtype_selected = initialize_model_cached(variant)
@@ -227,6 +236,7 @@ def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolutio
             output_type="pil",
             cpu_offloading=cpu_offloading,
             save_memory=True,
+            callback=progress_callback,
         )
         print("[INFO] Image-to-video generation completed.")
     except Exception as e:
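
A note on the app-side pattern: the progress=gr.Progress() default argument makes Gradio inject a progress tracker whenever the function is invoked from the UI, and calling that tracker with a fraction between 0 and 1 advances the bar. Below is a minimal, self-contained sketch of the same wiring, where fake_pipeline is a hypothetical stand-in for the cached Pyramid model:

import gradio as gr

def fake_pipeline(num_units, callback=None):
    # Stand-in for model.generate(...): reports one tick per generated unit.
    for i in range(num_units):
        if callback:
            callback(i, num_units)
    return "output.mp4"

def generate(prompt, progress=gr.Progress()):
    progress(0, desc="Loading model")

    def progress_callback(i, m):
        progress(i / m)  # fraction in [0, 1] drives the progress bar

    return fake_pipeline(10, callback=progress_callback)

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    gr.Button("Generate").click(generate, inputs=prompt_box, outputs=result)

if __name__ == "__main__":
    demo.launch()
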
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py: 15 additions & 0 deletions

@@ -1,5 +1,6 @@
 import torch
 import os
+import gc
 import sys
 import torch.nn as nn
 import torch.nn.functional as F
@@ -317,6 +318,7 @@ def generate_i2v(
         save_memory: bool = True,
         cpu_offloading: bool = False,  # If true, reload device will be cuda.
         inference_multigpu: bool = False,
+        callback: Optional[Callable[[int, int], None]] = None,
     ):
         device = self.device if not cpu_offloading else torch.device("cuda")
         dtype = self.dtype
@@ -429,6 +431,12 @@
         torch.cuda.empty_cache()
 
         for unit_index in tqdm(range(1, num_units)):
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            if callback:
+                callback(unit_index, num_units)
+
             if use_linear_guidance:
                 self._guidance_scale = guidance_scale_list[unit_index]
                 self._video_guidance_scale = guidance_scale_list[unit_index]
@@ -519,6 +527,7 @@ def generate(
         save_memory: bool = True,
         cpu_offloading: bool = False,  # If true, reload device will be cuda.
         inference_multigpu: bool = False,
+        callback: Optional[Callable[[int, int], None]] = None,
     ):
         device = self.device if not cpu_offloading else torch.device("cuda")
         dtype = self.dtype
@@ -613,6 +622,12 @@
         last_generated_latents = None
 
         for unit_index in tqdm(range(num_units)):
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            if callback:
+                callback(unit_index, num_units)
+
             if use_linear_guidance:
                 self._guidance_scale = guidance_scale_list[unit_index]
                 self._video_guidance_scale = guidance_scale_list[unit_index]
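
The per-unit gc.collect() / torch.cuda.empty_cache() pair is the memory half of this PR: collecting garbage drops Python references that keep tensors alive, and emptying the cache returns cached CUDA blocks to the driver before each unit allocates its own. This makes mid-generation OOMs less likely, and on machines relying on the driver's system-memory fallback it reduces how much of the working set spills into slower system RAM. A sketch of the loop pattern in isolation, with run_units and step_fn as hypothetical names:

import gc
from typing import Callable, List, Optional

import torch

def run_units(
    num_units: int,
    step_fn: Callable[[int], torch.Tensor],
    callback: Optional[Callable[[int, int], None]] = None,
) -> List[torch.Tensor]:
    # Mirrors the pipeline change: free memory before each unit, then
    # report coarse progress as callback(unit_index, num_units).
    outputs = []
    for unit_index in range(num_units):
        gc.collect()                   # release tensors held only by dead Python objects
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # hand cached CUDA blocks back to the driver
        if callback:
            callback(unit_index, num_units)
        outputs.append(step_fn(unit_index))
    return outputs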