WIP: Initial version of JAX-Meep wrapper #1569
Conversation
Just wanted to clarify that, as implemented right now, this could be merged without affecting the existing …
cc @smartalecH
I also wanted to add that much of the code in this class can, in principle, be reused to support manual back-propagation for a minimax-style optimization use case. The underlying steps performed here are essentially the same as what Meep's existing adjoint code is doing. I imagine that we could expose a public API for each step (forward, adjoint, VJP calculation) and provide a parameter to specify whether to reduce the gradient array over its frequency axis. The existing JAX-composable API would remain for scalar loss functions, but the same class / object could be used by users who wish to manually set up a minimax-style optimization. I think this part would look very similar to how the existing …
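To make this concrete, here's a rough sketch of what such a step-by-step API could look like. Every name in it (the wrapper object, run_forward, run_adjoint, calculate_vjp, reduce_over_frequency) is hypothetical and only illustrates the idea; none of it is defined in this PR.

```python
def minimax_gradients(wrapper, design, cotangents):
    """Illustrative only: a hypothetical step-by-step API, not in this PR."""
    fwd = wrapper.run_forward(design)            # forward Meep simulation
    adj = wrapper.run_adjoint(fwd, cotangents)   # adjoint Meep simulation
    # reduce_over_frequency=False would keep the gradient's frequency axis
    # so that a minimax-style outer optimizer can weight each frequency
    # on its own.
    return wrapper.calculate_vjp(fwd, adj, reduce_over_frequency=False)
```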
Looks fantastic from my perspective. Very nice addition. It's a pity that batching isn't supported... it seems odd that there isn't a way around this... Does JAX's …
JAX's custom VJP machinery doesn't support batching rules. This is a design decision made in JAX that is, honestly, pretty clever. It works well for machine-learning-like calculations but, unfortunately for us, I think it will be very difficult to define these things for something like Meep. We could look into it, though! One hack that has been suggested to me before is to define the reverse-mode rule for Meep as the forward-mode rule for the JAX primitive, and then define the primitive's transpose rule as an identity operation.
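For reference, here is a minimal, self-contained sketch of the machinery a JAX primitive requires (a JVP rule plus a transpose rule). The "simulation" here is just a stand-in that scales its input by 2; it is not Meep code, only an illustration of where the suggested hack would have to plug in.

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

# Hypothetical primitive standing in for a Meep forward run.
sim_p = core.Primitive("meep_sim")

def sim(x):
    return sim_p.bind(x)

# Concrete evaluation and shape/dtype inference for the stand-in operator.
sim_p.def_impl(lambda x: 2.0 * x)
sim_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

# Forward-mode (JVP) rule. JAX derives reverse mode by linearizing and then
# transposing this rule, which is where the hack described above would live.
def sim_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return sim(x), sim(t)  # the tangent map; linear in t

ad.primitive_jvps[sim_p] = sim_jvp

# Transpose rule, required because the JVP applies the primitive to the
# tangent. This stand-in operator is symmetric, so it is its own transpose.
def sim_transpose(ct, x):
    return (sim(ct),)

ad.primitive_transposes[sim_p] = sim_transpose

print(jax.grad(lambda x: sim(x).sum())(jnp.ones(3)))  # -> [2. 2. 2.]
```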
… jax-meep-wrapper (Conflicts: python/adjoint/jax/wrapper.py)
Here are a few updates to this PR:
```python
class VectorComparisonMixin(unittest.TestCase):
    """A mixin for adding proper floating point value and vector comparison."""

    def assertVectorsClose(self, x, y, epsilon = 1e-2, msg = ''):
```
It would be useful to be consistent with the function naming convention: all lower-case words separated by underscores, i.e. assert_vec_close, as suggested in #1575. (This is consistent with the other function defined in this file, compare_arrays.)
Also, why does this function need to be part of a mixin class when its arguments are only of a single type: arrays of floating-point numbers?
The goal of going the mixin route was to use the assertLessEqual() method of unittest.TestCase, which tends to have better reporting and testing-framework integration than a vanilla assert. This strikes me as being cleaner than the approach used by compare_arrays(), which takes a test_instance but isn't a proper method of the test case class. The camel-case naming in assertVectorsClose() matches the convention of the other methods in the TestCase class.
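For illustration, the mixin route described above might look roughly like this; the body of the comparison is a guess, since the diff hunk above doesn't show it.

```python
import unittest
import numpy as np

class VectorComparisonMixin(unittest.TestCase):
    """A mixin for adding proper floating point value and vector comparison."""

    def assertVectorsClose(self, x, y, epsilon=1e-2, msg=''):
        # Plausible body (not shown in the hunk above): report the maximum
        # absolute deviation through assertLessEqual so that failures go
        # through the unittest framework instead of a bare assert.
        max_err = np.max(np.abs(np.asarray(x) - np.asarray(y)))
        self.assertLessEqual(max_err, epsilon, msg)
```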
Looking at compare_arrays(), it actually seems to be almost the same comparison that I've implemented... I didn't actually look at it in detail... I just saw this file and dumped my class in here so it might be reused by the other gradient tests in the future. I actually wonder why the other gradient tests don't already use compare_arrays()?
Yes, rather than add a new function/class for comparing two arrays as in this PR, it might make more sense to simply revamp compare_arrays to switch from an L2 to an L∞ norm and update the affected tests in a separate PR.
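Something along these lines, perhaps. The signature and tolerance are assumptions based on the discussion above, not code copied from Meep's test suite.

```python
import numpy as np

def compare_arrays(test_instance, exp, res, tol=1e-2):
    # Sketch of the suggested revamp: an L-infinity (max-abs) relative error
    # in place of an L2-norm criterion.
    relative_error = np.max(np.abs(exp - res)) / np.max(np.abs(exp))
    test_instance.assertLessEqual(relative_error, tol)
```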
* Initial version of JAX-Meep wrapper
* Tests for JAX-Meep wrapper
* Factor out more utility functions
* Add more type annotations and update tests
* Add try / catch for jax portion of adjoint module
* Add jax adjoint module and its unit tests to Makefile
* Test tweaks
* Update conditional import and tweak test parameters
This PR is an initial implementation of a class for wrapping a Meep simulation into a JAX-differentiable callable. This API will let users flexibly compose one or more Meep simulations with other JAX functions, with end-to-end differentiability.
This implementation uses JAX's custom VJP rule, which has one limitation: it does not support defining batching rules. In practical terms, this unfortunately means that the interface implemented in this PR will not be able to take advantage of Meep's parallelism over frequency in multi-objective (minimax-style) optimizations. However, this interface will work well for the scalar-valued loss functions that are widely used in inverse design and machine learning (via JAX's jax.value_and_grad interface).

In the future, to support frequency parallelism in multi-objective optimizations, it may be worth investigating whether Meep can be wrapped into a JAX primitive, which has first-class support for batching rules. The challenge with this approach is that, rather than directly defining the reverse-mode differentiation rule, JAX requires the definition of a forward-mode differentiation rule as well as a transpose rule for primitives; together, these two rules are used to derive a reverse-mode differentiation rule. It's unclear whether we could pull this off with Meep.
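To make the custom-VJP pattern from the first paragraph concrete, here is a hedged, self-contained sketch. The run_forward / run_adjoint functions are toy stand-ins for the forward and adjoint Meep runs, not code from this PR's wrapper.py.

```python
import jax
import jax.numpy as jnp

def run_forward(design):
    # Stand-in for a forward Meep run producing a scalar objective quantity.
    return jnp.sum(design ** 2)

def run_adjoint(design, cotangent):
    # Stand-in for an adjoint Meep run: the gradient of run_forward,
    # contracted with the incoming cotangent.
    return 2.0 * design * cotangent

@jax.custom_vjp
def simulate(design):
    return run_forward(design)

def simulate_fwd(design):
    return run_forward(design), design  # save residuals for the backward pass

def simulate_bwd(design, cotangent):
    return (run_adjoint(design, cotangent),)

simulate.defvjp(simulate_fwd, simulate_bwd)

# Composes with other JAX code and is differentiable end to end:
value, grad = jax.value_and_grad(simulate)(jnp.ones(4))
```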
The initial implementation of the JAX-Meep wrapper in this PR only supports eigenmode coefficient outputs, but in a follow-up we can enable support for differentiable DFT fields and near-to-far field transformations. Once some initial feedback has been received, I can add some test cases to this PR.