Skip to content

Commit

Permalink
add test for torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 30, 2023
1 parent d3e0eb7 commit 4f9ecd3
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions tests/logger/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
def create_mlp(input_size, hidden_size, num_classes):
model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.GELU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Tanh(),
nn.Linear(hidden_size, num_classes),
)
return model
Expand Down Expand Up @@ -123,6 +123,49 @@ def compute_loss_func(_params, _buffers, _batch):
func_grad = grads_dict[module_name + ".weight"]
self.assertTrue(torch.allclose(analog_grad, func_grad, atol=1e-6))

def test_per_sample_gradient_with_compile(self):
# Instantiate AnaLog
analog = AnaLog(project="test")
analog.watch(self.model)

compiled_model = torch.compile(self.model)

# Input and target for batch size of 4
inputs = torch.randn(4, 4) # Dummy input
labels = torch.tensor([1, 3, 0, 2]) # Dummy labels
batch = (inputs, labels)

# functorch per-sample gradient
def compute_loss_func(_params, _buffers, _batch):
_output = torch.func.functional_call(
self.func_model,
(_params, _buffers),
args=(_batch[0],),
)
_loss = F.cross_entropy(_output, _batch[1])
return _loss

func_compute_grad = torch.func.grad(compute_loss_func, has_aux=False)

grads_dict = torch.func.vmap(
func_compute_grad,
in_dims=(None, None, 0),
randomness="same",
)(self.func_params, self.func_buffers, batch)

# Forward pass with original model
with analog(data_id=inputs, log=["grad"], hessian=False, save=False):
compiled_model.zero_grad()
output = compiled_model(inputs)
loss = F.cross_entropy(output, labels, reduction="sum")
loss.backward()
analog_grads_dict = analog.get_log()

for module_name in analog_grads_dict:
analog_grad = analog_grads_dict[module_name]
func_grad = grads_dict[module_name + ".weight"]
self.assertTrue(torch.allclose(analog_grad, func_grad, atol=1e-6))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4f9ecd3

Please sign in to comment.