
WIP: Initial version of JAX-Meep wrapper #1569

Merged
merged 9 commits into NanoComp:master on May 28, 2021

Conversation

ianwilliamson
Contributor

This PR is an initial implementation of a class for wrapping a Meep simulation into a JAX-differentiable callable. This API will allow users to flexibly compose one or more Meep simulations with other JAX functions with support for end-to-end differentiability.

This implementation uses JAX's custom VJP mechanism (jax.custom_vjp), which has one limitation: it does not support defining batching rules. In practical terms, this unfortunately means that the interface implemented in this PR will not be able to take advantage of Meep's parallelism over frequency in multi-objective (minimax-style) optimizations. However, this interface will work well for the scalar-valued loss functions that are widely used in inverse design and machine learning (via JAX's jax.value_and_grad interface).
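As a rough, hypothetical sketch of the custom-VJP pattern (with a differentiable stand-in in place of the actual Meep forward and adjoint runs, so the example is self-contained), the plumbing looks like:

```python
import jax
import jax.numpy as jnp

def _simulate_impl(design):
    # Stand-in for the forward Meep run; the real wrapper calls out to an
    # external solver here. A differentiable toy is used so this runs anywhere.
    return jnp.sum(design ** 2)

@jax.custom_vjp
def simulate(design):
    return _simulate_impl(design)

def simulate_fwd(design):
    # Save the design as the residual needed by the backward pass.
    return _simulate_impl(design), design

def simulate_bwd(design, cotangent):
    # Stand-in for the adjoint Meep run that produces the VJP.
    return (2.0 * design * cotangent,)

simulate.defvjp(simulate_fwd, simulate_bwd)

# Scalar losses compose cleanly with jax.value_and_grad:
value, grad = jax.value_and_grad(simulate)(jnp.array([1.0, 2.0, 3.0]))
# value = 14.0, grad = [2., 4., 6.]
```

Note that nothing here attaches a batching rule, which is the limitation described above.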

In the future, to support frequency parallelism in multi-objective optimizations, it may be worth investigating whether Meep can be wrapped into a JAX primitive, which has first-class support for batching rules. The challenge with this approach is that, for primitives, JAX does not let you directly define the reverse-mode differentiation rule; instead it requires a forward-mode differentiation rule together with a transpose rule, which JAX combines to derive the reverse-mode rule. It's unclear whether we could pull this off with Meep.
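To make the forward-plus-transpose mechanism concrete, here is a toy illustration (not Meep-specific) using jax.custom_jvp: JAX linearizes the user-supplied forward-mode rule and transposes it to obtain reverse mode, which only works when the JVP is built from transposable operations — exactly the part that would be hard to hand-write for a Meep primitive.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def scale(x):
    return 3.0 * x

@scale.defjvp
def scale_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Forward-mode rule: linear in the tangent t, so JAX can transpose it.
    return scale(x), 3.0 * t

# Reverse mode is derived by linearizing and transposing the JVP above:
g = jax.grad(lambda x: jnp.sum(scale(x)))(jnp.ones(3))
# g = [3., 3., 3.]
```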

The initial implementation of the JAX-Meep wrapper in this PR only supports eigenmode coefficient outputs, but in a follow-up we can enable support for differentiable DFT fields and near-to-far field transformations. Once some initial feedback has been received, I can add some test cases to this PR.

@ianwilliamson ianwilliamson changed the title Initial version of JAX-Meep wrapper WIP: Initial version of JAX-Meep wrapper May 13, 2021
@ianwilliamson
Contributor Author

Just wanted to clarify that, as implemented right now, this could be merged without affecting the existing OptimizationProblem interface. The components that this uses from meep.adjoint (just the EigenmodeCoefficient and the DesignRegion) do not rely on HIPS autograd. I think the main question is around organizing different adjoint / gradient interfaces.

@stevengj
Collaborator

cc @smartalecH

@ianwilliamson
Contributor Author

I also wanted to add that much of the code in this class can, in principle, be reused to support manual back-propagation for a minimax-style optimization use case. The underlying steps performed here are essentially the same as what Meep's existing adjoint code is doing.

I imagine that we could expose a public API for each step (forward, adjoint, VJP calculation) and provide a parameter to specify whether to reduce the gradient array over its frequency axis. The existing JAX-composable API would remain for scalar loss functions, but the same class / object could be used by users who wish to manually set up a minimax-style optimization. I think this part would look very similar to how the existing OptimizationProblem interface is used, but perhaps without embedding the loss function into the class.
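A rough sketch of what such a per-step API could look like (all names here are hypothetical, with a toy solver standing in for Meep):

```python
import numpy as np

class AdjointSteps:
    """Illustrative per-step interface: forward run, adjoint run, then a
    VJP that can either keep or reduce the frequency axis."""

    def __init__(self, freqs):
        self.freqs = np.asarray(freqs)

    def run_forward(self, design):
        # Toy forward solve: one scalar output per frequency.
        self.design = np.asarray(design)
        return self.freqs * self.design.sum()

    def run_adjoint(self, cotangent):
        # Toy adjoint solve driven by the per-frequency cotangent.
        self.cotangent = np.asarray(cotangent)

    def vjp(self, sum_over_frequencies=True):
        # Gradient has shape (num_freqs, num_design) before reduction;
        # scalar-loss users reduce it, minimax users keep the full array.
        grad = (self.cotangent * self.freqs)[:, None] * np.ones_like(self.design)
        return grad.sum(axis=0) if sum_over_frequencies else grad

steps = AdjointSteps(freqs=[1.0, 2.0])
out = steps.run_forward([1.0, 1.0, 1.0])      # [3., 6.]
steps.run_adjoint([1.0, 1.0])
full = steps.vjp(sum_over_frequencies=False)  # shape (2, 3)
reduced = steps.vjp()                         # [3., 3., 3.]
```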

@smartalecH
Collaborator

smartalecH commented May 18, 2021

Looks fantastic from my perspective. Very nice addition.

It's a pity that batching isn't supported... it seems odd that there isn't a way around this...

Does JAX's vmap operator not work unless a forward rule is defined?

@ianwilliamson
Contributor Author

Does JAX's vmap operator not work unless a forward rule is defined?

JAX's vmap uses batching rules for the mapped function, i.e. the batching rules of whatever primitives are composed to define the mapped function. It's not that vmap itself directly requires a forward mode differentiation rule, rather we would need to wrap Meep in a JAX primitive, which is what requires the forward mode rule. We would also need to define a transpose rule that would allow JAX to convert the forward mode rule into a reverse mode rule.
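For example, vmap of an ordinary function built from jnp primitives just works, because every primitive involved carries a batching rule (the function and axis choices below are illustrative):

```python
import jax
import jax.numpy as jnp

def per_frequency_objective(design, freq):
    # Toy per-frequency figure of merit built from standard primitives,
    # each of which has a batching rule that vmap can use.
    return freq * jnp.sum(design ** 2)

freqs = jnp.linspace(1.0, 2.0, 4)
objs = jax.vmap(per_frequency_objective, in_axes=(None, 0))(jnp.ones(3), freqs)
# objs has shape (4,): one objective value per frequency.
```

An opaque external call wrapped via jax.custom_vjp has no such rule to fall back on, which is the gap described above.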

This is a design decision made in JAX that is, honestly, pretty clever. It works well for machine learning-like calculations but, unfortunately for us, I think it will be very difficult to define these things for something like Meep. We could look into it though! One hack that has been suggested to me before is to define the reverse mode rule for Meep as the forward mode rule for the JAX primitive, and then define the primitive's transpose rule as an identity operation.

@ianwilliamson
Contributor Author

Here are a few updates to this PR:

  • Added a try / except around import of the JAX adjoint module
  • Added a unittest mixin for floating point vector comparisons (useful for the gradient comparisons)
  • Added unit tests for the JAX adjoint module
  • Updated the Makefile to include the JAX adjoint module and its test

class VectorComparisonMixin(unittest.TestCase):
    """A mixin for adding proper floating point value and vector comparison."""

    def assertVectorsClose(self, x, y, epsilon=1e-2, msg=''):
Collaborator

It would be useful to be consistent with the function naming convention: all-lower-case words separated by underscores, i.e. assert_vec_close as suggested in #1575. (This is consistent with the other function defined in this file, compare_arrays.)

Also, why does this function need to be part of a mixin class when its arguments are only of a single type: arrays of floating-point numbers?

Contributor Author

The goal of going the mixin route was to use the assertLessEqual() method of unittest.TestCase, which tends to have better reporting and testing framework integration than using a vanilla assert. This strikes me as being cleaner than the approach used by compare_arrays() which takes test_instance but isn't a proper method of the test case class. The camel case naming in assertVectorsClose() matches the convention of the other methods in the TestCase class.

Contributor Author

Looking at compare_arrays(), it actually seems to be almost the same comparison that I've implemented... I didn't actually look at it in detail... I just saw this file and dumped my class in here so it might be reused by the other gradient tests in the future. I actually wonder why the other gradient tests don't already use compare_arrays()?

Collaborator

Yes, rather than add a new function/class for comparing two arrays as in this PR, it might make more sense to simply revamp compare_arrays to switch from an L2 norm to an L∞ norm and update the affected tests in a separate PR.

@stevengj stevengj merged commit 1061a29 into NanoComp:master May 28, 2021
@ianwilliamson ianwilliamson deleted the jax-meep-wrapper branch May 28, 2021 02:13
bencbartlett pushed a commit to bencbartlett/meep that referenced this pull request Sep 9, 2021
* Initial version of JAX-Meep wrapper

* Tests for JAX-Meep wrapper

* Factor out more utility functions

* Add more type annotations and update tests

* Add try / catch for jax portion of adjoint module

* Add jax adjoint module and its unit tests to Makefile

* Test tweaks

* Update conditional import and tweak test parameters