competitive/heuristic/beam_search.rs

1use std::{cmp::Reverse, collections::HashSet, fmt::Debug, hash::Hash};
2
/// A search state that is mutated in place and rolled back, as driven by `beam_search`.
///
/// Implementors either override `update` and `revert`, or override only `change` when
/// applying an operation a second time undoes it: both defaults forward to `change`, whose
/// own default does nothing.
pub trait ModifiableState: Debug {
    /// One move that can be applied to the state.
    type Operation: Clone + Debug;
    /// Ranking value of a state; `beam_search` either minimizes or maximizes it.
    type Score: Clone + Ord + Debug;
    /// Key used to discard states that were already generated.
    type Hash: Clone + Eq + Hash + Debug;
    /// Iterator over the operations to try from the current state.
    type Cands: Iterator<Item = Self::Operation>;

    /// Score of the current state.
    fn score(&self) -> Self::Score;
    /// Hash of the current state.
    fn hash(&self) -> Self::Hash;
    /// Whether the current state is an acceptable answer.
    fn accept(&self) -> bool;

    /// Reports `(score, hash, accept)` for the state reached by `op`, without keeping the change.
    ///
    /// `_score` and `_hash` are the current state's values, available to implementations that
    /// can evaluate `op` incrementally. The default ignores them: it applies `op`, measures the
    /// result, and reverts. Returning `None` makes the caller drop `op`.
    fn soft_update(
        &mut self,
        op: Self::Operation,
        _score: Self::Score,
        _hash: Self::Hash,
    ) -> Option<(Self::Score, Self::Hash, bool)> {
        self.update(op.clone());
        let score = self.score();
        let hash = self.hash();
        let accept = self.accept();
        self.revert(op);
        Some((score, hash, accept))
    }

    /// Applies `op` to the state; by default forwards to `change`.
    fn update(&mut self, op: Self::Operation) {
        self.change(op)
    }

    /// Undoes a previously applied `op`; by default forwards to `change`.
    fn revert(&mut self, op: Self::Operation) {
        self.change(op)
    }

    /// Shared hook behind the default `update` and `revert`; a no-op unless overridden.
    fn change(&mut self, _op: Self::Operation) {}

    /// Operations to try from the current state.
    fn candidates(&self) -> Self::Cands;
}
31
32#[derive(Debug)]
33pub struct Candidate<S>
34where
35    S: ModifiableState,
36{
37    parent: usize,
38    op: S::Operation,
39    score: S::Score,
40    hash: S::Hash,
41    accept: bool,
42}
43
44impl<S> Clone for Candidate<S>
45where
46    S: ModifiableState,
47{
48    fn clone(&self) -> Self {
49        Self {
50            parent: self.parent,
51            op: self.op.clone(),
52            score: self.score.clone(),
53            hash: self.hash.clone(),
54            accept: self.accept,
55        }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub struct Node<S>
61where
62    S: ModifiableState,
63{
64    parent: usize,
65    child: usize,
66    prev: usize,
67    next: usize,
68    op: S::Operation,
69    score: S::Score,
70    hash: S::Hash,
71}
72
73impl<S> Node<S>
74where
75    S: ModifiableState,
76{
77    pub fn new(state: &S, init_op: S::Operation) -> Self {
78        Node {
79            parent: !0,
80            child: !0,
81            prev: !0,
82            next: !0,
83            op: init_op,
84            score: state.score(),
85            hash: state.hash(),
86        }
87    }
88}
89
90#[derive(Debug)]
91pub struct Tree<S>
92where
93    S: ModifiableState,
94{
95    state: S,
96    latest: usize,
97    nodes: Vec<Node<S>>,
98    cur_node: usize,
99}
100
101impl<S> Tree<S>
102where
103    S: ModifiableState,
104{
105    pub fn new(state: S, init_op: S::Operation) -> Self {
106        let node = Node::new(&state, init_op);
107        Tree {
108            state,
109            latest: 0,
110            nodes: vec![node],
111            cur_node: 0,
112        }
113    }
114
115    fn add_node(&mut self, op: S::Operation, parent: usize, score: S::Score, hash: S::Hash) {
116        let next = self.nodes[parent].child;
117        if next != !0 {
118            self.nodes[next].prev = self.nodes.len();
119        }
120        self.nodes[parent].child = self.nodes.len();
121
122        self.nodes.push(Node {
123            parent,
124            child: !0,
125            prev: !0,
126            next,
127            op,
128            score,
129            hash,
130        });
131    }
132
133    fn remove_node(&mut self, mut idx: usize) {
134        loop {
135            let Node {
136                prev, next, parent, ..
137            } = self.nodes[idx];
138            assert_ne!(parent, !0);
139            if prev & next == !0 {
140                idx = parent;
141                continue;
142            }
143
144            if prev != !0 {
145                self.nodes[prev].next = next;
146            } else {
147                self.nodes[parent].child = next;
148            }
149            if next != !0 {
150                self.nodes[next].prev = prev;
151            }
152
153            break;
154        }
155    }
156
157    pub fn operations(&self, mut idx: usize) -> Vec<S::Operation> {
158        let mut ret = vec![];
159        loop {
160            let Node { op, parent, .. } = &self.nodes[idx];
161            if *parent == !0 {
162                break;
163            }
164            ret.push(op.clone());
165            idx = *parent;
166        }
167        ret.reverse();
168        ret
169    }
170
171    fn update(&mut self, cands: &mut Vec<Candidate<S>>, beam_weidth: usize, minimize: bool) {
172        if cands.len() > beam_weidth {
173            if minimize {
174                cands.select_nth_unstable_by_key(beam_weidth, |s| s.score.clone());
175            } else {
176                cands.select_nth_unstable_by_key(beam_weidth, |s| Reverse(s.score.clone()));
177            }
178            cands.truncate(beam_weidth);
179        }
180        let len = self.nodes.len();
181        for Candidate {
182            parent,
183            op,
184            score,
185            hash,
186            ..
187        } in cands.drain(..)
188        {
189            self.add_node(op, parent, score, hash);
190        }
191        for i in self.latest..len {
192            if self.nodes[i].child == !0 {
193                self.remove_node(i);
194            }
195        }
196        self.latest = len;
197    }
198
199    pub fn dfs(&mut self, cands: &mut Vec<Candidate<S>>, set: &HashSet<S::Hash>, single: bool) {
200        let node = &self.nodes[self.cur_node];
201        if node.child == !0 {
202            assert!(node.score == self.state.score());
203            assert!(node.hash == self.state.hash());
204
205            for op in self.state.candidates() {
206                if let Some((score, hash, accept)) =
207                    self.state
208                        .soft_update(op.clone(), node.score.clone(), node.hash.clone())
209                {
210                    if !set.contains(&hash) {
211                        cands.push(Candidate {
212                            parent: self.cur_node,
213                            op,
214                            score,
215                            hash,
216                            accept,
217                        });
218                    }
219                };
220            }
221        } else {
222            let node = self.cur_node;
223            let mut child = self.nodes[node].child;
224            let next_single = single & (self.nodes[child].next == !0);
225
226            loop {
227                self.cur_node = child;
228                self.state.update(self.nodes[child].op.clone());
229                self.dfs(cands, set, next_single);
230
231                if !next_single {
232                    self.state.revert(self.nodes[child].op.clone());
233                }
234                child = self.nodes[child].next;
235                if child == !0 {
236                    break;
237                }
238            }
239
240            if !next_single {
241                self.cur_node = node;
242            }
243        }
244    }
245
246    pub fn take_best(
247        &self,
248        cands: &[Candidate<S>],
249        minimize: bool,
250    ) -> Option<(S::Score, Vec<S::Operation>)> {
251        let cands = cands.iter().filter(|cand| cand.accept);
252        if let Some(Candidate {
253            op, parent, score, ..
254        }) = if minimize {
255            cands.min_by_key(|cand| cand.score.clone())
256        } else {
257            cands.max_by_key(|cand| cand.score.clone())
258        } {
259            let mut ret = self.operations(*parent);
260            ret.push(op.clone());
261            Some((score.clone(), ret))
262        } else {
263            None
264        }
265    }
266}
267
268pub fn beam_search<S>(
269    state: S,
270    init_op: S::Operation,
271    beam_weidth: usize,
272    minimize: bool,
273) -> Option<(S::Score, Vec<S::Operation>)>
274where
275    S: ModifiableState,
276{
277    let mut tree = Tree::new(state, init_op);
278    let mut cands = vec![];
279    let mut set = HashSet::<S::Hash>::default();
280    loop {
281        tree.dfs(&mut cands, &set, true);
282        if let Some(res) = tree.take_best(&cands, minimize) {
283            return Some(res);
284        }
285        if cands.is_empty() {
286            return None;
287        }
288        set.extend(cands.iter().map(|cand| cand.hash.clone()));
289        tree.update(&mut cands, beam_weidth, minimize);
290    }
291}