Skip to content

Commit

Permalink
Generically implement atan2 for all dual numbers (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
prehner authored Nov 14, 2024
1 parent 7e06070 commit 21474f4
Show file tree
Hide file tree
Showing 15 changed files with 429 additions and 69 deletions.
7 changes: 7 additions & 0 deletions src/derivatives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ macro_rules! impl_derivatives {
chain_rule!($deriv, Self::chain_rule(self, f0, f1, f2, f3))
}

#[inline]
fn atan2(&self, other: Self) -> Self {
let mut res = (self / other.clone()).atan();
res.re = self.re.atan2(other.re);
res
}

#[inline]
fn sinh(&self) -> Self {
let s = self.re.sinh();
Expand Down
20 changes: 1 addition & 19 deletions src/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,7 @@ where

#[inline]
fn atan2(self, other: Self) -> Self {
let re = self.re.atan2(other.re);
let eps =
(self.eps * other.re - other.eps * self.re) / (self.re.powi(2) + other.re.powi(2));
Dual::new(re, eps)
DualNum::atan2(&self, other)
}

#[inline]
Expand Down Expand Up @@ -788,18 +785,3 @@ where
Some(Self::from_re(T::max_value()))
}
}

#[cfg(test)]
mod test {
use super::*;
use approx::assert_relative_eq;

#[test]
fn test_atan2() {
let x = Dual64::from(2.0).derivative();
let y = Dual64::from(-3.0);
let z = x.atan2(y);
let z2 = (x / y).atan();
assert_relative_eq!(z.eps, z2.eps, epsilon = 1e-14);
}
}
32 changes: 1 addition & 31 deletions src/dual2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,21 +718,7 @@ where

#[inline]
fn atan2(self, other: Self) -> Self {
let re = self.re.atan2(other.re);
let den = self.re.powi(2) + other.re.powi(2);

let da = other.re / den;
let db = -self.re / den;
let v1 = self.v1 * da + other.v1 * db;

let daa = db * da * (T::one() + T::one());
let dab = db * db - da * da;
let dbb = -daa;
let ca = self.v1 * daa + other.v1 * dab;
let cb = self.v1 * dab + other.v1 * dbb;
let v2 = self.v2 * da + other.v2 * db + ca * self.v1 + cb * other.v1;

Self::new(re, v1, v2)
DualNum::atan2(&self, other)
}

#[inline]
Expand Down Expand Up @@ -862,19 +848,3 @@ where
Some(Self::from_re(T::max_value()))
}
}

#[cfg(test)]
mod test {
use super::*;
use approx::assert_relative_eq;

#[test]
fn test_atan2() {
let x = Dual2_64::from(2.0).derivative();
let y = Dual2_64::from(-3.0);
let z = x.atan2(y);
let z2 = (x / y).atan();
assert_relative_eq!(z.v1, z2.v1, epsilon = 1e-14);
assert_relative_eq!(z.v2, z2.v2, epsilon = 1e-14);
}
}
16 changes: 1 addition & 15 deletions src/dual2_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,21 +763,7 @@ where

#[inline]
fn atan2(self, other: Self) -> Self {
let re = self.re.atan2(other.re);
let den = self.re.powi(2) + other.re.powi(2);

let da = other.re / den;
let db = -self.re / den;
let v1 = &self.v1 * da + &other.v1 * db;

let daa = db * da * (T::one() + T::one());
let dab = db * db - da * da;
let dbb = -daa;
let ca = &self.v1 * daa + &other.v1 * dab;
let cb = &self.v1 * dab + &other.v1 * dbb;
let v2 = self.v2 * da + other.v2 * db + ca.tr_mul(&self.v1) + cb.tr_mul(&other.v1);

Self::new(re, v1, v2)
DualNum::atan2(&self, other)
}

#[inline]
Expand Down
5 changes: 1 addition & 4 deletions src/dual_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,7 @@ where

#[inline]
fn atan2(self, other: Self) -> Self {
let re = self.re.atan2(other.re);
let eps =
(self.eps * other.re - other.eps * self.re) / (self.re.powi(2) + other.re.powi(2));
DualVec::new(re, eps)
DualNum::atan2(&self, other)
}

#[inline]
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ pub trait DualNum<F>:
/// Arctangent
fn atan(&self) -> Self;

/// Arctangent
fn atan2(&self, other: Self) -> Self;

/// Hyperbolic sine
fn sinh(&self) -> Self;

Expand Down Expand Up @@ -316,6 +319,9 @@ macro_rules! impl_dual_num_float {
fn atan(&self) -> Self {
<$float>::atan(*self)
}
fn atan2(&self, other: $float) -> Self {
<$float>::atan2(*self, other)
}
fn sin_cos(&self) -> (Self, Self) {
<$float>::sin_cos(*self)
}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,34 @@ fn test_dual_atan() {
assert!((res.eps - 0.961538461538462).abs() < 1e-12);
}

#[test]
fn test_dual_atan2_1() {
let res = Dual64::from(0.2).derivative().atan2((0.4).into());
assert!((res.re - 0.463647609000806).abs() < 1e-12);
assert!((res.eps - 2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_atan2_2() {
let res = Dual64::from(-0.2).derivative().atan2((0.4).into());
assert!((res.re - -0.463647609000806).abs() < 1e-12);
assert!((res.eps - 2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_atan2_3() {
let res = Dual64::from(0.2).derivative().atan2((-0.4).into());
assert!((res.re - 2.67794504458899).abs() < 1e-12);
assert!((res.eps - -2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_atan2_4() {
let res = Dual64::from(-0.2).derivative().atan2((-0.4).into());
assert!((res.re - -2.67794504458899).abs() < 1e-12);
assert!((res.eps - -2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_sinh() {
let res = Dual64::from(1.2).derivative().sinh();
Expand Down
32 changes: 32 additions & 0 deletions tests/test_dual2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,38 @@ fn test_dual2_atan() {
assert!((res.v2 - -0.369822485207101).abs() < 1e-12);
}

#[test]
fn test_dual2_atan2_1() {
let res = Dual2_64::from(0.2).derivative().atan2((0.4).into());
assert!((res.re - 0.463647609000806).abs() < 1e-12);
assert!((res.v1 - 2.00000000000000).abs() < 1e-12);
assert!((res.v2 - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_atan2_2() {
let res = Dual2_64::from(-0.2).derivative().atan2((0.4).into());
assert!((res.re - -0.463647609000806).abs() < 1e-12);
assert!((res.v1 - 2.00000000000000).abs() < 1e-12);
assert!((res.v2 - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_atan2_3() {
let res = Dual2_64::from(0.2).derivative().atan2((-0.4).into());
assert!((res.re - 2.67794504458899).abs() < 1e-12);
assert!((res.v1 - -2.00000000000000).abs() < 1e-12);
assert!((res.v2 - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_atan2_4() {
let res = Dual2_64::from(-0.2).derivative().atan2((-0.4).into());
assert!((res.re - -2.67794504458899).abs() < 1e-12);
assert!((res.v1 - -2.00000000000000).abs() < 1e-12);
assert!((res.v2 - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_sinh() {
let res = Dual2_64::from(1.2).derivative().sinh();
Expand Down
76 changes: 76 additions & 0 deletions tests/test_dual2_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,82 @@ fn test_dual2_vec_atan() {
assert!((v2[(1, 1)] - -0.369822485207101).abs() < 1e-12);
}

#[test]
fn test_dual2_vec_atan2_1() {
let res = Dual2SVec64::new(
0.2,
Derivative::some(RowSVector::from([1.0, 1.0])),
Derivative::none(),
)
.atan2((0.4).into());
let v1 = res.v1.unwrap_generic(Const::<1>, Const::<2>);
let v2 = res.v2.unwrap_generic(Const::<2>, Const::<2>);
assert!((res.re - 0.463647609000806).abs() < 1e-12);
assert!((v1[0] - 2.00000000000000).abs() < 1e-12);
assert!((v1[1] - 2.00000000000000).abs() < 1e-12);
assert!((v2[(0, 0)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(0, 1)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 0)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 1)] - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_vec_atan2_2() {
let res = Dual2SVec64::new(
-0.2,
Derivative::some(RowSVector::from([1.0, 1.0])),
Derivative::none(),
)
.atan2((0.4).into());
let v1 = res.v1.unwrap_generic(Const::<1>, Const::<2>);
let v2 = res.v2.unwrap_generic(Const::<2>, Const::<2>);
assert!((res.re - -0.463647609000806).abs() < 1e-12);
assert!((v1[0] - 2.00000000000000).abs() < 1e-12);
assert!((v1[1] - 2.00000000000000).abs() < 1e-12);
assert!((v2[(0, 0)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(0, 1)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 0)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 1)] - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_vec_atan2_3() {
let res = Dual2SVec64::new(
0.2,
Derivative::some(RowSVector::from([1.0, 1.0])),
Derivative::none(),
)
.atan2((-0.4).into());
let v1 = res.v1.unwrap_generic(Const::<1>, Const::<2>);
let v2 = res.v2.unwrap_generic(Const::<2>, Const::<2>);
assert!((res.re - 2.67794504458899).abs() < 1e-12);
assert!((v1[0] - -2.00000000000000).abs() < 1e-12);
assert!((v1[1] - -2.00000000000000).abs() < 1e-12);
assert!((v2[(0, 0)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(0, 1)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 0)] - 4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 1)] - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_vec_atan2_4() {
let res = Dual2SVec64::new(
-0.2,
Derivative::some(RowSVector::from([1.0, 1.0])),
Derivative::none(),
)
.atan2((-0.4).into());
let v1 = res.v1.unwrap_generic(Const::<1>, Const::<2>);
let v2 = res.v2.unwrap_generic(Const::<2>, Const::<2>);
assert!((res.re - -2.67794504458899).abs() < 1e-12);
assert!((v1[0] - -2.00000000000000).abs() < 1e-12);
assert!((v1[1] - -2.00000000000000).abs() < 1e-12);
assert!((v2[(0, 0)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(0, 1)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 0)] - -4.00000000000000).abs() < 1e-12);
assert!((v2[(1, 1)] - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual2_vec_sinh() {
let res = Dual2SVec64::new(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_dual3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,42 @@ fn test_dual3_atan() {
assert!((res.v3 - -1.56463359126081).abs() < 1e-12);
}

#[test]
fn test_dual3_atan2_1() {
let res = Dual3_64::from(0.2).derivative().atan2((0.4).into());
assert!((res.re - 0.463647609000806).abs() < 1e-12);
assert!((res.v1 - 2.00000000000000).abs() < 1e-12);
assert!((res.v2 - -4.00000000000000).abs() < 1e-12);
assert!((res.v3 - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual3_atan2_2() {
let res = Dual3_64::from(-0.2).derivative().atan2((0.4).into());
assert!((res.re - -0.463647609000806).abs() < 1e-12);
assert!((res.v1 - 2.00000000000000).abs() < 1e-12);
assert!((res.v2 - 4.00000000000000).abs() < 1e-12);
assert!((res.v3 - -4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual3_atan2_3() {
let res = Dual3_64::from(0.2).derivative().atan2((-0.4).into());
assert!((res.re - 2.67794504458899).abs() < 1e-12);
assert!((res.v1 - -2.00000000000000).abs() < 1e-12);
assert!((res.v2 - 4.00000000000000).abs() < 1e-12);
assert!((res.v3 - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual3_atan2_4() {
let res = Dual3_64::from(-0.2).derivative().atan2((-0.4).into());
assert!((res.re - -2.67794504458899).abs() < 1e-12);
assert!((res.v1 - -2.00000000000000).abs() < 1e-12);
assert!((res.v2 - -4.00000000000000).abs() < 1e-12);
assert!((res.v3 - 4.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual3_sinh() {
let res = Dual3_64::from(1.2).derivative().sinh();
Expand Down
36 changes: 36 additions & 0 deletions tests/test_dual_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,42 @@ fn test_dual_vec_atan() {
assert!((eps[1] - 0.961538461538462).abs() < 1e-12);
}

#[test]
fn test_dual_vec_atan2_1() {
let res = DualSVec64::new(0.2, Derivative::some(Vector::from([1.0, 1.0]))).atan2((0.4).into());
let eps = res.eps.unwrap_generic(Const::<2>, Const::<1>);
assert!((res.re - 0.463647609000806).abs() < 1e-12);
assert!((eps[0] - 2.00000000000000).abs() < 1e-12);
assert!((eps[1] - 2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_vec_atan2_2() {
let res = DualSVec64::new(-0.2, Derivative::some(Vector::from([1.0, 1.0]))).atan2((0.4).into());
let eps = res.eps.unwrap_generic(Const::<2>, Const::<1>);
assert!((res.re - -0.463647609000806).abs() < 1e-12);
assert!((eps[0] - 2.00000000000000).abs() < 1e-12);
assert!((eps[1] - 2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_vec_atan2_3() {
let res = DualSVec64::new(0.2, Derivative::some(Vector::from([1.0, 1.0]))).atan2((-0.4).into());
let eps = res.eps.unwrap_generic(Const::<2>, Const::<1>);
assert!((res.re - 2.67794504458899).abs() < 1e-12);
assert!((eps[0] - -2.00000000000000).abs() < 1e-12);
assert!((eps[1] - -2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_vec_atan2_4() {
let res = DualSVec64::new(-0.2, Derivative::some(Vector::from([1.0, 1.0]))).atan2((-0.4).into());
let eps = res.eps.unwrap_generic(Const::<2>, Const::<1>);
assert!((res.re - -2.67794504458899).abs() < 1e-12);
assert!((eps[0] - -2.00000000000000).abs() < 1e-12);
assert!((eps[1] - -2.00000000000000).abs() < 1e-12);
}

#[test]
fn test_dual_vec_sinh() {
let res = DualSVec64::new(1.2, Derivative::some(Vector::from([1.0, 1.0]))).sinh();
Expand Down
Loading

0 comments on commit 21474f4

Please sign in to comment.