add diffusion model solution nb #347

Merged
merged 2 commits into from
Jul 20, 2023

Conversation

kylesteckler
Collaborator

Solution notebook for training a diffusion model for image generation from scratch.

Please test with GPU.


@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Since there are several variants of diffusion models, it would be better to state which one this notebook implements. Maybe DDPM?

https://arxiv.org/abs/2006.11239


Collaborator

+1 (good to add the link to the second paper too)

Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Is it worth adding an explanation of why this schedule is important and why we use cosine scheduling? Ideally with a link to the paper that introduced the idea.

(To my understanding, DDPM used linear and quadratic schedules, and the cosine schedule was introduced in the Improved DDPM paper linked below.)

https://arxiv.org/pdf/2102.09672.pdf

"locally" -> "directly"?


Collaborator Author

Done.
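
For readers following this thread, here is a minimal sketch of what a simplified cosine-style schedule could look like. It is not taken from the notebook: the constant values are assumptions, and only the names `diffusion_schedule`, `MIN_SIGNAL_RATE`, and `MAX_SIGNAL_RATE` and the `(noise_rate, signal_rate)` return order appear elsewhere in this review. The idea is to map a diffusion time in [0, 1] to an angle, then take the cosine as the signal rate and the sine as the noise rate, so that their squares always sum to 1.

```python
import math

# Assumed values for illustration only; the notebook may use different ones.
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_time):
    """Map a diffusion time in [0, 1] to (noise_rate, signal_rate)."""
    # The signal-rate bounds define the start and end angles.
    start_angle = math.acos(MAX_SIGNAL_RATE)
    end_angle = math.acos(MIN_SIGNAL_RATE)
    # Interpolate linearly in angle space.
    angle = start_angle + diffusion_time * (end_angle - start_angle)
    # cos^2 + sin^2 = 1, so the noisy image keeps unit variance when the
    # image and the noise both have unit variance.
    signal_rate = math.cos(angle)
    noise_rate = math.sin(angle)
    return noise_rate, signal_rate


for t in (0.0, 0.5, 1.0):
    noise_rate, signal_rate = diffusion_schedule(t)
    print(f"t={t:.1f} noise_rate={noise_rate:.3f} signal_rate={signal_rate:.3f}")
```

Because the schedule is a function of a continuous time, it can be evaluated at any point, which is one plausible reason for it being a function rather than a fixed table.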

@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

It would be good to elaborate on why we use U-Net here, since this task is not semantic segmentation.

Also, would it be good to add a link to the semantic segmentation with U-Net notebook?


Collaborator

+1

@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Since this code block is very large and includes important concepts, I think it's better to explain the purpose and overview of each function first, and then elaborate on the train_step function.


Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Line #143.                self.network.weights, self.ema_network.weights

Should we explain the purpose of EMA and of this second network?
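
As background for this question, here is a minimal sketch of an exponential moving average (EMA) of weights. It is not the notebook's code: `update_ema` is a hypothetical helper, the decay value is an assumption, and plain Python floats stand in for weight tensors. After each training step the EMA copy moves a small fraction toward the trained weights, so it changes more smoothly than the trained network, and such a copy is commonly used for generation.

```python
# Assumed decay value for illustration; the notebook's EMA constant may differ.
EMA = 0.999


def update_ema(weights, ema_weights, ema=EMA):
    """Return new EMA weights: ema * old_ema + (1 - ema) * current weight."""
    return [ema * e + (1.0 - ema) * w for w, e in zip(weights, ema_weights)]


# Toy example: the trained weights jump around, the EMA copy follows slowly.
ema_weights = [0.0, 0.0]
for weights in ([1.0, -1.0], [1.2, -0.8], [0.9, -1.1]):
    ema_weights = update_ema(weights, ema_weights, ema=0.9)
print(ema_weights)
```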



@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Line #22.            tf.keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),

I like this callback!



@@ -0,0 +1,723 @@
{
Collaborator

@takumiohym takumiohym Jul 5, 2023

Though the generated images are plotted during training, I think it would also be nice to have a separate generation command after training.


Collaborator Author

Done.
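
To illustrate what a standalone generation step might involve, here is a rough sketch of a deterministic reverse-diffusion loop. Everything in it is an assumption rather than the notebook's code: `schedule`, `reverse_diffusion`, and `predict_noise` are hypothetical names, a single float stands in for an image, and an "oracle" that already knows the clean value stands in for the trained network. Each step predicts the noise, recovers an estimate of the clean image, and re-mixes the two at the next, lower noise level.

```python
import math


def schedule(t, min_signal=0.02, max_signal=0.95):
    """Simplified cosine-style schedule: returns (noise_rate, signal_rate)."""
    start, end = math.acos(max_signal), math.acos(min_signal)
    angle = start + t * (end - start)
    return math.sin(angle), math.cos(angle)


def reverse_diffusion(initial_noise, predict_noise, steps):
    """Denoise step by step, from diffusion time 1 down to 0."""
    noisy = initial_noise
    pred_image = noisy
    step_size = 1.0 / steps
    for step in range(steps):
        t = 1.0 - step * step_size
        noise_rate, signal_rate = schedule(t)
        # Ask the (stand-in) network for the noise component.
        pred_noise = predict_noise(noisy, noise_rate, signal_rate)
        # Invert noisy = signal_rate * image + noise_rate * noise.
        pred_image = (noisy - noise_rate * pred_noise) / signal_rate
        # Re-mix at the next, less noisy diffusion time.
        next_noise_rate, next_signal_rate = schedule(t - step_size)
        noisy = next_signal_rate * pred_image + next_noise_rate * pred_noise
    return pred_image


# Stand-in for a trained network: an oracle that knows the clean value.
clean = 0.5
oracle = lambda x, noise_rate, signal_rate: (x - signal_rate * clean) / noise_rate
result = reverse_diffusion(initial_noise=1.3, predict_noise=oracle, steps=10)
print(result)
```

With a real model, `predict_noise` would be a call into the trained (or EMA) network and `initial_noise` a batch of Gaussian noise images.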

@takumiohym
Collaborator

I added a few comments mainly regarding the markdown descriptions, but this is really a great notebook!
Thanks Kyle!

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

* "de-noinsing" --> "de-noising"

Great intro!


Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Maybe we should make MNIST the default, since it's likely easier and faster to test.


Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Consider adding a markdown cell above explaining a few of these constants, especially CENTRAL_CROP, MIN_SIGNAL_RATE, MAX_SIGNAL_RATE, PLOT_DIFFUSION_STEPS, and EMA.


Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Add a link to the MNIST dataset too.


Collaborator Author

Done.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Please explain in more detail the underlying process that happens during training, so that we can understand what the time, noise_rate, and signal_rate are. From the current wording, it sounds like this is a continuous diffusion process rather than a discrete one: maybe clarify that point too. I am having a hard time making sense of the current explanation, which I find a bit too short:

"We generate the noisy image (for the forward diffusion processes) by weighting random noise by the noise rate, and the training image by the signal rate, then adding them together. "

e.g. what do you mean by "weighting random noise by the noise rate" or "the training image by the signal rate"?

I agree with @takumi about adding an explanation of what diffusion_schedule() is (and also why it's a function).


Collaborator Author

Added more thorough explanation and visualization.
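
The sentence quoted in this thread describes a weighted sum, which the following sketch makes concrete. It is an illustration under assumptions, not the notebook's code: `add_noise` is a hypothetical helper, the rates are passed in directly, and a short list of floats stands in for an image. "Weighting by the rate" here simply means multiplying every pixel by that scalar before adding.

```python
import random


def add_noise(image, noise_rate, signal_rate, rng):
    """Forward diffusion: noisy = signal_rate * image + noise_rate * noise."""
    # One standard-normal sample per pixel.
    noise = [rng.gauss(0.0, 1.0) for _ in image]
    noisy = [signal_rate * p + noise_rate * e for p, e in zip(image, noise)]
    return noisy, noise


rng = random.Random(0)
image = [0.2, 0.8, -0.5]
# Mostly signal (early diffusion time) versus mostly noise (late diffusion time).
slightly_noisy, _ = add_noise(image, noise_rate=0.31, signal_rate=0.95, rng=rng)
mostly_noise, _ = add_noise(image, noise_rate=0.9998, signal_rate=0.02, rng=rng)
print(slightly_noisy)
print(mostly_noise)
```

During training, the network would receive the noisy image together with the noise level and be asked to recover the noise that was added.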

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

I would need more explanation to understand what is going on in the cell above.


Collaborator

+1 to this! Maybe a clearer breakdown in the markdown of what happens in each line.

Collaborator Author

Done (added visual and more explanation)

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Is it possible to avoid calling a "private" method like .__next__() directly?

What about data_batch = next(ds_iter)?


Collaborator Author

Done.
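
A small sketch of the suggestion above, using a plain generator as a stand-in for the dataset (the `batches` function is hypothetical, and only the names `ds_iter` and `data_batch` come from the review comment). The built-in `next()` calls `__next__()` on the iterator for you, and it also accepts an optional default to return once the iterator is exhausted.

```python
def batches():
    """Stand-in for a dataset that yields batches."""
    yield [1, 2, 3]
    yield [4, 5, 6]


ds_iter = iter(batches())

# Preferred: the built-in next() instead of ds_iter.__next__().
data_batch = next(ds_iter)
print(data_batch)

# next() can take a default, which avoids StopIteration at the end.
second = next(ds_iter, None)
exhausted = next(ds_iter, None)
print(second, exhausted)
```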

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Can you give the necessary background to understand what's happening in the following line?

noise_rates, signal_rates = diffusion_schedule(diffusion_times)

Collaborator Author

I think this should be clearer now with the added explanations and comments in the code.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

I don't understand the following phrases:

  • "the variances of their noise components"
  • "The variances are required since denoising a signal requires different operations at different levels of noise."
  • " the model needs to learn how to denoise at different noise rates "

It would be useful to define the terms and expand the explanation. For instance, what does it mean for the model to de-noise at different rates? Do you mean that each random image needs a different rate? Or do you mean that different steps of the de-noising need different rates? If so, how do you choose these rates, and why do you need different rates? (I don't think I actually understood what you meant, so my questions are probably meaningless.)


Collaborator Author

Reworded some things and added more explanation.

@@ -0,0 +1,723 @@
{
Collaborator

@BenoitDherin BenoitDherin Jul 10, 2023

Can you add some explanation before this cell about what this layer does and why we need it?

This is a very long class. Is it possible to break this class into smaller chunks and explain the methods?

One way to do that would be to extract some of the methods (the most important) as standalone functions, comment on them in markdown cells, and reuse these functions to define the methods, making the class much shorter.


Collaborator Author

Split the sinusoidal embedding out from the other blocks and added markdown.

@@ -0,0 +1,723 @@
{
Collaborator

@sanjanalreddy sanjanalreddy Jul 10, 2023

Line #4.        # diffusion times -> angles

Can you expand on this comment? Are you saying you're converting diffusion_times to angles?
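
For context on the "diffusion times -> angles" comment, here is one common way a sinusoidal embedding can turn a scalar diffusion time into angles. This is a sketch under assumptions, not the notebook's layer: the function name, the number of dimensions, and the frequency range are all illustrative. Each log-spaced frequency multiplies the time to give an angle, and the sine and cosine of those angles form the embedding vector.

```python
import math


def sinusoidal_embedding(diffusion_time, embedding_dims=8,
                         min_freq=1.0, max_freq=1000.0):
    """Embed a scalar diffusion time as sines and cosines of scaled angles."""
    half = embedding_dims // 2  # assumes embedding_dims >= 4
    log_min, log_max = math.log(min_freq), math.log(max_freq)
    # Frequencies spaced evenly on a log scale between min_freq and max_freq.
    frequencies = [
        math.exp(log_min + i * (log_max - log_min) / (half - 1))
        for i in range(half)
    ]
    # diffusion time -> angles: one angle per frequency.
    angles = [2.0 * math.pi * f * diffusion_time for f in frequencies]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


embedding = sinusoidal_embedding(0.25)
print(len(embedding))
```

Low frequencies change slowly with the diffusion time and high frequencies change quickly, so nearby times get distinguishable embeddings.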



@@ -0,0 +1,723 @@
{
Collaborator

@sanjanalreddy sanjanalreddy Jul 10, 2023

Can you also add a cell on prediction? Just using model.predict should be enough.



@takumiohym takumiohym self-requested a review July 20, 2023 06:58
Collaborator

@takumiohym takumiohym left a comment

LGTM so far! Thanks for building a great notebook!

@kylesteckler kylesteckler merged commit 76932d3 into master Jul 20, 2023
@kylesteckler kylesteckler deleted the diffusion-models branch July 20, 2023 13:18