competitive/num/
quad_double.rs

1use super::{Bounded, Decimal, IterScan, One, Zero};
2use std::{
3    cmp::Ordering,
4    fmt::{self, Display},
5    num::ParseFloatError,
6    ops::{Add, Div, Index, Mul, Neg, Sub},
7    str::FromStr,
8};
9
/// Quad-double precision number: the unevaluated sum
/// `self.0 + self.1 + self.2 + self.3` of four `f64` components, which
/// `renormalize` keeps ordered from the leading (largest) to the smallest part.
///
/// ref: <https://na-inet.jp/na/qd_ja.pdf>
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct QuadDouble(f64, f64, f64, f64);

impl QuadDouble {
    /// Renormalizes the five-term expansion `a0 + a1 + a2 + a3 + a4` into four
    /// non-overlapping components.
    ///
    /// The inputs are expected to be (roughly) in decreasing order of magnitude,
    /// as required by `quick_two_sum`. Anything that does not fit into four
    /// components is truncated.
    fn renormalize(a0: f64, a1: f64, a2: f64, a3: f64, a4: f64) -> Self {
        // Bottom-up pass: compress everything into a head `s` plus error terms t1..t4.
        let (s, t4) = quick_two_sum(a3, a4);
        let (s, t3) = quick_two_sum(a2, s);
        let (s, t2) = quick_two_sum(a1, s);
        let (mut s, t1) = quick_two_sum(a0, s);

        // Top-down pass: each non-zero rounding error starts a new component.
        // Slots must start at zero so no stale tail value survives.
        let mut b = [0.; 4];
        let mut k = 0;
        for &t in [t1, t2, t3, t4].iter() {
            let (sum, err) = quick_two_sum(s, t);
            s = sum;
            if err != 0. {
                // k <= 3 here: at most one increment per iteration, four iterations.
                b[k] = s;
                s = err;
                k += 1;
            }
        }
        // The final partial sum is itself a component and must be kept.
        if k < 4 {
            b[k] = s;
        }
        Self(b[0], b[1], b[2], b[3])
    }
}

/// Fast error-free addition, valid when `|a| >= |b|` (or `a == 0`).
///
/// Returns `(s, e)` with `s = fl(a + b)` and `a + b == s + e` exactly.
fn quick_two_sum(a: f64, b: f64) -> (f64, f64) {
    let s = a + b;
    let e = b - (s - a);
    (s, e)
}
40
/// Error-free addition (Knuth's TwoSum), valid for any ordering of `a` and `b`.
///
/// Returns `(s, e)` with `s = fl(a + b)` and `a + b == s + e` exactly.
fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let s = a + b;
    let v = s - a;
    let e = (a - (s - v)) + (b - v);
    (s, e)
}

/// Splits `a` into `(hi, lo)` with `a == hi + lo`, where each half has at most
/// 26 significant bits, so that products of halves are exact (Veltkamp split).
///
/// Very large inputs are scaled down by 2^28 first: otherwise `SPLITTER * a`
/// overflows to infinity and the split degenerates into NaN.
fn split(a: f64) -> (f64, f64) {
    // 2^27 + 1
    const SPLITTER: f64 = 134_217_729.;
    // ~2^996; above this, SPLITTER * a can overflow.
    const SPLIT_THRESH: f64 = 6.696_928_794_914_17e299;
    // 2^28: scaling by a power of two is exact.
    const SCALE: f64 = 268_435_456.;

    fn veltkamp(a: f64) -> (f64, f64) {
        let t = SPLITTER * a;
        let ahi = t - (t - a);
        let alo = a - ahi;
        (ahi, alo)
    }

    if a.abs() > SPLIT_THRESH {
        let (ahi, alo) = veltkamp(a / SCALE);
        (ahi * SCALE, alo * SCALE)
    } else {
        veltkamp(a)
    }
}

/// Error-free multiplication (Dekker's TwoProd).
///
/// Returns `(p, e)` with `p = fl(a * b)` and `a * b == p + e` exactly,
/// barring overflow/underflow of the product itself.
fn two_prod(a: f64, b: f64) -> (f64, f64) {
    let p = a * b;
    let (ahi, alo) = split(a);
    let (bhi, blo) = split(b);
    let e = ((ahi * bhi - p) + ahi * blo + alo * bhi) + alo * blo;
    (p, e)
}

/// Sums three doubles into three terms `(r0, r1, r2)`, where `r0` is the
/// rounded sum and `r1`, `r2` carry the rounding errors.
fn three_three_sum(a: f64, b: f64, c: f64) -> (f64, f64, f64) {
    let (u, v) = two_sum(a, b);
    let (r0, w) = two_sum(u, c);
    let (r1, r2) = two_sum(v, w);
    (r0, r1, r2)
}

/// Sums three doubles into two terms `(r0, r1)`: `r0` is the rounded sum and
/// `r1` the (rounded) sum of the errors.
fn three_two_sum(a: f64, b: f64, c: f64) -> (f64, f64) {
    let (u, v) = two_sum(a, b);
    let (r0, w) = two_sum(u, c);
    let r1 = v + w;
    (r0, r1)
}

/// Sums all of `xs` into three terms `(r0, r1, r2)`: `r0` is the running sum,
/// `r1` accumulates its errors via `two_sum`, and `r2` collects the errors of `r1`.
/// An empty slice yields `(0, 0, 0)`.
fn multiple_three_sum(xs: &[f64]) -> (f64, f64, f64) {
    let (mut r0, mut r1, mut r2) = (*xs.first().unwrap_or(&0.), 0., 0.);
    // The first element already seeds `r0`, so accumulate only the rest.
    for &x in xs.iter().skip(1) {
        let (s, e) = two_sum(r0, x);
        r0 = s;
        let (s, e) = two_sum(r1, e);
        r1 = s;
        r2 += e;
    }
    (r0, r1, r2)
}

/// Sums all of `xs` into two terms `(r0, r1)`: `r0` is the running sum and
/// `r1` the plain sum of its rounding errors. An empty slice yields `(0, 0)`.
fn multiple_two_sum(xs: &[f64]) -> (f64, f64) {
    let (mut r0, mut r1) = (*xs.first().unwrap_or(&0.), 0.);
    // The first element already seeds `r0`, so accumulate only the rest.
    for &x in xs.iter().skip(1) {
        let (s, e) = two_sum(r0, x);
        r0 = s;
        r1 += e;
    }
    (r0, r1)
}
98
99impl Add<f64> for QuadDouble {
100    type Output = Self;
101    fn add(self, rhs: f64) -> Self::Output {
102        let (t0, e) = two_sum(self.0, rhs);
103        let (t1, e) = two_sum(self.1, e);
104        let (t2, e) = two_sum(self.2, e);
105        let (t3, t4) = two_sum(self.3, e);
106        Self::renormalize(t0, t1, t2, t3, t4)
107    }
108}
109
110fn double_accumulate(u: f64, v: f64, x: f64) -> (f64, f64, f64) {
111    let (s, mut v) = two_sum(v, x);
112    let (mut s, mut u) = two_sum(u, s);
113    if u == 0. {
114        u = s;
115        s = 0.;
116    }
117    if v == 0. {
118        v = u;
119        u = s;
120        s = 0.
121    }
122    (s, u, v)
123}
124
125impl Add<QuadDouble> for QuadDouble {
126    type Output = Self;
127    fn add(self, rhs: Self) -> Self::Output {
128        let mut x = [0.; 8];
129        let (mut i, mut j, mut k) = (0, 0, 0);
130        while k < 8 {
131            if j >= 4 || i < 4 && self[i].abs() > rhs[j].abs() {
132                x[k] = self[i];
133                i += 1;
134            } else {
135                x[k] = rhs[j];
136                j += 1;
137            }
138            k += 1;
139        }
140
141        let (mut u, mut v) = (0., 0.);
142        let (mut k, mut i) = (0, 0);
143        let mut c = [0.; 4];
144        while k < 4 && i < 8 {
145            let tpl = double_accumulate(u, v, x[i]);
146            let s = tpl.0;
147            u = tpl.1;
148            v = tpl.2;
149            if s != 0. {
150                c[k] = s;
151                k += 1;
152            }
153            i += 1;
154        }
155        if k < 2 {
156            c[k + 1] = v;
157        }
158        if k < 3 {
159            c[k] = u;
160        }
161        Self::renormalize(c[0], c[1], c[2], c[3], 0.)
162    }
163}
164
165impl Sub for QuadDouble {
166    type Output = Self;
167    fn sub(self, rhs: Self) -> Self::Output {
168        self + -rhs
169    }
170}
171
172impl Neg for QuadDouble {
173    type Output = Self;
174    fn neg(self) -> Self::Output {
175        Self(-self.0, -self.1, -self.2, -self.3)
176    }
177}
178
179impl Mul<f64> for QuadDouble {
180    type Output = Self;
181    fn mul(self, rhs: f64) -> Self::Output {
182        let (t0, e0) = two_prod(self.0, rhs);
183        let (p1, e1) = two_prod(self.1, rhs);
184        let (p2, e2) = two_prod(self.2, rhs);
185        let p3 = self.3 * rhs;
186
187        let (t1, e4) = two_sum(p1, e0);
188        let (t2, e5, e6) = three_three_sum(p2, e1, e4);
189        let (t3, e7) = three_two_sum(p3, e2, e5);
190        let t4 = e7 + e6;
191        Self::renormalize(t0, t1, t2, t3, t4)
192    }
193}
194
195impl Mul<QuadDouble> for QuadDouble {
196    type Output = Self;
197    fn mul(self, rhs: Self) -> Self::Output {
198        let (t0, q00) = two_prod(self.0, rhs.0);
199
200        let (p01, q01) = two_prod(self.0, rhs.1);
201        let (p10, q10) = two_prod(self.1, rhs.0);
202
203        let (p02, q02) = two_prod(self.0, rhs.2);
204        let (p11, q11) = two_prod(self.1, rhs.1);
205        let (p20, q20) = two_prod(self.2, rhs.0);
206
207        let (p03, q03) = two_prod(self.0, rhs.3);
208        let (p12, q12) = two_prod(self.1, rhs.2);
209        let (p21, q21) = two_prod(self.2, rhs.1);
210        let (p30, q30) = two_prod(self.3, rhs.0);
211
212        let p13 = self.1 * rhs.3;
213        let p22 = self.2 * rhs.2;
214        let p31 = self.3 * rhs.1;
215
216        let (t1, e1, e2) = three_three_sum(q00, p01, p10);
217        let (t2, e3, e4) = multiple_three_sum(&[e1, q01, q10, p02, p11, p20]);
218        let (t3, e5) = multiple_two_sum(&[e2, e3, q02, q11, q20, p03, p12, p21, p30]);
219        let t4 = e4 + e5 + q03 + q12 + q21 + q30 + p13 + p22 + p31;
220        Self::renormalize(t0, t1, t2, t3, t4)
221    }
222}
223
224impl Div<QuadDouble> for QuadDouble {
225    type Output = Self;
226    fn div(self, rhs: Self) -> Self::Output {
227        let q0 = self.0 / rhs.0;
228        let r = self - rhs * q0;
229        let q1 = r.0 / rhs.0;
230        let r = r - rhs * q1;
231        let q2 = r.0 / rhs.0;
232        let r = r - rhs * q2;
233        let q3 = r.0 / rhs.0;
234        let r = r - rhs * q3;
235        let q4 = r.0 / rhs.0;
236        Self::renormalize(q0, q1, q2, q3, q4)
237    }
238}
239
240impl Index<usize> for QuadDouble {
241    type Output = f64;
242    fn index(&self, index: usize) -> &Self::Output {
243        match index {
244            0 => &self.0,
245            1 => &self.1,
246            2 => &self.2,
247            3 => &self.3,
248            _ => panic!(),
249        }
250    }
251}
252
253impl From<QuadDouble> for f64 {
254    fn from(x: QuadDouble) -> f64 {
255        x.3 + x.2 + x.1 + x.0
256    }
257}
258
259impl From<QuadDouble> for i64 {
260    fn from(mut x: QuadDouble) -> i64 {
261        let is_neg = x.0.is_sign_negative();
262        if is_neg {
263            x = -x;
264        }
265        let mut i = 0i64;
266        for k in (1..64).rev() {
267            let t = (k as f64).exp2();
268            if x.0 >= t {
269                x = x + -t;
270                i += 1 << k;
271            }
272        }
273        i += x.0.round() as i64;
274        if is_neg {
275            i = -i;
276        }
277        i
278    }
279}
280
281impl From<f64> for QuadDouble {
282    fn from(x: f64) -> Self {
283        Self(x, 0., 0., 0.)
284    }
285}
286
287impl Display for QuadDouble {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        write!(
290            f,
291            "{}",
292            Decimal::from(self.0)
293                + Decimal::from(self.1)
294                + Decimal::from(self.2)
295                + Decimal::from(self.3)
296        )
297    }
298}
299
300#[derive(Debug, Clone)]
301pub enum ParseDoubleDoubleError {
302    ParseFloatError(ParseFloatError),
303    ParseDecimalError(super::decimal::convert::ParseDecimalError),
304}
305
306impl From<ParseFloatError> for ParseDoubleDoubleError {
307    fn from(e: ParseFloatError) -> Self {
308        Self::ParseFloatError(e)
309    }
310}
311
312impl From<super::decimal::convert::ParseDecimalError> for ParseDoubleDoubleError {
313    fn from(e: super::decimal::convert::ParseDecimalError) -> Self {
314        Self::ParseDecimalError(e)
315    }
316}
317
318impl FromStr for QuadDouble {
319    type Err = ParseDoubleDoubleError;
320    fn from_str(s: &str) -> Result<Self, Self::Err> {
321        let f0: f64 = s.parse()?;
322        let d1 = Decimal::from_str(s)? - Decimal::from(f0);
323        let f1: f64 = d1.to_string().parse()?;
324        let d2 = d1 - Decimal::from(f1);
325        let f2: f64 = d2.to_string().parse()?;
326        let d3 = d2 - Decimal::from(f2);
327        let f3: f64 = d3.to_string().parse()?;
328        Ok(Self::renormalize(f0, f1, f2, f3, 0.))
329    }
330}
331
332impl Eq for QuadDouble {}
333impl PartialOrd for QuadDouble {
334    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
335        Some(self.cmp(other))
336    }
337}
338impl Ord for QuadDouble {
339    fn cmp(&self, other: &Self) -> Ordering {
340        fn total_cmp(x: f64, y: f64) -> Ordering {
341            let mut left = x.to_bits() as i64;
342            let mut right = y.to_bits() as i64;
343            left ^= (((left >> 63) as u64) >> 1) as i64;
344            right ^= (((right >> 63) as u64) >> 1) as i64;
345            left.cmp(&right)
346        }
347        total_cmp(self.0, other.0).then_with(|| total_cmp(self.1, other.1))
348    }
349}
350impl Bounded for QuadDouble {
351    fn maximum() -> Self {
352        Self::from(<f64 as Bounded>::maximum())
353    }
354    fn minimum() -> Self {
355        Self::from(<f64 as Bounded>::minimum())
356    }
357}
358
359impl Zero for QuadDouble {
360    fn zero() -> Self {
361        Self::from(0.)
362    }
363    fn is_zero(&self) -> bool
364    where
365        Self: PartialEq,
366    {
367        self.0 == 0.
368    }
369}
370
371impl One for QuadDouble {
372    fn one() -> Self {
373        Self::from(1.)
374    }
375    fn is_one(&self) -> bool
376    where
377        Self: PartialEq,
378    {
379        self.0 == 1.
380    }
381}
382
383impl IterScan for QuadDouble {
384    type Output = Self;
385    fn scan<'a, I: Iterator<Item = &'a str>>(iter: &mut I) -> Option<Self::Output> {
386        iter.next().and_then(|s| s.parse().ok())
387    }
388}
389
390impl QuadDouble {
391    pub fn is_zero(&self) -> bool {
392        self.0 == 0.
393    }
394    pub fn sqrt(self) -> Self {
395        if self.is_zero() {
396            return Self::from(0.);
397        }
398        let x = Self::from(1. / self.0.sqrt());
399        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
400        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
401        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
402        x * self
403    }
404    pub fn abs(self) -> Self {
405        if self.0.is_sign_negative() {
406            -self
407        } else {
408            self
409        }
410    }
411    fn div2(self, rhs: f64) -> Self {
412        Self(self.0 / rhs, self.1 / rhs, self.2 / rhs, self.3 / rhs)
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display() {
        let cases = [
            (QuadDouble::from(1.234), "1.234"),
            (QuadDouble::from(1.234e-10), "0.0000000001234"),
            (QuadDouble::from(1.234e10), "12340000000"),
            (
                QuadDouble::from(1.234e-10) + QuadDouble::from(1.234e10),
                "12340000000.0000000001234",
            ),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_string(), *expected);
        }
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("1.234", QuadDouble::from(1.234)),
            ("0.0000000001234", QuadDouble::from(1.234e-10)),
            ("12340000000", QuadDouble::from(1.234e10)),
            (
                "12340000000.0000000001234",
                QuadDouble::from(1.234e10) + QuadDouble::from(1.234e-10),
            ),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(QuadDouble::from_str(text).unwrap(), expected);
        }
    }
}