ResNet:
When learning the transition from one input to the direct next output, it is simpler to focus on the difference between the two. This approach ensures that the gradient can flow unchanged through the identity (skip) connection, which mitigates the vanishing-gradient problem in very deep networks.
Let's consider the layer index $t$ with hidden state $x_t$. A residual block then computes $x_{t+1} = x_t + f(x_t, \theta_t)$, i.e., it only learns the difference between two consecutive states.
By imagining that the layer index $t$ becomes a continuous variable and the step between two layers shrinks towards zero, the residual update turns into $\frac{dx(t)}{dt} = f(x(t), t, \theta)$.
This is a simple ordinary differential equation. We have now gone from an easily understandable discrete hidden-layer network to the description of a dynamical system.
How can we interpret this system now? What are the input, the intermediate states, and the final output?
The general function $f(x(t), t, \theta)$ describes the dynamics of the system and is parameterized by the weights and biases $\theta$.
Note: In the end, we want to train the internal weights and biases analogously to a conventional neural network.
The NODE can be expressed as follows:

$$x(T) = x(t_0) + \int_{t_0}^{T} f(x(t), t, \theta)\,dt = \mathrm{ODESolve}(x(t_0), f, t_0, T, \theta)$$

Here, $x(t_0)$ is the input, $x(T)$ is the final output, and the integral is evaluated numerically by an ODE solver.
In conventional artificial neural networks, the output of the current layer is the input of the next layer.
Here, we only have one "hidden layer".
Instead of transforming the intermediate output (which becomes the next input) with a new set of weights at every layer, the same function $f$ is applied again and again over continuous time.
To understand this even better, we will take a look at one forward pass.
Define the input (e.g., an RGB image) as $x(t_0)$, the state of the system at the starting time $t_0$.
The difference: the function $f$ can now be any learnable parameterized neural layer or combination of layers, as long as the input shape matches the output shape, since the output of one time step becomes the input of the next time step.
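As a rough illustration (not from the paper): assuming PyTorch, such a dynamics function could be a small sub-network whose output has the same shape as its input. The class name `ODEFunc` and the layer sizes below are arbitrary choices for this sketch.

```python
import torch
import torch.nn as nn

class ODEFunc(nn.Module):
    """Parameterized dynamics f(x, t, theta): maps a state to a tensor of the
    same shape, so the same module can be applied at every time step."""

    def __init__(self, dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),  # output shape == input shape
        )

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # t is accepted so the dynamics may depend on time; unused in this sketch.
        return self.net(x)
```

Note that this one module is reused at every time step, in contrast to stacking many distinct layers.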
One forward pass computes the output (and the trajectory) with the well-known Euler method.
The gradients (green arrows in the figure) can be computed since we have the formula of the dynamics $f(x(t), t, \theta)$ at hand.
We first compute the gradient (the local slope) at the initial starting point $x(t_0)$ and follow it for one step of size $\Delta t$.
The result of this step, $x(t_1) = x(t_0) + \Delta t \cdot f(x(t_0), t_0, \theta)$, is the starting point for the next step.
We will end at the last point $x(T)$, which is the final output of the NODE.
Note: Because we go in discrete steps defined by the step size $\Delta t$, the computed trajectory is only an approximation of the true continuous solution; smaller steps give a more accurate but more expensive approximation. A minimal sketch of this forward pass is shown below.
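The sketch uses plain NumPy; the function name `euler_forward`, the toy dynamics, and the step count are illustrative assumptions, not part of the original method description.

```python
import numpy as np

def euler_forward(f, x0, t0=0.0, T=1.0, n_steps=10):
    """Approximate the NODE output x(T) with Euler steps and return the trajectory."""
    dt = (T - t0) / n_steps
    x = np.asarray(x0, dtype=float)
    t = t0
    trajectory = [x]
    for _ in range(n_steps):
        x = x + dt * f(x, t)   # follow the local slope for one step of size dt
        t = t + dt
        trajectory.append(x)
    return x, trajectory       # x(T) and all intermediate points

# Toy dynamics: rotate a 2D state; after roughly pi time units the state is flipped.
f = lambda x, t: np.array([-x[1], x[0]])
xT, traj = euler_forward(f, x0=[1.0, 0.0], T=np.pi, n_steps=100)
```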
Instead of learning the transformation of the input data from layer to layer, we try to learn the underlying dynamics of the transformation itself. The weights and biases of a NODE layer do not only contain the transformative information learned to map the data from the current step to the direct next step (as in discrete layers); rather, they contain the information to transform the data from every step to the next step.
Instead of stacking more and more layers on top of each other to obtain a more powerful network, a NODE contains only one "layer" (or sub-network) and computes the output "over time". The dynamics of the transformation desired by the user are learned directly and not approximated by a big architecture with discrete transformation steps.
The above figure from the original paper highlights the difference in data transformation. ResNets (left) try to learn the transformation itself (colorful vectors at each depth on the y-axis). The NODE (right) tries to learn the underlying vector field. Hence, intermediate data can be spread in a non-uniform manner along the depth axis, which represents continuous time.
To train the model, the mismatch between the output, in our case the output of the model $x(T)$, and the desired target $y$ has to be quantified by a loss function $L$.
For example, we can use a standard mean-squared error: $L = \frac{1}{N}\sum_{i=1}^{N}\left(x_i(T) - y_i\right)^2$.
Hence, to train the model we need the gradient of this loss with respect to the parameters, $\frac{dL}{d\theta}$.
Since we already have the trajectory of $x(t)$ from the forward pass, we could simply backpropagate through all operations of the ODE solver. However, this requires storing every intermediate step, which leads to a high memory cost.
So, we take another approach.
At the end of our backward pass, we are interested in the following gradient: $\frac{dL}{d\theta}$.
Using the so-called adjoint sensitivity method, we utilize a mathematical trick to first transform our initial differential equation into another differential equation.
The first step is to define the following adjoint state: $a(t) = \frac{\partial L}{\partial x(t)}$.
One can show that the adjoint state obeys the following relationship: $\frac{da(t)}{dt} = -a(t)^{\top} \frac{\partial f(x(t), t, \theta)}{\partial x}$.
This relationship comes from a derivation shown in Appendix B.1 of the original paper. For now, we will focus on the backward pass itself.
Moreover, one can show that the derivative of the loss with respect to the parameters is given by the integral $\frac{dL}{d\theta} = -\int_{T}^{t_0} a(t)^{\top} \frac{\partial f(x(t), t, \theta)}{\partial \theta}\,dt$.
To solve this integral, we need the trajectories of $a(t)$ and $x(t)$, which we do not want to store from the forward pass.
So, by starting our trajectory at time step $T$ and computing the derivatives of $a(t)$ and $x(t)$ backwards in time (since we have the differential equations for both at hand), we can compute our final derivative $\frac{dL}{d\theta}$ in a single backward solve.
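As a rough sketch of this idea (plain NumPy with Euler steps; the names `adjoint_backward`, `df_dx`, and `df_dtheta` are assumptions for illustration, and a real implementation would hand the augmented state to a proper ODE solver):

```python
import numpy as np

def adjoint_backward(f, df_dx, df_dtheta, xT, dL_dxT, theta,
                     t0=0.0, T=1.0, n_steps=100):
    """Integrate x(t), the adjoint a(t) = dL/dx(t), and dL/dtheta
    backwards in time from T to t0 with simple Euler steps."""
    dt = (T - t0) / n_steps
    x = np.array(xT, dtype=float)
    a = np.array(dL_dxT, dtype=float)
    dL_dtheta = np.zeros_like(np.asarray(theta, dtype=float))
    t = T
    for _ in range(n_steps):
        # da/dt = -a^T df/dx, evaluated at the current point of the trajectory
        a_dot = -a @ df_dx(x, t, theta)
        # accumulate dL/dtheta = -integral_T^t0 a^T df/dtheta dt
        dL_dtheta += dt * (a @ df_dtheta(x, t, theta))
        x = x - dt * f(x, t, theta)   # recompute x(t) backwards in time
        a = a - dt * a_dot            # step the adjoint backwards in time
        t = t - dt
    return a, dL_dtheta               # dL/dx(t0) and dL/dtheta
```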
The memory cost of our NODE is constant, i.e., $\mathcal{O}(1)$ with respect to depth, since no intermediate activations of the forward pass have to be stored.
Euler's method is only one type of ODE solver; more modern solvers can solve ODEs more efficiently and accurately, and they allow the user to bound the growth of the approximation error.
Based on the precision of the ODE solver (which may be chosen by the user based on the available resources), the cost of evaluating the NODE model scales with the complexity of the problem. Quote from Chen et al.: "After training, accuracy can be reduced for real-time or low-power applications."
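To make this precision/cost trade-off concrete, here is a small illustration using SciPy's adaptive solver on toy dynamics (unrelated to any trained NODE): tighter tolerances force the solver to perform more function evaluations, and hence more compute.

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, x):
    # toy dynamics: a 2D rotation
    return [-x[1], x[0]]

for tol in (1e-3, 1e-6, 1e-9):
    sol = solve_ivp(f, t_span=(0.0, 10.0), y0=[1.0, 0.0], rtol=tol, atol=tol)
    # tighter tolerances -> more function evaluations -> higher cost
    print(f"tol={tol:.0e}: {sol.nfev} function evaluations")
```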
First, we take a look at a simple mapping problem, which we want to solve with a NODE: imagine our input consists of one-dimensional scalar values that are either $-1$ or $1$. This input value should now be mapped to the corresponding label: $1$ for input $-1$, and $-1$ for input $1$.
We can visualize this problem with the following figure from this paper:
Imagine the input points (red and blue points on the left-hand side) should be transported over time along a learned vector field to the output points (right-hand side).
Intuitively, we can understand that this is not possible with a standard NODE: trajectories of an ODE can never cross each other, so there is no 2D vector field in which the two desired trajectories could swap places. Note: The mathematical proof of this is in the paper, but let's not bother with it here.
To tackle this issue, Dupont et al. introduced the Augmented Neural ODE. The main idea is to append additional dimensions to the input, such that the network can find a vector field in which the desired trajectories are not forced to cross.
For example, a 2D vector field can be learned such that the mapping is achieved without the two trajectories crossing.
Another example:
It is impossible to train a NODE to map the points inside the blue circle to $-1$ and the red points to $1$. We understand this intuitively, since the blue points cannot cross the ring of red points with a 2D vector field.
The solution to these problems is to append a third dimension to the input of the NODE, to "give the NODE more space" to form a vector field that achieves the desired mapping.
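A minimal sketch of this augmentation step (NumPy; the helper name `augment` is a hypothetical choice for illustration — the rest of the NODE is unchanged):

```python
import numpy as np

def augment(x, extra_dims=1):
    """Append zero-valued dimensions so the ODE has 'more room' for its trajectories."""
    x = np.atleast_2d(x)                        # (batch, dim)
    zeros = np.zeros((x.shape[0], extra_dims))  # (batch, extra_dims)
    return np.concatenate([x, zeros], axis=1)   # (batch, dim + extra_dims)

# Each 2D point of the circle example gains a zero-valued third coordinate
# before entering the NODE; after solving the ODE, a small linear layer can
# map the augmented output to the label.
points = np.array([[0.0, 0.0], [1.5, 0.0]])
print(augment(points, extra_dims=1))
```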
Take a look at the 3D vector space for the latter problem to achieve the desired binary mapping:
Despite the power of NODEs, a ResNet can easily outperform them on the above mapping problems.
A ResNet does not learn the underlying dynamics of the transformative system; it learns every discrete step one by one. Hence, the ResNet can insert the needed "error" at the right step of the trajectories, so that the two trajectories can cross in the transformation space. The "vector field" of a simple ResNet can be "distorted" such that the vectors cross at the right depth of the network.
Take a look at the first two images inside the first figure of the section "Augmented neural ODE".
The following images contain the proof of the Adjoint Sensitivity Formulas:
Note: The images are directly copied from Chen et al.
In my opinion, the most difficult step to understand is the step from (40) to (41).
Here, a first-order Taylor series is constructed around the current state to express $x(t+\varepsilon)$.
Reminder: the first-order Taylor series of a function $g$ around a point $a$ is $g(x) \approx g(a) + g'(a)\,(x-a)$.
Imagine applying this to the ODE solution as a function of $\varepsilon$: expanding around $\varepsilon = 0$ gives $x(t+\varepsilon) \approx x(t) + \varepsilon\, f(x(t), t, \theta)$, which is the expression substituted in the proof.
How can we obtain the desired gradients with respect to the parameters $\theta$ (and the time variable) from this adjoint state?
Take a look here:
By generalizing this derivation to an augmented state consisting of the three variables $x$, $\theta$, and $t$, we can obtain our desired gradients, which we need for the backward pass.
Here, it is crucial to understand that in this formulation not only the hidden state $x(t)$, but also the parameters $\theta$ and the time $t$ are treated as part of the state, each with its own adjoint.
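Concretely, the augmented state used in the paper's derivation stacks the hidden state, the parameters, and the time, and the adjoint method from above is then applied to this augmented ODE (notation adapted to this post, where $x(t)$ plays the role of the paper's $z(t)$):

$$
\frac{d}{dt}\begin{bmatrix} x(t) \\ \theta \\ t \end{bmatrix}
= \begin{bmatrix} f(x(t), t, \theta) \\ \mathbf{0} \\ 1 \end{bmatrix},
\qquad
a_{\mathrm{aug}}(t) = \begin{bmatrix} a(t) & a_{\theta}(t) & a_{t}(t) \end{bmatrix}
= \begin{bmatrix} \dfrac{\partial L}{\partial x(t)} & \dfrac{\partial L}{\partial \theta} & \dfrac{\partial L}{\partial t} \end{bmatrix}
$$

Applying the adjoint relationship to this augmented state and integrating backwards yields $\frac{dL}{d\theta}$ (and the gradients with respect to the start and end times) alongside $a(t)$.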