1use super::{BitSet, Field, Invertible, Matrix, RandomSpec, SerdeByteStr, Xorshift};
2use std::{
3 cell::RefCell,
4 collections::{HashMap, HashSet},
5 fmt::{self, Debug},
6 iter::{from_fn, once_with},
7 marker::PhantomData,
8 time::Instant,
9};
10
/// An automaton that can only be observed by querying it on input words.
///
/// Words are sequences of letters taken from `0..sigma()`.
pub trait BlackBoxAutomaton {
    /// The value produced for a word (e.g. `bool` for acceptance, a field element for weights).
    type Output;

    /// Size of the input alphabet.
    fn sigma(&self) -> usize;

    /// Evaluates the automaton on the word `input`.
    fn behavior<I: IntoIterator<Item = usize>>(&self, input: I) -> Self::Output;
}
18
/// A black-box automaton backed by an arbitrary closure, with a memo table of answers.
#[derive(Debug, Clone)]
pub struct BlackBoxAutomatonImpl<T, F: Fn(Vec<usize>) -> T> {
    /// Alphabet size reported to learners.
    sigma: usize,
    /// Oracle computing the output for a whole input word.
    behavior_fn: F,
    /// Previously answered queries, keyed by the input word.
    memo: RefCell<HashMap<Vec<usize>, T>>,
}
28
29impl<T, F> BlackBoxAutomatonImpl<T, F>
30where
31 F: Fn(Vec<usize>) -> T,
32{
33 pub fn new(sigma: usize, behavior_fn: F) -> Self {
34 Self {
35 sigma,
36 behavior_fn,
37 memo: RefCell::new(HashMap::new()),
38 }
39 }
40}
41
42impl<T, F> BlackBoxAutomaton for BlackBoxAutomatonImpl<T, F>
43where
44 F: Fn(Vec<usize>) -> T,
45 T: Clone,
46{
47 type Output = T;
48
49 fn sigma(&self) -> usize {
50 self.sigma
51 }
52
53 fn behavior<I>(&self, input: I) -> Self::Output
54 where
55 I: IntoIterator<Item = usize>,
56 {
57 let input: Vec<usize> = input.into_iter().collect();
58 self.memo
59 .borrow_mut()
60 .entry(input.clone())
61 .or_insert_with(|| (self.behavior_fn)(input))
62 .clone()
63 }
64}
65
66impl<A> BlackBoxAutomaton for &A
67where
68 A: BlackBoxAutomaton,
69{
70 type Output = A::Output;
71
72 fn sigma(&self) -> usize {
73 (*self).sigma()
74 }
75
76 fn behavior<I>(&self, input: I) -> Self::Output
77 where
78 I: IntoIterator<Item = usize>,
79 {
80 (*self).behavior(input)
81 }
82}
83
/// One state of a `DeterministicFiniteAutomaton`.
#[derive(Debug, Clone)]
struct DfaState {
    /// `delta[x]` is the index of the successor state on letter `x`.
    delta: Vec<usize>,
    /// Whether a run that ends in this state accepts.
    accept: bool,
}
89
90#[derive(Debug, Clone)]
91pub struct DeterministicFiniteAutomaton {
92 states: Vec<DfaState>,
93 initial_state: usize,
94}
95
96impl DeterministicFiniteAutomaton {
97 pub fn size(&self) -> usize {
98 self.states.len()
99 }
100 pub fn delta(&self, state: usize, input: usize) -> usize {
101 assert!(state < self.states.len());
102 assert!(input < self.states[0].delta.len());
103 self.states[state].delta[input]
104 }
105 pub fn accept(&self, state: usize) -> bool {
106 assert!(state < self.states.len());
107 self.states[state].accept
108 }
109}
110
111impl BlackBoxAutomaton for DeterministicFiniteAutomaton {
112 type Output = bool;
113
114 fn sigma(&self) -> usize {
115 self.states[0].delta.len()
116 }
117
118 fn behavior<I>(&self, input: I) -> Self::Output
119 where
120 I: IntoIterator<Item = usize>,
121 {
122 let mut state = self.initial_state;
123 for x in input {
124 state = self.states[state].delta[x];
125 }
126 self.states[state].accept
127 }
128}
129
130impl SerdeByteStr for DfaState {
131 fn serialize(&self, buf: &mut Vec<u8>) {
132 self.delta.serialize(buf);
133 self.accept.serialize(buf);
134 }
135
136 fn deserialize<I>(iter: &mut I) -> Self
137 where
138 I: Iterator<Item = u8>,
139 {
140 let delta = Vec::deserialize(iter);
141 let accept = bool::deserialize(iter);
142 Self { delta, accept }
143 }
144}
145
146impl SerdeByteStr for DeterministicFiniteAutomaton {
147 fn serialize(&self, buf: &mut Vec<u8>) {
148 self.states.serialize(buf);
149 self.initial_state.serialize(buf);
150 }
151
152 fn deserialize<I>(iter: &mut I) -> Self
153 where
154 I: Iterator<Item = u8>,
155 {
156 let states = Vec::deserialize(iter);
157 let initial_state = usize::deserialize(iter);
158 Self {
159 states,
160 initial_state,
161 }
162 }
163}
164
165pub struct WeightedFiniteAutomaton<F>
166where
167 F: Field<Additive: Invertible, Multiplicative: Invertible>,
168{
169 pub initial_weights: Matrix<F>,
170 pub transitions: Vec<Matrix<F>>,
171 pub final_weights: Matrix<F>,
172}
173
174impl<F> Debug for WeightedFiniteAutomaton<F>
175where
176 F: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
177{
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 f.debug_struct("WeightedFiniteAutomaton")
180 .field("initial_weights", &self.initial_weights)
181 .field("transitions", &self.transitions)
182 .field("final_weights", &self.final_weights)
183 .finish()
184 }
185}
186
187impl<F> Clone for WeightedFiniteAutomaton<F>
188where
189 F: Field<Additive: Invertible, Multiplicative: Invertible>,
190{
191 fn clone(&self) -> Self {
192 Self {
193 initial_weights: self.initial_weights.clone(),
194 transitions: self.transitions.clone(),
195 final_weights: self.final_weights.clone(),
196 }
197 }
198}
199
200impl<F> BlackBoxAutomaton for WeightedFiniteAutomaton<F>
201where
202 F: Field<Additive: Invertible, Multiplicative: Invertible>,
203{
204 type Output = F::T;
205
206 fn sigma(&self) -> usize {
207 self.transitions.len()
208 }
209
210 fn behavior<I>(&self, input: I) -> Self::Output
211 where
212 I: IntoIterator<Item = usize>,
213 {
214 let mut weights = self.initial_weights.clone();
215 for x in input {
216 weights = &weights * &self.transitions[x];
217 }
218 let result = &weights * &self.final_weights;
219 if result.shape != (0, 0) {
220 result[0][0].clone()
221 } else {
222 F::zero()
223 }
224 }
225}
226
227impl<F> SerdeByteStr for WeightedFiniteAutomaton<F>
228where
229 F: Field<T: SerdeByteStr, Additive: Invertible, Multiplicative: Invertible>,
230{
231 fn serialize(&self, buf: &mut Vec<u8>) {
232 self.initial_weights.serialize(buf);
233 self.transitions.serialize(buf);
234 self.final_weights.serialize(buf);
235 }
236
237 fn deserialize<I>(iter: &mut I) -> Self
238 where
239 I: Iterator<Item = u8>,
240 {
241 let initial_weights = Matrix::deserialize(iter);
242 let transitions = Vec::deserialize(iter);
243 let final_weights = Matrix::deserialize(iter);
244 Self {
245 initial_weights,
246 transitions,
247 final_weights,
248 }
249 }
250}
251
/// Enumerates every word over `0..sigma` of length at most `max_len`.
///
/// Words come shortest first and, within one length, in increasing order when read as
/// base-`sigma` numbers (most significant letter first).
///
/// # Panics
/// Immediately, if `sigma == 0`.
pub fn dense_sampling(sigma: usize, max_len: usize) -> impl Iterator<Item = Vec<usize>> {
    assert_ne!(sigma, 0, "Sigma must be greater than 0");
    let mut word: Vec<usize> = Vec::new();
    let successors = from_fn(move || {
        // Advance `word` like an odometer, least-significant letter last. `all` stops at
        // the first letter that does not wrap around; if every letter wrapped (or the word
        // is empty), the next length begins.
        let wrapped_everywhere = word.iter_mut().rev().all(|digit| {
            *digit += 1;
            let wrapped = *digit == sigma;
            if wrapped {
                *digit = 0;
            }
            wrapped
        });
        if wrapped_everywhere {
            word.push(0);
        }
        if word.len() > max_len {
            None
        } else {
            Some(word.clone())
        }
    });
    once_with(Vec::new).chain(successors)
}
276
277pub fn random_sampling(
278 sigma: usize,
279 len_spec: impl RandomSpec<usize>,
280 seconds: f64,
281) -> impl Iterator<Item = Vec<usize>> {
282 assert_ne!(sigma, 0, "Sigma must be greater than 0");
283 let now = Instant::now();
284 let mut rng = Xorshift::new();
285 from_fn(move || {
286 if now.elapsed().as_secs_f64() > seconds {
287 None
288 } else {
289 let n = rng.random(&len_spec);
290 Some(rng.random_iter(0..sigma).take(n).collect())
291 }
292 })
293}
294
295#[derive(Debug, Clone)]
296pub struct DfaLearning<A>
297where
298 A: BlackBoxAutomaton<Output = bool>,
299{
300 automaton: A,
301 prefixes: Vec<Vec<usize>>,
302 suffixes: Vec<Vec<usize>>,
303 table: Vec<BitSet>,
304 row_map: HashMap<BitSet, usize>,
305}
306
307impl<A> DfaLearning<A>
308where
309 A: BlackBoxAutomaton<Output = bool>,
310{
311 pub fn new(automaton: A) -> Self {
312 let mut this = Self {
313 automaton,
314 prefixes: vec![],
315 suffixes: vec![],
316 table: vec![],
317 row_map: HashMap::new(),
318 };
319 this.add_suffix(vec![]);
320 this.add_prefix(vec![]);
321 this
322 }
323 fn add_prefix(&mut self, prefix: Vec<usize>) -> usize {
324 let row: BitSet = self
325 .suffixes
326 .iter()
327 .map(|s| {
328 self.automaton
329 .behavior(prefix.iter().cloned().chain(s.iter().cloned()))
330 })
331 .collect();
332 *self.row_map.entry(row.clone()).or_insert_with(|| {
333 let idx = self.table.len();
334 self.table.push(row);
335 self.prefixes.push(prefix);
336 idx
337 })
338 }
339 fn add_suffix(&mut self, suffix: Vec<usize>) {
340 if self.suffixes.contains(&suffix) {
341 return;
342 }
343 for (prefix, table) in self.prefixes.iter_mut().zip(&mut self.table) {
344 table.push(
345 self.automaton
346 .behavior(prefix.iter().cloned().chain(suffix.iter().cloned())),
347 );
348 }
349 self.suffixes.push(suffix);
350 self.row_map.clear();
351 for (i_prefix, row) in self.table.iter().enumerate() {
352 self.row_map.insert(row.clone(), i_prefix);
353 }
354 }
355 pub fn construct_dfa(&mut self) -> DeterministicFiniteAutomaton {
356 let sigma = self.automaton.sigma();
357 let mut dfa = DeterministicFiniteAutomaton {
358 states: vec![],
359 initial_state: 0,
360 };
361 let mut i_prefix = 0;
362 while i_prefix < self.prefixes.len() {
363 let mut delta = vec![];
364 for x in 0..sigma {
365 let prefix: Vec<usize> =
366 self.prefixes[i_prefix].iter().cloned().chain([x]).collect();
367 let index = self.add_prefix(prefix);
368 delta.push(index);
369 }
370 dfa.states.push(DfaState {
371 delta,
372 accept: self.table[i_prefix].get(0),
373 });
374 i_prefix += 1;
375 }
376 dfa
377 }
378 pub fn train_sample(&mut self, dfa: &DeterministicFiniteAutomaton, sample: &[usize]) -> bool {
379 let expected = self.automaton.behavior(sample.iter().cloned());
380 if expected == dfa.behavior(sample.iter().cloned()) {
381 return false;
382 }
383 let n = sample.len();
384 let mut states: Vec<(usize, usize)> = Vec::with_capacity(n + 1);
385 let mut s = 0usize;
386 states.push((s, 0));
387 for (k, &x) in sample.iter().enumerate() {
388 s = dfa.states[s].delta[x];
389 states.push((s, k + 1));
390 }
391 let split = states.partition_point(|&(state, k)| {
392 self.automaton.behavior(
393 self.prefixes[state]
394 .iter()
395 .cloned()
396 .chain(sample[k..].iter().cloned()),
397 ) == expected
398 });
399 let new_prefix = sample[..split].to_vec();
400 let new_suffix = sample[split..].to_vec();
401 self.add_suffix(new_suffix);
402 self.add_prefix(new_prefix);
403 true
404 }
405 pub fn train(
406 &mut self,
407 samples: impl IntoIterator<Item = Vec<usize>>,
408 ) -> DeterministicFiniteAutomaton {
409 let mut dfa = self.construct_dfa();
410 for sample in samples {
411 if self.train_sample(&dfa, &sample) {
412 dfa = self.construct_dfa();
413 }
414 }
415 dfa
416 }
417}
418
419pub struct WfaLearning<F, A>
420where
421 F: Field<Additive: Invertible, Multiplicative: Invertible>,
422 A: BlackBoxAutomaton<Output = F::T>,
423{
424 automaton: A,
425 prefixes: Vec<Vec<usize>>,
426 suffixes: Vec<Vec<usize>>,
427 inv_h: Matrix<F>,
428 nh: Vec<Matrix<F>>,
429 wfa: WeightedFiniteAutomaton<F>,
430 _marker: PhantomData<fn() -> F>,
431}
432
433impl<F, A> Debug for WfaLearning<F, A>
434where
435 F: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
436 A: BlackBoxAutomaton<Output = F::T> + Debug,
437{
438 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439 f.debug_struct("WfaLearning")
440 .field("automaton", &self.automaton)
441 .field("prefixes", &self.prefixes)
442 .field("suffixes", &self.suffixes)
443 .field("inv_h", &self.inv_h)
444 .field("nh", &self.nh)
445 .field("wfa", &self.wfa)
446 .finish()
447 }
448}
449
450impl<F, A> Clone for WfaLearning<F, A>
451where
452 F: Field<Additive: Invertible, Multiplicative: Invertible>,
453 A: BlackBoxAutomaton<Output = F::T> + Clone,
454{
455 fn clone(&self) -> Self {
456 Self {
457 automaton: self.automaton.clone(),
458 prefixes: self.prefixes.clone(),
459 suffixes: self.suffixes.clone(),
460 inv_h: self.inv_h.clone(),
461 nh: self.nh.clone(),
462 wfa: self.wfa.clone(),
463 _marker: self._marker,
464 }
465 }
466}
467
468impl<F, A> WfaLearning<F, A>
469where
470 F: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
471 A: BlackBoxAutomaton<Output = F::T>,
472{
473 pub fn new(automaton: A) -> Self {
474 let sigma = automaton.sigma();
475 Self {
476 automaton,
477 prefixes: vec![],
478 suffixes: vec![],
479 inv_h: Matrix::zeros((0, 0)),
480 nh: vec![Matrix::zeros((0, 0)); sigma],
481 wfa: WeightedFiniteAutomaton {
482 initial_weights: Matrix::zeros((1, 0)),
483 transitions: vec![Matrix::zeros((0, 0)); sigma],
484 final_weights: Matrix::zeros((0, 1)),
485 },
486 _marker: PhantomData,
487 }
488 }
489 pub fn wfa(&self) -> &WeightedFiniteAutomaton<F> {
490 &self.wfa
491 }
492 fn split_sample(&mut self, sample: &[usize]) -> Option<(Vec<usize>, Vec<usize>)> {
493 if self.prefixes.is_empty() && !F::is_zero(&self.automaton.behavior(sample.iter().cloned()))
494 {
495 return Some((vec![], sample.to_vec()));
496 }
497 let expected = self.automaton.behavior(sample.iter().cloned());
498 if expected == self.wfa.behavior(sample.iter().cloned()) {
499 return None;
500 }
501 let n = sample.len();
502 let dim = self.wfa.final_weights.shape.0;
503 let mut states: Vec<(Matrix<F>, usize)> = Vec::with_capacity(n + 1);
504 let mut v = self.wfa.final_weights.clone();
505 states.push((v.clone(), n));
506 for k in (0..n).rev() {
507 v = &self.wfa.transitions[sample[k]] * &v;
508 states.push((v.clone(), k));
509 }
510 states.reverse();
511 let split = states.partition_point(|(state, k)| {
512 (0..dim).any(|j| {
513 self.automaton.behavior(
514 self.prefixes[j]
515 .iter()
516 .cloned()
517 .chain(sample[*k..].iter().cloned()),
518 ) != state[j][0]
519 })
520 });
521 Some((sample[..split].to_vec(), sample[split..].to_vec()))
522 }
523 pub fn train_sample(&mut self, sample: &[usize]) -> bool {
524 let Some((prefix, suffix)) = self.split_sample(sample) else {
525 return false;
526 };
527 self.prefixes.push(prefix);
528 self.suffixes.push(suffix);
529 let n = self.inv_h.shape.0;
530 let prefix = &self.prefixes[n];
531 let suffix = &self.suffixes[n];
532 let u = Matrix::<F>::new_with((n, 1), |i, _| {
533 self.automaton.behavior(
534 self.prefixes[i]
535 .iter()
536 .cloned()
537 .chain(suffix.iter().cloned()),
538 )
539 });
540 let v = Matrix::<F>::new_with((1, n), |_, j| {
541 self.automaton.behavior(
542 prefix
543 .iter()
544 .cloned()
545 .chain(self.suffixes[j].iter().cloned()),
546 )
547 });
548 let w = Matrix::<F>::new_with((1, 1), |_, _| {
549 self.automaton
550 .behavior(prefix.iter().cloned().chain(suffix.iter().cloned()))
551 });
552 let t = &self.inv_h * &u;
553 let s = &v * &self.inv_h;
554 let d = F::inv(&(&w - &(&v * &t))[0][0]);
555 let dh = &t * &s;
556 for i in 0..n {
557 for j in 0..n {
558 F::add_assign(&mut self.inv_h[i][j], &F::mul(&dh[i][j], &d));
559 }
560 }
561 self.inv_h
562 .add_col_with(|i, _| F::neg(&F::mul(&t[i][0], &d)));
563 self.inv_h.add_row_with(|_, j| {
564 if j != n {
565 F::neg(&F::mul(&s[0][j], &d))
566 } else {
567 d.clone()
568 }
569 });
570
571 for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
572 let b = &(&self.nh[x] * &t) * &s;
573 for i in 0..n {
574 for j in 0..n {
575 F::add_assign(&mut transition[i][j], &F::mul(&b[i][j], &d));
576 }
577 }
578 }
579 for (x, nh) in self.nh.iter_mut().enumerate() {
580 nh.add_col_with(|i, j| {
581 self.automaton.behavior(
582 self.prefixes[i]
583 .iter()
584 .cloned()
585 .chain([x])
586 .chain(self.suffixes[j].iter().cloned()),
587 )
588 });
589 nh.add_row_with(|i, j| {
590 self.automaton.behavior(
591 self.prefixes[i]
592 .iter()
593 .cloned()
594 .chain([x])
595 .chain(self.suffixes[j].iter().cloned()),
596 )
597 });
598 }
599 self.wfa
600 .initial_weights
601 .add_col_with(|_, _| if n == 0 { F::one() } else { F::zero() });
602 self.wfa
603 .final_weights
604 .add_row_with(|_, _| self.automaton.behavior(prefix.iter().cloned()));
605 for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
606 transition.add_col_with(|_, _| F::zero());
607 transition.add_row_with(|_, _| F::zero());
608 for i in 0..=n {
609 for j in 0..=n {
610 if i == n || j == n {
611 for k in 0..=n {
612 if i != n && j != n && k != n {
613 continue;
614 }
615 F::add_assign(
616 &mut transition[i][k],
617 &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
618 );
619 }
620 } else {
621 let k = n;
622 F::add_assign(
623 &mut transition[i][k],
624 &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
625 );
626 }
627 }
628 }
629 }
630 true
631 }
632 pub fn train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
633 for sample in samples {
634 self.train_sample(&sample);
635 }
636 }
637 pub fn batch_train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
638 let mut prefix_set: HashSet<_> = self.prefixes.iter().cloned().collect();
639 let mut suffix_set: HashSet<_> = self.suffixes.iter().cloned().collect();
640 for sample in samples {
641 if prefix_set.insert(sample.to_vec()) {
642 self.prefixes.push(sample.to_vec());
643 }
644 if suffix_set.insert(sample.to_vec()) {
645 self.suffixes.push(sample);
646 }
647 }
648 let mut h = Matrix::<F>::new_with((self.prefixes.len(), self.suffixes.len()), |i, j| {
649 self.automaton.behavior(
650 self.prefixes[i]
651 .iter()
652 .cloned()
653 .chain(self.suffixes[j].iter().cloned()),
654 )
655 });
656 if !self.prefixes.is_empty() && !self.suffixes.is_empty() && F::is_zero(&h[0][0]) {
657 for j in 1..self.suffixes.len() {
658 if !F::is_zero(&h[0][j]) {
659 self.suffixes.swap(0, j);
660 for i in 0..self.prefixes.len() {
661 h.data[i].swap(0, j);
662 }
663 break;
664 }
665 }
666 }
667 let mut row_id: Vec<usize> = (0..h.shape.0).collect();
668 let mut pivots = vec![];
669 h.row_reduction_with(false, |r, p, c| {
670 row_id.swap(r, p);
671 pivots.push((row_id[r], c));
672 });
673 let mut new_prefixes = vec![];
674 let mut new_suffixes = vec![];
675 for (i, j) in pivots {
676 new_prefixes.push(self.prefixes[i].clone());
677 new_suffixes.push(self.suffixes[j].clone());
678 }
679 self.prefixes = new_prefixes;
680 self.suffixes = new_suffixes;
681 assert_eq!(self.prefixes.len(), self.suffixes.len());
682 let n = self.prefixes.len();
683 let h = Matrix::<F>::new_with((n, n), |i, j| {
684 self.automaton.behavior(
685 self.prefixes[i]
686 .iter()
687 .cloned()
688 .chain(self.suffixes[j].iter().cloned()),
689 )
690 });
691 self.inv_h = h.inverse().expect("Hankel matrix must be invertible");
692 self.wfa = WeightedFiniteAutomaton::<F> {
693 initial_weights: Matrix::new_with((1, n), |_, j| {
694 if self.prefixes[j].is_empty() {
695 F::one()
696 } else {
697 F::zero()
698 }
699 }),
700 transitions: (0..self.automaton.sigma())
701 .map(|x| {
702 &Matrix::new_with((n, n), |i, j| {
703 self.automaton.behavior(
704 self.prefixes[i]
705 .iter()
706 .cloned()
707 .chain([x])
708 .chain(self.suffixes[j].iter().cloned()),
709 )
710 }) * &self.inv_h
711 })
712 .collect(),
713 final_weights: Matrix::new_with((n, 1), |i, _| {
714 self.automaton.behavior(self.prefixes[i].iter().cloned())
715 }),
716 };
717 }
718}
719
// Unit tests for sampling and for the DFA / WFA learners. Each learner is trained on short
// words and then compared against the black box on longer ones.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One as _, Zero as _, mint_basic::MInt998244353},
    };
    use std::collections::{HashSet, VecDeque};

    // `dense_sampling(base, 3)` must list every word of length 0..=3 over `0..base`, shortest
    // first and, within a length, in increasing base-`base` order.
    #[test]
    fn test_dense_sampling() {
        for base in 1usize..=10 {
            let mut expected = vec![];
            for len in 0..=3 {
                for n in 0..base.pow(len) {
                    // Spell `n` with exactly `len` base-`base` digits, most significant first.
                    let mut n = n;
                    let mut current = vec![];
                    for _ in 0..len {
                        current.push(n % base);
                        n /= base;
                    }
                    current.reverse();
                    expected.push(current);
                }
            }

            // NOTE(review): `zip` stops at the shorter side, so an enumeration that ends
            // early would not fail this loop.
            for (expected, result) in expected.into_iter().zip(dense_sampling(base, 3)) {
                assert_eq!(expected, result);
            }
        }
    }

    // DFA learning must reproduce each language exactly. Every learned DFA is checked on all
    // words up to twice the training length.
    #[test]
    fn test_lstar() {
        {
            // Words whose length is a multiple of 6.
            let automaton = BlackBoxAutomatonImpl::new(2, |input| input.len() % 6 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 6));
            for sample in dense_sampling(automaton.sigma(), 12) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            // Words over {0, 1, 2} whose letter sum is a multiple of 4.
            let automaton =
                BlackBoxAutomatonImpl::new(3, |input| input.iter().sum::<usize>() % 4 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(3, 4));
            for sample in dense_sampling(automaton.sigma(), 8) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        // Each of the 16 binary tables (a, b, c, d) for the pairs (0,0), (0,1), (1,0), (1,1).
        for i in 0usize..16 {
            let a = i >> 3 & 1;
            let b = i >> 2 & 1;
            let c = i >> 1 & 1;
            let d = i & 1;
            // Breadth-first search over rewrites that replace an adjacent pair
            // (t[i], t[i + 1]) by the table entry for that pair. Accepts iff the
            // one-letter word [1] is reachable from `t`.
            let naive = |t: &[usize]| {
                let mut set = HashSet::new();
                let mut deq = VecDeque::new();
                deq.push_back(t.to_vec());
                set.insert(t.to_vec());
                while let Some(t) = deq.pop_front() {
                    for i in 0..t.len().saturating_sub(1) {
                        let x = match (t[i], t[i + 1]) {
                            (0, 0) => a,
                            (0, 1) => b,
                            (1, 0) => c,
                            (1, 1) => d,
                            _ => unreachable!(),
                        };
                        // Drop t[i] and overwrite the former t[i + 1] with the result.
                        let mut t = t.to_vec();
                        t.remove(i);
                        t[i] = x;
                        if set.insert(t.to_vec()) {
                            deq.push_back(t);
                        }
                    }
                }
                set.contains(&vec![1])
            };
            let automaton = BlackBoxAutomatonImpl::new(2, |t| naive(&t));
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 4));
            for sample in dense_sampling(automaton.sigma(), 8) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
    }

    // WFA learning over the prime field mod 998244353. Learned automata are checked on dense
    // short words plus ~0.1 s of random longer words.
    #[test]
    fn test_wfa_learning() {
        {
            // Value = sum of the letters.
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                MInt998244353::from(input.iter().sum::<usize>())
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 12) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            // Value = alternating-sign sum x_0 - x_1 + x_2 - ….
            let automaton = BlackBoxAutomatonImpl::new(3, |input| {
                let mut s = MInt998244353::zero();
                let mut c = MInt998244353::one();
                for &x in &input {
                    s += MInt998244353::from(x) * c;
                    c = -c;
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(3, 4));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                automaton.sigma(),
                6..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                // Read the word as binary digits after a leading 1, giving the bound n.
                let mut n = 1;
                for x in input {
                    n = n * 2 + x;
                }
                // Count pairs (u, v) in [0, n]² for which some a in [0, n] has
                // a + (u ^ a) == v.
                let mut s = MInt998244353::zero();
                for u in 0..=n {
                    for v in 0..=n {
                        let mut ok = false;
                        for a in 0..=n {
                            let b = u ^ a;
                            ok |= a + b == v;
                        }
                        s += MInt998244353::new(ok as _);
                    }
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 4));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                automaton.sigma(),
                6..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        // Each of the 16 binary tables (a, b, c, d) for the pairs (0,0), (0,1), (1,0), (1,1).
        for i in 0usize..16 {
            let a = i >> 3 & 1;
            let b = i >> 2 & 1;
            let c = i >> 1 & 1;
            let d = i & 1;
            // Same reduction search as in `test_lstar`: can `t` be rewritten to [1]?
            let naive = |t: &[usize]| {
                let mut set = HashSet::new();
                let mut deq = VecDeque::new();
                deq.push_back(t.to_vec());
                set.insert(t.to_vec());
                while let Some(t) = deq.pop_front() {
                    for i in 0..t.len().saturating_sub(1) {
                        let x = match (t[i], t[i + 1]) {
                            (0, 0) => a,
                            (0, 1) => b,
                            (1, 0) => c,
                            (1, 1) => d,
                            _ => unreachable!(),
                        };
                        let mut t = t.to_vec();
                        t.remove(i);
                        t[i] = x;
                        if set.insert(t.to_vec()) {
                            deq.push_back(t);
                        }
                    }
                }
                set.contains(&vec![1])
            };
            // Value = number of non-empty substrings t[l..r] that reduce to [1].
            let naive = |t: &[usize]| {
                let mut s = MInt998244353::zero();
                for l in 0..t.len() {
                    for r in l + 1..=t.len() {
                        if naive(&t[l..r]) {
                            s += MInt998244353::one();
                        }
                    }
                }
                s
            };
            let automaton = BlackBoxAutomatonImpl::new(2, |t| naive(&t));
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 6));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                automaton.sigma(),
                9..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
            // The one-shot `batch_train` must learn the same function from a smaller set.
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.batch_train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                automaton.sigma(),
                9..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
    }
}