// competitive/num/decimal/addsub.rs

1use super::*;
2use std::{
3    mem::replace,
4    ops::{Add, AddAssign, Sub, SubAssign},
5};
6
7fn add_carry(carry: bool, lhs: u64, rhs: u64, out: &mut u64) -> bool {
8    let mut sum = lhs + rhs + carry as u64;
9    let cond = sum >= RADIX;
10    if cond {
11        sum -= RADIX;
12    }
13    *out = sum;
14    cond
15}
16
17fn add_absolute_parts(lhs: &mut Decimal, rhs: &Decimal) {
18    let mut carry = false;
19
20    // decimal part
21    let lhs_decimal_len = lhs.decimal.len();
22    if lhs_decimal_len < rhs.decimal.len() {
23        for (l, r) in lhs
24            .decimal
25            .iter_mut()
26            .rev()
27            .zip(rhs.decimal[..lhs_decimal_len].iter().rev())
28        {
29            carry = add_carry(carry, *l, *r, l);
30        }
31        lhs.decimal
32            .extend_from_slice(&rhs.decimal[lhs_decimal_len..]);
33    } else {
34        for (l, r) in lhs.decimal[..rhs.decimal.len()]
35            .iter_mut()
36            .rev()
37            .zip(rhs.decimal.iter().rev())
38        {
39            carry = add_carry(carry, *l, *r, l);
40        }
41    }
42
43    // integer part
44    let lhs_integer_len = lhs.integer.len();
45    if lhs_integer_len < rhs.integer.len() {
46        for (l, r) in lhs.integer.iter_mut().zip(&rhs.integer[..lhs_integer_len]) {
47            carry = add_carry(carry, *l, *r, l);
48        }
49        lhs.integer
50            .extend_from_slice(&rhs.integer[lhs_integer_len..]);
51        if carry {
52            for l in lhs.integer[lhs_integer_len..].iter_mut() {
53                carry = add_carry(carry, *l, 0, l);
54                if !carry {
55                    break;
56                }
57            }
58        }
59    } else {
60        for (l, r) in lhs.integer.iter_mut().zip(&rhs.integer) {
61            carry = add_carry(carry, *l, *r, l);
62        }
63        if carry {
64            for l in lhs.integer[rhs.integer.len()..].iter_mut() {
65                carry = add_carry(carry, *l, 0, l);
66                if !carry {
67                    break;
68                }
69            }
70        }
71    }
72
73    if carry {
74        lhs.integer.push(carry as u64);
75    }
76
77    lhs.normalize();
78}
79
80fn sub_borrow(borrow: bool, lhs: u64, rhs: u64, out: &mut u64) -> bool {
81    let (sum, borrow1) = lhs.overflowing_sub(rhs);
82    let (mut sum, borrow2) = sum.overflowing_sub(borrow as u64);
83    let borrow = borrow1 || borrow2;
84    if borrow {
85        sum = sum.wrapping_add(RADIX);
86    }
87    *out = sum;
88    borrow
89}
90
91// assume |lhs| >= |rhs|
92fn sub_absolute_parts_gte(lhs: &Decimal, rhs: &mut Decimal) {
93    debug_assert!(matches!(lhs.cmp_absolute_parts(rhs), Ordering::Greater));
94
95    let mut borrow = false;
96
97    // decimal part
98    let rhs_decimal_len = rhs.decimal.len();
99    if lhs.decimal.len() > rhs_decimal_len {
100        for (l, r) in lhs.decimal[..rhs_decimal_len]
101            .iter()
102            .rev()
103            .zip(rhs.decimal.iter_mut().rev())
104        {
105            borrow = sub_borrow(borrow, *l, *r, r);
106        }
107        rhs.decimal
108            .extend_from_slice(&lhs.decimal[rhs_decimal_len..]);
109    } else {
110        for r in rhs.decimal[lhs.decimal.len()..].iter_mut().rev() {
111            borrow = sub_borrow(borrow, 0, *r, r);
112        }
113        for (l, r) in lhs
114            .decimal
115            .iter()
116            .rev()
117            .zip(rhs.decimal[..lhs.decimal.len()].iter_mut().rev())
118        {
119            borrow = sub_borrow(borrow, *l, *r, r);
120        }
121    }
122
123    // integer part
124    let rhs_integer_len = rhs.integer.len();
125    if lhs.integer.len() > rhs_integer_len {
126        for (l, r) in lhs.integer[..rhs_integer_len]
127            .iter()
128            .zip(rhs.integer.iter_mut())
129        {
130            borrow = sub_borrow(borrow, *l, *r, r);
131        }
132        rhs.integer
133            .extend_from_slice(&lhs.integer[rhs_integer_len..]);
134        if borrow {
135            for r in rhs.integer[rhs_integer_len..].iter_mut() {
136                borrow = sub_borrow(borrow, *r, 0, r);
137                if !borrow {
138                    break;
139                }
140            }
141        }
142    } else {
143        debug_assert_eq!(lhs.integer.len(), rhs_integer_len);
144        for (l, r) in lhs.integer.iter().zip(&mut rhs.integer) {
145            borrow = sub_borrow(borrow, *l, *r, r);
146        }
147    }
148
149    assert!(
150        !borrow,
151        "Cannot subtract lhs from rhs because lhs is smaller than rhs"
152    );
153
154    rhs.normalize();
155}
156
/// Shared body of every `Add` impl; evaluates to `$lhs + $rhs`.
///
/// `$lhs`/`$rhs` are only inspected (signs and magnitudes). `$lhs_owned` and
/// `$rhs_owned` are expressions yielding owned operands; each is pasted only
/// into the arms that consume it, so it is evaluated only there.
macro_rules! add {
    ($lhs:expr, $lhs_owned:expr, $rhs:expr, $rhs_owned:expr) => {
        match ($lhs.sign, $rhs.sign) {
            // Adding zero yields the other operand.
            (Sign::Zero, _) => $rhs_owned,
            (_, Sign::Zero) => $lhs_owned,
            // Matching signs: magnitudes add; the result reuses `$lhs_owned`.
            (Sign::Plus, Sign::Plus) | (Sign::Minus, Sign::Minus) => {
                let mut sum = $lhs_owned;
                add_absolute_parts(&mut sum, &$rhs);
                sum
            }
            // Opposite signs: the larger magnitude wins and donates its sign.
            (Sign::Plus, Sign::Minus) | (Sign::Minus, Sign::Plus) => {
                match $lhs.cmp_absolute_parts(&$rhs) {
                    Ordering::Equal => ZERO,
                    Ordering::Greater => {
                        let mut diff = $rhs_owned;
                        sub_absolute_parts_gte(&$lhs, &mut diff);
                        diff.sign = $lhs.sign;
                        diff
                    }
                    Ordering::Less => {
                        let mut diff = $lhs_owned;
                        sub_absolute_parts_gte(&$rhs, &mut diff);
                        diff.sign = $rhs.sign;
                        diff
                    }
                }
            }
        }
    };
}
187
/// Shared body of every `Sub` impl; evaluates to `$lhs - $rhs`.
///
/// `$lhs`/`$rhs` are only inspected (signs and magnitudes). `$lhs_owned` and
/// `$rhs_owned` are expressions yielding owned operands; each is pasted only
/// into the arms that consume it, so it is evaluated only there.
macro_rules! sub {
    ($lhs:expr, $lhs_owned:expr, $rhs:expr, $rhs_owned:expr) => {
        match ($lhs.sign, $rhs.sign) {
            // 0 - x = -x and x - 0 = x.
            (Sign::Zero, _) => -$rhs_owned,
            (_, Sign::Zero) => $lhs_owned,
            // Opposite signs: magnitudes add; the result reuses `$lhs_owned`
            // and so carries `$lhs`'s sign.
            (Sign::Plus, Sign::Minus) | (Sign::Minus, Sign::Plus) => {
                let mut lhs = $lhs_owned;
                add_absolute_parts(&mut lhs, &$rhs);
                lhs
            }
            // Same signs: subtract the smaller magnitude from the larger.
            (Sign::Plus, Sign::Plus) | (Sign::Minus, Sign::Minus) => {
                match $lhs.cmp_absolute_parts(&$rhs) {
                    Ordering::Less => {
                        // |lhs| < |rhs|: the result flips the shared sign.
                        let mut lhs = $lhs_owned;
                        sub_absolute_parts_gte(&$rhs, &mut lhs);
                        lhs.sign = -$rhs.sign;
                        lhs
                    }
                    Ordering::Equal => ZERO,
                    Ordering::Greater => {
                        // Set the sign explicitly, as `add!` does, instead of
                        // relying on `rhs` still holding the shared sign after
                        // the in-place subtraction and normalization.
                        let mut rhs = $rhs_owned;
                        sub_absolute_parts_gte(&$lhs, &mut rhs);
                        rhs.sign = $lhs.sign;
                        rhs
                    }
                }
            }
        }
    };
}
217
218macro_rules! impl_binop {
219    (impl $Trait:ident for Decimal, $method:ident, $macro:ident) => {
220        impl $Trait<Decimal> for Decimal {
221            type Output = Decimal;
222
223            fn $method(self, rhs: Decimal) -> Self::Output {
224                $macro!(self, self, rhs, rhs)
225            }
226        }
227
228        impl $Trait<&Decimal> for Decimal {
229            type Output = Decimal;
230
231            fn $method(self, rhs: &Decimal) -> Self::Output {
232                $macro!(self, self, rhs, rhs.clone())
233            }
234        }
235
236        impl $Trait<Decimal> for &Decimal {
237            type Output = Decimal;
238
239            fn $method(self, rhs: Decimal) -> Self::Output {
240                $macro!(self, self.clone(), rhs, rhs)
241            }
242        }
243
244        impl $Trait<&Decimal> for &Decimal {
245            type Output = Decimal;
246
247            fn $method(self, rhs: &Decimal) -> Self::Output {
248                $macro!(self, self.clone(), rhs, rhs.clone())
249            }
250        }
251    };
252}
253impl_binop!(impl Add for Decimal, add, add);
254impl_binop!(impl Sub for Decimal, sub, sub);
255
256macro_rules! impl_binop_assign {
257    (impl $Trait:ident for Decimal, $method:ident, $op:tt) => {
258        impl $Trait for Decimal {
259            fn $method(&mut self, rhs: Decimal) {
260                let lhs = replace(self, ZERO);
261                *self = lhs $op rhs;
262            }
263        }
264
265        impl $Trait<&Decimal> for Decimal {
266            fn $method(&mut self, rhs: &Decimal) {
267                let lhs = replace(self, ZERO);
268                *self = lhs $op rhs;
269            }
270        }
271    };
272}
273
274impl_binop_assign!(impl AddAssign for Decimal, add_assign, +);
275impl_binop_assign!(impl SubAssign for Decimal, sub_assign, -);
276
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case("0", "0", "0"; "zero")]
    #[test_case("0", "1", "1"; "zero vs plus")]
    #[test_case("0", "-1", "-1"; "zero vs minus")]
    #[test_case("1", "0", "1"; "plus vs zero")]
    #[test_case("-1", "0", "-1"; "minus vs zero")]
    #[test_case("1", "1", "2"; "plus vs plus")]
    #[test_case("1", "-1", "0"; "plus vs minus results zero")]
    #[test_case("2", "-1", "1"; "plus vs minus results plus")]
    #[test_case("1", "-2", "-1"; "plus vs minus results minus")]
    #[test_case("-1", "1", "0"; "minus vs plus results zero")]
    #[test_case("-2", "1", "-1"; "minus vs plus results minus")]
    #[test_case("-1", "2", "1"; "minus vs plus results plus")]
    #[test_case("-1", "-1", "-2"; "minus vs minus")]
    #[test_case(
        "999999999999999999.999999999999999999",
        "000000000000000000.000000000000000001",
        "1000000000000000000";
        "carry"
    )]
    #[test_case(
        "012345678901234567890.1234567890123456789",
        "098765432109876543210.9876543210987654321",
        "111111111011111111101.1111111101111111110";
        "plus long vs plus long"
    )]
    #[test_case(
        "00345678901234567890.1234567890",
        "98765432109876543210.9876543210987654321",
        "99111111011111111101.1111111100987654321";
        "plus short vs plus long"
    )]
    #[test_case(
        "12345678901234567890.1234567890123456789",
        "00765432109876543210.9876543210",
        "13111111011111111101.1111111100123456789";
        "plus long vs plus short"
    )]
    #[test_case(
        "+1000000000000000000.0000000000000000000",
        "-0000000000000000000.0000000000000000001",
        "+0999999999999999999.9999999999999999999";
        "borrow"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "-012345678901234567890.1234567890123456789",
        "+086419753208641975320.8641975320864197532";
        "plus long vs minus long results plus"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "-086419753208641975320.8641975320864197532";
        "plus long vs minus long results minus"
    )]
    #[test_case(
        "-098765432109876543210.9876543210987654321",
        "+012345678901234567890.1234567890123456789",
        "-086419753208641975320.8641975320864197532";
        "minus long vs plus long results minus"
    )]
    #[test_case(
        "-012345678901234567890.1234567890123456789",
        "+098765432109876543210.9876543210987654321",
        "+086419753208641975320.8641975320864197532";
        "minus long vs plus long results plus"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "-000945678901234567890.123456789",
        "+097819753208641975320.8641975320987654321";
        "plus long vs minus short results plus"
    )]
    fn test_add(lhs: &str, rhs: &str, expected: &str) {
        let lhs: Decimal = lhs.parse().unwrap();
        let rhs: Decimal = rhs.parse().unwrap();
        let expected: Decimal = expected.parse().unwrap();
        assert_eq!(lhs.clone() + rhs.clone(), expected);
        assert_eq!(lhs.clone() + &rhs, expected);
        assert_eq!(&lhs + rhs.clone(), expected);
        assert_eq!(&lhs + &rhs, expected);

        // `+=` is generated separately by `impl_binop_assign!`, so exercise
        // both its owned and borrowed right-hand-side impls as well.
        let mut acc = lhs.clone();
        acc += rhs.clone();
        assert_eq!(acc, expected);
        let mut acc = lhs;
        acc += &rhs;
        assert_eq!(acc, expected);
    }

    #[test_case("0", "0", "0"; "zero")]
    #[test_case("0", "1", "-1"; "zero vs plus")]
    #[test_case("0", "-1", "1"; "zero vs minus")]
    #[test_case("1", "0", "1"; "plus vs zero")]
    #[test_case("-1", "0", "-1"; "minus vs zero")]
    #[test_case("1", "-1", "2"; "plus vs minus")]
    #[test_case("1", "1", "0"; "plus vs plus results zero")]
    #[test_case("2", "1", "1"; "plus vs plus results plus")]
    #[test_case("1", "2", "-1"; "plus vs plus results minus")]
    #[test_case("-1", "-1", "0"; "minus vs minus results zero")]
    #[test_case("-2", "-1", "-1"; "minus vs minus results minus")]
    #[test_case("-1", "-2", "1"; "minus vs minus results plus")]
    #[test_case("-1", "1", "-2"; "minus vs plus")]
    #[test_case(
        "+999999999999999999.999999999999999999",
        "-000000000000000000.000000000000000001",
        "+1000000000000000000";
        "carry"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "+111111111011111111101.1111111101111111110";
        "plus long vs minus long"
    )]
    #[test_case(
        "+00345678901234567890.1234567890",
        "-98765432109876543210.9876543210987654321",
        "+99111111011111111101.1111111100987654321";
        "plus short vs minus long"
    )]
    #[test_case(
        "+12345678901234567890.1234567890123456789",
        "-00765432109876543210.9876543210",
        "+13111111011111111101.1111111100123456789";
        "plus long vs minus short"
    )]
    #[test_case(
        "+1000000000000000000.0000000000000000000",
        "+0000000000000000000.0000000000000000001",
        "+0999999999999999999.9999999999999999999";
        "borrow"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "+012345678901234567890.1234567890123456789",
        "+086419753208641975320.8641975320864197532";
        "plus long vs plus long results plus"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "+098765432109876543210.9876543210987654321",
        "-086419753208641975320.8641975320864197532";
        "plus long vs plus long results minus"
    )]
    #[test_case(
        "-098765432109876543210.9876543210987654321",
        "-012345678901234567890.1234567890123456789",
        "-086419753208641975320.8641975320864197532";
        "minus long vs minus long results minus"
    )]
    #[test_case(
        "-012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "+086419753208641975320.8641975320864197532";
        "minus long vs minus long results plus"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "+000945678901234567890.123456789",
        "+097819753208641975320.8641975320987654321";
        "plus long vs plus short results plus"
    )]
    fn test_sub(lhs: &str, rhs: &str, expected: &str) {
        let lhs: Decimal = lhs.parse().unwrap();
        let rhs: Decimal = rhs.parse().unwrap();
        let expected: Decimal = expected.parse().unwrap();
        assert_eq!(lhs.clone() - rhs.clone(), expected);
        assert_eq!(lhs.clone() - &rhs, expected);
        assert_eq!(&lhs - rhs.clone(), expected);
        assert_eq!(&lhs - &rhs, expected);

        // `-=` is generated separately by `impl_binop_assign!`, so exercise
        // both its owned and borrowed right-hand-side impls as well.
        let mut acc = lhs.clone();
        acc -= rhs.clone();
        assert_eq!(acc, expected);
        let mut acc = lhs;
        acc -= &rhs;
        assert_eq!(acc, expected);
    }
}