add diffusion model solution nb #347
@@ -0,0 +1,723 @@ | |||
{ |
Since there are several variants of diffusion models, it would be better to state which one this notebook implements. Maybe DDPM?
https://arxiv.org/abs/2006.11239
+1 (good to add the link to the second paper too)
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Worth adding an explanation of why this schedule is important and why cosine scheduling is used? Ideally with a link to the paper that introduced the idea.
(DDPM used linear and quadratic schedules, and the Improved DDPM paper linked below introduced cosine scheduling, to my understanding.)
https://arxiv.org/pdf/2102.09672.pdf
"locally" -> "directly"?
Done.
@@ -0,0 +1,723 @@ | |||
{ |
It would be good to elaborate on why we use U-Net here, since this task is not semantic segmentation.
Also, would it be good to add a link to the semantic segmentation with U-Net notebook?
+1
@@ -0,0 +1,723 @@ | |||
{ |
Since this code block is very large and includes important concepts, I think it's better to explain the purpose and overview of each function first, and then elaborate on the train_step function.
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Line #143. self.network.weights, self.ema_network.weights
Should explain the purpose of EMA and this network?
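For context on that line, here is a minimal, framework-free sketch of how an exponential moving average (EMA) copy of the weights is typically maintained. The ema_update helper and the ema=0.9 value are assumptions for illustration, not taken from the notebook; in the notebook the same update would presumably run over pairs from self.network.weights and self.ema_network.weights.

```python
def ema_update(ema_weights, weights, ema=0.999):
    """Return EMA weights moved a fraction (1 - ema) toward the current weights."""
    return [ema * e + (1.0 - ema) * w for e, w in zip(ema_weights, weights)]


# Toy example: the trained weights jump around, the EMA copy follows slowly.
weights = [0.0, 1.0]
ema_weights = list(weights)  # the EMA network starts as a copy

for step in range(3):
    weights = [w + 1.0 for w in weights]  # stand-in for an optimizer step
    ema_weights = ema_update(ema_weights, weights, ema=0.9)
```

The averaged copy changes more smoothly than the trained weights, which is why an EMA network is commonly the one used for generation.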
@@ -0,0 +1,723 @@ | |||
{ |
Line #22. tf.keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
I like this callback!
@@ -0,0 +1,723 @@ | |||
{ |
Though the generated images are plotted during training, I think it's also nice to have a separate generation command after training.
Done.
I added a few comments mainly regarding the markdown descriptions, but this is really a great notebook!
@@ -0,0 +1,723 @@ | |||
{ |
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Maybe we should make mnist the default since it's likely easier/faster to test.
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Consider adding a markdown cell above explaining a few of these constants, especially: CENTRAL_CROP, MIN_SIGNAL_RATE, MAX_SIGNAL_RATE, PLOT_DIFFUSION_STEPS, and EMA.
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Explain in more detail what the underlying process that's happening during training is so that we can understand what the time, noise_rate, and signal_rate are. From the current wording, it sounds like this is a continuous diffusion process rather than a discrete one; maybe clarify that point too. I am having a hard time making sense of the current explanation, which I find a bit too short:
"We generate the noisy image (for the forward diffusion processes) by weighting random noise by the noise rate, and the training image by the signal rate, then adding them together."
e.g. what do you mean by "weighting random noise by the noise rate" or "the training image by the signal rate"?
I agree with @takumi about adding an explanation of what diffusion_schedule() is (and also why it's a function).
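To make the quoted sentence concrete, here is a rough sketch of what "weighting" could mean: each pixel of the noisy image is signal_rate times the clean pixel plus noise_rate times a Gaussian noise sample. The add_noise helper, the toy values, and the choice of rates as the cosine and sine of one angle are assumptions for illustration, not the notebook's code.

```python
import math
import random


def add_noise(image, noise, signal_rate, noise_rate):
    """Mix a clean image and Gaussian noise pixel by pixel."""
    return [signal_rate * x + noise_rate * n for x, n in zip(image, noise)]


rng = random.Random(0)
image = [0.5, -0.25, 1.0, 0.0]                # flattened, normalized toy "image"
noise = [rng.gauss(0.0, 1.0) for _ in image]  # unit-variance Gaussian noise

# Assumption: the rates are the cosine and sine of one angle,
# so their squares always sum to 1.
angle = math.pi / 6
signal_rate, noise_rate = math.cos(angle), math.sin(angle)

noisy_image = add_noise(image, noise, signal_rate, noise_rate)
```

With signal_rate=1 and noise_rate=0 the image comes back unchanged, and with the values swapped the result is pure noise; intermediate rates interpolate between the two.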
Added more thorough explanation and visualization.
@@ -0,0 +1,723 @@ | |||
{ |
+1 to this! Maybe a more clear breakdown of what happens in each line in the markdown.
Done (added visual and more explanation)
@@ -0,0 +1,723 @@ | |||
{ |
Is it possible to avoid calling a "private" method like .__next__()?
What about data_batch = next(ds_iter)?
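A small sketch of the suggestion, using a plain list of batches as a stand-in for the dataset (the batches list is hypothetical; ds_iter mirrors the name in the comment). The built-in next() calls the iterator's __next__() for you and also accepts an optional default.

```python
batches = [[1, 2], [3, 4], [5, 6]]  # stand-in for a dataset of batches
ds_iter = iter(batches)

first = next(ds_iter)            # preferred spelling
second = ds_iter.__next__()      # same effect, but calls the dunder directly
third = next(ds_iter)
exhausted = next(ds_iter, None)  # optional default instead of StopIteration
```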
Done.
@@ -0,0 +1,723 @@ | |||
{ |
Can you give the necessary background to understand what's happening in the following line
noise_rates, signal_rates = diffusion_schedule(diffusion_times)
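As background for that line, here is one way a cosine diffusion schedule can be written, sketched in plain Python. It assumes the schedule interpolates an angle between acos(MAX_SIGNAL_RATE) and acos(MIN_SIGNAL_RATE); the constant values below are placeholders, and the notebook's actual implementation may differ.

```python
import math

# Assumed values for illustration; the notebook defines its own constants.
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_times):
    """Map diffusion times in [0, 1] to (noise_rates, signal_rates)."""
    start_angle = math.acos(MAX_SIGNAL_RATE)  # angle at time 0 (mostly signal)
    end_angle = math.acos(MIN_SIGNAL_RATE)    # angle at time 1 (mostly noise)

    noise_rates, signal_rates = [], []
    for t in diffusion_times:
        # diffusion time -> angle, by linear interpolation
        angle = start_angle + t * (end_angle - start_angle)
        # angle -> rates; sin^2 + cos^2 = 1 keeps the total variance fixed
        signal_rates.append(math.cos(angle))
        noise_rates.append(math.sin(angle))
    return noise_rates, signal_rates


noise_rates, signal_rates = diffusion_schedule([0.0, 0.5, 1.0])
```

Writing the schedule as a function of time means the same code can serve training, where times are sampled at random, and generation, where times step from 1 down to 0.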
I think this should be more clear now with the added explanations + comments in the code.
@@ -0,0 +1,723 @@ | |||
{ |
I am not understanding the following phrases
- "the variances of their noise components"
- "The variances are required since denoising a signal requires different operations at different levels of noise."
- " the model needs to learn how to denoise at different noise rates "
It would be useful to define the terms and expand the explanation. For instance, what does it mean for the model to de-noise at different rates? Do you mean for each different random image you need a different rate? Or do you mean at different steps of the de-noising you need different rates? If so, how do you choose these rates and why do you need different rates? (I don't think I actually understood what you meant so my questions are probably meaningless)
Reworded some things and added more explanation.
@@ -0,0 +1,723 @@ | |||
{ |
Can you add some explanation before this cell about what this layer does and why we need it?
This is a very long class. Is it possible to break this class into smaller chunks and explain the methods?
One way to do that would be to extract some of the methods (the most important) as standalone functions, comment on them in markdown cells, and reuse these functions to define the methods, making the class much shorter.
Broke apart sinembedding from other blocks and added markdown.
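For readers of this thread, here is a minimal sketch of a sinusoidal embedding, assuming "sinembedding" refers to encoding a scalar (such as the diffusion time or noise level) as sines and cosines at several frequencies. The function name, default dimensions, and frequency range are assumptions, not the notebook's values.

```python
import math


def sinusoidal_embedding(t, embedding_dims=8, min_freq=1.0, max_freq=1000.0):
    """Encode scalar t as [sin(w_i * t)..., cos(w_i * t)...]; needs embedding_dims >= 4."""
    half = embedding_dims // 2
    # Frequencies spaced evenly on a log scale between min_freq and max_freq.
    log_min, log_max = math.log(min_freq), math.log(max_freq)
    frequencies = [
        math.exp(log_min + i * (log_max - log_min) / (half - 1)) for i in range(half)
    ]
    angular_speeds = [2.0 * math.pi * f for f in frequencies]
    sines = [math.sin(w * t) for w in angular_speeds]
    cosines = [math.cos(w * t) for w in angular_speeds]
    return sines + cosines


embedding = sinusoidal_embedding(0.3)
```

The point of such an embedding is that nearby input values get similar vectors while the high-frequency components still let the network tell them apart.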
@@ -0,0 +1,723 @@ | |||
{ |
Line #4. # diffusion times -> angles
Can you expand on this comment? Are you saying you're converting diffusion_times to angles?
@@ -0,0 +1,723 @@ | |||
{ |
Can you also add a cell on prediction? Just using model.predict should be enough.
LGTM so far! Thanks for building a great notebook!
Solution notebook for training a diffusion model for image generation from scratch.
Please test with GPU.