Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataloading Revamp #3216

Draft
wants to merge 45 commits into
base: main
Choose a base branch
from

Conversation

AntonioMacaronio
Copy link
Contributor

@AntonioMacaronio AntonioMacaronio commented Jun 12, 2024

Problems and Background

  • With a sufficiently large enough dataset, the current parallel_datamanager.py will try to cache the entire dataset into RAM, which will lead to an OOM error
  • parallel_datamanager.py only uses one worker to generate ray bundles. Since various subprocesses such as unprojecting during ray generation, or pixel sampling within a custom mask can be a CPU-intensive task, it may be better suited to parallelize this. While parallel_datamanager.py does support multiple workers, each worker caches the entire dataset to RAM and it does not support massive datasets, leading to duplicate copies of the dataset in computer memory.
  • Additionally, both VanillaDataManager and ParallelDataManager rely on CacheDataloader, which subclasses torch.utils.data.DataLoader, which is a strange coding practice
  • Similarly for full_images_datamanager.py: As we can not fit the entire dataset in RAM, the current implementation loads in entire dataset into the FullImageDataloader's cached_train attribute. To do this efficiently, we need multiprocess parallelization to load in a batch of images (support for batched image dataloading since gsplat now supports batched rastuerization)

Overview of Changes

  • Replacing CacheDataloader with RayBatchStream, which subclasses torch.utils.data.IterableDataset. The goal of this class is to generate ray bundles directly without caching all images to RAM. This is done by collating a sampled batch of images to sample from.
  • Adding an ImageBatchStream to expand while simplifying FullImageDataloader
  • A new pil_to_numpy() function is added. This function reads a PIL.Image's data buffer and fills an empty numpy array while reading, hastening the conversion process and removing an extra memory allocation. It is the fastest way to get from a PIL Image to a Pytorch tensor averaging ~2.5ms for a 1080x1920 image (~40% faster)

Impact

  • Checkout these comparisons! The left was trained on 200 images of a 4k video, while the right was trained on 2000 images of the same 4k video.

Copy link
Contributor

@pwais pwais left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice progress! sorry its not fast but i think i know why:

i think the main reason this is slower than expected is because _get_collated_batch() gets called per raybundle and sadly _get_collated_batch() is AFAIK needlessly slow.

  • take note about how the current CachedDataloader avoids doing _get_collated_batch() per raybundle. it would have been nice for the author to have left some notes about how slow _get_collated_batch() is, but evidently that author found it's necessary to not collate images per raybundle .
  • in my impl, I just _get_collated_batch() once on a small set of images an keep that batch cached. the main problem I saw is that _get_collated_batch() on thousands of images seemed to use 2x or 3x as much RAM as actually needed and thus cause many minutes of swapping and stuff

Even if you only call _get_collated_batch() once tho, you might need a bigger prefetch factor and/or more workers depending on the model.

IMO it's worth trying to find a way to get the result of nerfstudio_collate on cameras (I think the cameras do need to be collated because they can be ragged? i could be wrong and they don't need collation) but on images just have the worker read image files / buffers and never call collate on those tensors.

Just to be clear, this is the line where collate on images can go nuts and start taking forever to allocate 200GB or more of RAM for many images in code in main:

storage = elem.storage()._new_shared(numel, device=elem.device)

So! If a worker is just emitting raybundles then the images never need to be in shared tensor memory then eh? Thus should be able to save some RAM and CPU by skipping that line for images. Still need to think about the cost of reading the images themselves, but collate is definitely a troublemaker.

"""The limit number of batches a worker will start loading once an iterator is created.
Each next() call on the iterator has the CPU prepare more batches up to this
limit while the GPU is performing forward and backward passes on the model."""
dataloader_num_workers: int = 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW for a 3090 i was using 16 workers and prefetch factor of 16, and a train ray batch size of 24000. And I was getting the same "rays per sec" in the console output or better as with the in-repo impl (ParallelDataManger). Steady-state I was using less than 16 CPU I believe. If batch size and prefetch is small, then definitely need more workers

self.device = device
self.collate_fn = collate_fn
# self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults
self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is really a hack to hide disk I/O ... I only needed 2 here. it really depends on how much RAM / disk cache the user has.

# print(indices)
# print(type(batch_list[0])) # prints <class 'dict'>
# print(self.collate_fn) # prints nerfstudio_collate
collated_batch = self.collate_fn(batch_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i wish we knew if there's some way to get rid of collate on images because that appears to be the biggest waste

nerfstudio/data/datamanagers/base_datamanager.py Outdated Show resolved Hide resolved
if self.config.use_ray_train_dataloader:
import torch.multiprocessing as mp

mp.set_start_method("spawn")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this should get removed and something farther up the stack should call if needed. i think we shouldn't need it if the workers don't use cuda?

Copy link
Contributor

@pwais pwais left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just took a quick look (can't do a full review right now), so cool to see this coming along!!

Sounds like this change will target the case that uncompressed image tensors can't fit in RAM, but the raw image files (typically jpeg) do fit in RAM. In that case I guess we do want each worker to literally load the file bytes into Python RAM (as implemented) versus let the OS disk cache work, because the idea is that the uncompressed image tensors will otherwise blow out the disk cache.

I think it would be important to test in the end like a case where the user only has limited RAM (say 16GB) and e.g. a 8GB laptop graphics card, in that case I think there are moderate or larger image datasets where the whole thing would OOM when using the current cache impl. In that case, it would be helpful to have some way to disable the cache, or just communicate to the user that they simply have too weak of a machine for the dataset (e.g. just a CONSOLE.print("[bold yellow]Warning ...") in the line where the workers start reading image files into RAM.

nerfstudio/data/datamanagers/base_datamanager.py Outdated Show resolved Hide resolved
nerfstudio/data/datamanagers/base_datamanager.py Outdated Show resolved Hide resolved
nerfstudio/data/utils/data_utils.py Show resolved Hide resolved
nerfstudio/data/utils/dataloaders.py Outdated Show resolved Hide resolved
Copy link
Contributor

@pwais pwais left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow the visual results look amazing! Thank you so much for continuing to hack on this!

class FullImageBatchStreamConfig(DataManagerConfig):
_target: Type = field(default_factory=lambda: ImageBatchStream)

## Let's implement a parallelized splat dataloader!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

"indices": torch.cat([batch_i["indices"] for batch_i in batch_list], dim=0),
}
# end = time.time()
# print((end - start) * 1000)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, this function is pretty fast right?

@@ -178,6 +206,7 @@ def __init__(self):
self.train_count = 0
self.eval_count = 0
if self.train_dataset and self.test_mode != "inference":
# print(self.setup_train) # prints <bound method ParallelFullImageDatamanager.setup_train of ParallelFullImageDatamanager()>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# print(self.setup_train) # prints <bound method ParallelFullImageDatamanager.setup_train of ParallelFullImageDatamanager()>

self.exclude_batch_keys_from_device = exclude_batch_keys_from_device
# print("self.exclude_batch_keys_from_device", self.exclude_batch_keys_from_device) # usually prints ['image']
self.datamanager_config = datamanager_config
self.pixel_sampler: PixelSampler = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.pixel_sampler: PixelSampler = None
self.pixel_sampler: Optional[PixelSampler] = None # lazy init

# print("self.exclude_batch_keys_from_device", self.exclude_batch_keys_from_device) # usually prints ['image']
self.datamanager_config = datamanager_config
self.pixel_sampler: PixelSampler = None
self.ray_generator: RayGenerator = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.ray_generator: RayGenerator = None
self.ray_generator: Optional[RayGenerator] = None # lazy init

camera.metadata["cam_idx"] = idx
i += 1
if torch.sum(camera.camera_to_worlds) == 0:
print(i, camera.camera_to_worlds, "YOYO INSIDE IMAGEBATCHSTREAM")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yoyo im inside ur GPU eating ur RAMs!

camera.metadata = {}
camera.metadata["cam_idx"] = idx
i += 1
if torch.sum(camera.camera_to_worlds) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ratrow, this doesn't mean the splat optimizer will get empty poses as well?

@@ -724,7 +724,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
render_mode = "RGB+ED"
else:
render_mode = "RGB"

# breakpoint()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# breakpoint()

@@ -118,25 +123,101 @@ def read_trajectory_csv_to_dict(file_iterable_csv: str) -> TimedPoses:
)


def undistort_image_and_calibration(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the changes in this file / module should be broken out into separate PR at some point. and could probably ship sooner too then (?)

self.input_dataset = input_dataset
self.device = device

def __iter__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iteresting so this works fast enough to undistort and not need caching of the undistorted image? curious what CPU / GPU combo and number of workers was good here.

or maybe it does slow training a bit, but rather it works without OOM on bigger datasets :)

Copy link
Contributor Author

@AntonioMacaronio AntonioMacaronio Sep 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it does slow training quite a bit. With 1 worker, training time went from 9 minutes -> 18 minutes (with downsampling to a resolution inside 1600x1600)

It also depends on what type of undistortion is occuring. I only have tested this with datasets of CAMERA.PERSPECTIVE types that had small amounts of radial and tangential distortion, but with fisheye and equirectangular camera models, it will definitely take longer to undistort, will do some further benchmarks on this to find how many workers are needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah interesting! 2x is not that bad, given that this change makes training feasible when low RAM.

maybe caching the undistorted image could help then? so in your other code, you try to cache the jpeg bytes which helps, and clearly here you can't necessarily cache the compressed image and so you have to cache the undistorted image.

all that said, it's very nice to even have training enabled even if u don't implement / test caching

@AntonioMacaronio AntonioMacaronio changed the title Dataloading revamp Dataloading Revamp Sep 1, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants