// competitive/num/quad_double.rs

1#![allow(clippy::suspicious_arithmetic_impl)]
2
3use super::{Bounded, Decimal, IterScan, One, Zero};
4use std::{
5    cmp::Ordering,
6    fmt::{self, Display},
7    num::ParseFloatError,
8    ops::{Add, Div, Index, Mul, Neg, Sub},
9    str::FromStr,
10};
11
/// Quad-double number: the unevaluated sum `self.0 + self.1 + self.2 + self.3`
/// of four `f64` limbs, giving roughly four times the significand precision of
/// a single `f64`.
///
/// Arithmetic results are passed through `renormalize`, which leaves the limbs
/// non-overlapping and ordered by decreasing magnitude (unused low limbs are
/// zero), so `self.0` is the leading approximation of the value.
///
/// ref: <https://na-inet.jp/na/qd_ja.pdf>
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct QuadDouble(f64, f64, f64, f64);

impl QuadDouble {
    /// Renormalizes the five-term expansion `a0 + a1 + a2 + a3 + a4` into four
    /// non-overlapping limbs of decreasing magnitude.
    ///
    /// The inputs must be (roughly) ordered by decreasing magnitude, as the
    /// `quick_two_sum` steps assume. Limbs that receive no value are `0`; a
    /// residual that does not fit into four limbs is rounded into the last one.
    fn renormalize(a0: f64, a1: f64, a2: f64, a3: f64, a4: f64) -> Self {
        // Bottom-up pass: compress so that `s + t1 + t2 + t3 + t4` equals the
        // input sum exactly.
        let (s, t4) = quick_two_sum(a3, a4);
        let (s, t3) = quick_two_sum(a2, s);
        let (s, t2) = quick_two_sum(a1, s);
        let (mut s, t1) = quick_two_sum(a0, s);

        // Top-down pass: `s` is the limb being built; whenever adding the next
        // term leaves a non-zero error, `s` is complete and the error starts
        // the next limb. Output slots must start at zero and the final partial
        // limb must be flushed after the loop — otherwise stale `t` values can
        // remain in unused slots and the last `s` is lost.
        let mut b = [0.; 4];
        let mut k = 0;
        for &t in &[t1, t2, t3, t4] {
            let (hi, lo) = quick_two_sum(s, t);
            s = hi;
            if lo != 0. {
                b[k] = s;
                s = lo;
                k += 1;
            }
        }
        if k < b.len() {
            b[k] = s;
        }
        Self(b[0], b[1], b[2], b[3])
    }
}

/// Error-free sum assuming `|a| >= |b|` (or `a == 0`): returns `(s, e)` with
/// `s = fl(a + b)` and `a + b == s + e` exactly (barring overflow). Cheaper
/// than `two_sum`, which has no ordering requirement.
fn quick_two_sum(a: f64, b: f64) -> (f64, f64) {
    let s = a + b;
    let e = b - (s - a);
    (s, e)
}
42
/// Error-free sum (Knuth's Two-Sum): returns `(s, e)` with `s = fl(a + b)` and
/// `a + b == s + e` exactly (barring overflow), whatever the magnitudes.
fn two_sum(a: f64, b: f64) -> (f64, f64) {
    let s = a + b;
    let v = s - a;
    let e = (a - (s - v)) + (b - v);
    (s, e)
}

/// Veltkamp/Dekker split: returns `(hi, lo)` with `a == hi + lo` exactly and
/// `hi` holding the upper half of the significand, so that the partial products
/// in `two_prod` are exact.
///
/// For `|a| > 2^996` the product with the splitter would overflow (yielding
/// NaN), so the input is first scaled down by `2^28` and both halves are scaled
/// back up afterwards; both scalings are exact powers of two.
fn split(a: f64) -> (f64, f64) {
    const SPLITTER: f64 = 134_217_729.; // 2^27 + 1
    const THRESHOLD: f64 = 6.696_928_794_914_17e299; // ~2^996
    const SCALE: f64 = 268_435_456.; // 2^28
    let big = a.abs() > THRESHOLD;
    let a_scaled = if big { a / SCALE } else { a };
    let t = SPLITTER * a_scaled;
    let hi = t - (t - a_scaled);
    let lo = a_scaled - hi;
    if big {
        (hi * SCALE, lo * SCALE)
    } else {
        (hi, lo)
    }
}

/// Error-free product (Dekker's Two-Product): returns `(p, e)` with
/// `p = fl(a * b)` and `a * b == p + e` exactly (barring overflow/underflow).
fn two_prod(a: f64, b: f64) -> (f64, f64) {
    let p = a * b;
    let (ahi, alo) = split(a);
    let (bhi, blo) = split(b);
    let e = ((ahi * bhi - p) + ahi * blo + alo * bhi) + alo * blo;
    (p, e)
}

/// Exact three-term sum: returns `(r0, r1, r2)` with `a + b + c == r0 + r1 + r2`.
fn three_three_sum(a: f64, b: f64, c: f64) -> (f64, f64, f64) {
    let (u, v) = two_sum(a, b);
    let (r0, w) = two_sum(u, c);
    let (r1, r2) = two_sum(v, w);
    (r0, r1, r2)
}

/// Three-term sum to two terms: `(r0, r1)` where `r0` is exact and only the
/// final error addition `r1 = v + w` is rounded.
fn three_two_sum(a: f64, b: f64, c: f64) -> (f64, f64) {
    let (u, v) = two_sum(a, b);
    let (r0, w) = two_sum(u, c);
    let r1 = v + w;
    (r0, r1)
}

/// Accumulates `xs` into a three-term result `(r0, r1, r2)` (used for the
/// six-term stage of quad-double multiplication). Every element is added
/// exactly once; an empty slice yields `(0, 0, 0)`.
fn multiple_three_sum(xs: &[f64]) -> (f64, f64, f64) {
    let (first, rest) = match xs.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return (0., 0., 0.),
    };
    // Seed with the first element and fold in only the remaining ones, so
    // `xs[0]` is not counted twice.
    rest.iter().fold((first, 0., 0.), |(r0, r1, r2), &x| {
        let (r0, e) = two_sum(r0, x);
        let (r1, e) = two_sum(r1, e);
        (r0, r1, r2 + e)
    })
}

/// Accumulates `xs` into a two-term result `(r0, r1)` (used for the nine-term
/// stage of quad-double multiplication). Every element is added exactly once;
/// an empty slice yields `(0, 0)`.
fn multiple_two_sum(xs: &[f64]) -> (f64, f64) {
    let (first, rest) = match xs.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return (0., 0.),
    };
    // Seed with the first element and fold in only the remaining ones.
    rest.iter().fold((first, 0.), |(r0, r1), &x| {
        let (r0, e) = two_sum(r0, x);
        (r0, r1 + e)
    })
}
100
101impl Add<f64> for QuadDouble {
102    type Output = Self;
103    fn add(self, rhs: f64) -> Self::Output {
104        let (t0, e) = two_sum(self.0, rhs);
105        let (t1, e) = two_sum(self.1, e);
106        let (t2, e) = two_sum(self.2, e);
107        let (t3, t4) = two_sum(self.3, e);
108        Self::renormalize(t0, t1, t2, t3, t4)
109    }
110}
111
112fn double_accumulate(u: f64, v: f64, x: f64) -> (f64, f64, f64) {
113    let (s, mut v) = two_sum(v, x);
114    let (mut s, mut u) = two_sum(u, s);
115    if u == 0. {
116        u = s;
117        s = 0.;
118    }
119    if v == 0. {
120        v = u;
121        u = s;
122        s = 0.
123    }
124    (s, u, v)
125}
126
127impl Add<QuadDouble> for QuadDouble {
128    type Output = Self;
129    fn add(self, rhs: Self) -> Self::Output {
130        let mut x = [0.; 8];
131        let (mut i, mut j, mut k) = (0, 0, 0);
132        while k < 8 {
133            if j >= 4 || i < 4 && self[i].abs() > rhs[j].abs() {
134                x[k] = self[i];
135                i += 1;
136            } else {
137                x[k] = rhs[j];
138                j += 1;
139            }
140            k += 1;
141        }
142
143        let (mut u, mut v) = (0., 0.);
144        let (mut k, mut i) = (0, 0);
145        let mut c = [0.; 4];
146        while k < 4 && i < 8 {
147            let tpl = double_accumulate(u, v, x[i]);
148            let s = tpl.0;
149            u = tpl.1;
150            v = tpl.2;
151            if s != 0. {
152                c[k] = s;
153                k += 1;
154            }
155            i += 1;
156        }
157        if k < 2 {
158            c[k + 1] = v;
159        }
160        if k < 3 {
161            c[k] = u;
162        }
163        Self::renormalize(c[0], c[1], c[2], c[3], 0.)
164    }
165}
166
167impl Sub for QuadDouble {
168    type Output = Self;
169    fn sub(self, rhs: Self) -> Self::Output {
170        self + -rhs
171    }
172}
173
174impl Neg for QuadDouble {
175    type Output = Self;
176    fn neg(self) -> Self::Output {
177        Self(-self.0, -self.1, -self.2, -self.3)
178    }
179}
180
181impl Mul<f64> for QuadDouble {
182    type Output = Self;
183    fn mul(self, rhs: f64) -> Self::Output {
184        let (t0, e0) = two_prod(self.0, rhs);
185        let (p1, e1) = two_prod(self.1, rhs);
186        let (p2, e2) = two_prod(self.2, rhs);
187        let p3 = self.3 * rhs;
188
189        let (t1, e4) = two_sum(p1, e0);
190        let (t2, e5, e6) = three_three_sum(p2, e1, e4);
191        let (t3, e7) = three_two_sum(p3, e2, e5);
192        let t4 = e7 + e6;
193        Self::renormalize(t0, t1, t2, t3, t4)
194    }
195}
196
197impl Mul<QuadDouble> for QuadDouble {
198    type Output = Self;
199    fn mul(self, rhs: Self) -> Self::Output {
200        let (t0, q00) = two_prod(self.0, rhs.0);
201
202        let (p01, q01) = two_prod(self.0, rhs.1);
203        let (p10, q10) = two_prod(self.1, rhs.0);
204
205        let (p02, q02) = two_prod(self.0, rhs.2);
206        let (p11, q11) = two_prod(self.1, rhs.1);
207        let (p20, q20) = two_prod(self.2, rhs.0);
208
209        let (p03, q03) = two_prod(self.0, rhs.3);
210        let (p12, q12) = two_prod(self.1, rhs.2);
211        let (p21, q21) = two_prod(self.2, rhs.1);
212        let (p30, q30) = two_prod(self.3, rhs.0);
213
214        let p13 = self.1 * rhs.3;
215        let p22 = self.2 * rhs.2;
216        let p31 = self.3 * rhs.1;
217
218        let (t1, e1, e2) = three_three_sum(q00, p01, p10);
219        let (t2, e3, e4) = multiple_three_sum(&[e1, q01, q10, p02, p11, p20]);
220        let (t3, e5) = multiple_two_sum(&[e2, e3, q02, q11, q20, p03, p12, p21, p30]);
221        let t4 = e4 + e5 + q03 + q12 + q21 + q30 + p13 + p22 + p31;
222        Self::renormalize(t0, t1, t2, t3, t4)
223    }
224}
225
226impl Div<QuadDouble> for QuadDouble {
227    type Output = Self;
228    fn div(self, rhs: Self) -> Self::Output {
229        let q0 = self.0 / rhs.0;
230        let r = self - rhs * q0;
231        let q1 = r.0 / rhs.0;
232        let r = r - rhs * q1;
233        let q2 = r.0 / rhs.0;
234        let r = r - rhs * q2;
235        let q3 = r.0 / rhs.0;
236        let r = r - rhs * q3;
237        let q4 = r.0 / rhs.0;
238        Self::renormalize(q0, q1, q2, q3, q4)
239    }
240}
241
242impl Index<usize> for QuadDouble {
243    type Output = f64;
244    fn index(&self, index: usize) -> &Self::Output {
245        match index {
246            0 => &self.0,
247            1 => &self.1,
248            2 => &self.2,
249            3 => &self.3,
250            _ => panic!(),
251        }
252    }
253}
254
255impl From<QuadDouble> for f64 {
256    fn from(x: QuadDouble) -> f64 {
257        x.3 + x.2 + x.1 + x.0
258    }
259}
260
261impl From<QuadDouble> for i64 {
262    fn from(mut x: QuadDouble) -> i64 {
263        let is_neg = x.0.is_sign_negative();
264        if is_neg {
265            x = -x;
266        }
267        let mut i = 0i64;
268        for k in (1..64).rev() {
269            let t = (k as f64).exp2();
270            if x.0 >= t {
271                x = x + -t;
272                i += 1 << k;
273            }
274        }
275        i += x.0.round() as i64;
276        if is_neg {
277            i = -i;
278        }
279        i
280    }
281}
282
283impl From<f64> for QuadDouble {
284    fn from(x: f64) -> Self {
285        Self(x, 0., 0., 0.)
286    }
287}
288
289impl Display for QuadDouble {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        write!(
292            f,
293            "{}",
294            Decimal::from(self.0)
295                + Decimal::from(self.1)
296                + Decimal::from(self.2)
297                + Decimal::from(self.3)
298        )
299    }
300}
301
302#[derive(Debug, Clone)]
303pub enum ParseDoubleDoubleError {
304    ParseFloatError(ParseFloatError),
305    ParseDecimalError(super::decimal::convert::ParseDecimalError),
306}
307
308impl From<ParseFloatError> for ParseDoubleDoubleError {
309    fn from(e: ParseFloatError) -> Self {
310        Self::ParseFloatError(e)
311    }
312}
313
314impl From<super::decimal::convert::ParseDecimalError> for ParseDoubleDoubleError {
315    fn from(e: super::decimal::convert::ParseDecimalError) -> Self {
316        Self::ParseDecimalError(e)
317    }
318}
319
320impl FromStr for QuadDouble {
321    type Err = ParseDoubleDoubleError;
322    fn from_str(s: &str) -> Result<Self, Self::Err> {
323        let f0: f64 = s.parse()?;
324        let d1 = Decimal::from_str(s)? - Decimal::from(f0);
325        let f1: f64 = d1.to_string().parse()?;
326        let d2 = d1 - Decimal::from(f1);
327        let f2: f64 = d2.to_string().parse()?;
328        let d3 = d2 - Decimal::from(f2);
329        let f3: f64 = d3.to_string().parse()?;
330        Ok(Self::renormalize(f0, f1, f2, f3, 0.))
331    }
332}
333
334impl Eq for QuadDouble {}
335impl PartialOrd for QuadDouble {
336    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
337        Some(self.cmp(other))
338    }
339}
340impl Ord for QuadDouble {
341    fn cmp(&self, other: &Self) -> Ordering {
342        fn total_cmp(x: f64, y: f64) -> Ordering {
343            let mut left = x.to_bits() as i64;
344            let mut right = y.to_bits() as i64;
345            left ^= (((left >> 63) as u64) >> 1) as i64;
346            right ^= (((right >> 63) as u64) >> 1) as i64;
347            left.cmp(&right)
348        }
349        total_cmp(self.0, other.0).then_with(|| total_cmp(self.1, other.1))
350    }
351}
352impl Bounded for QuadDouble {
353    fn maximum() -> Self {
354        Self::from(<f64 as Bounded>::maximum())
355    }
356    fn minimum() -> Self {
357        Self::from(<f64 as Bounded>::minimum())
358    }
359}
360
361impl Zero for QuadDouble {
362    fn zero() -> Self {
363        Self::from(0.)
364    }
365    fn is_zero(&self) -> bool
366    where
367        Self: PartialEq,
368    {
369        self.0 == 0.
370    }
371}
372
373impl One for QuadDouble {
374    fn one() -> Self {
375        Self::from(1.)
376    }
377    fn is_one(&self) -> bool
378    where
379        Self: PartialEq,
380    {
381        self.0 == 1.
382    }
383}
384
385impl IterScan for QuadDouble {
386    type Output = Self;
387    fn scan<'a, I: Iterator<Item = &'a str>>(iter: &mut I) -> Option<Self::Output> {
388        iter.next().and_then(|s| s.parse().ok())
389    }
390}
391
392impl QuadDouble {
393    pub fn is_zero(&self) -> bool {
394        self.0 == 0.
395    }
396    pub fn sqrt(self) -> Self {
397        if self.is_zero() {
398            return Self::from(0.);
399        }
400        let x = Self::from(1. / self.0.sqrt());
401        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
402        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
403        let x = x + x * (Self::from(1.) - self * x * x).div2(2.);
404        x * self
405    }
406    pub fn abs(self) -> Self {
407        if self.0.is_sign_negative() {
408            -self
409        } else {
410            self
411        }
412    }
413    fn div2(self, rhs: f64) -> Self {
414        Self(self.0 / rhs, self.1 / rhs, self.2 / rhs, self.3 / rhs)
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display() {
        let cases = [
            (QuadDouble::from(1.234), "1.234"),
            (QuadDouble::from(1.234e-10), "0.0000000001234"),
            (QuadDouble::from(1.234e10), "12340000000"),
            (
                QuadDouble::from(1.234e-10) + QuadDouble::from(1.234e10),
                "12340000000.0000000001234",
            ),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_string(), *expected);
        }
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("1.234", QuadDouble::from(1.234)),
            ("0.0000000001234", QuadDouble::from(1.234e-10)),
            ("12340000000", QuadDouble::from(1.234e10)),
            (
                "12340000000.0000000001234",
                QuadDouble::from(1.234e10) + QuadDouble::from(1.234e-10),
            ),
        ];
        for (text, expected) in cases.iter() {
            assert_eq!(QuadDouble::from_str(text).unwrap(), *expected);
        }
    }
}