competitive/tools/
capture.rs

/// Macro that returns a recursive function that (semi-)automatically captures.
///
/// Three variants are available, selected by the keyword written in front of `fn`:
///
/// - (none) **default**: each `name: T` in the optional capture list is passed to every
///   recursive call as `&mut T`, so `name` must be a mutable binding. Any other variable used
///   in the body is captured by a closure that has to implement [`Fn`], i.e. it can only be
///   read.
/// - `unsafe`: the body is an [`FnMut`] closure, so any variable in scope can be captured and
///   mutated directly. The implementation re-enters that closure through a raw pointer, which
///   creates overlapping `&mut` borrows of it. This is not sound under Rust's aliasing rules;
///   callers opt into it explicitly with `unsafe fn`.
/// - `static`: the body becomes a plain `fn` item that captures nothing. Every captured variable
///   must be listed with the parameter type it is passed as: `name: &mut T` passes `&mut name`,
///   `name: &T` passes `&name`, and `name: T` passes `name` by value.
///
/// Inside the body, `name!(args...)` recurses. It appends the listed captures after the explicit
/// arguments, and it accepts a trailing comma. Omitting the return type means `-> ()`.
///
/// # Examples
/// default version
/// ```
/// # use competitive::crecurse;
/// let mut res = 0usize;
/// let coeff = 3usize;
/// crecurse!(
///     // (1) semi-automatically capture mutable reference (res: &mut usize)
///     [res: usize],
///     fn mul(x: usize, y: usize) {
///         if y > 0 {
///             if y % 2 == 1 {
///                 // (2) automatically capture reference (coeff: &usize)
///                 *res += coeff * x;
///             }
///             // (3) call macro to recurse
///             mul!(x + x, y / 2);
///         }
///     }
/// )(10, 19); // (4) macro returns captured version of the recursive function
/// assert_eq!(res, coeff * 10 * 19);
/// ```
///
/// unsafe version (automatically capture everything)
/// ```
/// # use competitive::crecurse;
/// let mut res = 0usize;
/// let coeff = 3usize;
/// crecurse!(
///     unsafe fn mul(x: usize, y: usize) {
///         if y > 0 {
///             if y % 2 == 1 {
///                 res += coeff * x;
///             }
///             mul!(x + x, y / 2);
///         }
///     }
/// )(10, 19);
/// assert_eq!(res, coeff * 10 * 19);
/// ```
///
/// no overhead version (every capture is listed explicitly and passed as a plain argument)
/// ```
/// # use competitive::crecurse;
/// let mut res = 0usize;
/// let coeff = 3usize;
/// crecurse!(
///     [res: &mut usize, coeff: &usize],
///     static fn mul(x: usize, y: usize) {
///         if y > 0 {
///             if y % 2 == 1 {
///                 *res += coeff * x;
///             }
///             mul!(x + x, y / 2);
///         }
///     }
/// )(10, 19);
/// assert_eq!(res, coeff * 10 * 19);
/// ```
///
/// # Syntax
/// ```txt
/// crecurse!(
///     ([($ident: $type),*,?],)?
///     (unsafe|static)? fn $ident\(($ident: $type),*,?\) (-> $type)? $block
/// )
/// ```
#[macro_export]
macro_rules! crecurse {
    // Defines `$name!(args...)` for use inside the body. It calls `$name` with the explicit
    // arguments followed by the captured parameters `$cargs`, so recursive calls never spell
    // out the captures. `$dol` is a literal `$`, needed to write the nested macro's
    // metavariables.
    (@macro_def ($dol:tt) $name:ident $($cargs:ident)*) => {
        #[allow(unused_macros)]
        macro_rules! $name {
            ($dol($dol args:expr),* $dol(,)?) => {
                $name($dol($dol args,)* $($cargs,)*)
            };
        }
    };

    // ---- `static fn` ----
    // Internal form: `@static [(name, passed expr, param type)*] [unprocessed captures], fn ...`.
    // Once every capture is processed, emit a plain `fn` taking the captures as trailing
    // parameters, plus a closure that supplies them.
    (
        @static [$(($cargs:ident, $cargsexpr:expr, $cargsty:ty))*] [$(,)?],
        fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) -> $ret:ty $body:block
    ) => {{
        fn $func($($args: $argsty,)* $($cargs: $cargsty,)*) -> $ret {
            $crate::crecurse!(@macro_def ($) $func $($cargs)*);
            $body
        }
        |$($args: $argsty,)*| -> $ret { $func($($args,)* $($cargsexpr,)*) }
    }};
    // No return type: default to `()`. Requiring `$body:block` (rather than arbitrary trailing
    // tokens) makes malformed input fall through to the `compile_error!` arm instead of
    // re-entering this arm until the recursion limit is reached.
    (@static [$($pcaps:tt)*] [$(,)?], fn $func:ident ($($argstt:tt)*) $body:block) => {
        $crate::crecurse!(@static [$($pcaps)*] [], fn $func ($($argstt)*) -> () $body)
    };
    // `name: &mut T` is passed as `&mut name`.
    (@static [$($pcaps:tt)*] [$carg:ident: &mut $cargty:ty, $($caps:tt)*], $($rest:tt)*) => {
        $crate::crecurse!(@static [$($pcaps)* ($carg, &mut $carg, &mut $cargty)] [$($caps)*], $($rest)*)
    };
    // `name: &T` is passed as `&name`.
    (@static [$($pcaps:tt)*] [$carg:ident: &$cargty:ty, $($caps:tt)*], $($rest:tt)*) => {
        $crate::crecurse!(@static [$($pcaps)* ($carg, &$carg, &$cargty)] [$($caps)*], $($rest)*)
    };
    // `name: T` is passed by value.
    (@static [$($pcaps:tt)*] [$carg:ident: $cargty:ty, $($caps:tt)*], $($rest:tt)*) => {
        $crate::crecurse!(@static [$($pcaps)* ($carg, $carg, $cargty)] [$($caps)*], $($rest)*)
    };
    // Entry point. A `,` is appended to the capture list so that every capture, including the
    // last one, is followed by the comma the arms above expect.
    ($([$($caps:tt)*],)? static fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) $($rest:tt)*) => {
        $crate::crecurse!(@static [] [$($($caps)*)?,], fn $func ($($args: $argsty),*) $($rest)*)
    };

    // ---- default ----
    // The body becomes an `Fn` closure that receives a `&dyn Fn` recursion handle as `$func`.
    // `__crecurse_call` builds that handle by calling back into itself. Captures are threaded
    // through every call as `&mut` arguments. Item and generic-parameter names are not
    // hygienic in `macro_rules!`, so the helper uses names that are unlikely to collide with
    // a user's function or type used in the body or signature.
    (
        @default [$($cargs:ident: $cargsty:ty),* $(,)?],
        fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) -> $ret:ty $body:block
    ) => {{
        fn __crecurse_call<__CrecurseF>(
            f: &__CrecurseF,
            $($args: $argsty,)*
            $($cargs: &mut $cargsty,)*
        ) -> $ret
        where
            __CrecurseF: Fn(
                &dyn Fn($($argsty,)* $(&mut $cargsty,)*) -> $ret,
                $($argsty,)*
                $(&mut $cargsty,)*
            ) -> $ret,
        {
            f(
                &|$($args: $argsty,)* $($cargs: &mut $cargsty,)*| -> $ret {
                    __crecurse_call(f, $($args,)* $($cargs,)*)
                },
                $($args,)* $($cargs,)*
            )
        }
        |$($args: $argsty,)*| -> $ret {
            __crecurse_call(
                &|$func, $($args: $argsty,)* $($cargs: &mut $cargsty,)*| -> $ret {
                    $crate::crecurse!(@macro_def ($) $func $($cargs)*);
                    $body
                },
                $($args,)* $(&mut $cargs,)*
            )
        }
    }};
    // No return type: default to `()` (see the `static` counterpart for why `$body:block`).
    (@default [$($caps:tt)*], fn $func:ident ($($argstt:tt)*) $body:block) => {
        $crate::crecurse!(@default [$($caps)*], fn $func ($($argstt)*) -> () $body)
    };
    ($([$($caps:tt)*],)? fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) $($rest:tt)*) => {
        $crate::crecurse!(@default [$($($caps)*)?], fn $func ($($args: $argsty),*) $($rest)*)
    };

    // ---- `unsafe fn` ----
    // Same shape as the default variant, but the body is `FnMut`, so it may mutate anything it
    // captures. `__crecurse_call` invokes the closure through `&mut *fp`. While that call is
    // still running, the recursion handle creates a second `&mut *fp` to re-enter it, so two
    // live `&mut` references to one closure exist at once. That aliasing is undefined behavior
    // under Rust's rules; this arm is only reachable through the explicit `unsafe fn` syntax.
    (
        @unsafe [$($cargs:ident: $cargsty:ty),* $(,)?],
        fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) -> $ret:ty $body:block
    ) => {{
        fn __crecurse_call<__CrecurseF>(
            f: &mut __CrecurseF,
            $($args: $argsty,)*
            $($cargs: &mut $cargsty,)*
        ) -> $ret
        where
            __CrecurseF: FnMut(
                &mut dyn FnMut($($argsty,)* $(&mut $cargsty,)*) -> $ret,
                $($argsty,)*
                $(&mut $cargsty,)*
            ) -> $ret,
        {
            let fp = f as *mut __CrecurseF;
            (unsafe { &mut *fp })(
                &mut |$($args: $argsty,)* $($cargs: &mut $cargsty,)*| -> $ret {
                    __crecurse_call(unsafe { &mut *fp }, $($args,)* $($cargs,)*)
                },
                $($args,)* $($cargs,)*
            )
        }
        |$($args: $argsty,)*| -> $ret {
            __crecurse_call(
                &mut |$func, $($args: $argsty,)* $($cargs: &mut $cargsty,)*| -> $ret {
                    $crate::crecurse!(@macro_def ($) $func $($cargs)*);
                    $body
                },
                $($args,)* $(&mut $cargs,)*
            )
        }
    }};
    // No return type: default to `()` (see the `static` counterpart for why `$body:block`).
    (@unsafe [$($caps:tt)*], fn $func:ident ($($argstt:tt)*) $body:block) => {
        $crate::crecurse!(@unsafe [$($caps)*], fn $func ($($argstt)*) -> () $body)
    };
    ($([$($caps:tt)*],)? unsafe fn $func:ident ($($args:ident: $argsty:ty),* $(,)?) $($rest:tt)*) => {
        $crate::crecurse!(@unsafe [$($($caps)*)?], fn $func ($($args: $argsty),*) $($rest)*)
    };

    // Anything else (including malformed internal forms) is reported at compile time.
    ($($t:tt)*) => {
        ::std::compile_error!(::std::concat!("invalid input: ", ::std::stringify!($($t)*)))
    };
}

/// Automatic memoization for recursive functions.
///
/// This macro binds a memoized version of the recursive function to a local variable with the
/// same name as the function. The function declaration part follows the same syntax as
/// [`crecurse`], except that a return type is required and no capture list is accepted.
/// Recurse with `name!(args...)` inside the body.
///
/// Results are cached in a [`HashMap`](std::collections::HashMap) keyed by the tuple of all
/// arguments. Every argument type must therefore implement `Clone + Eq + Hash`, and the return
/// type must implement `Clone`.
///
/// [`crecurse`]: crate::crecurse
///
/// # Example
/// ```
/// # use competitive::memorize;
/// memorize!(
///     fn comb(n: usize, r: usize) -> usize {
///         if r > n {
///             0
///         } else if r == 0 || r == n {
///             1
///         } else {
///             comb!(n - 1, r) + comb!(n - 1, r - 1)
///         }
///     }
/// );
/// assert_eq!(comb(30, 12), 86493225);
/// ```
#[macro_export]
macro_rules! memorize {
    // `$map` is the cache binding, `$Map` its type, and `$init` its initial value. The cache is
    // handed to `crecurse!` as a `&mut` capture, so every recursive call shares it.
    (
        @inner [$map:ident, $Map:ty, $init:expr]
        fn $name:ident ($($args:ident: $argsty:ty),* $(,)?) -> $ret:ty $body:block
    ) => {
        let mut $map: $Map = $init;
        #[allow(unused_mut)]
        let mut $name = $crate::crecurse!(
            [$map: $Map],
            fn $name ($($args: $argsty),*) -> $ret {
                // Clone the arguments into the cache key up front. Building the key from the
                // arguments directly would move them, so the body could not use non-`Copy`
                // arguments afterwards.
                let key = ($(::std::clone::Clone::clone(&$args),)*);
                if let Some(value) = $map.get(&key).cloned() {
                    value
                } else {
                    // Evaluate the body inside a closure so that an early `return` in it
                    // still reaches the cache insertion below.
                    let value = (|| $body)();
                    $map.insert(key, value.clone());
                    value
                }
            }
        );
    };
    (fn $name:ident ($($args:ident: $argsty:ty),* $(,)?) -> $ret:ty $body:block) => {
        $crate::memorize!(
            @inner [
                __memorize_map,
                ::std::collections::HashMap<($($argsty,)*), $ret>,
                ::std::default::Default::default()
            ]
            fn $name ($($args: $argsty),*) -> $ret $body
        );
    }
}