competitive/tools/
coding.rs

1use std::{
2    char::from_u32_unchecked,
3    cmp::Reverse,
4    collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, VecDeque},
5    hash::Hash,
6    iter::repeat_with,
7    mem::size_of,
8};
9
/// Undoes the backslash escaping produced for string literals.
///
/// `\n`, `\r`, `\t` and `\0` become their control bytes. Any other escaped byte
/// (including `\\` and `\"`) is kept as-is without its backslash. A trailing,
/// unpaired backslash is dropped.
pub fn unescape(bytes: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(bytes.len());
    let mut pending = false;
    for &byte in bytes {
        if pending {
            pending = false;
            let decoded = match byte {
                b'n' => b'\n',
                b'r' => b'\r',
                b't' => b'\t',
                b'0' => b'\0',
                other => other,
            };
            out.push(decoded);
        } else if byte == b'\\' {
            pending = true;
        } else {
            out.push(byte);
        }
    }
    out
}
50
/// Packs arbitrary bytes into text that can be pasted into a string literal.
///
/// The input bit stream is cut into 7-bit groups, most significant bit first,
/// with the last group zero-padded, so every group is ASCII. Every group is
/// XOR-ed with a `salt` in `0..128`. The salt is chosen to minimise how many
/// characters need a backslash escape (`\n`, `\r`, `\t`, `\\`, `\0`, `"`); the
/// lowest salt wins ties. The salt itself is emitted first, escaped if needed.
///
/// The result is the *escaped* literal body. `from_bytestring` expects the
/// bytes after escape processing.
fn to_bytestring(bytes: &[u8]) -> String {
    /// Re-slices `input` into 7-bit groups, MSB first; a trailing partial
    /// group is left-aligned and zero-padded.
    fn septets(input: &[u8]) -> Vec<u8> {
        let mut groups = Vec::with_capacity(input.len() + input.len() / 7 + 1);
        let mut carry = 0u8;
        let mut used = 0u32;
        for &byte in input {
            used += 1;
            groups.push(carry | (byte >> used));
            if used == 7 {
                // Seven input bytes fill exactly eight groups.
                groups.push(byte & 0b111_1111);
                carry = 0;
                used = 0;
            } else {
                carry = (byte << (7 - used)) & 0b111_1111;
            }
        }
        if used > 0 {
            groups.push(carry);
        }
        groups
    }

    /// Appends `value` to `out`, escaping the bytes that cannot appear raw.
    fn push_escaped(out: &mut String, value: u8) {
        match value {
            b'\n' => out.push_str("\\n"),
            b'\r' => out.push_str("\\r"),
            b'\t' => out.push_str("\\t"),
            b'\\' => out.push_str("\\\\"),
            b'\0' => out.push_str("\\0"),
            b'"' => out.push_str("\\\""),
            plain => out.push(plain as char),
        }
    }

    let groups = septets(bytes);
    let mut histogram = [0usize; 128];
    for &g in &groups {
        histogram[g as usize] += 1;
    }

    // Cost of a salt = number of escaped characters it causes, including the
    // salt character itself. Tuples compare by cost first, then by salt, so
    // `min` returns the lowest salt among the cheapest.
    let (extra_min, salt) = (0u8..128)
        .map(|s| {
            let extra: usize = b"\n\r\t\\\0\""
                .iter()
                .map(|&e| histogram[(e ^ s) as usize] + (e == s) as usize)
                .sum();
            (extra, s)
        })
        .min()
        .unwrap();

    let cap = extra_min + groups.len() + 1;
    let mut out = String::with_capacity(cap);
    push_escaped(&mut out, salt);
    for g in groups {
        push_escaped(&mut out, g ^ salt);
    }
    assert_eq!(cap, out.len());
    out
}
124
/// Inverse of `to_bytestring`, applied to the literal's bytes *after* escape
/// processing.
///
/// The first byte is the salt; every following byte is a salted 7-bit group.
/// Each run of eight groups yields seven output bytes. Bits of a trailing
/// partial group that do not complete a byte are discarded.
///
/// # Panics
/// Panics if `bytes` is empty.
fn from_bytestring(bytes: &[u8]) -> Vec<u8> {
    assert!(!bytes.is_empty());
    let salt = bytes[0];
    let groups = &bytes[1..];
    let cap = groups.len() * 7 / 8;
    let mut out = Vec::with_capacity(cap);
    for chunk in groups.chunks(8) {
        let mut acc = 0u8;
        for (i, &raw) in chunk.iter().enumerate() {
            let g = raw ^ salt;
            if i == 0 {
                // The first group of a chunk only supplies the top 7 bits.
                acc = g << 1;
            } else {
                // Group `i` completes output byte `i - 1` with its top
                // `7 - shift` bits; its low `shift` bits start the next byte.
                let shift = 7 - i as u32;
                out.push(acc | (g >> shift));
                acc = if shift > 0 { g << (8 - shift) } else { 0 };
            }
        }
    }
    assert_eq!(cap, out.len());
    out
}
147
/// A Huffman code tree over byte symbols.
///
/// The derived ordering compares the variant first: every `Leaf` sorts before
/// every `Node`, then contents are compared. It serves as the final tie-breaker
/// when trees sit in a `BinaryHeap`. Reordering the variants would therefore
/// change which tree gets built.
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug)]
enum HuffmanTree {
    /// A symbol (byte value).
    Leaf(u8),
    /// An internal node; the left child is reached by bit `0`, the right by bit `1`.
    Node(Box<Self>, Box<Self>),
}
153
/// Accumulates bits most-significant-first and packs them into bytes.
///
/// `last` is the partially filled current byte. `w` is how many of its bit
/// positions are still free (`1..=8`; `8` means nothing is pending).
#[derive(Debug)]
struct BitWriter {
    bytes: Vec<u8>,
    last: u8,
    w: u32,
}

impl Default for BitWriter {
    fn default() -> Self {
        Self {
            bytes: Default::default(),
            last: Default::default(),
            w: 8,
        }
    }
}

impl BitWriter {
    /// Appends one bit.
    fn push_bit(&mut self, b: bool) {
        self.w -= 1;
        self.last |= (b as u8) << self.w;
        if self.w == 0 {
            self.bytes.push(self.last);
            self.w = 8;
            self.last = 0;
        }
    }

    /// Appends all eight bits of `b`, most significant first.
    fn push_u8(&mut self, b: u8) {
        // The high `8 - w` bits of the output byte are already in `last`.
        // The top `w` bits of `b` complete it; its low `8 - w` bits stay pending.
        self.bytes.push(self.last | (b >> (8 - self.w)));
        if self.w < 8 {
            self.last = b << self.w;
        }
    }

    /// Appends the low `c` bits of `b` (`c <= 64`), most significant first.
    ///
    /// Bits of `b` above position `c` are ignored. Without masking they would
    /// be OR-ed into bits of the current byte that were already written.
    fn push_u64(&mut self, b: u64, mut c: u32) {
        debug_assert!(c <= 64, "push_u64 can write at most 64 bits");
        let b = if c < 64 { b & ((1u64 << c) - 1) } else { b };
        // First fill the free part of the current byte.
        let k = self.w.min(c);
        self.w -= k;
        c -= k;
        self.last |= (b >> c << self.w) as u8;
        if self.w == 0 {
            self.bytes.push(self.last);
            // Then emit whole bytes, keeping the remaining `t` bits pending.
            let (s, t) = (c / 8, c % 8);
            for _ in 0..s {
                c -= 8;
                self.bytes.push((b >> c) as u8);
            }
            self.last = 0;
            self.w = 8 - t;
            if t > 0 {
                self.last = (b as u8) << self.w;
            }
        }
    }

    /// Returns the packed bytes, zero-padding a trailing partial byte.
    fn into_inner(mut self) -> Vec<u8> {
        if self.w < 8 {
            self.bytes.push(self.last);
        }
        self.bytes
    }
}
213
/// Reads bits most-significant-first from a byte slice, mirroring `BitWriter`.
#[derive(Debug)]
struct BitReader<'a> {
    /// Unconsumed input; `bytes[0]` is the byte currently being read.
    bytes: &'a [u8],
    /// Number of bits of `bytes[0]` already consumed (`0..8`).
    pos: u32,
}

impl<'a> BitReader<'a> {
    fn new(bytes: &'a [u8]) -> Self {
        BitReader { bytes, pos: 0 }
    }

    /// Reads one bit. Panics if the input is exhausted.
    fn read_bit(&mut self) -> bool {
        let bit = (self.bytes[0] & (0x80 >> self.pos)) != 0;
        if self.pos == 7 {
            self.bytes = &self.bytes[1..];
            self.pos = 0;
        } else {
            self.pos += 1;
        }
        bit
    }

    /// Reads the next eight bits as a byte. When not byte-aligned, the value
    /// spans two input bytes and the bit offset is unchanged afterwards.
    fn read_u8(&mut self) -> u8 {
        let high = self.bytes[0] << self.pos;
        self.bytes = &self.bytes[1..];
        match self.pos {
            0 => high,
            offset => high | (self.bytes[0] >> (8 - offset)),
        }
    }
}
243
244fn huffman_coding(bytes: &[u8]) -> Vec<u8> {
245    fn make_table(t: &HuffmanTree, code: u64, len: u32, table: &mut [(u64, u32)]) {
246        match t {
247            HuffmanTree::Leaf(i) => {
248                table[*i as usize] = (code, len.max(1));
249            }
250            HuffmanTree::Node(l, r) => {
251                make_table(l, code << 1, len + 1, table);
252                make_table(r, (code << 1) | 1, len + 1, table);
253            }
254        }
255    }
256    fn output_tree(t: &HuffmanTree, buf: &mut BitWriter) {
257        match t {
258            HuffmanTree::Leaf(i) => {
259                buf.push_bit(false);
260                buf.push_u8(*i);
261            }
262            HuffmanTree::Node(l, r) => {
263                buf.push_bit(true);
264                output_tree(l, buf);
265                output_tree(r, buf);
266            }
267        }
268    }
269
270    let mut freq = [0usize; 256];
271    for &b in &bytes.len().to_le_bytes() {
272        freq[b as usize] += 1;
273    }
274    for &b in bytes {
275        freq[b as usize] += 1;
276    }
277    let mut heap = BinaryHeap::new();
278    for (i, &f) in freq.iter().enumerate() {
279        if f > 0 {
280            heap.push(Reverse((f, 0usize, HuffmanTree::Leaf(i as _))));
281        }
282    }
283    let t = if heap.is_empty() {
284        HuffmanTree::Node(
285            Box::new(HuffmanTree::Leaf(0)),
286            Box::new(HuffmanTree::Leaf(0)),
287        )
288    } else {
289        loop {
290            let Reverse((f, c, t)) = heap.pop().unwrap();
291            if let Some(Reverse((ff, cc, tt))) = heap.pop() {
292                heap.push(Reverse((
293                    f + ff,
294                    c.max(cc) + 1,
295                    HuffmanTree::Node(Box::new(t), Box::new(tt)),
296                )));
297            } else {
298                break t;
299            }
300        }
301    };
302
303    let mut table = vec![(0u64, 0u32); 256];
304    make_table(&t, 0, 0, &mut table);
305    let mut buf = BitWriter::default();
306    output_tree(&t, &mut buf);
307    for &b in &bytes.len().to_le_bytes() {
308        let (x, y) = table[b as usize];
309        buf.push_u64(x, y);
310    }
311    for &b in bytes {
312        let (x, y) = table[b as usize];
313        buf.push_u64(x, y);
314    }
315    buf.into_inner()
316}
317
318fn huffman_decoding(bytes: &[u8]) -> Vec<u8> {
319    fn read_tree(reader: &mut BitReader) -> HuffmanTree {
320        if reader.read_bit() {
321            HuffmanTree::Node(Box::new(read_tree(reader)), Box::new(read_tree(reader)))
322        } else {
323            HuffmanTree::Leaf(reader.read_u8())
324        }
325    }
326    fn decode(mut t: &HuffmanTree, reader: &mut BitReader) -> u8 {
327        loop {
328            match t {
329                HuffmanTree::Leaf(i) => break *i,
330                HuffmanTree::Node(l, r) => t = if reader.read_bit() { r } else { l },
331            }
332        }
333    }
334
335    let mut reader = BitReader::new(bytes);
336    let t = read_tree(&mut reader);
337    const C: usize = size_of::<usize>();
338    let mut size: [u8; C] = [0u8; C];
339    for b in &mut size {
340        *b = decode(&t, &mut reader);
341    }
342    let size = usize::from_le_bytes(size);
343    let mut buf = vec![];
344    for i in 0..size {
345        buf.push(decode(&t, &mut reader));
346        if i < 10 {}
347    }
348    buf
349}
350
351pub trait SerdeByteStr {
352    fn serialize(&self, buf: &mut Vec<u8>);
353
354    fn deserialize<I>(iter: &mut I) -> Self
355    where
356        I: Iterator<Item = u8>;
357
358    fn serialize_bytestr(&self) -> String {
359        let mut bytes = vec![];
360        self.serialize(&mut bytes);
361        let bytes = huffman_coding(&bytes);
362        to_bytestring(&bytes)
363    }
364
365    fn deserialize_from_bytes(bytes: &[u8]) -> Self
366    where
367        Self: Sized,
368    {
369        let bytes = from_bytestring(bytes);
370        let bytes = huffman_decoding(&bytes);
371        Self::deserialize(&mut bytes.as_slice().iter().cloned())
372    }
373}
374
375impl SerdeByteStr for bool {
376    fn serialize(&self, buf: &mut Vec<u8>) {
377        (*self as u8).serialize(buf)
378    }
379    fn deserialize<I>(iter: &mut I) -> Self
380    where
381        I: Iterator<Item = u8>,
382    {
383        iter.next().unwrap() != 0
384    }
385}
386
387macro_rules! impl_serdebytestr_num {
388    ($($t:ty)*) => {
389        $(
390            impl SerdeByteStr for $t {
391                fn serialize(&self, buf: &mut Vec<u8>) {
392                    buf.extend(self.to_le_bytes().iter());
393                }
394                fn deserialize<I>(iter: &mut I) -> Self
395                where
396                    I: Iterator<Item = u8>,
397                {
398                    const C: usize = size_of::<$t>();
399                    let mut bytes: [u8; C] = [0u8; C];
400                    for (b, i) in bytes.iter_mut().zip(iter) {
401                        *b = i;
402                    }
403                    <$t>::from_le_bytes(bytes)
404                }
405            }
406        )*
407    };
408}
409
410impl_serdebytestr_num!(u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize f32 f64);
411
412impl SerdeByteStr for char {
413    fn serialize(&self, buf: &mut Vec<u8>) {
414        (*self as u32).serialize(buf)
415    }
416    fn deserialize<I>(iter: &mut I) -> Self
417    where
418        I: Iterator<Item = u8>,
419    {
420        unsafe { from_u32_unchecked(u32::deserialize(iter)) }
421    }
422}
423
424impl SerdeByteStr for () {
425    fn serialize(&self, _buf: &mut Vec<u8>) {}
426    fn deserialize<I>(_iter: &mut I) -> Self
427    where
428        I: Iterator<Item = u8>,
429    {
430    }
431}
432
433macro_rules! impl_serdebytestr_tuple {
434    (@impl $($A:ident $a:ident)*) => {
435        impl<$($A,)*> SerdeByteStr for ($($A,)*)
436        where
437            $($A: SerdeByteStr),*
438        {
439            fn serialize(&self, buf: &mut Vec<u8>) {
440                let ($($a,)*) = self;
441                $(SerdeByteStr::serialize($a, buf));*
442            }
443            fn deserialize<Iter>(iter: &mut Iter) -> Self
444            where
445                Iter: Iterator<Item = u8>,
446            {
447                ($(<$A as SerdeByteStr>::deserialize(iter),)*)
448            }
449        }
450    };
451    (@inc , $B:ident $b:ident $($C:ident $c:ident)*) => {
452        impl_serdebytestr_tuple!(@inc $B $b, $($C $c)*);
453    };
454    (@inc $($A:ident $a:ident)*, $B:ident $b:ident $($C:ident $c:ident)*) => {
455        impl_serdebytestr_tuple!(@impl $($A $a)*);
456        impl_serdebytestr_tuple!(@inc $($A $a)* $B $b, $($C $c)*);
457    };
458    (@inc $($A:ident $a:ident)*,) => {
459        impl_serdebytestr_tuple!(@impl $($A $a)*);
460    };
461    ($($t:tt)*) => {
462        impl_serdebytestr_tuple!(@inc , $($t)*);
463    };
464}
465impl_serdebytestr_tuple!(A a B b C c D d E e F f G g H h I i J j K k);
466
467impl<T> SerdeByteStr for Option<T>
468where
469    T: SerdeByteStr,
470{
471    fn serialize(&self, buf: &mut Vec<u8>) {
472        self.is_some().serialize(buf);
473        if let Some(x) = self {
474            x.serialize(buf);
475        }
476    }
477    fn deserialize<I>(iter: &mut I) -> Self
478    where
479        I: Iterator<Item = u8>,
480    {
481        if bool::deserialize(iter) {
482            Some(T::deserialize(iter))
483        } else {
484            None
485        }
486    }
487}
488
489impl SerdeByteStr for String {
490    fn serialize(&self, buf: &mut Vec<u8>) {
491        let bytes = self.bytes();
492        bytes.len().serialize(buf);
493        for x in bytes {
494            x.serialize(buf);
495        }
496    }
497    fn deserialize<I>(iter: &mut I) -> Self
498    where
499        I: Iterator<Item = u8>,
500    {
501        let n = usize::deserialize(iter);
502        unsafe {
503            String::from_utf8_unchecked(
504                repeat_with(|| u8::deserialize(iter))
505                    .take(n)
506                    .collect::<Vec<u8>>(),
507            )
508        }
509    }
510}
511
512macro_rules! impl_serdebytestr_seq {
513    ($([$($g:ident)*] $t:ty $(where [$($tt:tt)*])?),* $(,)?) => {
514        $(
515            impl<$($g),*> SerdeByteStr for $t
516            where
517                $($g: SerdeByteStr,)*
518                $($($tt)*)?
519            {
520                fn serialize(&self, buf: &mut Vec<u8>) {
521                    self.len().serialize(buf);
522                    for x in self {
523                        x.serialize(buf);
524                    }
525                }
526                fn deserialize<I>(iter: &mut I) -> Self
527                where
528                    I: Iterator<Item = u8>,
529                {
530                    let n = usize::deserialize(iter);
531                    repeat_with(|| SerdeByteStr::deserialize(iter)).take(n).collect()
532                }
533            }
534        )*
535    };
536    (@kv $([$($g:ident)*] $t:ty $(where [$($tt:tt)*])?),* $(,)?) => {
537        $(
538            impl<$($g),*> SerdeByteStr for $t
539            where
540                $($g: SerdeByteStr,)*
541                $($($tt)*)?
542            {
543                fn serialize(&self, buf: &mut Vec<u8>) {
544                    self.len().serialize(buf);
545                    for (k, v) in self {
546                        k.serialize(buf);
547                        v.serialize(buf);
548                    }
549                }
550                fn deserialize<I>(iter: &mut I) -> Self
551                where
552                    I: Iterator<Item = u8>,
553                {
554                    let n = usize::deserialize(iter);
555                    repeat_with(|| SerdeByteStr::deserialize(iter)).take(n).collect()
556                }
557            }
558        )*
559    };
560}
561
562impl_serdebytestr_seq!(
563    [T] Vec<T>,
564    [T] VecDeque<T>,
565    [T] BinaryHeap<T> where [T: Ord],
566    [T] BTreeSet<T> where [T: Ord],
567    [T] HashSet<T> where [T: Eq + Hash],
568);
569impl_serdebytestr_seq!(
570    @kv
571    [K V] BTreeMap<K, V> where [K: Ord],
572    [K V] HashMap<K, V> where [K: Eq + Hash],
573);
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `value` through the full pipeline.
    ///
    /// The steps are `serialize_bytestr`, then the escape processing the
    /// compiler would apply to the emitted literal (simulated with `unescape`),
    /// then `deserialize_from_bytes`.
    fn roundtrip<T: SerdeByteStr>(value: &T) -> T {
        let literal = value.serialize_bytestr();
        T::deserialize_from_bytes(&unescape(literal.as_bytes()))
    }

    #[test]
    fn test_bitrw() {
        let mut writer = BitWriter::default();
        for i in 0..10 {
            writer.push_bit(i % 3 == 0);
        }
        for i in 0..10 {
            writer.push_u8(i);
        }
        for i in 1..=10 {
            writer.push_u64(i, i as _);
        }
        let bytes = writer.into_inner();

        let mut reader = BitReader::new(&bytes);
        for i in 0..10 {
            assert_eq!(i % 3 == 0, reader.read_bit());
        }
        for i in 0..10 {
            assert_eq!(i, reader.read_u8());
        }
        for i in 1..=10 {
            let mut x = 0u64;
            for j in (0..i).rev() {
                x |= (reader.read_bit() as u64) << j;
            }
            assert_eq!(i, x);
        }
    }

    #[test]
    fn test_unescape() {
        assert_eq!(
            unescape(br#"a\nb\rc\td\\e\0f\"g"#),
            b"a\nb\rc\td\\e\0f\"g".to_vec()
        );
    }

    #[test]
    fn test_bytestring_roundtrip() {
        // Covers every length residue modulo 7 (the 7-bit packing period) several times.
        for n in 0..40usize {
            let data: Vec<u8> = (0..n).map(|i| (i * 37 + 11) as u8).collect();
            let literal = to_bytestring(&data);
            assert_eq!(from_bytestring(&unescape(literal.as_bytes())), data);
        }
    }

    #[test]
    fn test_serde() {
        let a = (
            (0..=255).collect::<Vec<u8>>(),
            String::from_utf8((0..128).collect::<Vec<u8>>()).unwrap(),
            (0..=255).collect::<VecDeque<u64>>(),
            (0..=255).collect::<BTreeSet<usize>>(),
            (-255..=255).collect::<HashSet<i128>>(),
        );
        assert_eq!(roundtrip(&a), a);

        let a = (0, 0);
        assert_eq!(roundtrip(&a), a);
    }

    #[test]
    fn test_serde_edge_cases() {
        // An empty payload leaves only the all-zero length prefix, so the
        // Huffman tree degenerates to a single leaf.
        roundtrip(&());
        assert_eq!(roundtrip(&Vec::<u8>::new()), Vec::<u8>::new());
        assert_eq!(roundtrip(&vec![0u8; 1000]), vec![0u8; 1000]);

        assert_eq!(
            roundtrip(&(Some('é'), None::<u32>, true)),
            (Some('é'), None, true)
        );
        assert_eq!(
            roundtrip(&(1.5f64, -0.25f32, -1i8, isize::MIN)),
            (1.5f64, -0.25f32, -1i8, isize::MIN)
        );

        let s = String::from("héllo, wörld\n\"quoted\"\\");
        assert_eq!(roundtrip(&s), s);

        let m: HashMap<u8, String> = (0..10u8).map(|i| (i, i.to_string())).collect();
        assert_eq!(roundtrip(&m), m);

        let heap: BinaryHeap<i64> = vec![5, -3, 9, 0].into_iter().collect();
        assert_eq!(roundtrip(&heap).into_sorted_vec(), heap.into_sorted_vec());
    }
}
627}