Adds TabNet model for tabular dataset in modelzoo #2036

Merged · 10 commits · Sep 29, 2022
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/nn/core/SparseMax.java
@@ -10,7 +10,6 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
@@ -66,7 +65,7 @@ public SparseMax(int axis, int topK) {
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
// the shapes of the input and output are the same
- return new Shape[0];
+ return new Shape[] {inputShapes[0]};
}

/** {@inheritDoc} */
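For context, the corrected getOutputShapes propagates the input shape instead of returning an empty array. A minimal sketch of the new behavior, using the two-argument constructor visible in the hunk above (the axis and topK values here are illustrative):

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.core.SparseMax;

public class SparseMaxShapeCheck {
    public static void main(String[] args) {
        // SparseMax preserves the input shape, so downstream blocks can infer theirs
        SparseMax sparseMax = new SparseMax(1, 3);
        Shape[] out = sparseMax.getOutputShapes(new Shape[] {new Shape(2, 5)});
        System.out.println(out[0]); // (2, 5) after this fix; previously an empty array
    }
}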
44 changes: 44 additions & 0 deletions api/src/main/java/ai/djl/training/loss/TabNetLoss.java
@@ -0,0 +1,44 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
* Calculates the loss for TabNet. The loss is already computed inside TabNet's forward
* function, so this class simply extracts the loss value from the predictions.
*/
public class TabNetLoss extends Loss {
/** Calculates the loss of a TabNet instance. */
public TabNetLoss() {
this("TabNetLoss");
}

/**
* Calculates the loss of a TabNet instance.
*
* @param name the name of the loss function
*/
public TabNetLoss(String name) {
super(name);
}

/** {@inheritDoc} */
@Override
public NDArray evaluate(NDList labels, NDList predictions) {
// the loss is already calculated inside TabNet's forward function,
// so we simply extract it from the predictions
return predictions.get(1);
}
}
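For context, a minimal sketch of how this loss consumes TabNet's two-element forward output; the shapes and the 0.42 loss value are illustrative, not part of the PR:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.loss.TabNetLoss;

public class TabNetLossExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDList labels = new NDList(manager.ones(new Shape(1, 10)));
            // TabNet's forward returns {prediction, loss}; index 1 carries the loss
            NDList predictions = new NDList(
                    manager.zeros(new Shape(1, 10)), // prediction logits
                    manager.create(0.42f));          // loss already computed in forward
            NDArray loss = new TabNetLoss().evaluate(labels, predictions);
            System.out.println(loss.getFloat()); // 0.42
        }
    }
}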
@@ -0,0 +1,112 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.integration.tests.model_zoo.tabular;

import ai.djl.Model;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.testing.Assertions;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.TabNetLoss;
import ai.djl.translate.Batchifier;
import ai.djl.util.PairList;

import org.testng.Assert;
import org.testng.annotations.Test;

public class TabNetTest {
@Test
public void testTabNetGLU() {
TrainingConfig config =
new DefaultTrainingConfig(Loss.l2Loss())
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
try (Model model = Model.newInstance("model")) {
model.setBlock(TabNet.tabNetGLUBlock(1));

try (Trainer trainer = model.newTrainer(config)) {
trainer.initialize(new Shape(4));
NDManager manager = trainer.getManager();
NDArray data = manager.create(new float[] {1, 2, 3, 4});
data = data.reshape(2, 2);
// expected values computed with PyTorch
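// sanity check: GLU splits the last axis in half and returns a * sigmoid(b),
// so row 0 = 1 * sigmoid(2) = 0.8808 and row 1 = 3 * sigmoid(4) = 2.9460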
NDArray expected = manager.create(new float[] {0.8808f, 2.946f});
NDArray result = trainer.forward(new NDList(data)).singletonOrThrow().squeeze();
Assertions.assertAlmostEquals(result, expected);
}
}
}

@Test
public void testTrainingAndLogic() {
TrainingConfig config =
new DefaultTrainingConfig(new TabNetLoss())
.optDevices(Engine.getInstance().getDevices(2));

Block tabNet = TabNet.builder().setOutDim(10).build();
try (Model model = Model.newInstance("tabNet")) {
model.setBlock(tabNet);
try (Trainer trainer = model.newTrainer(config)) {
int batchSize = 1;
Shape inputShape = new Shape(batchSize, 128);
trainer.initialize(inputShape);
NDManager manager = trainer.getManager();
NDArray input = manager.randomUniform(0, 1, inputShape);
NDArray label = manager.ones(new Shape(batchSize, 10));
Batch batch =
new Batch(
manager.newSubManager(),
new NDList(input),
new NDList(label),
batchSize,
Batchifier.STACK,
Batchifier.STACK,
0,
0);
PairList<String, Parameter> parameters = tabNet.getParameters();
EasyTrain.trainBatch(trainer, batch);
trainer.step();
// the gamma of the batchNorm layer
Assert.assertEquals(
parameters.get(0).getValue().getArray().getShape(), new Shape(1));

// the weight of shared fullyConnected block 0
Assert.assertEquals(
parameters.get(4).getValue().getArray().getShape(), new Shape(256, 128));

// the parameter values of a shared fc block should be identical across steps
Assert.assertEquals(parameters.get(8).getValue(), parameters.get(4).getValue());
Assert.assertEquals(parameters.get(32).getValue(), parameters.get(4).getValue());

// the fc weight of the attention transformer at step 1
Assert.assertEquals(
parameters.get(56).getValue().getArray().getShape(), new Shape(128, 64));

// the final fc Block
Assert.assertEquals(
parameters.get(152).getValue().getArray().getShape(), new Shape(10, 64));
}
}
}
}
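Beyond the unit tests above, a minimal end-to-end training sketch using the same API; the random ArrayDataset stands in for a real tabular dataset, and the feature width of 128 matches the test input above:

import ai.djl.Model;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.loss.TabNetLoss;

public class TabNetTrainingSketch {
    public static void main(String[] args) throws Exception {
        try (Model model = Model.newInstance("tabNet")) {
            model.setBlock(TabNet.builder().setOutDim(10).build());
            try (Trainer trainer =
                    model.newTrainer(new DefaultTrainingConfig(new TabNetLoss()))) {
                int batchSize = 32;
                trainer.initialize(new Shape(batchSize, 128));
                NDManager manager = trainer.getManager();
                // random features and labels stand in for a real tabular dataset
                ArrayDataset dataset = new ArrayDataset.Builder()
                        .setData(manager.randomUniform(0, 1, new Shape(512, 128)))
                        .optLabels(manager.ones(new Shape(512, 10)))
                        .setSampling(batchSize, true)
                        .build();
                EasyTrain.fit(trainer, 2, dataset, null);
            }
        }
    }
}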
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains tests using the engine for {@link ai.djl.basicmodelzoo.tabular}. */
package ai.djl.integration.tests.model_zoo.tabular;