Randperm on PyTorch and MxNet #2084

Merged: 4 commits merged into deepjavalibrary:master on Nov 1, 2022

Conversation

KexinFeng (Contributor)

Description

  • A new method randPerm is added.
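A minimal usage sketch of the new method (the name randPerm follows the description above; the class it lives on and the exact signature are assumptions for illustration and may differ in the merged code):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class RandPermExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Assumed call site: a random permutation of the integers [0, 10).
            // The final method name/signature in the API may differ from this sketch.
            NDArray perm = manager.randPerm(10);
            System.out.println(perm); // e.g. the values 0..9 in shuffled order
        }
    }
}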

@codecov-commenter commented on Oct 14, 2022:

Codecov Report

Base: 72.08% // Head: 71.36% // Decreases project coverage by 0.71% ⚠️

Coverage data is based on head (e767e89) compared to base (bb5073f).
Patch coverage: 71.66% of modified lines in pull request are covered.

❗ Current head e767e89 differs from pull request most recent head 485fe3b. Consider uploading reports for the commit 485fe3b to get more accurate results

Additional details and impacted files
@@             Coverage Diff              @@
##             master    #2084      +/-   ##
============================================
- Coverage     72.08%   71.36%   -0.72%     
- Complexity     5126     6279    +1153     
============================================
  Files           473      624     +151     
  Lines         21970    27806    +5836     
  Branches       2351     2997     +646     
============================================
+ Hits          15838    19845    +4007     
- Misses         4925     6501    +1576     
- Partials       1207     1460     +253     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java 71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java 72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <0.00%> (-5.24%) ⬇️
.../modality/cv/translator/ImageFeatureExtractor.java 0.00% <0.00%> (ø)
.../ai/djl/modality/cv/translator/YoloTranslator.java 27.77% <0.00%> (+18.95%) ⬆️
...modality/cv/translator/wrapper/FileTranslator.java 44.44% <ø> (ø)
...y/cv/translator/wrapper/InputStreamTranslator.java 44.44% <ø> (ø)
... and 558 more


@KexinFeng changed the title from "Randperm on PyTorch" to "Randperm on PyTorch and MxNet" on Oct 14, 2022
Review threads (outdated, resolved) on:
api/src/main/java/ai/djl/ndarray/BaseNDManager.java
api/src/main/java/ai/djl/nn/core/Linear.java
api/src/main/java/ai/djl/translate/StackBatchifier.java
for (int i = 0; i < extraFileKeys.length; i++) {
    properties.put(extraFileKeys[i], extraFileValues[i]);
}
// Freeze the parameters if not retrain
for (Pair<String, Parameter> paramPair : block.getParameters()) {
Contributor:

Is there a reason that retrain has to be added as an option to the model? Can we not just direct users to call freeze() themselves if they don't wish to retrain the model? And as part of that, I am not sure if freezing the parameters is the default that users would expect

Contributor Author (@KexinFeng):

I see. Yes, this should be replaced by baseBlock.freezeParameters(false);. I'll edit it in PR #2070.

The reason the default is to freeze the baseBlock is that even when the pretrained model is not frozen, its learning rate cannot be very large, or the same as that of the subsequent blocks; i.e. the pretrained parameters should not be trained too much. So it is safe to freeze those parameters by default. Maybe we can add some comments about this.
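For context, the freeze-by-default pattern being discussed could be sketched roughly as follows (illustrative only: the retrain flag and the helper method are hypothetical, and Block#freezeParameters is assumed to freeze when passed true):

import ai.djl.nn.Block;

public final class TransferSetup {

    private TransferSetup() {}

    // Hypothetical helper: freeze the pretrained base block unless the caller
    // explicitly asks to retrain it, so only the newly added layers are updated.
    public static void configureBaseBlock(Block baseBlock, boolean retrain) {
        if (!retrain) {
            // Assumed semantics: freezeParameters(true) marks the pretrained
            // parameters as not requiring gradients during training.
            baseBlock.freezeParameters(true);
        }
    }
}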

@@ -56,6 +57,9 @@ protected Sgd(Builder builder) {
    /** {@inheritDoc} */
    @Override
    public void update(String parameterId, NDArray weight, NDArray grad) {
        if (learningRateTracker instanceof FixedPerVarTracker) {
            ((FixedPerVarTracker) learningRateTracker).setParameterId(parameterId);
Contributor:

It isn't great to use a system like this to pass a special-case argument of parameterId here. Instead, maybe we could add a new kind of tracker named something like ParameterTracker. The ParameterTracker can require the parameter to get a new value, and all standard Trackers would also be a ParameterTracker if we have Tracker extend ParameterTracker. If we leave it as a special case like this, it would make it difficult if a new tracker that also required parameters were created, or if new optimizers were created without the special handling for the FixedPerVarTracker.
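A rough sketch of that suggestion (names and signatures are illustrative, not taken from the PR or from DJL's current API):

// Hypothetical interfaces sketching the suggestion; the real API may differ.
public interface ParameterTracker {
    /** Returns the value (e.g. a learning rate) for the given parameter at the given step. */
    float getNewValue(String parameterId, int numUpdate);
}

public interface Tracker extends ParameterTracker {
    /** Standard trackers depend only on the update count. */
    float getNewValue(int numUpdate);

    /** Every standard Tracker is also a ParameterTracker; it simply ignores the parameter id. */
    @Override
    default float getNewValue(String parameterId, int numUpdate) {
        return getNewValue(numUpdate);
    }
}

With something like this, an optimizer could call learningRateTracker.getNewValue(parameterId, numUpdate) uniformly and drop the instanceof FixedPerVarTracker special case.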

@@ -44,7 +44,25 @@ Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleLoad__Ljava_lang_String_2_3IZ_3Ljav
        map[name] = "";
    }

    JITCallGuard guard;
    if (!jretrain) {
Contributor:

I couldn't find the difference in behavior between with and without retrain. Is there one?

.addTrainingListeners(listener);
}

private static class SoftmaxCrossEntropy extends Loss {
Contributor:

Why not use the standard SoftmaxCrossEntropyLoss?

Contributor Author (@KexinFeng) commented on Oct 26, 2022:

The standard SoftmaxCrossEntropyLoss doesn't take the output of the softmax function; it currently takes either the values before applying softmax, or the values after applying the logit, i.e. log(softmax(*)).
Here, I built the network to output the softmax, mainly because it is natural to interpret as probabilities. Correspondingly, I also corrected the calculation in Accuracy.java. Could you take a look to check whether it is correct? (It looks like, before this edit, the calculation there did not work well.)
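For reference, a cross-entropy loss that consumes probabilities directly (as described above) could look roughly like the following sketch; it is not the code in the PR, and the class name and numClasses parameter are assumptions:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.loss.Loss;

public class ProbabilitySoftmaxCrossEntropy extends Loss {

    private final int numClasses; // assumed: labels arrive as class indices

    public ProbabilitySoftmaxCrossEntropy(String name, int numClasses) {
        super(name);
        this.numClasses = numClasses;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        // The predictions are already probabilities (softmax output), so take the log
        // directly instead of applying logSoftmax as SoftmaxCrossEntropyLoss does.
        NDArray prob = predictions.singletonOrThrow().clip(1e-12, 1.0);
        NDArray oneHot =
                labels.singletonOrThrow().toType(DataType.INT32, false).oneHot(numClasses);
        // Cross entropy: -mean over the batch of sum(label * log(prob)).
        return prob.log().mul(oneHot).sum(new int[] {-1}).neg().mean();
    }
}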

Commits added: "training data cut", "transfer features"
@KexinFeng requested a review from zachgk on October 31, 2022
@KexinFeng merged commit b0427a8 into deepjavalibrary:master on Nov 1, 2022
@KexinFeng deleted the randperm branch on November 22, 2022