competitive/num/
dual_number.rs

1use super::{One, Zero};
2use std::{
3    iter::{Product, Sum},
4    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
/// A dual number `a + b·ε` (where `ε² = 0`), stored as the pair `(a, b)`.
///
/// Equality, ordering and hashing are derived, so ordering is lexicographic
/// over the two fields in declaration order.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DualNumber<T>(
    /// Real part `a`.
    pub T,
    /// Coefficient `b` of `ε`.
    pub T,
);
9
10impl<T> DualNumber<T> {
11    pub fn transpose(self) -> Self {
12        Self(self.1, self.0)
13    }
14}
15impl<T> Zero for DualNumber<T>
16where
17    T: Zero,
18{
19    fn zero() -> Self {
20        Self(T::zero(), T::zero())
21    }
22}
23impl<T> One for DualNumber<T>
24where
25    T: Zero + One,
26{
27    fn one() -> Self {
28        Self(T::one(), T::zero())
29    }
30}
31impl<T> DualNumber<T>
32where
33    T: Zero + One,
34{
35    pub fn epsilon() -> Self {
36        Self(T::zero(), T::one())
37    }
38}
39impl<T> DualNumber<T>
40where
41    T: Neg<Output = T>,
42{
43    pub fn conjugate(self) -> Self {
44        Self(self.0, -self.1)
45    }
46}
47impl<T> DualNumber<T>
48where
49    T: Add<Output = T> + Mul<Output = T>,
50{
51    pub fn eval(self, eps: T) -> T {
52        self.0 + self.1 * eps
53    }
54}
55impl<T> DualNumber<T>
56where
57    T: Div<Output = T> + Neg<Output = T>,
58{
59    pub fn root(self) -> T {
60        -self.0 / self.1
61    }
62}
63
64impl<T> Add for DualNumber<T>
65where
66    T: Add<Output = T>,
67{
68    type Output = Self;
69    fn add(self, rhs: Self) -> Self::Output {
70        Self(self.0 + rhs.0, self.1 + rhs.1)
71    }
72}
73impl<T> Add<T> for DualNumber<T>
74where
75    T: Add<Output = T>,
76{
77    type Output = Self;
78    fn add(self, rhs: T) -> Self::Output {
79        Self(self.0 + rhs, self.1)
80    }
81}
82impl<T> Sub for DualNumber<T>
83where
84    T: Sub<Output = T>,
85{
86    type Output = Self;
87    fn sub(self, rhs: Self) -> Self::Output {
88        Self(self.0 - rhs.0, self.1 - rhs.1)
89    }
90}
91impl<T> Sub<T> for DualNumber<T>
92where
93    T: Sub<Output = T>,
94{
95    type Output = Self;
96    fn sub(self, rhs: T) -> Self::Output {
97        Self(self.0 - rhs, self.1)
98    }
99}
100impl<T> Mul for DualNumber<T>
101where
102    T: Clone + Add<Output = T> + Sub<Output = T> + Mul<Output = T>,
103{
104    type Output = Self;
105    fn mul(self, rhs: Self) -> Self::Output {
106        Self(
107            self.0.clone() * rhs.0.clone(),
108            self.0 * rhs.1 + self.1 * rhs.0,
109        )
110    }
111}
112impl<T> Mul<T> for DualNumber<T>
113where
114    T: Clone + Mul<Output = T>,
115{
116    type Output = Self;
117    fn mul(self, rhs: T) -> Self::Output {
118        Self(self.0 * rhs.clone(), self.1 * rhs)
119    }
120}
121impl<T> Div for DualNumber<T>
122where
123    T: Clone + One + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
124{
125    type Output = Self;
126    fn div(self, rhs: Self) -> Self::Output {
127        let d = T::one() / rhs.1.clone();
128        Self(
129            self.0.clone() * d.clone(),
130            (self.1 * rhs.0 - self.0 * rhs.1) * d.clone() * d,
131        )
132    }
133}
134impl<T> Div<T> for DualNumber<T>
135where
136    T: Clone + One + Div<Output = T>,
137{
138    type Output = Self;
139    fn div(self, rhs: T) -> Self::Output {
140        let d = T::one() / rhs.clone();
141        Self(self.0 / d.clone(), self.1 / d)
142    }
143}
144impl<T> Neg for DualNumber<T>
145where
146    T: Neg<Output = T>,
147{
148    type Output = Self;
149    fn neg(self) -> Self::Output {
150        Self(-self.0, -self.1)
151    }
152}
// Generates the three borrowed-operand forms of a binary operator
// (`&l op r`, `l op &r`, `&l op &r`) by cloning whichever side is borrowed
// and forwarding to the by-value impl `<$Lhs as $Op<$Rhs>>`.
//
// Bounds after `where`: every identifier before the optional `+` becomes
// `Ident<Output = T>`; identifiers after it are added as plain trait bounds.
// `Clone` is always required.
macro_rules! impl_dual_number_ref_binop {
    (impl<$G:ident> $Op:ident $op_fn:ident ($Lhs:ty, $Rhs:ty) where $($out_bound:ident)* $(+ $($marker:ident)*)?) => {
        // `&lhs op rhs`
        impl<$G> $Op<$Rhs> for &$Lhs
        where
            $G: Clone $(+ $out_bound<Output = $G>)* $($(+ $marker)*)?,
        {
            type Output = <$Lhs as $Op<$Rhs>>::Output;
            fn $op_fn(self, rhs: $Rhs) -> Self::Output {
                let lhs: $Lhs = <$Lhs as Clone>::clone(self);
                <$Lhs as $Op<$Rhs>>::$op_fn(lhs, rhs)
            }
        }
        // `lhs op &rhs`
        impl<$G> $Op<&$Rhs> for $Lhs
        where
            $G: Clone $(+ $out_bound<Output = $G>)* $($(+ $marker)*)?,
        {
            type Output = <$Lhs as $Op<$Rhs>>::Output;
            fn $op_fn(self, rhs: &$Rhs) -> Self::Output {
                let rhs: $Rhs = <$Rhs as Clone>::clone(rhs);
                <$Lhs as $Op<$Rhs>>::$op_fn(self, rhs)
            }
        }
        // `&lhs op &rhs`
        impl<$G> $Op<&$Rhs> for &$Lhs
        where
            $G: Clone $(+ $out_bound<Output = $G>)* $($(+ $marker)*)?,
        {
            type Output = <$Lhs as $Op<$Rhs>>::Output;
            fn $op_fn(self, rhs: &$Rhs) -> Self::Output {
                let lhs: $Lhs = <$Lhs as Clone>::clone(self);
                let rhs: $Rhs = <$Rhs as Clone>::clone(rhs);
                <$Lhs as $Op<$Rhs>>::$op_fn(lhs, rhs)
            }
        }
    };
}
184impl_dual_number_ref_binop!(impl<T> Add add (DualNumber<T>, DualNumber<T>) where Add);
185impl_dual_number_ref_binop!(impl<T> Add add (DualNumber<T>, T) where Add);
186impl_dual_number_ref_binop!(impl<T> Sub sub (DualNumber<T>, DualNumber<T>) where Sub);
187impl_dual_number_ref_binop!(impl<T> Sub sub (DualNumber<T>, T) where Sub);
188impl_dual_number_ref_binop!(impl<T> Mul mul (DualNumber<T>, DualNumber<T>) where Add Sub Mul);
189impl_dual_number_ref_binop!(impl<T> Mul mul (DualNumber<T>, T) where Mul);
190impl_dual_number_ref_binop!(impl<T> Div div (DualNumber<T>, DualNumber<T>) where Add Sub Mul Div + One);
191impl_dual_number_ref_binop!(impl<T> Div div (DualNumber<T>, T) where Div + One);
// Generates the borrowed-operand form of a unary operator (`op &x`) by
// cloning the operand and forwarding to the by-value impl.
//
// Every identifier after `where` becomes an `Ident<Output = T>` bound;
// `Clone` is always required.
macro_rules! impl_dual_number_ref_unop {
    (impl<$G:ident> $Op:ident $op_fn:ident ($Operand:ty) where $($out_bound:ident)*) => {
        impl<$G> $Op for &$Operand
        where
            $G: Clone $(+ $out_bound<Output = $G>)*,
        {
            type Output = <$Operand as $Op>::Output;
            fn $op_fn(self) -> Self::Output {
                let owned: $Operand = <$Operand as Clone>::clone(self);
                <$Operand as $Op>::$op_fn(owned)
            }
        }
    };
}
205impl_dual_number_ref_unop!(impl<T> Neg neg (DualNumber<T>) where Neg);
// Generates a compound-assignment operator (`l op= r` and `l op= &r`) from
// the matching by-value binary operator `$Op::$op_fn`.
//
// The owned form clones `*self`, applies the binary operator and stores the
// result back; the borrowed form clones `rhs` and defers to the owned form.
// Bounds after `where` follow the same grammar as the binary-operator macro:
// identifiers before `+` become `Ident<Output = T>`, those after are plain.
macro_rules! impl_dual_number_op_assign {
    (impl<$G:ident> $Assign:ident $assign_fn:ident ($Lhs:ty, $Rhs:ty) $Op:ident $op_fn:ident where $($out_bound:ident)* $(+ $($marker:ident)*)?) => {
        impl<$G> $Assign<$Rhs> for $Lhs
        where
            $G: Clone $(+ $out_bound<Output = $G>)* $($(+ $marker)*)?,
        {
            fn $assign_fn(&mut self, rhs: $Rhs) {
                let current: $Lhs = <$Lhs as Clone>::clone(&*self);
                *self = <$Lhs as $Op<$Rhs>>::$op_fn(current, rhs);
            }
        }
        impl<$G> $Assign<&$Rhs> for $Lhs
        where
            $G: Clone $(+ $out_bound<Output = $G>)* $($(+ $marker)*)?,
        {
            fn $assign_fn(&mut self, rhs: &$Rhs) {
                let owned_rhs: $Rhs = <$Rhs as Clone>::clone(rhs);
                <$Lhs as $Assign<$Rhs>>::$assign_fn(self, owned_rhs);
            }
        }
    };
}
226impl_dual_number_op_assign!(impl<T> AddAssign add_assign (DualNumber<T>, DualNumber<T>) Add add where Add);
227impl_dual_number_op_assign!(impl<T> AddAssign add_assign (DualNumber<T>, T) Add add where Add);
228impl_dual_number_op_assign!(impl<T> SubAssign sub_assign (DualNumber<T>, DualNumber<T>) Sub sub where Sub);
229impl_dual_number_op_assign!(impl<T> SubAssign sub_assign (DualNumber<T>, T) Sub sub where Sub);
230impl_dual_number_op_assign!(impl<T> MulAssign mul_assign (DualNumber<T>, DualNumber<T>) Mul mul where Add Sub Mul);
231impl_dual_number_op_assign!(impl<T> MulAssign mul_assign (DualNumber<T>, T) Mul mul where Mul);
232impl_dual_number_op_assign!(impl<T> DivAssign div_assign (DualNumber<T>, DualNumber<T>) Div div where Add Sub Mul Div + One);
233impl_dual_number_op_assign!(impl<T> DivAssign div_assign (DualNumber<T>, T) Div div where Div + One);
// Generates an iterator-folding trait (`Sum` / `Product`) for both owned and
// borrowed items: the fold starts from `<Self as $Identity>::$identity_fn()`
// and combines elements with the binary operator `$Op::$op_fn`.
//
// Bounds after `where`: identifiers before the first `+` become
// `Ident<Output = T>`; each `+ Ident` after that is a plain trait bound.
// The borrowed-item impl additionally requires `Clone`.
macro_rules! impl_dual_number_fold {
    (impl<$G:ident> $Fold:ident $fold_fn:ident ($Acc:ty) $Identity:ident $identity_fn:ident $Op:ident $op_fn:ident where $($out_bound:ident)* $(+ $marker:ident)*) => {
        impl<$G> $Fold for $Acc
        where
            $G: $Identity $(+ $out_bound<Output = $G>)* $(+ $marker)*,
        {
            fn $fold_fn<I: Iterator<Item = Self>>(iter: I) -> Self {
                let identity = <Self as $Identity>::$identity_fn();
                iter.fold(identity, |acc, item| $Op::$op_fn(acc, item))
            }
        }
        impl<'a, $G: 'a> $Fold<&'a $Acc> for $Acc
        where
            $G: Clone + $Identity $(+ $out_bound<Output = $G>)* $(+ $marker)*,
        {
            fn $fold_fn<I: Iterator<Item = &'a $Acc>>(iter: I) -> Self {
                let identity = <Self as $Identity>::$identity_fn();
                iter.fold(identity, |acc, item| $Op::$op_fn(acc, item))
            }
        }
    };
}
254impl_dual_number_fold!(impl<T> Sum sum (DualNumber<T>) Zero zero Add add where Add);
255impl_dual_number_fold!(impl<T> Product product (DualNumber<T>) One one Mul mul where Add Sub Mul + Zero + Clone);