1use super::*;
2use std::{
3 mem::replace,
4 ops::{Add, AddAssign, Sub, SubAssign},
5};
6
7fn add_carry(carry: bool, lhs: u64, rhs: u64, out: &mut u64) -> bool {
8 let mut sum = lhs + rhs + carry as u64;
9 let cond = sum >= RADIX;
10 if cond {
11 sum -= RADIX;
12 }
13 *out = sum;
14 cond
15}
16
17fn add_absolute_parts(lhs: &mut Decimal, rhs: &Decimal) {
18 let mut carry = false;
19
20 let lhs_decimal_len = lhs.decimal.len();
22 if lhs_decimal_len < rhs.decimal.len() {
23 for (l, r) in lhs
24 .decimal
25 .iter_mut()
26 .rev()
27 .zip(rhs.decimal[..lhs_decimal_len].iter().rev())
28 {
29 carry = add_carry(carry, *l, *r, l);
30 }
31 lhs.decimal
32 .extend_from_slice(&rhs.decimal[lhs_decimal_len..]);
33 } else {
34 for (l, r) in lhs.decimal[..rhs.decimal.len()]
35 .iter_mut()
36 .rev()
37 .zip(rhs.decimal.iter().rev())
38 {
39 carry = add_carry(carry, *l, *r, l);
40 }
41 }
42
43 let lhs_integer_len = lhs.integer.len();
45 if lhs_integer_len < rhs.integer.len() {
46 for (l, r) in lhs.integer.iter_mut().zip(&rhs.integer[..lhs_integer_len]) {
47 carry = add_carry(carry, *l, *r, l);
48 }
49 lhs.integer
50 .extend_from_slice(&rhs.integer[lhs_integer_len..]);
51 if carry {
52 for l in lhs.integer[lhs_integer_len..].iter_mut() {
53 carry = add_carry(carry, *l, 0, l);
54 if !carry {
55 break;
56 }
57 }
58 }
59 } else {
60 for (l, r) in lhs.integer.iter_mut().zip(&rhs.integer) {
61 carry = add_carry(carry, *l, *r, l);
62 }
63 if carry {
64 for l in lhs.integer[rhs.integer.len()..].iter_mut() {
65 carry = add_carry(carry, *l, 0, l);
66 if !carry {
67 break;
68 }
69 }
70 }
71 }
72
73 if carry {
74 lhs.integer.push(carry as u64);
75 }
76
77 lhs.normalize();
78}
79
80fn sub_borrow(borrow: bool, lhs: u64, rhs: u64, out: &mut u64) -> bool {
81 let (sum, borrow1) = lhs.overflowing_sub(rhs);
82 let (mut sum, borrow2) = sum.overflowing_sub(borrow as u64);
83 let borrow = borrow1 || borrow2;
84 if borrow {
85 sum = sum.wrapping_add(RADIX);
86 }
87 *out = sum;
88 borrow
89}
90
91fn sub_absolute_parts_gte(lhs: &Decimal, rhs: &mut Decimal) {
93 debug_assert!(matches!(lhs.cmp_absolute_parts(rhs), Ordering::Greater));
94
95 let mut borrow = false;
96
97 let rhs_decimal_len = rhs.decimal.len();
99 if lhs.decimal.len() > rhs_decimal_len {
100 for (l, r) in lhs.decimal[..rhs_decimal_len]
101 .iter()
102 .rev()
103 .zip(rhs.decimal.iter_mut().rev())
104 {
105 borrow = sub_borrow(borrow, *l, *r, r);
106 }
107 rhs.decimal
108 .extend_from_slice(&lhs.decimal[rhs_decimal_len..]);
109 } else {
110 for r in rhs.decimal[lhs.decimal.len()..].iter_mut().rev() {
111 borrow = sub_borrow(borrow, 0, *r, r);
112 }
113 for (l, r) in lhs
114 .decimal
115 .iter()
116 .rev()
117 .zip(rhs.decimal[..lhs.decimal.len()].iter_mut().rev())
118 {
119 borrow = sub_borrow(borrow, *l, *r, r);
120 }
121 }
122
123 let rhs_integer_len = rhs.integer.len();
125 if lhs.integer.len() > rhs_integer_len {
126 for (l, r) in lhs.integer[..rhs_integer_len]
127 .iter()
128 .zip(rhs.integer.iter_mut())
129 {
130 borrow = sub_borrow(borrow, *l, *r, r);
131 }
132 rhs.integer
133 .extend_from_slice(&lhs.integer[rhs_integer_len..]);
134 if borrow {
135 for r in rhs.integer[rhs_integer_len..].iter_mut() {
136 borrow = sub_borrow(borrow, *r, 0, r);
137 if !borrow {
138 break;
139 }
140 }
141 }
142 } else {
143 debug_assert_eq!(lhs.integer.len(), rhs_integer_len);
144 for (l, r) in lhs.integer.iter().zip(&mut rhs.integer) {
145 borrow = sub_borrow(borrow, *l, *r, r);
146 }
147 }
148
149 assert!(
150 !borrow,
151 "Cannot subtract lhs from rhs because lhs is smaller than rhs"
152 );
153
154 rhs.normalize();
155}
156
/// Expands to the sum `$lhs + $rhs` of two `Decimal` operands.
///
/// `$lhs` / `$rhs` are only inspected (sign, magnitude comparison, borrowing).
/// `$lhs_owned` / `$rhs_owned` must yield an owned `Decimal`; each appears only
/// in the match arms that consume that operand, so a `.clone()` passed there
/// runs only when the owned value is actually needed.
macro_rules! add {
    ($lhs:expr, $lhs_owned:expr, $rhs:expr, $rhs_owned:expr) => {
        match ($lhs.sign, $rhs.sign) {
            // x + 0 and 0 + x: return the other operand unchanged.
            (Sign::Zero, _) => $rhs_owned,
            (_, Sign::Zero) => $lhs_owned,
            // Same signs: magnitudes add; the sign already carried by lhs is kept.
            (Sign::Plus, Sign::Plus) | (Sign::Minus, Sign::Minus) => {
                let mut sum = $lhs_owned;
                add_absolute_parts(&mut sum, &$rhs);
                sum
            }
            // Opposite signs: subtract the smaller magnitude from the larger one;
            // the result takes the sign of the operand with the larger magnitude.
            (Sign::Plus, Sign::Minus) | (Sign::Minus, Sign::Plus) => {
                match $lhs.cmp_absolute_parts(&$rhs) {
                    Ordering::Greater => {
                        let mut diff = $rhs_owned;
                        sub_absolute_parts_gte(&$lhs, &mut diff);
                        diff.sign = $lhs.sign;
                        diff
                    }
                    Ordering::Less => {
                        let mut diff = $lhs_owned;
                        sub_absolute_parts_gte(&$rhs, &mut diff);
                        diff.sign = $rhs.sign;
                        diff
                    }
                    Ordering::Equal => ZERO,
                }
            }
        }
    };
}
187
/// Expands to the difference `$lhs - $rhs` of two `Decimal` operands.
///
/// The operand slots have the same meaning as in `add!`: `$lhs` / `$rhs` are
/// only inspected, while `$lhs_owned` / `$rhs_owned` yield owned values and
/// appear only in the arms that consume them.
macro_rules! sub {
    ($lhs:expr, $lhs_owned:expr, $rhs:expr, $rhs_owned:expr) => {
        match ($lhs.sign, $rhs.sign) {
            // 0 - x = -x and x - 0 = x.
            (Sign::Zero, _) => -$rhs_owned,
            (_, Sign::Zero) => $lhs_owned,
            // Opposite signs: x - (-y) = x + y, so magnitudes add and the sign
            // already carried by lhs is kept.
            (Sign::Plus, Sign::Minus) | (Sign::Minus, Sign::Plus) => {
                let mut sum = $lhs_owned;
                add_absolute_parts(&mut sum, &$rhs);
                sum
            }
            // Same signs: magnitudes subtract.
            (Sign::Plus, Sign::Plus) | (Sign::Minus, Sign::Minus) => {
                match $lhs.cmp_absolute_parts(&$rhs) {
                    // |lhs| > |rhs|: the result keeps the operands' common sign,
                    // which `diff` already carries since it started out as rhs.
                    Ordering::Greater => {
                        let mut diff = $rhs_owned;
                        sub_absolute_parts_gte(&$lhs, &mut diff);
                        diff
                    }
                    // |lhs| < |rhs|: the result has the opposite of the common sign.
                    Ordering::Less => {
                        let mut diff = $lhs_owned;
                        sub_absolute_parts_gte(&$rhs, &mut diff);
                        diff.sign = -$rhs.sign;
                        diff
                    }
                    Ordering::Equal => ZERO,
                }
            }
        }
    };
}
217
218macro_rules! impl_binop {
219 (impl $Trait:ident for Decimal, $method:ident, $macro:ident) => {
220 impl $Trait<Decimal> for Decimal {
221 type Output = Decimal;
222
223 fn $method(self, rhs: Decimal) -> Self::Output {
224 $macro!(self, self, rhs, rhs)
225 }
226 }
227
228 impl $Trait<&Decimal> for Decimal {
229 type Output = Decimal;
230
231 fn $method(self, rhs: &Decimal) -> Self::Output {
232 $macro!(self, self, rhs, rhs.clone())
233 }
234 }
235
236 impl $Trait<Decimal> for &Decimal {
237 type Output = Decimal;
238
239 fn $method(self, rhs: Decimal) -> Self::Output {
240 $macro!(self, self.clone(), rhs, rhs)
241 }
242 }
243
244 impl $Trait<&Decimal> for &Decimal {
245 type Output = Decimal;
246
247 fn $method(self, rhs: &Decimal) -> Self::Output {
248 $macro!(self, self.clone(), rhs, rhs.clone())
249 }
250 }
251 };
252}
253impl_binop!(impl Add for Decimal, add, add);
254impl_binop!(impl Sub for Decimal, sub, sub);
255
256macro_rules! impl_binop_assign {
257 (impl $Trait:ident for Decimal, $method:ident, $op:tt) => {
258 impl $Trait for Decimal {
259 fn $method(&mut self, rhs: Decimal) {
260 let lhs = replace(self, ZERO);
261 *self = lhs $op rhs;
262 }
263 }
264
265 impl $Trait<&Decimal> for Decimal {
266 fn $method(&mut self, rhs: &Decimal) {
267 let lhs = replace(self, ZERO);
268 *self = lhs $op rhs;
269 }
270 }
271 };
272}
273
274impl_binop_assign!(impl AddAssign for Decimal, add_assign, +);
275impl_binop_assign!(impl SubAssign for Decimal, sub_assign, -);
276
/// Tests for `Decimal` addition and subtraction.
///
/// Every case is checked through all four owned/borrowed operator impls and
/// through both compound-assignment impls (`+=`/`-=` with owned and borrowed
/// right-hand sides), which must all agree with the expected value.
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case("0", "0", "0"; "zero")]
    #[test_case("0", "1", "1"; "zero vs plus")]
    #[test_case("0", "-1", "-1"; "zero vs minus")]
    #[test_case("1", "0", "1"; "plus vs zero")]
    #[test_case("-1", "0", "-1"; "minus vs zero")]
    #[test_case("1", "1", "2"; "plus vs plus")]
    #[test_case("1", "-1", "0"; "plus vs minus results zero")]
    #[test_case("2", "-1", "1"; "plus vs minus results plus")]
    #[test_case("1", "-2", "-1"; "plus vs minus results minus")]
    #[test_case("-1", "1", "0"; "minus vs plus results zero")]
    #[test_case("-2", "1", "-1"; "minus vs plus results minus")]
    #[test_case("-1", "2", "1"; "minus vs plus results plus")]
    #[test_case("-1", "-1", "-2"; "minus vs minus")]
    #[test_case(
        "999999999999999999.999999999999999999",
        "000000000000000000.000000000000000001",
        "1000000000000000000";
        "carry"
    )]
    #[test_case(
        "012345678901234567890.1234567890123456789",
        "098765432109876543210.9876543210987654321",
        "111111111011111111101.1111111101111111110";
        "plus long vs plus long"
    )]
    #[test_case(
        "00345678901234567890.1234567890",
        "98765432109876543210.9876543210987654321",
        "99111111011111111101.1111111100987654321";
        "plus short vs plus long"
    )]
    #[test_case(
        "12345678901234567890.1234567890123456789",
        "00765432109876543210.9876543210",
        "13111111011111111101.1111111100123456789";
        "plus long vs plus short"
    )]
    #[test_case(
        "+1000000000000000000.0000000000000000000",
        "-0000000000000000000.0000000000000000001",
        "+0999999999999999999.9999999999999999999";
        "borrow"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "-012345678901234567890.1234567890123456789",
        "+086419753208641975320.8641975320864197532";
        "plus long vs minus long results plus"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "-086419753208641975320.8641975320864197532";
        "plus long vs minus long results minus"
    )]
    #[test_case(
        "-098765432109876543210.9876543210987654321",
        "+012345678901234567890.1234567890123456789",
        "-086419753208641975320.8641975320864197532";
        "minus long vs plus long results minus"
    )]
    #[test_case(
        "-012345678901234567890.1234567890123456789",
        "+098765432109876543210.9876543210987654321",
        "+086419753208641975320.8641975320864197532";
        "minus long vs plus long results plus"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "-000945678901234567890.123456789",
        "+097819753208641975320.8641975320987654321";
        "plus long vs minus short results plus"
    )]
    fn test_add(lhs: &str, rhs: &str, expected: &str) {
        let lhs: Decimal = lhs.parse().unwrap();
        let rhs: Decimal = rhs.parse().unwrap();
        let expected: Decimal = expected.parse().unwrap();
        assert_eq!(lhs.clone() + rhs.clone(), expected);
        assert_eq!(lhs.clone() + &rhs, expected);
        assert_eq!(&lhs + rhs.clone(), expected);
        assert_eq!(&lhs + &rhs, expected);

        // `+=` must agree with `+` for both owned and borrowed right-hand sides.
        let mut acc = lhs.clone();
        acc += rhs.clone();
        assert_eq!(acc, expected);
        let mut acc = lhs;
        acc += &rhs;
        assert_eq!(acc, expected);
    }

    #[test_case("0", "0", "0"; "zero")]
    #[test_case("0", "1", "-1"; "zero vs plus")]
    #[test_case("0", "-1", "1"; "zero vs minus")]
    #[test_case("1", "0", "1"; "plus vs zero")]
    #[test_case("-1", "0", "-1"; "minus vs zero")]
    #[test_case("1", "-1", "2"; "plus vs minus")]
    #[test_case("1", "1", "0"; "plus vs plus results zero")]
    #[test_case("2", "1", "1"; "plus vs plus results plus")]
    #[test_case("1", "2", "-1"; "plus vs plus results minus")]
    #[test_case("-1", "-1", "0"; "minus vs minus results zero")]
    #[test_case("-2", "-1", "-1"; "minus vs minus results minus")]
    #[test_case("-1", "-2", "1"; "minus vs minus results plus")]
    #[test_case("-1", "1", "-2"; "minus vs plus")]
    #[test_case(
        "+999999999999999999.999999999999999999",
        "-000000000000000000.000000000000000001",
        "+1000000000000000000";
        "carry"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "+111111111011111111101.1111111101111111110";
        "plus long vs minus long"
    )]
    #[test_case(
        "+00345678901234567890.1234567890",
        "-98765432109876543210.9876543210987654321",
        "+99111111011111111101.1111111100987654321";
        "plus short vs minus long"
    )]
    #[test_case(
        "+12345678901234567890.1234567890123456789",
        "-00765432109876543210.9876543210",
        "+13111111011111111101.1111111100123456789";
        "plus long vs minus short"
    )]
    #[test_case(
        "+1000000000000000000.0000000000000000000",
        "+0000000000000000000.0000000000000000001",
        "+0999999999999999999.9999999999999999999";
        "borrow"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "+012345678901234567890.1234567890123456789",
        "+086419753208641975320.8641975320864197532";
        "plus long vs plus long results plus"
    )]
    #[test_case(
        "+012345678901234567890.1234567890123456789",
        "+098765432109876543210.9876543210987654321",
        "-086419753208641975320.8641975320864197532";
        "plus long vs plus long results minus"
    )]
    #[test_case(
        "-098765432109876543210.9876543210987654321",
        "-012345678901234567890.1234567890123456789",
        "-086419753208641975320.8641975320864197532";
        "minus long vs minus long results minus"
    )]
    #[test_case(
        "-012345678901234567890.1234567890123456789",
        "-098765432109876543210.9876543210987654321",
        "+086419753208641975320.8641975320864197532";
        "minus long vs minus long results plus"
    )]
    #[test_case(
        "+098765432109876543210.9876543210987654321",
        "+000945678901234567890.123456789",
        "+097819753208641975320.8641975320987654321";
        "plus long vs plus short results plus"
    )]
    fn test_sub(lhs: &str, rhs: &str, expected: &str) {
        let lhs: Decimal = lhs.parse().unwrap();
        let rhs: Decimal = rhs.parse().unwrap();
        let expected: Decimal = expected.parse().unwrap();
        assert_eq!(lhs.clone() - rhs.clone(), expected);
        assert_eq!(lhs.clone() - &rhs, expected);
        assert_eq!(&lhs - rhs.clone(), expected);
        assert_eq!(&lhs - &rhs, expected);

        // `-=` must agree with `-` for both owned and borrowed right-hand sides.
        let mut acc = lhs.clone();
        acc -= rhs.clone();
        assert_eq!(acc, expected);
        let mut acc = lhs;
        acc -= &rhs;
        assert_eq!(acc, expected);
    }
}