Randperm on PyTorch and MxNet #2084
Conversation
Codecov Report. Base: 72.08% // Head: 71.36% // Decreases project coverage by 0.72%.
Additional details and impacted files:

@@             Coverage Diff             @@
##              master     #2084     +/- ##
============================================
- Coverage      72.08%    71.36%   -0.72%
- Complexity      5126      6279    +1153
============================================
  Files            473       624     +151
  Lines          21970     27806    +5836
  Branches        2351      2997     +646
============================================
+ Hits           15838     19845    +4007
- Misses          4925      6501    +1576
- Partials       1207      1460     +253

☔ View full report at Codecov.
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java (outdated review comment, resolved)
for (int i = 0; i < extraFileKeys.length; i++) {
    properties.put(extraFileKeys[i], extraFileValues[i]);
}
// Freeze the parameters if not retrain
for (Pair<String, Parameter> paramPair : block.getParameters()) {
Is there a reason that retrain has to be added as an option to the model? Can we not just direct users to call freeze() themselves if they don't wish to retrain the model? And as part of that, I am not sure that freezing the parameters is the default that users would expect.
I see. Yes, this should be replaced by baseBlock.freezeParameters(false). I'll edit it in PR #2070.
The reason the default is to freeze the baseBlock is that even if the pretrained model is not frozen, its learning rate cannot be very large, or the same as that of the subsequent blocks; i.e. the pretrained parameters should not be trained too much. So it is safe to freeze those parameters by default. Maybe we can add some comments about this.
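For illustration, a minimal sketch of freezing the pretrained base block by default when assembling the transfer model. The class, variable names, and the boolean convention of freezeParameters are assumptions for this sketch, not the exact code in #2070:

import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;

final class FreezeSketch {

    /**
     * Assembles the transfer model. The pretrained base block is frozen by default
     * so its parameters are not updated too aggressively; callers who really want
     * to fine-tune everything can unfreeze it afterwards.
     */
    static Block buildTransferModel(Block baseBlock, Block newHead) {
        // Assumed convention: freezeParameters(true) stops gradient updates for the
        // block's parameters (check the DJL Block javadoc for the exact semantics).
        baseBlock.freezeParameters(true);
        return new SequentialBlock().add(baseBlock).add(newHead);
    }
}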
@@ -56,6 +57,9 @@ protected Sgd(Builder builder) {

    /** {@inheritDoc} */
    @Override
    public void update(String parameterId, NDArray weight, NDArray grad) {
        if (learningRateTracker instanceof FixedPerVarTracker) {
            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
It isn't great to use a system like this to pass a special-case argument like parameterId here. Instead, maybe we could add a new kind of tracker named something like ParameterTracker. The ParameterTracker can require the parameter to get a new value, and all standard Trackers would also be ParameterTrackers if we have Tracker extend ParameterTracker. If we leave it as a special case like this, it would make it difficult if a new tracker that also required parameters were created, or if new optimizers were created without the special handling for the FixedPerVarTracker.
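A rough sketch of the shape that proposal could take; the interface and method names below are illustrative only and were not part of DJL at the time of this review:

/**
 * Sketch of the proposed tracker hierarchy (names are illustrative).
 * A ParameterTracker may return a different value per parameter; plain
 * Trackers ignore the parameter id, so Tracker can extend ParameterTracker
 * with a default method and every existing tracker keeps working.
 */
interface ParameterTracker {
    float getNewValue(String parameterId, int numUpdate);
}

interface Tracker extends ParameterTracker {
    float getNewValue(int numUpdate);

    /** Standard trackers do not depend on which parameter is being updated. */
    @Override
    default float getNewValue(String parameterId, int numUpdate) {
        return getNewValue(numUpdate);
    }
}

With something like this, the Sgd update could simply ask learningRateTracker for a value with the parameter id and drop the instanceof FixedPerVarTracker check.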
@@ -44,7 +44,25 @@ Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljav
        map[name] = "";
    }

    JITCallGuard guard;
    if (!jretrain) {
I couldn't find the difference in behavior between with and without retrain. Is there one?
examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java (outdated review comment, resolved)
        .addTrainingListeners(listener);
}

private static class SoftmaxCrossEntropy extends Loss {
Why not use the standard SoftmaxCrossEntropyLoss?
The standard SoftmaxCrossEntropyLoss doesn't take the output of the softmax function; it currently takes either the values before applying softmax, or the values after applying logit, i.e. log(softmax(*)).
Here, I built the network to output softmax probabilities, mainly because they are natural to understand (as probabilities). Correspondingly, I also corrected the calculation in Accuracy.java; you can take a look to check whether it is correct. (It looks like the calculation there did not work well before this edit.)
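For reference, a minimal sketch of a loss that consumes softmax probabilities directly; this is illustrative only, not the exact SoftmaxCrossEntropy implementation in this example:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * Cross-entropy over probabilities: expects the prediction to already be the
 * output of softmax and picks the probability of the true class via one-hot.
 */
class SoftmaxCrossEntropySketch extends Loss {

    SoftmaxCrossEntropySketch() {
        super("SoftmaxCrossEntropySketch");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray prob = predictions.singletonOrThrow(); // shape (batch, classes), softmax output
        int numClasses = (int) prob.getShape().get(1);
        NDArray oneHot =
                labels.singletonOrThrow()
                        .flatten()
                        .oneHot(numClasses)
                        .toType(prob.getDataType(), false);
        // -log p(true class), clipped for numerical safety, averaged over the batch
        NDArray loss = prob.clip(1e-12, 1.0).log().mul(oneHot).sum(new int[] {1}).neg();
        return loss.mean();
    }
}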
Description
randPerm is added.
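As a usage sketch (the description above only states that randPerm is added; the Java-side method name randomPermutation and its placement on NDManager below are assumptions about how the new operator is exposed):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class RandPermExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Assumed entry point: a random permutation of the integers 0..9
            NDArray perm = manager.randomPermutation(10);
            System.out.println(perm); // e.g. [3, 7, 0, 9, 1, 5, 8, 2, 6, 4]
        }
    }
}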