competitive/algorithm/automata_learning.rs

1use super::{BitSet, Field, Invertible, Matrix, RandomSpec, SerdeByteStr, Xorshift};
2use std::{
3    cell::RefCell,
4    collections::{HashMap, HashSet},
5    fmt::{self, Debug},
6    iter::{from_fn, once_with},
7    marker::PhantomData,
8    time::Instant,
9};
10
/// An automaton that can only be observed by querying it on input words.
///
/// The input alphabet is `{0, 1, ..., sigma() - 1}`; `behavior` maps a word over
/// that alphabet to the automaton's output for it.
pub trait BlackBoxAutomaton {
    /// Value produced for a single input word.
    type Output;

    /// Alphabet size: valid input symbols are `0..sigma()`.
    fn sigma(&self) -> usize;

    /// Output of the automaton on the word yielded by `input`.
    fn behavior<I: IntoIterator<Item = usize>>(&self, input: I) -> Self::Output;
}
18
/// A black-box automaton backed by a closure, with every query memoized.
///
/// `behavior_fn` receives the full input word; answers are cached per word in `memo`.
#[derive(Debug, Clone)]
pub struct BlackBoxAutomatonImpl<T, F>
where
    F: Fn(Vec<usize>) -> T,
{
    /// Alphabet size reported by `sigma()`.
    sigma: usize,
    /// The underlying (uncached) behavior function.
    behavior_fn: F,
    /// Cache from an input word to its previously computed output.
    memo: RefCell<HashMap<Vec<usize>, T>>,
}
28
29impl<T, F> BlackBoxAutomatonImpl<T, F>
30where
31    F: Fn(Vec<usize>) -> T,
32{
33    pub fn new(sigma: usize, behavior_fn: F) -> Self {
34        Self {
35            sigma,
36            behavior_fn,
37            memo: RefCell::new(HashMap::new()),
38        }
39    }
40}
41
42impl<T, F> BlackBoxAutomaton for BlackBoxAutomatonImpl<T, F>
43where
44    F: Fn(Vec<usize>) -> T,
45    T: Clone,
46{
47    type Output = T;
48
49    fn sigma(&self) -> usize {
50        self.sigma
51    }
52
53    fn behavior<I>(&self, input: I) -> Self::Output
54    where
55        I: IntoIterator<Item = usize>,
56    {
57        let input: Vec<usize> = input.into_iter().collect();
58        self.memo
59            .borrow_mut()
60            .entry(input.clone())
61            .or_insert_with(|| (self.behavior_fn)(input))
62            .clone()
63    }
64}
65
66impl<A> BlackBoxAutomaton for &A
67where
68    A: BlackBoxAutomaton,
69{
70    type Output = A::Output;
71
72    fn sigma(&self) -> usize {
73        (*self).sigma()
74    }
75
76    fn behavior<I>(&self, input: I) -> Self::Output
77    where
78        I: IntoIterator<Item = usize>,
79    {
80        (*self).behavior(input)
81    }
82}
83
/// One state of a deterministic finite automaton.
#[derive(Debug, Clone)]
struct DfaState {
    /// `delta[x]` is the index of the state reached on input symbol `x`.
    delta: Vec<usize>,
    /// Whether a run that ends in this state accepts.
    accept: bool,
}
89
90#[derive(Debug, Clone)]
91pub struct DeterministicFiniteAutomaton {
92    states: Vec<DfaState>,
93    initial_state: usize,
94}
95
96impl DeterministicFiniteAutomaton {
97    pub fn size(&self) -> usize {
98        self.states.len()
99    }
100    pub fn delta(&self, state: usize, input: usize) -> usize {
101        assert!(state < self.states.len());
102        assert!(input < self.states[0].delta.len());
103        self.states[state].delta[input]
104    }
105    pub fn accept(&self, state: usize) -> bool {
106        assert!(state < self.states.len());
107        self.states[state].accept
108    }
109}
110
111impl BlackBoxAutomaton for DeterministicFiniteAutomaton {
112    type Output = bool;
113
114    fn sigma(&self) -> usize {
115        self.states[0].delta.len()
116    }
117
118    fn behavior<I>(&self, input: I) -> Self::Output
119    where
120        I: IntoIterator<Item = usize>,
121    {
122        let mut state = self.initial_state;
123        for x in input {
124            state = self.states[state].delta[x];
125        }
126        self.states[state].accept
127    }
128}
129
130impl SerdeByteStr for DfaState {
131    fn serialize(&self, buf: &mut Vec<u8>) {
132        self.delta.serialize(buf);
133        self.accept.serialize(buf);
134    }
135
136    fn deserialize<I>(iter: &mut I) -> Self
137    where
138        I: Iterator<Item = u8>,
139    {
140        let delta = Vec::deserialize(iter);
141        let accept = bool::deserialize(iter);
142        Self { delta, accept }
143    }
144}
145
146impl SerdeByteStr for DeterministicFiniteAutomaton {
147    fn serialize(&self, buf: &mut Vec<u8>) {
148        self.states.serialize(buf);
149        self.initial_state.serialize(buf);
150    }
151
152    fn deserialize<I>(iter: &mut I) -> Self
153    where
154        I: Iterator<Item = u8>,
155    {
156        let states = Vec::deserialize(iter);
157        let initial_state = usize::deserialize(iter);
158        Self {
159            states,
160            initial_state,
161        }
162    }
163}
164
165pub struct WeightedFiniteAutomaton<F>
166where
167    F: Field<Additive: Invertible, Multiplicative: Invertible>,
168{
169    pub initial_weights: Matrix<F>,
170    pub transitions: Vec<Matrix<F>>,
171    pub final_weights: Matrix<F>,
172}
173
174impl<F> Debug for WeightedFiniteAutomaton<F>
175where
176    F: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
177{
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        f.debug_struct("WeightedFiniteAutomaton")
180            .field("initial_weights", &self.initial_weights)
181            .field("transitions", &self.transitions)
182            .field("final_weights", &self.final_weights)
183            .finish()
184    }
185}
186
187impl<F> Clone for WeightedFiniteAutomaton<F>
188where
189    F: Field<Additive: Invertible, Multiplicative: Invertible>,
190{
191    fn clone(&self) -> Self {
192        Self {
193            initial_weights: self.initial_weights.clone(),
194            transitions: self.transitions.clone(),
195            final_weights: self.final_weights.clone(),
196        }
197    }
198}
199
200impl<F> BlackBoxAutomaton for WeightedFiniteAutomaton<F>
201where
202    F: Field<Additive: Invertible, Multiplicative: Invertible>,
203{
204    type Output = F::T;
205
206    fn sigma(&self) -> usize {
207        self.transitions.len()
208    }
209
210    fn behavior<I>(&self, input: I) -> Self::Output
211    where
212        I: IntoIterator<Item = usize>,
213    {
214        let mut weights = self.initial_weights.clone();
215        for x in input {
216            weights = &weights * &self.transitions[x];
217        }
218        let result = &weights * &self.final_weights;
219        if result.shape != (0, 0) {
220            result[0][0].clone()
221        } else {
222            F::zero()
223        }
224    }
225}
226
227impl<F> SerdeByteStr for WeightedFiniteAutomaton<F>
228where
229    F: Field<T: SerdeByteStr, Additive: Invertible, Multiplicative: Invertible>,
230{
231    fn serialize(&self, buf: &mut Vec<u8>) {
232        self.initial_weights.serialize(buf);
233        self.transitions.serialize(buf);
234        self.final_weights.serialize(buf);
235    }
236
237    fn deserialize<I>(iter: &mut I) -> Self
238    where
239        I: Iterator<Item = u8>,
240    {
241        let initial_weights = Matrix::deserialize(iter);
242        let transitions = Vec::deserialize(iter);
243        let final_weights = Matrix::deserialize(iter);
244        Self {
245            initial_weights,
246            transitions,
247            final_weights,
248        }
249    }
250}
251
/// Enumerates every word over `{0, ..., sigma - 1}` of length at most `max_len`.
///
/// Words come shortest first and in lexicographic order within each length,
/// starting with the empty word.
///
/// # Panics
/// Panics if `sigma == 0`.
pub fn dense_sampling(sigma: usize, max_len: usize) -> impl Iterator<Item = Vec<usize>> {
    assert_ne!(sigma, 0, "Sigma must be greater than 0");
    let mut word: Vec<usize> = Vec::new();
    let successors = from_fn(move || {
        // Treat `word` as a base-`sigma` counter whose last symbol is least significant.
        match word.iter().rposition(|&digit| digit + 1 < sigma) {
            Some(pos) => {
                word[pos] += 1;
                word[pos + 1..].iter_mut().for_each(|digit| *digit = 0);
            }
            None => {
                // Every symbol overflowed: wrap to all zeros and grow by one symbol.
                word.iter_mut().for_each(|digit| *digit = 0);
                word.push(0);
            }
        }
        (word.len() <= max_len).then(|| word.clone())
    });
    once_with(Vec::new).chain(successors)
}
276
277pub fn random_sampling(
278    sigma: usize,
279    len_spec: impl RandomSpec<usize>,
280    seconds: f64,
281) -> impl Iterator<Item = Vec<usize>> {
282    assert_ne!(sigma, 0, "Sigma must be greater than 0");
283    let now = Instant::now();
284    let mut rng = Xorshift::new();
285    from_fn(move || {
286        if now.elapsed().as_secs_f64() > seconds {
287            None
288        } else {
289            let n = rng.random(&len_spec);
290            Some(rng.random_iter(0..sigma).take(n).collect())
291        }
292    })
293}
294
295#[derive(Debug, Clone)]
296pub struct DfaLearning<A>
297where
298    A: BlackBoxAutomaton<Output = bool>,
299{
300    automaton: A,
301    prefixes: Vec<Vec<usize>>,
302    suffixes: Vec<Vec<usize>>,
303    table: Vec<BitSet>,
304    row_map: HashMap<BitSet, usize>,
305}
306
307impl<A> DfaLearning<A>
308where
309    A: BlackBoxAutomaton<Output = bool>,
310{
311    pub fn new(automaton: A) -> Self {
312        let mut this = Self {
313            automaton,
314            prefixes: vec![],
315            suffixes: vec![],
316            table: vec![],
317            row_map: HashMap::new(),
318        };
319        this.add_suffix(vec![]);
320        this.add_prefix(vec![]);
321        this
322    }
323    fn add_prefix(&mut self, prefix: Vec<usize>) -> usize {
324        let row: BitSet = self
325            .suffixes
326            .iter()
327            .map(|s| {
328                self.automaton
329                    .behavior(prefix.iter().cloned().chain(s.iter().cloned()))
330            })
331            .collect();
332        *self.row_map.entry(row.clone()).or_insert_with(|| {
333            let idx = self.table.len();
334            self.table.push(row);
335            self.prefixes.push(prefix);
336            idx
337        })
338    }
339    fn add_suffix(&mut self, suffix: Vec<usize>) {
340        if self.suffixes.contains(&suffix) {
341            return;
342        }
343        for (prefix, table) in self.prefixes.iter_mut().zip(&mut self.table) {
344            table.push(
345                self.automaton
346                    .behavior(prefix.iter().cloned().chain(suffix.iter().cloned())),
347            );
348        }
349        self.suffixes.push(suffix);
350        self.row_map.clear();
351        for (i_prefix, row) in self.table.iter().enumerate() {
352            self.row_map.insert(row.clone(), i_prefix);
353        }
354    }
355    pub fn construct_dfa(&mut self) -> DeterministicFiniteAutomaton {
356        let sigma = self.automaton.sigma();
357        let mut dfa = DeterministicFiniteAutomaton {
358            states: vec![],
359            initial_state: 0,
360        };
361        let mut i_prefix = 0;
362        while i_prefix < self.prefixes.len() {
363            let mut delta = vec![];
364            for x in 0..sigma {
365                let prefix: Vec<usize> =
366                    self.prefixes[i_prefix].iter().cloned().chain([x]).collect();
367                let index = self.add_prefix(prefix);
368                delta.push(index);
369            }
370            dfa.states.push(DfaState {
371                delta,
372                accept: self.table[i_prefix].get(0),
373            });
374            i_prefix += 1;
375        }
376        dfa
377    }
378    pub fn train_sample(&mut self, dfa: &DeterministicFiniteAutomaton, sample: &[usize]) -> bool {
379        let expected = self.automaton.behavior(sample.iter().cloned());
380        if expected == dfa.behavior(sample.iter().cloned()) {
381            return false;
382        }
383        let n = sample.len();
384        let mut states: Vec<(usize, usize)> = Vec::with_capacity(n + 1);
385        let mut s = 0usize;
386        states.push((s, 0));
387        for (k, &x) in sample.iter().enumerate() {
388            s = dfa.states[s].delta[x];
389            states.push((s, k + 1));
390        }
391        let split = states.partition_point(|&(state, k)| {
392            self.automaton.behavior(
393                self.prefixes[state]
394                    .iter()
395                    .cloned()
396                    .chain(sample[k..].iter().cloned()),
397            ) == expected
398        });
399        let new_prefix = sample[..split].to_vec();
400        let new_suffix = sample[split..].to_vec();
401        self.add_suffix(new_suffix);
402        self.add_prefix(new_prefix);
403        true
404    }
405    pub fn train(
406        &mut self,
407        samples: impl IntoIterator<Item = Vec<usize>>,
408    ) -> DeterministicFiniteAutomaton {
409        let mut dfa = self.construct_dfa();
410        for sample in samples {
411            if self.train_sample(&dfa, &sample) {
412                dfa = self.construct_dfa();
413            }
414        }
415        dfa
416    }
417}
418
419pub struct WfaLearning<F, A>
420where
421    F: Field<Additive: Invertible, Multiplicative: Invertible>,
422    A: BlackBoxAutomaton<Output = F::T>,
423{
424    automaton: A,
425    prefixes: Vec<Vec<usize>>,
426    suffixes: Vec<Vec<usize>>,
427    inv_h: Matrix<F>,
428    nh: Vec<Matrix<F>>,
429    wfa: WeightedFiniteAutomaton<F>,
430    _marker: PhantomData<fn() -> F>,
431}
432
433impl<F, A> Debug for WfaLearning<F, A>
434where
435    F: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
436    A: BlackBoxAutomaton<Output = F::T> + Debug,
437{
438    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439        f.debug_struct("WfaLearning")
440            .field("automaton", &self.automaton)
441            .field("prefixes", &self.prefixes)
442            .field("suffixes", &self.suffixes)
443            .field("inv_h", &self.inv_h)
444            .field("nh", &self.nh)
445            .field("wfa", &self.wfa)
446            .finish()
447    }
448}
449
450impl<F, A> Clone for WfaLearning<F, A>
451where
452    F: Field<Additive: Invertible, Multiplicative: Invertible>,
453    A: BlackBoxAutomaton<Output = F::T> + Clone,
454{
455    fn clone(&self) -> Self {
456        Self {
457            automaton: self.automaton.clone(),
458            prefixes: self.prefixes.clone(),
459            suffixes: self.suffixes.clone(),
460            inv_h: self.inv_h.clone(),
461            nh: self.nh.clone(),
462            wfa: self.wfa.clone(),
463            _marker: self._marker,
464        }
465    }
466}
467
468impl<F, A> WfaLearning<F, A>
469where
470    F: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
471    A: BlackBoxAutomaton<Output = F::T>,
472{
473    pub fn new(automaton: A) -> Self {
474        let sigma = automaton.sigma();
475        Self {
476            automaton,
477            prefixes: vec![],
478            suffixes: vec![],
479            inv_h: Matrix::zeros((0, 0)),
480            nh: vec![Matrix::zeros((0, 0)); sigma],
481            wfa: WeightedFiniteAutomaton {
482                initial_weights: Matrix::zeros((1, 0)),
483                transitions: vec![Matrix::zeros((0, 0)); sigma],
484                final_weights: Matrix::zeros((0, 1)),
485            },
486            _marker: PhantomData,
487        }
488    }
489    pub fn wfa(&self) -> &WeightedFiniteAutomaton<F> {
490        &self.wfa
491    }
492    fn split_sample(&mut self, sample: &[usize]) -> Option<(Vec<usize>, Vec<usize>)> {
493        if self.prefixes.is_empty() && !F::is_zero(&self.automaton.behavior(sample.iter().cloned()))
494        {
495            return Some((vec![], sample.to_vec()));
496        }
497        let expected = self.automaton.behavior(sample.iter().cloned());
498        if expected == self.wfa.behavior(sample.iter().cloned()) {
499            return None;
500        }
501        let n = sample.len();
502        let dim = self.wfa.final_weights.shape.0;
503        let mut states: Vec<(Matrix<F>, usize)> = Vec::with_capacity(n + 1);
504        let mut v = self.wfa.final_weights.clone();
505        states.push((v.clone(), n));
506        for k in (0..n).rev() {
507            v = &self.wfa.transitions[sample[k]] * &v;
508            states.push((v.clone(), k));
509        }
510        states.reverse();
511        let split = states.partition_point(|(state, k)| {
512            (0..dim).any(|j| {
513                self.automaton.behavior(
514                    self.prefixes[j]
515                        .iter()
516                        .cloned()
517                        .chain(sample[*k..].iter().cloned()),
518                ) != state[j][0]
519            })
520        });
521        Some((sample[..split].to_vec(), sample[split..].to_vec()))
522    }
523    pub fn train_sample(&mut self, sample: &[usize]) -> bool {
524        let Some((prefix, suffix)) = self.split_sample(sample) else {
525            return false;
526        };
527        self.prefixes.push(prefix);
528        self.suffixes.push(suffix);
529        let n = self.inv_h.shape.0;
530        let prefix = &self.prefixes[n];
531        let suffix = &self.suffixes[n];
532        let u = Matrix::<F>::new_with((n, 1), |i, _| {
533            self.automaton.behavior(
534                self.prefixes[i]
535                    .iter()
536                    .cloned()
537                    .chain(suffix.iter().cloned()),
538            )
539        });
540        let v = Matrix::<F>::new_with((1, n), |_, j| {
541            self.automaton.behavior(
542                prefix
543                    .iter()
544                    .cloned()
545                    .chain(self.suffixes[j].iter().cloned()),
546            )
547        });
548        let w = Matrix::<F>::new_with((1, 1), |_, _| {
549            self.automaton
550                .behavior(prefix.iter().cloned().chain(suffix.iter().cloned()))
551        });
552        let t = &self.inv_h * &u;
553        let s = &v * &self.inv_h;
554        let d = F::inv(&(&w - &(&v * &t))[0][0]);
555        let dh = &t * &s;
556        for i in 0..n {
557            for j in 0..n {
558                F::add_assign(&mut self.inv_h[i][j], &F::mul(&dh[i][j], &d));
559            }
560        }
561        self.inv_h
562            .add_col_with(|i, _| F::neg(&F::mul(&t[i][0], &d)));
563        self.inv_h.add_row_with(|_, j| {
564            if j != n {
565                F::neg(&F::mul(&s[0][j], &d))
566            } else {
567                d.clone()
568            }
569        });
570
571        for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
572            let b = &(&self.nh[x] * &t) * &s;
573            for i in 0..n {
574                for j in 0..n {
575                    F::add_assign(&mut transition[i][j], &F::mul(&b[i][j], &d));
576                }
577            }
578        }
579        for (x, nh) in self.nh.iter_mut().enumerate() {
580            nh.add_col_with(|i, j| {
581                self.automaton.behavior(
582                    self.prefixes[i]
583                        .iter()
584                        .cloned()
585                        .chain([x])
586                        .chain(self.suffixes[j].iter().cloned()),
587                )
588            });
589            nh.add_row_with(|i, j| {
590                self.automaton.behavior(
591                    self.prefixes[i]
592                        .iter()
593                        .cloned()
594                        .chain([x])
595                        .chain(self.suffixes[j].iter().cloned()),
596                )
597            });
598        }
599        self.wfa
600            .initial_weights
601            .add_col_with(|_, _| if n == 0 { F::one() } else { F::zero() });
602        self.wfa
603            .final_weights
604            .add_row_with(|_, _| self.automaton.behavior(prefix.iter().cloned()));
605        for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
606            transition.add_col_with(|_, _| F::zero());
607            transition.add_row_with(|_, _| F::zero());
608            for i in 0..=n {
609                for j in 0..=n {
610                    if i == n || j == n {
611                        for k in 0..=n {
612                            if i != n && j != n && k != n {
613                                continue;
614                            }
615                            F::add_assign(
616                                &mut transition[i][k],
617                                &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
618                            );
619                        }
620                    } else {
621                        let k = n;
622                        F::add_assign(
623                            &mut transition[i][k],
624                            &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
625                        );
626                    }
627                }
628            }
629        }
630        true
631    }
632    pub fn train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
633        for sample in samples {
634            self.train_sample(&sample);
635        }
636    }
637    pub fn batch_train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
638        let mut prefix_set: HashSet<_> = self.prefixes.iter().cloned().collect();
639        let mut suffix_set: HashSet<_> = self.suffixes.iter().cloned().collect();
640        for sample in samples {
641            if prefix_set.insert(sample.to_vec()) {
642                self.prefixes.push(sample.to_vec());
643            }
644            if suffix_set.insert(sample.to_vec()) {
645                self.suffixes.push(sample);
646            }
647        }
648        let mut h = Matrix::<F>::new_with((self.prefixes.len(), self.suffixes.len()), |i, j| {
649            self.automaton.behavior(
650                self.prefixes[i]
651                    .iter()
652                    .cloned()
653                    .chain(self.suffixes[j].iter().cloned()),
654            )
655        });
656        if !self.prefixes.is_empty() && !self.suffixes.is_empty() && F::is_zero(&h[0][0]) {
657            for j in 1..self.suffixes.len() {
658                if !F::is_zero(&h[0][j]) {
659                    self.suffixes.swap(0, j);
660                    for i in 0..self.prefixes.len() {
661                        h.data[i].swap(0, j);
662                    }
663                    break;
664                }
665            }
666        }
667        let mut row_id: Vec<usize> = (0..h.shape.0).collect();
668        let mut pivots = vec![];
669        h.row_reduction_with(false, |r, p, c| {
670            row_id.swap(r, p);
671            pivots.push((row_id[r], c));
672        });
673        let mut new_prefixes = vec![];
674        let mut new_suffixes = vec![];
675        for (i, j) in pivots {
676            new_prefixes.push(self.prefixes[i].clone());
677            new_suffixes.push(self.suffixes[j].clone());
678        }
679        self.prefixes = new_prefixes;
680        self.suffixes = new_suffixes;
681        assert_eq!(self.prefixes.len(), self.suffixes.len());
682        let n = self.prefixes.len();
683        let h = Matrix::<F>::new_with((n, n), |i, j| {
684            self.automaton.behavior(
685                self.prefixes[i]
686                    .iter()
687                    .cloned()
688                    .chain(self.suffixes[j].iter().cloned()),
689            )
690        });
691        self.inv_h = h.inverse().expect("Hankel matrix must be invertible");
692        self.wfa = WeightedFiniteAutomaton::<F> {
693            initial_weights: Matrix::new_with((1, n), |_, j| {
694                if self.prefixes[j].is_empty() {
695                    F::one()
696                } else {
697                    F::zero()
698                }
699            }),
700            transitions: (0..self.automaton.sigma())
701                .map(|x| {
702                    &Matrix::new_with((n, n), |i, j| {
703                        self.automaton.behavior(
704                            self.prefixes[i]
705                                .iter()
706                                .cloned()
707                                .chain([x])
708                                .chain(self.suffixes[j].iter().cloned()),
709                        )
710                    }) * &self.inv_h
711                })
712                .collect(),
713            final_weights: Matrix::new_with((n, 1), |i, _| {
714                self.automaton.behavior(self.prefixes[i].iter().cloned())
715            }),
716        };
717    }
718}
719
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One as _, Zero as _, mint_basic::MInt998244353},
    };
    use std::collections::{HashSet, VecDeque};

    /// Asserts that `model` produces the same output as `target` on every word.
    fn assert_agrees<T, M>(target: &T, model: &M, words: impl IntoIterator<Item = Vec<usize>>)
    where
        T: BlackBoxAutomaton,
        M: BlackBoxAutomaton<Output = T::Output>,
        T::Output: PartialEq + std::fmt::Debug,
    {
        for word in words {
            let expected = target.behavior(word.iter().cloned());
            let result = model.behavior(word.iter().cloned());
            assert_eq!(expected, result);
        }
    }

    /// Rewrite table for binary words: `rules[2 * p + q]` replaces the pair `(p, q)`.
    fn rules_from_mask(mask: usize) -> [usize; 4] {
        [mask >> 3 & 1, mask >> 2 & 1, mask >> 1 & 1, mask & 1]
    }

    /// Whether `t` can reach `[1]`; each step picks an adjacent pair, drops its first
    /// symbol and overwrites the second with the `rules` value for that pair.
    fn rewrites_to_one(t: &[usize], rules: [usize; 4]) -> bool {
        let mut seen = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(t.to_vec());
        seen.insert(t.to_vec());
        while let Some(word) = queue.pop_front() {
            for i in 0..word.len().saturating_sub(1) {
                let x = match (word[i], word[i + 1]) {
                    (0, 0) => rules[0],
                    (0, 1) => rules[1],
                    (1, 0) => rules[2],
                    (1, 1) => rules[3],
                    _ => unreachable!(),
                };
                let mut next = word.to_vec();
                next.remove(i);
                next[i] = x;
                if seen.insert(next.to_vec()) {
                    queue.push_back(next);
                }
            }
        }
        seen.contains(&vec![1])
    }

    /// Number of non-empty substrings of `t` that satisfy `rewrites_to_one`.
    fn count_rewritable_substrings(t: &[usize], rules: [usize; 4]) -> MInt998244353 {
        let mut count = MInt998244353::zero();
        for l in 0..t.len() {
            for r in l + 1..=t.len() {
                if rewrites_to_one(&t[l..r], rules) {
                    count += MInt998244353::one();
                }
            }
        }
        count
    }

    #[test]
    fn test_dense_sampling() {
        for base in 1usize..=10 {
            let expected = (0..=3u32).flat_map(|len| {
                (0..base.pow(len)).map(move |code| {
                    let mut digits: Vec<usize> = (0..len)
                        .scan(code, |rest, _| {
                            let digit = *rest % base;
                            *rest /= base;
                            Some(digit)
                        })
                        .collect();
                    digits.reverse();
                    digits
                })
            });
            for (expected, result) in expected.zip(dense_sampling(base, 3)) {
                assert_eq!(expected, result);
            }
        }
    }

    #[test]
    fn test_lstar() {
        {
            let automaton = BlackBoxAutomatonImpl::new(2, |input| input.len() % 6 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 6));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 12));
        }
        {
            let automaton =
                BlackBoxAutomatonImpl::new(3, |input| input.iter().sum::<usize>() % 4 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(3, 4));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 8));
        }
        for mask in 0usize..16 {
            let rules = rules_from_mask(mask);
            let automaton =
                BlackBoxAutomatonImpl::new(2, move |t: Vec<usize>| rewrites_to_one(&t, rules));
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 4));
            assert_agrees(&automaton, &dfa, dense_sampling(automaton.sigma(), 8));
        }
    }

    #[test]
    fn test_wfa_learning() {
        {
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                MInt998244353::from(input.iter().sum::<usize>())
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 3));
            assert_agrees(&automaton, wl.wfa(), dense_sampling(automaton.sigma(), 12));
        }
        {
            let automaton = BlackBoxAutomatonImpl::new(3, |input| {
                let mut s = MInt998244353::zero();
                let mut c = MInt998244353::one();
                for &x in &input {
                    s += MInt998244353::from(x) * c;
                    c = -c;
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(3, 4));
            assert_agrees(
                &automaton,
                wl.wfa(),
                dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                    automaton.sigma(),
                    6..=12,
                    0.1,
                )),
            );
        }
        {
            // Xor Sum
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                let mut n = 1; // prevent leading zero
                for x in input {
                    n = n * 2 + x;
                }
                let mut s = MInt998244353::zero();
                for u in 0..=n {
                    for v in 0..=n {
                        let mut ok = false;
                        for a in 0..=n {
                            let b = u ^ a;
                            ok |= a + b == v;
                        }
                        s += MInt998244353::new(ok as _);
                    }
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 4));
            assert_agrees(
                &automaton,
                wl.wfa(),
                dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                    automaton.sigma(),
                    6..=12,
                    0.1,
                )),
            );
        }
        for mask in 0usize..16 {
            let rules = rules_from_mask(mask);
            let automaton = BlackBoxAutomatonImpl::new(2, move |t: Vec<usize>| {
                count_rewritable_substrings(&t, rules)
            });

            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 6));
            assert_agrees(
                &automaton,
                wl.wfa(),
                dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                    automaton.sigma(),
                    9..=12,
                    0.1,
                )),
            );

            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.batch_train(dense_sampling(2, 3));
            assert_agrees(
                &automaton,
                wl.wfa(),
                dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                    automaton.sigma(),
                    9..=12,
                    0.1,
                )),
            );
        }
    }
}