competitive/algorithm/
automata_learning.rs

1use super::{BitSet, Field, Invertible, Matrix, RandomSpec, SerdeByteStr, Xorshift};
2use std::{
3    cell::RefCell,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    iter::{from_fn, once_with},
7    marker::PhantomData,
8    time::Instant,
9};
10
/// An automaton that can only be observed through queries: feed it a word over
/// `Σ = {0, 1, ..., sigma() - 1}` and read back the value it produces.
///
/// Implemented in this module by [`DeterministicFiniteAutomaton`] (`Output = bool`),
/// [`WeightedFiniteAutomaton`] (`Output = F::T`), the closure-backed
/// [`BlackBoxAutomatonImpl`], and shared references to any implementor.
pub trait BlackBoxAutomaton {
    /// Value produced for a word.
    type Output;
    /// Alphabet size; valid input symbols are `0..sigma()`.
    fn sigma(&self) -> usize; // Σ={0,1,...,sigma-1}
    /// Evaluates the automaton on the word `input`.
    fn behavior<I>(&self, input: I) -> Self::Output
    where
        I: IntoIterator<Item = usize>;
}

/// [`BlackBoxAutomaton`] whose answers come from the closure `behavior_fn`.
///
/// Answers are cached per input word: a word is passed to `behavior_fn` on its first
/// query, and later queries for the same word are served from the cache.
#[derive(Debug, Clone)]
pub struct BlackBoxAutomatonImpl<T, F>
where
    F: Fn(Vec<usize>) -> T,
{
    sigma: usize,
    behavior_fn: F,
    memo: RefCell<HashMap<Vec<usize>, T>>,
}

impl<T, F> BlackBoxAutomatonImpl<T, F>
where
    F: Fn(Vec<usize>) -> T,
{
    /// Wraps `behavior_fn` as an automaton over an alphabet of size `sigma`.
    pub fn new(sigma: usize, behavior_fn: F) -> Self {
        Self {
            sigma,
            behavior_fn,
            memo: RefCell::new(HashMap::new()),
        }
    }
}

impl<T, F> BlackBoxAutomaton for BlackBoxAutomatonImpl<T, F>
where
    F: Fn(Vec<usize>) -> T,
    T: Clone,
{
    type Output = T;

    fn sigma(&self) -> usize {
        self.sigma
    }

    /// Returns the cached answer for `input`, evaluating `behavior_fn` on a miss.
    fn behavior<I>(&self, input: I) -> Self::Output
    where
        I: IntoIterator<Item = usize>,
    {
        let input: Vec<usize> = input.into_iter().collect();
        // Cache hit: answer without cloning the key.
        if let Some(cached) = self.memo.borrow().get(&input) {
            return cached.clone();
        }
        // Evaluate while `memo` is not borrowed, so user code running inside
        // `behavior_fn` that queries this automaton again cannot panic with a
        // `BorrowMutError`.
        let value = (self.behavior_fn)(input.clone());
        self.memo
            .borrow_mut()
            .entry(input)
            .or_insert(value)
            .clone()
    }
}

/// Queries through a shared reference are forwarded to the referenced automaton.
impl<A> BlackBoxAutomaton for &A
where
    A: BlackBoxAutomaton,
{
    type Output = A::Output;

    fn sigma(&self) -> usize {
        A::sigma(*self)
    }

    fn behavior<I>(&self, input: I) -> Self::Output
    where
        I: IntoIterator<Item = usize>,
    {
        A::behavior(*self, input)
    }
}
83
/// One state of a [`DeterministicFiniteAutomaton`].
#[derive(Debug, Clone)]
struct DfaState {
    /// `delta[x]` is the index of the successor state on input symbol `x`.
    delta: Vec<usize>,
    /// Whether this state is accepting.
    accept: bool,
}

/// A DFA whose states are numbered `0..size()`, starting in `initial_state`.
#[derive(Debug, Clone)]
pub struct DeterministicFiniteAutomaton {
    states: Vec<DfaState>,
    initial_state: usize,
}

impl DeterministicFiniteAutomaton {
    /// Number of states.
    pub fn size(&self) -> usize {
        let Self { states, .. } = self;
        states.len()
    }

    /// Successor of `state` on symbol `input`.
    ///
    /// # Panics
    /// Panics if `state >= self.size()` or `input` is not smaller than the number of
    /// transitions stored for state 0.
    pub fn delta(&self, state: usize, input: usize) -> usize {
        assert!(state < self.states.len());
        assert!(input < self.states[0].delta.len());
        let DfaState { delta, .. } = &self.states[state];
        delta[input]
    }

    /// Whether `state` is accepting.
    ///
    /// # Panics
    /// Panics if `state >= self.size()`.
    pub fn accept(&self, state: usize) -> bool {
        assert!(state < self.states.len());
        let DfaState { accept, .. } = &self.states[state];
        *accept
    }
}
110
111impl BlackBoxAutomaton for DeterministicFiniteAutomaton {
112    type Output = bool;
113
114    fn sigma(&self) -> usize {
115        self.states[0].delta.len()
116    }
117
118    fn behavior<I>(&self, input: I) -> Self::Output
119    where
120        I: IntoIterator<Item = usize>,
121    {
122        let mut state = self.initial_state;
123        for x in input {
124            state = self.states[state].delta[x];
125        }
126        self.states[state].accept
127    }
128}
129
130impl SerdeByteStr for DfaState {
131    fn serialize(&self, buf: &mut Vec<u8>) {
132        self.delta.serialize(buf);
133        self.accept.serialize(buf);
134    }
135
136    fn deserialize<I>(iter: &mut I) -> Self
137    where
138        I: Iterator<Item = u8>,
139    {
140        let delta = Vec::deserialize(iter);
141        let accept = bool::deserialize(iter);
142        Self { delta, accept }
143    }
144}
145
146impl SerdeByteStr for DeterministicFiniteAutomaton {
147    fn serialize(&self, buf: &mut Vec<u8>) {
148        self.states.serialize(buf);
149        self.initial_state.serialize(buf);
150    }
151
152    fn deserialize<I>(iter: &mut I) -> Self
153    where
154        I: Iterator<Item = u8>,
155    {
156        let states = Vec::deserialize(iter);
157        let initial_state = usize::deserialize(iter);
158        Self {
159            states,
160            initial_state,
161        }
162    }
163}
164
165pub struct WeightedFiniteAutomaton<F>
166where
167    F: Field,
168    F::Additive: Invertible,
169    F::Multiplicative: Invertible,
170{
171    pub initial_weights: Matrix<F>,
172    pub transitions: Vec<Matrix<F>>,
173    pub final_weights: Matrix<F>,
174}
175
176impl<F> Debug for WeightedFiniteAutomaton<F>
177where
178    F: Field,
179    F::Additive: Invertible,
180    F::Multiplicative: Invertible,
181    F::T: Debug,
182{
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        f.debug_struct("WeightedFiniteAutomaton")
185            .field("initial_weights", &self.initial_weights)
186            .field("transitions", &self.transitions)
187            .field("final_weights", &self.final_weights)
188            .finish()
189    }
190}
191
192impl<F> Clone for WeightedFiniteAutomaton<F>
193where
194    F: Field,
195    F::Additive: Invertible,
196    F::Multiplicative: Invertible,
197{
198    fn clone(&self) -> Self {
199        Self {
200            initial_weights: self.initial_weights.clone(),
201            transitions: self.transitions.clone(),
202            final_weights: self.final_weights.clone(),
203        }
204    }
205}
206
207impl<F> BlackBoxAutomaton for WeightedFiniteAutomaton<F>
208where
209    F: Field,
210    F::Additive: Invertible,
211    F::Multiplicative: Invertible,
212{
213    type Output = F::T;
214
215    fn sigma(&self) -> usize {
216        self.transitions.len()
217    }
218
219    fn behavior<I>(&self, input: I) -> Self::Output
220    where
221        I: IntoIterator<Item = usize>,
222    {
223        let mut weights = self.initial_weights.clone();
224        for x in input {
225            weights = &weights * &self.transitions[x];
226        }
227        let result = &weights * &self.final_weights;
228        if result.shape != (0, 0) {
229            result[0][0].clone()
230        } else {
231            F::zero()
232        }
233    }
234}
235
236impl<F> SerdeByteStr for WeightedFiniteAutomaton<F>
237where
238    F: Field,
239    F::Additive: Invertible,
240    F::Multiplicative: Invertible,
241    F::T: SerdeByteStr,
242{
243    fn serialize(&self, buf: &mut Vec<u8>) {
244        self.initial_weights.serialize(buf);
245        self.transitions.serialize(buf);
246        self.final_weights.serialize(buf);
247    }
248
249    fn deserialize<I>(iter: &mut I) -> Self
250    where
251        I: Iterator<Item = u8>,
252    {
253        let initial_weights = Matrix::deserialize(iter);
254        let transitions = Vec::deserialize(iter);
255        let final_weights = Matrix::deserialize(iter);
256        Self {
257            initial_weights,
258            transitions,
259            final_weights,
260        }
261    }
262}
263
/// Enumerates every word over `Σ = {0, ..., sigma - 1}` of length at most `max_len`.
///
/// Words come by increasing length, and in lexicographic order within one length,
/// starting with the empty word. For example, `sigma = 2` yields
/// `[], [0], [1], [0, 0], [0, 1], [1, 0], [1, 1], [0, 0, 0], ...`.
/// Once exhausted, the iterator keeps returning `None` without doing further work.
///
/// # Panics
/// Panics if `sigma == 0`.
pub fn dense_sampling(sigma: usize, max_len: usize) -> impl Iterator<Item = Vec<usize>> {
    assert_ne!(sigma, 0, "Sigma must be greater than 0");
    // The previously yielded word, treated as a base-`sigma` counter.
    let mut current: Vec<usize> = vec![];
    once_with(Vec::new).chain(from_fn(move || {
        // Already past the last word: stay exhausted instead of advancing (and
        // growing) the counter on every further call.
        if current.len() > max_len {
            return None;
        }
        // Increment from the least-significant (last) digit, carrying on wrap-around.
        let mut carry = true;
        for digit in current.iter_mut().rev() {
            *digit += 1;
            if *digit == sigma {
                *digit = 0;
            } else {
                carry = false;
                break;
            }
        }
        // Every digit wrapped: move on to the first word of the next length.
        if carry {
            current.push(0);
        }
        if current.len() <= max_len {
            Some(current.clone())
        } else {
            None
        }
    }))
}
288
289pub fn random_sampling(
290    sigma: usize,
291    len_spec: impl RandomSpec<usize>,
292    seconds: f64,
293) -> impl Iterator<Item = Vec<usize>> {
294    assert_ne!(sigma, 0, "Sigma must be greater than 0");
295    let now = Instant::now();
296    let mut rng = Xorshift::new();
297    from_fn(move || {
298        if now.elapsed().as_secs_f64() > seconds {
299            None
300        } else {
301            let n = rng.random(&len_spec);
302            Some(rng.random_iter(0..sigma).take(n).collect())
303        }
304    })
305}
306
307#[derive(Debug, Clone)]
308pub struct DfaLearning<A>
309where
310    A: BlackBoxAutomaton<Output = bool>,
311{
312    automaton: A,
313    prefixes: Vec<Vec<usize>>,
314    suffixes: Vec<Vec<usize>>,
315    table: Vec<BitSet>,
316    row_map: HashMap<BitSet, usize>,
317}
318
319impl<A> DfaLearning<A>
320where
321    A: BlackBoxAutomaton<Output = bool>,
322{
323    pub fn new(automaton: A) -> Self {
324        let mut this = Self {
325            automaton,
326            prefixes: vec![],
327            suffixes: vec![],
328            table: vec![],
329            row_map: HashMap::new(),
330        };
331        this.add_suffix(vec![]);
332        this.add_prefix(vec![]);
333        this
334    }
335    fn add_prefix(&mut self, prefix: Vec<usize>) -> usize {
336        let row: BitSet = self
337            .suffixes
338            .iter()
339            .map(|s| {
340                self.automaton
341                    .behavior(prefix.iter().cloned().chain(s.iter().cloned()))
342            })
343            .collect();
344        *self.row_map.entry(row.clone()).or_insert_with(|| {
345            let idx = self.table.len();
346            self.table.push(row);
347            self.prefixes.push(prefix);
348            idx
349        })
350    }
351    fn add_suffix(&mut self, suffix: Vec<usize>) {
352        if self.suffixes.contains(&suffix) {
353            return;
354        }
355        for (prefix, table) in self.prefixes.iter_mut().zip(&mut self.table) {
356            table.push(
357                self.automaton
358                    .behavior(prefix.iter().cloned().chain(suffix.iter().cloned())),
359            );
360        }
361        self.suffixes.push(suffix);
362        self.row_map.clear();
363        for (i_prefix, row) in self.table.iter().enumerate() {
364            self.row_map.insert(row.clone(), i_prefix);
365        }
366    }
367    pub fn construct_dfa(&mut self) -> DeterministicFiniteAutomaton {
368        let sigma = self.automaton.sigma();
369        let mut dfa = DeterministicFiniteAutomaton {
370            states: vec![],
371            initial_state: 0,
372        };
373        let mut i_prefix = 0;
374        while i_prefix < self.prefixes.len() {
375            let mut delta = vec![];
376            for x in 0..sigma {
377                let prefix: Vec<usize> =
378                    self.prefixes[i_prefix].iter().cloned().chain([x]).collect();
379                let index = self.add_prefix(prefix);
380                delta.push(index);
381            }
382            dfa.states.push(DfaState {
383                delta,
384                accept: self.table[i_prefix].get(0),
385            });
386            i_prefix += 1;
387        }
388        dfa
389    }
390    pub fn train_sample(&mut self, dfa: &DeterministicFiniteAutomaton, sample: &[usize]) -> bool {
391        let expected = self.automaton.behavior(sample.iter().cloned());
392        if expected == dfa.behavior(sample.iter().cloned()) {
393            return false;
394        }
395        let n = sample.len();
396        let mut states: Vec<(usize, usize)> = Vec::with_capacity(n + 1);
397        let mut s = 0usize;
398        states.push((s, 0));
399        for (k, &x) in sample.iter().enumerate() {
400            s = dfa.states[s].delta[x];
401            states.push((s, k + 1));
402        }
403        let split = states.partition_point(|&(state, k)| {
404            self.automaton.behavior(
405                self.prefixes[state]
406                    .iter()
407                    .cloned()
408                    .chain(sample[k..].iter().cloned()),
409            ) == expected
410        });
411        let new_prefix = sample[..split].to_vec();
412        let new_suffix = sample[split..].to_vec();
413        self.add_suffix(new_suffix);
414        self.add_prefix(new_prefix);
415        true
416    }
417    pub fn train(
418        &mut self,
419        samples: impl IntoIterator<Item = Vec<usize>>,
420    ) -> DeterministicFiniteAutomaton {
421        let mut dfa = self.construct_dfa();
422        for sample in samples {
423            if self.train_sample(&dfa, &sample) {
424                dfa = self.construct_dfa();
425            }
426        }
427        dfa
428    }
429}
430
431pub struct WfaLearning<F, A>
432where
433    F: Field,
434    F::Additive: Invertible,
435    F::Multiplicative: Invertible,
436    A: BlackBoxAutomaton<Output = F::T>,
437{
438    automaton: A,
439    prefixes: Vec<Vec<usize>>,
440    suffixes: Vec<Vec<usize>>,
441    inv_h: Matrix<F>,
442    nh: Vec<Matrix<F>>,
443    wfa: WeightedFiniteAutomaton<F>,
444    _marker: PhantomData<fn() -> F>,
445}
446
447impl<F, A> Debug for WfaLearning<F, A>
448where
449    F: Field,
450    F::Additive: Invertible,
451    F::Multiplicative: Invertible,
452    F::T: Debug,
453    A: BlackBoxAutomaton<Output = F::T> + Debug,
454{
455    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456        f.debug_struct("WfaLearning")
457            .field("automaton", &self.automaton)
458            .field("prefixes", &self.prefixes)
459            .field("suffixes", &self.suffixes)
460            .field("inv_h", &self.inv_h)
461            .field("nh", &self.nh)
462            .field("wfa", &self.wfa)
463            .finish()
464    }
465}
466
467impl<F, A> Clone for WfaLearning<F, A>
468where
469    F: Field,
470    F::Additive: Invertible,
471    F::Multiplicative: Invertible,
472    A: BlackBoxAutomaton<Output = F::T> + Clone,
473{
474    fn clone(&self) -> Self {
475        Self {
476            automaton: self.automaton.clone(),
477            prefixes: self.prefixes.clone(),
478            suffixes: self.suffixes.clone(),
479            inv_h: self.inv_h.clone(),
480            nh: self.nh.clone(),
481            wfa: self.wfa.clone(),
482            _marker: self._marker,
483        }
484    }
485}
486
487impl<F, A> WfaLearning<F, A>
488where
489    F: Field,
490    F::Additive: Invertible,
491    F::Multiplicative: Invertible,
492    F::T: PartialEq,
493    A: BlackBoxAutomaton<Output = F::T>,
494{
495    pub fn new(automaton: A) -> Self {
496        let sigma = automaton.sigma();
497        Self {
498            automaton,
499            prefixes: vec![],
500            suffixes: vec![],
501            inv_h: Matrix::zeros((0, 0)),
502            nh: vec![Matrix::zeros((0, 0)); sigma],
503            wfa: WeightedFiniteAutomaton {
504                initial_weights: Matrix::zeros((1, 0)),
505                transitions: vec![Matrix::zeros((0, 0)); sigma],
506                final_weights: Matrix::zeros((0, 1)),
507            },
508            _marker: PhantomData,
509        }
510    }
511    pub fn wfa(&self) -> &WeightedFiniteAutomaton<F> {
512        &self.wfa
513    }
514    fn split_sample(&mut self, sample: &[usize]) -> Option<(Vec<usize>, Vec<usize>)> {
515        if self.prefixes.is_empty() && !F::is_zero(&self.automaton.behavior(sample.iter().cloned()))
516        {
517            return Some((vec![], sample.to_vec()));
518        }
519        let expected = self.automaton.behavior(sample.iter().cloned());
520        if expected == self.wfa.behavior(sample.iter().cloned()) {
521            return None;
522        }
523        let n = sample.len();
524        let dim = self.wfa.final_weights.shape.0;
525        let mut states: Vec<(Matrix<F>, usize)> = Vec::with_capacity(n + 1);
526        let mut v = self.wfa.final_weights.clone();
527        states.push((v.clone(), n));
528        for k in (0..n).rev() {
529            v = &self.wfa.transitions[sample[k]] * &v;
530            states.push((v.clone(), k));
531        }
532        states.reverse();
533        let split = states.partition_point(|(state, k)| {
534            (0..dim).any(|j| {
535                self.automaton.behavior(
536                    self.prefixes[j]
537                        .iter()
538                        .cloned()
539                        .chain(sample[*k..].iter().cloned()),
540                ) != state[j][0]
541            })
542        });
543        Some((sample[..split].to_vec(), sample[split..].to_vec()))
544    }
545    pub fn train_sample(&mut self, sample: &[usize]) -> bool {
546        let Some((prefix, suffix)) = self.split_sample(sample) else {
547            return false;
548        };
549        self.prefixes.push(prefix);
550        self.suffixes.push(suffix);
551        let n = self.inv_h.shape.0;
552        let prefix = &self.prefixes[n];
553        let suffix = &self.suffixes[n];
554        let u = Matrix::<F>::new_with((n, 1), |i, _| {
555            self.automaton.behavior(
556                self.prefixes[i]
557                    .iter()
558                    .cloned()
559                    .chain(suffix.iter().cloned()),
560            )
561        });
562        let v = Matrix::<F>::new_with((1, n), |_, j| {
563            self.automaton.behavior(
564                prefix
565                    .iter()
566                    .cloned()
567                    .chain(self.suffixes[j].iter().cloned()),
568            )
569        });
570        let w = Matrix::<F>::new_with((1, 1), |_, _| {
571            self.automaton
572                .behavior(prefix.iter().cloned().chain(suffix.iter().cloned()))
573        });
574        let t = &self.inv_h * &u;
575        let s = &v * &self.inv_h;
576        let d = F::inv(&(&w - &(&v * &t))[0][0]);
577        let dh = &t * &s;
578        for i in 0..n {
579            for j in 0..n {
580                F::add_assign(&mut self.inv_h[i][j], &F::mul(&dh[i][j], &d));
581            }
582        }
583        self.inv_h
584            .add_col_with(|i, _| F::neg(&F::mul(&t[i][0], &d)));
585        self.inv_h.add_row_with(|_, j| {
586            if j != n {
587                F::neg(&F::mul(&s[0][j], &d))
588            } else {
589                d.clone()
590            }
591        });
592
593        for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
594            let b = &(&self.nh[x] * &t) * &s;
595            for i in 0..n {
596                for j in 0..n {
597                    F::add_assign(&mut transition[i][j], &F::mul(&b[i][j], &d));
598                }
599            }
600        }
601        for (x, nh) in self.nh.iter_mut().enumerate() {
602            nh.add_col_with(|i, j| {
603                self.automaton.behavior(
604                    self.prefixes[i]
605                        .iter()
606                        .cloned()
607                        .chain([x])
608                        .chain(self.suffixes[j].iter().cloned()),
609                )
610            });
611            nh.add_row_with(|i, j| {
612                self.automaton.behavior(
613                    self.prefixes[i]
614                        .iter()
615                        .cloned()
616                        .chain([x])
617                        .chain(self.suffixes[j].iter().cloned()),
618                )
619            });
620        }
621        self.wfa
622            .initial_weights
623            .add_col_with(|_, _| if n == 0 { F::one() } else { F::zero() });
624        self.wfa
625            .final_weights
626            .add_row_with(|_, _| self.automaton.behavior(prefix.iter().cloned()));
627        for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
628            transition.add_col_with(|_, _| F::zero());
629            transition.add_row_with(|_, _| F::zero());
630            for i in 0..=n {
631                for j in 0..=n {
632                    if i == n || j == n {
633                        for k in 0..=n {
634                            if i != n && j != n && k != n {
635                                continue;
636                            }
637                            F::add_assign(
638                                &mut transition[i][k],
639                                &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
640                            );
641                        }
642                    } else {
643                        let k = n;
644                        F::add_assign(
645                            &mut transition[i][k],
646                            &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
647                        );
648                    }
649                }
650            }
651        }
652        true
653    }
654    pub fn train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
655        for sample in samples {
656            self.train_sample(&sample);
657        }
658    }
659    pub fn batch_train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
660        let mut prefix_set: HashSet<_> = self.prefixes.iter().cloned().collect();
661        let mut suffix_set: HashSet<_> = self.suffixes.iter().cloned().collect();
662        for sample in samples {
663            if prefix_set.insert(sample.to_vec()) {
664                self.prefixes.push(sample.to_vec());
665            }
666            if suffix_set.insert(sample.to_vec()) {
667                self.suffixes.push(sample);
668            }
669        }
670        let mut h = Matrix::<F>::new_with((self.prefixes.len(), self.suffixes.len()), |i, j| {
671            self.automaton.behavior(
672                self.prefixes[i]
673                    .iter()
674                    .cloned()
675                    .chain(self.suffixes[j].iter().cloned()),
676            )
677        });
678        if !self.prefixes.is_empty() && !self.suffixes.is_empty() && F::is_zero(&h[0][0]) {
679            for j in 1..self.suffixes.len() {
680                if !F::is_zero(&h[0][j]) {
681                    self.suffixes.swap(0, j);
682                    for i in 0..self.prefixes.len() {
683                        h.data[i].swap(0, j);
684                    }
685                    break;
686                }
687            }
688        }
689        let mut row_id: Vec<usize> = (0..h.shape.0).collect();
690        let mut pivots = vec![];
691        h.row_reduction_with(false, |r, p, c| {
692            row_id.swap(r, p);
693            pivots.push((row_id[r], c));
694        });
695        let mut new_prefixes = vec![];
696        let mut new_suffixes = vec![];
697        for (i, j) in pivots {
698            new_prefixes.push(self.prefixes[i].clone());
699            new_suffixes.push(self.suffixes[j].clone());
700        }
701        self.prefixes = new_prefixes;
702        self.suffixes = new_suffixes;
703        assert_eq!(self.prefixes.len(), self.suffixes.len());
704        let n = self.prefixes.len();
705        let h = Matrix::<F>::new_with((n, n), |i, j| {
706            self.automaton.behavior(
707                self.prefixes[i]
708                    .iter()
709                    .cloned()
710                    .chain(self.suffixes[j].iter().cloned()),
711            )
712        });
713        self.inv_h = h.inverse().expect("Hankel matrix must be invertible");
714        self.wfa = WeightedFiniteAutomaton::<F> {
715            initial_weights: Matrix::new_with((1, n), |_, j| {
716                if self.prefixes[j].is_empty() {
717                    F::one()
718                } else {
719                    F::zero()
720                }
721            }),
722            transitions: (0..self.automaton.sigma())
723                .map(|x| {
724                    &Matrix::new_with((n, n), |i, j| {
725                        self.automaton.behavior(
726                            self.prefixes[i]
727                                .iter()
728                                .cloned()
729                                .chain([x])
730                                .chain(self.suffixes[j].iter().cloned()),
731                        )
732                    }) * &self.inv_h
733                })
734                .collect(),
735            final_weights: Matrix::new_with((n, 1), |i, _| {
736                self.automaton.behavior(self.prefixes[i].iter().cloned())
737            }),
738        };
739    }
740}
741
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One as _, Zero as _, mint_basic::MInt998244353},
    };
    use std::collections::{HashSet, VecDeque};

    /// Asserts that `learned` agrees with `expected` on every word in `samples`.
    fn assert_agrees<A, B>(
        expected: &A,
        learned: &B,
        samples: impl IntoIterator<Item = Vec<usize>>,
    ) where
        A: BlackBoxAutomaton,
        B: BlackBoxAutomaton<Output = A::Output>,
        A::Output: PartialEq + std::fmt::Debug,
    {
        for sample in samples {
            let want = expected.behavior(sample.iter().cloned());
            let got = learned.behavior(sample.iter().cloned());
            assert_eq!(want, got, "sample = {:?}", sample);
        }
    }

    /// Rewrite rule for a 4-bit `mask`: the replacements for the adjacent pairs
    /// `(0,0)`, `(0,1)`, `(1,0)`, `(1,1)`, taken from bits 3, 2, 1, 0.
    fn rule_from_mask(mask: usize) -> [usize; 4] {
        [mask >> 3 & 1, mask >> 2 & 1, mask >> 1 & 1, mask & 1]
    }

    /// Brute force (BFS): can `t` be rewritten to exactly `[1]` by repeatedly
    /// replacing an adjacent pair with its image under `rule`?
    fn reduces_to_one(t: &[usize], rule: [usize; 4]) -> bool {
        let [a, b, c, d] = rule;
        let mut seen = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(t.to_vec());
        seen.insert(t.to_vec());
        while let Some(t) = queue.pop_front() {
            for i in 0..t.len().saturating_sub(1) {
                let x = match (t[i], t[i + 1]) {
                    (0, 0) => a,
                    (0, 1) => b,
                    (1, 0) => c,
                    (1, 1) => d,
                    _ => unreachable!(),
                };
                let mut t = t.to_vec();
                t.remove(i);
                t[i] = x;
                if seen.insert(t.to_vec()) {
                    queue.push_back(t);
                }
            }
        }
        seen.contains(&vec![1])
    }

    #[test]
    fn test_dense_sampling() {
        for base in 1usize..=10 {
            let mut expected = vec![];
            for len in 0..=3 {
                for n in 0..base.pow(len) {
                    let mut n = n;
                    let mut current = vec![];
                    for _ in 0..len {
                        current.push(n % base);
                        n /= base;
                    }
                    current.reverse();
                    expected.push(current);
                }
            }
            // Compare whole sequences: `zip` would silently ignore missing/extra words.
            let result: Vec<Vec<usize>> = dense_sampling(base, 3).collect();
            assert_eq!(expected, result);
        }
    }

    #[test]
    fn test_lstar() {
        {
            let automaton = BlackBoxAutomatonImpl::new(2, |input| input.len() % 6 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 6));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 12));
        }
        {
            let automaton =
                BlackBoxAutomatonImpl::new(3, |input| input.iter().sum::<usize>() % 4 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(3, 4));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 8));
        }
        for mask in 0usize..16 {
            let rule = rule_from_mask(mask);
            let automaton = BlackBoxAutomatonImpl::new(2, |t| reduces_to_one(&t, rule));
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 4));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 8));
        }
    }

    #[test]
    fn test_wfa_learning() {
        {
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                MInt998244353::from(input.iter().sum::<usize>())
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            assert_agrees(&automaton, wfa, dense_sampling(automaton.sigma(), 12));
        }
        {
            let automaton = BlackBoxAutomatonImpl::new(3, |input| {
                let mut s = MInt998244353::zero();
                let mut c = MInt998244353::one();
                for &x in &input {
                    s += MInt998244353::from(x) * c;
                    c = -c;
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(3, 4));
            let wfa = wl.wfa();
            assert_agrees(
                &automaton,
                wfa,
                dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                    automaton.sigma(),
                    6..=12,
                    0.1,
                )),
            );
        }
        {
            // Xor Sum
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                let mut n = 1; // prevent leading zero
                for x in input {
                    n = n * 2 + x;
                }
                let mut s = MInt998244353::zero();
                for u in 0..=n {
                    for v in 0..=n {
                        let mut ok = false;
                        for a in 0..=n {
                            let b = u ^ a;
                            ok |= a + b == v;
                        }
                        s += MInt998244353::new(ok as _);
                    }
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 4));
            let wfa = wl.wfa();
            assert_agrees(
                &automaton,
                wfa,
                dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                    automaton.sigma(),
                    6..=12,
                    0.1,
                )),
            );
        }
        for mask in 0usize..16 {
            let rule = rule_from_mask(mask);
            // Number of non-empty substrings of `t` that reduce to `[1]`.
            let count = |t: &[usize]| {
                let mut s = MInt998244353::zero();
                for l in 0..t.len() {
                    for r in l + 1..=t.len() {
                        if reduces_to_one(&t[l..r], rule) {
                            s += MInt998244353::one();
                        }
                    }
                }
                s
            };
            let automaton = BlackBoxAutomatonImpl::new(2, |t| count(&t));

            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 6));
            let wfa = wl.wfa();
            assert_agrees(
                &automaton,
                wfa,
                dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                    automaton.sigma(),
                    9..=12,
                    0.1,
                )),
            );

            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.batch_train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            assert_agrees(
                &automaton,
                wfa,
                dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                    automaton.sigma(),
                    9..=12,
                    0.1,
                )),
            );
        }
    }
}