-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unittest PR for workflow update #14
Conversation
LGTM. You can merge the request after fixing the black-formatting issue. |
827110e
to
9866c7f
Compare
6b32cc6
to
1e439b0
Compare
76b208e
to
e4344cb
Compare
26795eb
to
d55deaf
Compare
def construct_mlp(num_inputs=784, num_classes=10): | ||
return torch.nn.Sequential( | ||
nn.Flatten(), | ||
nn.Linear(num_inputs, 4, bias=False), | ||
nn.ReLU(), | ||
nn.Linear(4, 2, bias=False), | ||
nn.ReLU(), | ||
nn.Linear(2, num_classes, bias=False), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to make to smaller version of the model.
import unittest | ||
import torch | ||
import torchvision | ||
import torch.nn as nn | ||
import numpy as np | ||
import os | ||
|
||
DEVICE = torch.device("cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only testing the compute_influence as I think the train.py
does not include anything from analog
lib.
I strongly think that this test may have to be removed before the public release from the /tests/
as this is not pertaining to the actual module we are building. However, I am leaving it here as per the internal contributors may overlook the test run of their work in local with the mnist data.
Caveat, this may not well represent the overall behavior of the module as the sheer size of the model architecture.
No description provided.