competitive/algorithm/
sort.rs1use std::{cmp::Ordering, ptr::copy_nonoverlapping};
2
/// Extension trait adding several comparison sorts to slice-like types.
///
/// Every algorithm is offered twice: once ordering elements through `T: Ord`,
/// and once as a `_by` variant driven by a caller-supplied comparator that
/// returns an [`Ordering`].
pub trait SliceSortExt<T> {
    /// Sorts `self` in place with bubble sort, using the `Ord` implementation of `T`.
    fn bubble_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` in place with bubble sort, ordering elements by `compare`.
    fn bubble_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts `self` in place with merge sort, using the `Ord` implementation of `T`.
    fn merge_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` in place with merge sort, ordering elements by `compare`.
    fn merge_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts `self` in place with insertion sort, using the `Ord` implementation of `T`.
    fn insertion_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` in place with insertion sort, ordering elements by `compare`.
    fn insertion_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);
}
23impl<T> SliceSortExt<T> for [T] {
24 fn bubble_sort(&mut self)
25 where
26 T: Ord,
27 {
28 bubble_sort(self, |a, b| a.lt(b));
29 }
30 fn bubble_sort_by<F>(&mut self, mut compare: F)
31 where
32 F: FnMut(&T, &T) -> Ordering,
33 {
34 bubble_sort(self, |a, b| compare(a, b) == Ordering::Less);
35 }
36 fn merge_sort(&mut self)
37 where
38 T: Ord,
39 {
40 merge_sort(self, |a, b| a.lt(b));
41 }
42 fn merge_sort_by<F>(&mut self, mut compare: F)
43 where
44 F: FnMut(&T, &T) -> Ordering,
45 {
46 merge_sort(self, |a, b| compare(a, b) == Ordering::Less);
47 }
48 fn insertion_sort(&mut self)
49 where
50 T: Ord,
51 {
52 insertion_sort(self, |a, b| a.lt(b));
53 }
54 fn insertion_sort_by<F>(&mut self, mut compare: F)
55 where
56 F: FnMut(&T, &T) -> Ordering,
57 {
58 insertion_sort(self, |a, b| compare(a, b) == Ordering::Less);
59 }
60}
61
/// Sorts `v` in place with bubble sort.
///
/// `is_less(a, b)` must return `true` exactly when `a` has to be placed strictly
/// before `b`. Two neighbours are exchanged only when the right one is strictly
/// less than the left one, so elements that compare equal keep their relative order.
///
/// Scanning stops as soon as a complete pass performs no exchange, which makes an
/// already sorted slice cost a single pass.
fn bubble_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    // `v[unsorted..]` already holds its final contents; each pass bubbles the
    // largest element of `v[..unsorted]` up to index `unsorted - 1`.
    for unsorted in (2..=v.len()).rev() {
        let mut swapped = false;
        for j in 1..unsorted {
            if is_less(&v[j], &v[j - 1]) {
                v.swap(j - 1, j);
                swapped = true;
            }
        }
        if !swapped {
            // No inversion among neighbours means the prefix is sorted as well.
            break;
        }
    }
}
80
/// Merges the two consecutive sorted runs `v[..mid]` and `v[mid..]` in place.
///
/// The left run is first moved into the scratch area `buf`; the merged sequence is
/// then written back into `v` from front to back. When the head of the right run is
/// not strictly less than the head of the left run, the left element is taken, so
/// equal elements keep their relative order.
///
/// If `is_less` panics, a drop guard moves the elements still parked in `buf` back
/// into the unfilled part of `v`. `v` therefore always ends up holding each of its
/// original elements exactly once (possibly not fully merged), and no element is
/// duplicated or lost during unwinding.
///
/// # Safety
///
/// * `mid` must not exceed `v.len()`.
/// * `buf` must be properly aligned, valid for reads and writes of `mid` values of
///   type `T`, and must not overlap `v`.
unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    /// Book-keeping for the part of the left run that still lives in the scratch
    /// buffer: `remaining` elements starting at `src` belong into the gap at `dest`.
    struct Hole<T> {
        src: *mut T,
        dest: *mut T,
        remaining: usize,
    }

    impl<T> Drop for Hole<T> {
        fn drop(&mut self) {
            // SAFETY: `src` addresses `remaining` initialized elements inside the
            // scratch buffer, `dest` addresses a gap of exactly `remaining` slots
            // inside the slice, and the two regions do not overlap.
            unsafe { copy_nonoverlapping(self.src, self.dest, self.remaining) };
        }
    }

    let len = v.len();
    let base = v.as_mut_ptr();

    // SAFETY: the caller guarantees `mid <= len` and that `buf` has room for `mid`
    // elements and is disjoint from `v`. `right` stays within `[base + mid, base + len]`
    // and `hole.src` within `[buf, buf + mid]`; both are dereferenced only while
    // strictly below their upper bound. `hole.dest` always trails `right` by
    // `hole.remaining` slots, so every single-element copy has disjoint operands.
    unsafe {
        let v_end = base.add(len);
        copy_nonoverlapping(base, buf, mid);

        let mut hole = Hole {
            src: buf,
            dest: base,
            remaining: mid,
        };
        let mut right = base.add(mid);

        while hole.remaining > 0 && right < v_end {
            let to_copy = if is_less(&*right, &*hole.src) {
                let taken = right;
                right = right.add(1);
                taken
            } else {
                let taken = hole.src;
                hole.src = hole.src.add(1);
                hole.remaining -= 1;
                taken
            };
            copy_nonoverlapping(to_copy, hole.dest, 1);
            hole.dest = hole.dest.add(1);
        }
        // Dropping `hole` here (or during unwinding) moves whatever is left of the
        // buffered left run into the gap that ends where the right run continues.
    }
}
117
118fn merge_sort<T, F>(v: &mut [T], mut is_less: F)
119where
120 F: FnMut(&T, &T) -> bool,
121{
122 let len = v.len();
123 if len <= 1 {
124 return;
125 }
126 let mut buf = Vec::with_capacity(len / 2);
127 let mut runs: Vec<Run> = vec![];
128 let mut end = len;
129 while end > 0 {
130 let start = end - 1;
131 let mut left = Run {
132 start,
133 len: end - start,
134 };
135 end = start;
136
137 while let Some(right) = runs.pop_if(|right| left.start == 0 || right.len <= left.len) {
138 unsafe {
139 merge(
140 &mut v[left.start..right.start + right.len],
141 left.len,
142 buf.as_mut_ptr(),
143 &mut is_less,
144 );
145 }
146 left = Run {
147 start: left.start,
148 len: left.len + right.len,
149 };
150 }
151 runs.push(left);
152 }
153
154 debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
155
156 #[derive(Clone, Copy)]
157 struct Run {
158 start: usize,
159 len: usize,
160 }
161}
162
/// Sorts `v` in place with a binary insertion sort.
///
/// `is_less(a, b)` must return `true` exactly when `a` has to be placed strictly
/// before `b`. For every index `i` the insertion point of `v[i]` inside the already
/// sorted prefix `v[..i]` is found by binary search, and the element is rotated
/// into place.
///
/// The sort is stable: an element is inserted *behind* all prefix elements it
/// compares equal to, so ties keep their original relative order.
fn insertion_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    for i in 1..v.len() {
        let x = &v[i];
        // Upper bound: skip every prefix element that is not greater than `x`.
        // (Searching with `is_less(y, x)` would give the lower bound and place `x`
        // in front of its equals, reversing ties.)
        let p = v[..i].partition_point(|y| !is_less(x, y));
        // Shift `v[p..i]` one slot to the right and drop `x` into position `p`.
        v[p..=i].rotate_right(1);
    }
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algorithm::SliceCombinationsExt, tools::Xorshift};

    /// Checks a sort method of `SliceSortExt` against the standard library's `sort`.
    ///
    /// * `@small method` runs the method on every permutation of `0..n` for `n` in `0..=8`.
    /// * `@large method, max_len` runs the method on ten random vectors whose length
    ///   is drawn from `..max_len`.
    macro_rules! test_sort {
        (@small $method:ident) => {
            for n in 0..=8 {
                let items: Vec<_> = (0..n).collect();
                items.for_each_permutations(n, |perm| {
                    let mut expected = perm.to_vec();
                    expected.sort();
                    let mut actual = perm.to_vec();
                    actual.$method();
                    assert_eq!(expected, actual);
                });
            }
        };
        (@large $method:ident, $max_len:expr) => {{
            let mut rng = Xorshift::default();
            for _ in 0..10 {
                // Draw the length first, then the value range, then the values.
                let n = rng.random(..$max_len);
                let ub = 1 << rng.random(0..20);
                let input: Vec<_> = rng.random_iter(0..ub).take(n).collect();
                let mut expected = input.to_vec();
                expected.sort();
                let mut actual = input.to_vec();
                actual.$method();
                assert_eq!(expected, actual);
            }
        }};
    }

    /// Bubble sort on all permutations of up to eight elements.
    #[test]
    fn test_bubble_sort_small() {
        test_sort!(@small bubble_sort);
    }

    /// Bubble sort on random inputs of fewer than 3000 elements.
    #[test]
    fn test_bubble_sort_large() {
        test_sort!(@large bubble_sort, 3000);
    }

    /// Merge sort on all permutations of up to eight elements.
    #[test]
    fn test_merge_sort_small() {
        test_sort!(@small merge_sort);
    }

    /// Merge sort on random inputs of fewer than 100 000 elements.
    #[test]
    fn test_merge_sort_large() {
        test_sort!(@large merge_sort, 100_000);
    }

    /// Insertion sort on all permutations of up to eight elements.
    #[test]
    fn test_insertion_sort_small() {
        test_sort!(@small insertion_sort);
    }

    /// Insertion sort on random inputs of fewer than 100 000 elements.
    #[test]
    fn test_insertion_sort_large() {
        test_sort!(@large insertion_sort, 100_000);
    }
}