Skip to main content

competitive/math/
array_vec.rs

1use std::ops::{
2    Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
3    Index, IndexMut, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
4    SubAssign,
5};
6
/// Marks a value as a scalar operand for the element-wise operators on `ArrayVec`.
///
/// Wrapping the value is what distinguishes `vector op scalar` (and `scalar op vector`)
/// from `vector op vector`; the wrapped value is broadcast to every element.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ArrayVecScalar<T>(pub T);

/// Wraps any value as a scalar operand: `ArrayVecScalar::from(x)` / `x.into()`.
impl<T> From<T> for ArrayVecScalar<T> {
    fn from(inner: T) -> Self {
        ArrayVecScalar(inner)
    }
}

/// Extension trait providing `.to_array_vec_scalar()` on every type.
pub trait ToArrayVecScalar: Sized {
    /// Consumes `self` and wraps it in an [`ArrayVecScalar`].
    fn to_array_vec_scalar(self) -> ArrayVecScalar<Self>;
}

/// Blanket impl: every sized type can be turned into a scalar operand.
impl<T> ToArrayVecScalar for T {
    fn to_array_vec_scalar(self) -> ArrayVecScalar<Self> {
        ArrayVecScalar::from(self)
    }
}
25
/// Fixed-length `[T; N]` newtype on which this module implements element-wise
/// arithmetic, bitwise, shift and compound-assignment operators.
///
/// The inner array is public, so `vector.0` gives direct access to it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ArrayVec<T, const N: usize>(pub [T; N]);

/// Extension trait adding `.to_array_vec()` to plain arrays.
pub trait ToArrayVec<T, const N: usize>: Sized {
    /// Consumes `self` and wraps it in an [`ArrayVec`] of the same length.
    fn to_array_vec(self) -> ArrayVec<T, N>;
}

impl<T, const N: usize> ToArrayVec<T, N> for [T; N] {
    fn to_array_vec(self) -> ArrayVec<T, N> {
        ArrayVec::<T, N>(self)
    }
}
38
39impl<T, const N: usize> Default for ArrayVec<T, N>
40where
41    T: Default,
42{
43    fn default() -> Self {
44        Self(std::array::from_fn(|_| T::default()))
45    }
46}
47
48impl<T, const N: usize> ArrayVec<T, N> {
49    pub fn new(data: [T; N]) -> Self {
50        Self(data)
51    }
52
53    pub fn map<U>(&self, transform: impl FnMut(&T) -> U) -> ArrayVec<U, N> {
54        ArrayVec(array_from_iter(self.0.iter().map(transform)))
55    }
56
57    pub fn zip_with<U, V>(
58        &self,
59        other: &ArrayVec<U, N>,
60        mut combine: impl FnMut(&T, &U) -> V,
61    ) -> ArrayVec<V, N> {
62        ArrayVec(array_from_iter(
63            self.0
64                .iter()
65                .zip(other.0.iter())
66                .map(|(left, right)| combine(left, right)),
67        ))
68    }
69}
70
71impl<T, const N: usize> From<[T; N]> for ArrayVec<T, N> {
72    fn from(data: [T; N]) -> Self {
73        Self(data)
74    }
75}
76
77impl<T, const N: usize> From<ArrayVec<T, N>> for [T; N] {
78    fn from(data: ArrayVec<T, N>) -> Self {
79        data.0
80    }
81}
82
83impl<T, const N: usize> AsRef<[T; N]> for ArrayVec<T, N> {
84    fn as_ref(&self) -> &[T; N] {
85        &self.0
86    }
87}
88
89impl<T, const N: usize> AsMut<[T; N]> for ArrayVec<T, N> {
90    fn as_mut(&mut self) -> &mut [T; N] {
91        &mut self.0
92    }
93}
94
95impl<T, const N: usize> Index<usize> for ArrayVec<T, N> {
96    type Output = T;
97    fn index(&self, index: usize) -> &Self::Output {
98        &self.0[index]
99    }
100}
101
102impl<T, const N: usize> IndexMut<usize> for ArrayVec<T, N> {
103    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
104        &mut self.0[index]
105    }
106}
107
/// Collects the first `N` items of `iter` into a fixed-size array.
///
/// Callers in this module pass iterators built from `N`-element arrays (a `map`
/// over one array or a `zip` of two), so at least `N` items are always available.
/// Any items beyond the first `N` are left unconsumed.
///
/// # Panics
///
/// Panics if `iter` yields fewer than `N` items; that would be a caller bug.
fn array_from_iter<T, I, const N: usize>(mut iter: I) -> [T; N]
where
    I: Iterator<Item = T>,
{
    // `from_fn` asks for slots 0..N in order, so items land in iteration order.
    std::array::from_fn(|_| {
        iter.next()
            .expect("array_from_iter: iterator yielded fewer than N items")
    })
}
115
// Implements the binary operator trait `$imp` (method `$method`, token `$op`)
// element-wise for `ArrayVec`. Twelve impls are generated:
//
//   vector $op vector -- out[i] = lhs[i] $op rhs[i]
//   vector $op scalar -- out[i] = lhs[i] $op scalar.clone()
//   scalar $op vector -- out[i] = scalar.clone() $op rhs[i]
//
// Each shape is generated for all four owned/borrowed combinations of the two
// operands. Only the all-owned forms do the element work. A borrowed operand is
// cloned in full and the call is forwarded to the owned form, which is why those
// impls carry extra `Clone` bounds. Element types may differ (`T $op U -> V`).
macro_rules! impl_arrayvec_binop {
    ($imp:ident, $method:ident, $op:tt) => {
        // vector $op vector, both owned: pairs elements by index.
        impl<T, U, V, const N: usize> $imp<ArrayVec<U, N>> for ArrayVec<T, N>
        where
            T: $imp<U, Output = V>,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVec<U, N>) -> Self::Output {
                let ArrayVec(lhs_items) = self;
                let ArrayVec(rhs_items) = rhs;
                let pairs = lhs_items.into_iter().zip(rhs_items);
                ArrayVec(array_from_iter(pairs.map(|(a, b)| a $op b)))
            }
        }
        // vector $op &vector
        impl<T, U, V, const N: usize> $imp<&ArrayVec<U, N>> for ArrayVec<T, N>
        where
            T: $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVec<U, N>) -> Self::Output {
                self $op rhs.clone()
            }
        }
        // &vector $op vector
        impl<T, U, V, const N: usize> $imp<ArrayVec<U, N>> for &ArrayVec<T, N>
        where
            T: Clone + $imp<U, Output = V>,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVec<U, N>) -> Self::Output {
                self.clone() $op rhs
            }
        }
        // &vector $op &vector
        impl<T, U, V, const N: usize> $imp<&ArrayVec<U, N>> for &ArrayVec<T, N>
        where
            T: Clone + $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVec<U, N>) -> Self::Output {
                self.clone() $op rhs.clone()
            }
        }

        // vector $op scalar, both owned: the scalar is cloned once per element.
        impl<T, U, V, const N: usize> $imp<ArrayVecScalar<U>> for ArrayVec<T, N>
        where
            T: $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVecScalar<U>) -> Self::Output {
                let ArrayVecScalar(scalar) = rhs;
                let ArrayVec(items) = self;
                ArrayVec(array_from_iter(
                    items.into_iter().map(|item| item $op scalar.clone()),
                ))
            }
        }
        // vector $op &scalar
        impl<T, U, V, const N: usize> $imp<&ArrayVecScalar<U>> for ArrayVec<T, N>
        where
            T: $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVecScalar<U>) -> Self::Output {
                self $op rhs.clone()
            }
        }
        // &vector $op scalar
        impl<T, U, V, const N: usize> $imp<ArrayVecScalar<U>> for &ArrayVec<T, N>
        where
            T: Clone + $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVecScalar<U>) -> Self::Output {
                self.clone() $op rhs
            }
        }
        // &vector $op &scalar
        impl<T, U, V, const N: usize> $imp<&ArrayVecScalar<U>> for &ArrayVec<T, N>
        where
            T: Clone + $imp<U, Output = V>,
            U: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVecScalar<U>) -> Self::Output {
                self.clone() $op rhs.clone()
            }
        }

        // scalar $op vector, both owned: the scalar stays on the left of `$op`.
        impl<T, U, V, const N: usize> $imp<ArrayVec<T, N>> for ArrayVecScalar<U>
        where
            U: Clone + $imp<T, Output = V>,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVec<T, N>) -> Self::Output {
                let ArrayVecScalar(scalar) = self;
                let ArrayVec(items) = rhs;
                ArrayVec(array_from_iter(
                    items.into_iter().map(|item| scalar.clone() $op item),
                ))
            }
        }
        // scalar $op &vector
        impl<T, U, V, const N: usize> $imp<&ArrayVec<T, N>> for ArrayVecScalar<U>
        where
            U: Clone + $imp<T, Output = V>,
            T: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVec<T, N>) -> Self::Output {
                self $op rhs.clone()
            }
        }
        // &scalar $op vector
        impl<T, U, V, const N: usize> $imp<ArrayVec<T, N>> for &ArrayVecScalar<U>
        where
            U: Clone + $imp<T, Output = V>,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: ArrayVec<T, N>) -> Self::Output {
                self.clone() $op rhs
            }
        }
        // &scalar $op &vector
        impl<T, U, V, const N: usize> $imp<&ArrayVec<T, N>> for &ArrayVecScalar<U>
        where
            U: Clone + $imp<T, Output = V>,
            T: Clone,
        {
            type Output = ArrayVec<V, N>;
            fn $method(self, rhs: &ArrayVec<T, N>) -> Self::Output {
                self.clone() $op rhs.clone()
            }
        }
    };
}
253
// Implements the unary operator trait `$imp` (method `$method`, token `$op`)
// element-wise for `ArrayVec`, for both an owned vector and a borrowed one.
// The borrowed form clones the whole vector and forwards to the owned form.
macro_rules! impl_arrayvec_unop {
    ($imp:ident, $method:ident, $op:tt) => {
        // Owned operand: apply `$op` to each element in index order.
        impl<T, U, const N: usize> $imp for ArrayVec<T, N>
        where
            T: $imp<Output = U>,
        {
            type Output = ArrayVec<U, N>;
            fn $method(self) -> Self::Output {
                let ArrayVec(items) = self;
                ArrayVec(array_from_iter(items.into_iter().map(|item| $op item)))
            }
        }
        // Borrowed operand: clone, then reuse the owned impl above.
        impl<T, U, const N: usize> $imp for &ArrayVec<T, N>
        where
            T: Clone + $imp<Output = U>,
        {
            type Output = ArrayVec<U, N>;
            fn $method(self) -> Self::Output {
                $op self.clone()
            }
        }
    };
}
278
// Implements the compound-assignment trait `$imp` (method `$method`) for
// `ArrayVec`, updating each element in place. The right-hand side may be an
// owned or borrowed vector (paired by index), or an owned or borrowed scalar
// (cloned once per element). Borrowed operands are cloned element by element.
macro_rules! impl_arrayvec_assign {
    ($imp:ident, $method:ident) => {
        // vector $op= vector
        impl<T, U, const N: usize> $imp<ArrayVec<U, N>> for ArrayVec<T, N>
        where
            T: $imp<U>,
        {
            fn $method(&mut self, rhs: ArrayVec<U, N>) {
                let ArrayVec(operands) = rhs;
                self.0
                    .iter_mut()
                    .zip(operands)
                    .for_each(|(slot, operand)| slot.$method(operand));
            }
        }
        // vector $op= &vector: each rhs element is cloned right before it is applied.
        impl<T, U, const N: usize> $imp<&ArrayVec<U, N>> for ArrayVec<T, N>
        where
            T: $imp<U>,
            U: Clone,
        {
            fn $method(&mut self, rhs: &ArrayVec<U, N>) {
                for (slot, operand) in self.0.iter_mut().zip(rhs.0.iter().cloned()) {
                    slot.$method(operand);
                }
            }
        }
        // vector $op= scalar
        impl<T, U, const N: usize> $imp<ArrayVecScalar<U>> for ArrayVec<T, N>
        where
            T: $imp<U>,
            U: Clone,
        {
            fn $method(&mut self, rhs: ArrayVecScalar<U>) {
                let ArrayVecScalar(scalar) = rhs;
                self.0
                    .iter_mut()
                    .for_each(|slot| slot.$method(scalar.clone()));
            }
        }
        // vector $op= &scalar: clone the wrapper and reuse the owned-scalar impl.
        impl<T, U, const N: usize> $imp<&ArrayVecScalar<U>> for ArrayVec<T, N>
        where
            T: $imp<U>,
            U: Clone,
        {
            fn $method(&mut self, rhs: &ArrayVecScalar<U>) {
                $imp::$method(self, rhs.clone());
            }
        }
    };
}
325
326impl_arrayvec_binop!(Add, add, +);
327impl_arrayvec_binop!(Sub, sub, -);
328impl_arrayvec_binop!(Mul, mul, *);
329impl_arrayvec_binop!(Div, div, /);
330impl_arrayvec_binop!(Rem, rem, %);
331impl_arrayvec_binop!(BitAnd, bitand, &);
332impl_arrayvec_binop!(BitOr, bitor, |);
333impl_arrayvec_binop!(BitXor, bitxor, ^);
334impl_arrayvec_binop!(Shl, shl, <<);
335impl_arrayvec_binop!(Shr, shr, >>);
336
337impl_arrayvec_unop!(Neg, neg, -);
338impl_arrayvec_unop!(Not, not, !);
339
340impl_arrayvec_assign!(AddAssign, add_assign);
341impl_arrayvec_assign!(SubAssign, sub_assign);
342impl_arrayvec_assign!(MulAssign, mul_assign);
343impl_arrayvec_assign!(DivAssign, div_assign);
344impl_arrayvec_assign!(RemAssign, rem_assign);
345impl_arrayvec_assign!(BitAndAssign, bitand_assign);
346impl_arrayvec_assign!(BitOrAssign, bitor_assign);
347impl_arrayvec_assign!(BitXorAssign, bitxor_assign);
348impl_arrayvec_assign!(ShlAssign, shl_assign);
349impl_arrayvec_assign!(ShrAssign, shr_assign);
350
#[cfg(test)]
mod tests {
    // The glob import also brings in the parent module's private `std::ops`
    // imports, so operator traits such as `Add` need no separate `use` here.
    use super::*;

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct LeftValue(i32);

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct RightValue(i32);

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct SumValue(i32);

    // `LeftValue + RightValue -> SumValue`: all three element types differ.
    impl Add<RightValue> for LeftValue {
        type Output = SumValue;
        fn add(self, rhs: RightValue) -> Self::Output {
            SumValue(self.0 + rhs.0)
        }
    }

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct ScalarValue(i32);

    // `ScalarValue + i32 -> i64`: the output differs from the element type.
    impl Add<i32> for ScalarValue {
        type Output = i64;
        fn add(self, rhs: i32) -> Self::Output {
            i64::from(self.0) + i64::from(rhs)
        }
    }

    #[test]
    fn test_vec_vec_output_change() {
        let left = [LeftValue(1), LeftValue(2)].to_array_vec();
        let right = [RightValue(3), RightValue(4)].to_array_vec();
        let sum = left + right;
        assert_eq!(sum.0, [SumValue(4), SumValue(6)]);
    }

    #[test]
    fn test_vec_scalar_output_change() {
        let vector = [ScalarValue(1), ScalarValue(2)].to_array_vec();
        let output = vector + 3.to_array_vec_scalar();
        assert_eq!(output.0, [4i64, 5i64]);
    }

    #[test]
    fn test_binary_ops() {
        let left = [10i32, 20i32].to_array_vec();
        let right = [3i32, 4i32].to_array_vec();
        assert_eq!((left + right).0, [13, 24]);
        assert_eq!((left - right).0, [7, 16]);
        assert_eq!((left * 2.to_array_vec_scalar()).0, [20, 40]);
        assert_eq!((left / 2.to_array_vec_scalar()).0, [5, 10]);
        assert_eq!((left % 7.to_array_vec_scalar()).0, [3, 6]);
        assert_eq!((2.to_array_vec_scalar() * left).0, [20, 40]);
    }

    #[test]
    fn test_bit_and_shift_ops() {
        let vector = [0b1100u8, 0b1010u8].to_array_vec();
        let other = [0b1010u8, 0b1100u8].to_array_vec();
        assert_eq!((vector & other).0, [0b1000, 0b1000]);
        assert_eq!((vector | 0b0001.to_array_vec_scalar()).0, [0b1101, 0b1011]);
        assert_eq!((vector ^ other).0, [0b0110, 0b0110]);
        let shift_amounts = [1u32, 2u32].to_array_vec();
        assert_eq!((vector << shift_amounts).0, [0b11000, 0b101000]);
        assert_eq!((vector >> 1u32.to_array_vec_scalar()).0, [0b0110, 0b0101]);
    }

    #[test]
    fn test_assign_ops() {
        let mut values = [10i32, 20i32].to_array_vec();
        values += [1, 2].to_array_vec();
        values -= &[2, 3].to_array_vec();
        values *= 2.to_array_vec_scalar();
        values /= 3.to_array_vec_scalar();
        values %= 5.to_array_vec_scalar();
        assert_eq!(values.0, [1, 2]);

        let mut bits = [0b1100u8, 0b1010u8].to_array_vec();
        bits &= [0b1010u8, 0b1100u8].to_array_vec();
        bits |= &[0b0001u8, 0b0010u8].to_array_vec();
        bits ^= 0b0011.to_array_vec_scalar();
        bits <<= 1u32.to_array_vec_scalar();
        bits >>= 1u32.to_array_vec_scalar();
        assert_eq!(bits.0, [0b1010u8, 0b1001u8]);
    }

    // Every owned/borrowed operand combination must give the same result.
    #[test]
    fn test_reference_operands() {
        let left = [1i32, 2i32].to_array_vec();
        let right = [10i32, 20i32].to_array_vec();
        let scalar = 5i32.to_array_vec_scalar();
        assert_eq!((left + &right).0, [11, 22]);
        assert_eq!((&left + right).0, [11, 22]);
        assert_eq!((&left + &right).0, [11, 22]);
        assert_eq!((left * &scalar).0, [5, 10]);
        assert_eq!((&left * scalar).0, [5, 10]);
        assert_eq!((&left * &scalar).0, [5, 10]);
        assert_eq!((scalar - &left).0, [4, 3]);
        assert_eq!((&scalar - left).0, [4, 3]);
        assert_eq!((&scalar - &left).0, [4, 3]);

        let mut totals = [1i32, 2i32].to_array_vec();
        totals += &scalar;
        assert_eq!(totals.0, [6, 7]);
    }

    #[test]
    fn test_unary_ops() {
        let values = [1i32, -2i32].to_array_vec();
        assert_eq!((-values).0, [-1, 2]);
        assert_eq!((-&values).0, [-1, 2]);
        let flags = [true, false].to_array_vec();
        assert_eq!((!flags).0, [false, true]);
        let bytes = [0u8, 0xF0u8].to_array_vec();
        assert_eq!((!&bytes).0, [0xFF, 0x0F]);
    }

    #[test]
    fn test_map_and_zip_with() {
        let values = ArrayVec::new([1i32, 2, 3]);
        let doubled = values.map(|&value| value * 2);
        assert_eq!(doubled.0, [2, 4, 6]);
        let labels = values.map(|value| value.to_string());
        assert_eq!(labels.0, ["1", "2", "3"]);
        let combined = values.zip_with(&doubled, |left, right| left + right);
        assert_eq!(combined.0, [3, 6, 9]);
    }

    #[test]
    fn test_default_conversions_and_indexing() {
        let zeros: ArrayVec<i32, 3> = ArrayVec::default();
        assert_eq!(zeros.0, [0, 0, 0]);

        let mut vector: ArrayVec<u8, 2> = [3u8, 4u8].into();
        vector[1] = 9;
        assert_eq!(vector[0], 3);
        let slots: &mut [u8; 2] = vector.as_mut();
        slots[0] = 7;
        let view: &[u8; 2] = vector.as_ref();
        assert_eq!(view, &[7, 9]);
        let array: [u8; 2] = vector.into();
        assert_eq!(array, [7, 9]);
        assert_eq!(ArrayVec::new([1, 2]), [1, 2].to_array_vec());
    }
}