1use super::{BitSet, Field, Invertible, Matrix, RandomSpec, SerdeByteStr, Xorshift};
2use std::{
3 cell::RefCell,
4 collections::{HashMap, HashSet},
5 fmt::{self, Debug},
6 iter::{from_fn, once_with},
7 marker::PhantomData,
8 time::Instant,
9};
10
/// A target that can only be observed through queries: feed it a word over the
/// alphabet `0..sigma()` and it answers with a value of type [`Self::Output`].
///
/// The learners in this module interact with the target exclusively through
/// this trait.
pub trait BlackBoxAutomaton {
    /// Answer type of a query (e.g. `bool` for a DFA, a field element for a WFA).
    type Output;

    /// Number of input symbols.
    fn sigma(&self) -> usize;

    /// Evaluates the automaton on the word produced by `input`.
    fn behavior<I: IntoIterator<Item = usize>>(&self, input: I) -> Self::Output;
}
19#[derive(Debug, Clone)]
20pub struct BlackBoxAutomatonImpl<T, F>
21where
22 F: Fn(Vec<usize>) -> T,
23{
24 sigma: usize,
25 behavior_fn: F,
26 memo: RefCell<HashMap<Vec<usize>, T>>,
27}
28
29impl<T, F> BlackBoxAutomatonImpl<T, F>
30where
31 F: Fn(Vec<usize>) -> T,
32{
33 pub fn new(sigma: usize, behavior_fn: F) -> Self {
34 Self {
35 sigma,
36 behavior_fn,
37 memo: RefCell::new(HashMap::new()),
38 }
39 }
40}
41
42impl<T, F> BlackBoxAutomaton for BlackBoxAutomatonImpl<T, F>
43where
44 F: Fn(Vec<usize>) -> T,
45 T: Clone,
46{
47 type Output = T;
48
49 fn sigma(&self) -> usize {
50 self.sigma
51 }
52
53 fn behavior<I>(&self, input: I) -> Self::Output
54 where
55 I: IntoIterator<Item = usize>,
56 {
57 let input: Vec<usize> = input.into_iter().collect();
58 self.memo
59 .borrow_mut()
60 .entry(input.clone())
61 .or_insert_with(|| (self.behavior_fn)(input))
62 .clone()
63 }
64}
65
66impl<A> BlackBoxAutomaton for &A
67where
68 A: BlackBoxAutomaton,
69{
70 type Output = A::Output;
71
72 fn sigma(&self) -> usize {
73 (*self).sigma()
74 }
75
76 fn behavior<I>(&self, input: I) -> Self::Output
77 where
78 I: IntoIterator<Item = usize>,
79 {
80 (*self).behavior(input)
81 }
82}
83
84#[derive(Debug, Clone)]
85struct DfaState {
86 delta: Vec<usize>,
87 accept: bool,
88}
89
90#[derive(Debug, Clone)]
91pub struct DeterministicFiniteAutomaton {
92 states: Vec<DfaState>,
93 initial_state: usize,
94}
95
96impl DeterministicFiniteAutomaton {
97 pub fn size(&self) -> usize {
98 self.states.len()
99 }
100 pub fn delta(&self, state: usize, input: usize) -> usize {
101 assert!(state < self.states.len());
102 assert!(input < self.states[0].delta.len());
103 self.states[state].delta[input]
104 }
105 pub fn accept(&self, state: usize) -> bool {
106 assert!(state < self.states.len());
107 self.states[state].accept
108 }
109}
110
111impl BlackBoxAutomaton for DeterministicFiniteAutomaton {
112 type Output = bool;
113
114 fn sigma(&self) -> usize {
115 self.states[0].delta.len()
116 }
117
118 fn behavior<I>(&self, input: I) -> Self::Output
119 where
120 I: IntoIterator<Item = usize>,
121 {
122 let mut state = self.initial_state;
123 for x in input {
124 state = self.states[state].delta[x];
125 }
126 self.states[state].accept
127 }
128}
129
130impl SerdeByteStr for DfaState {
131 fn serialize(&self, buf: &mut Vec<u8>) {
132 self.delta.serialize(buf);
133 self.accept.serialize(buf);
134 }
135
136 fn deserialize<I>(iter: &mut I) -> Self
137 where
138 I: Iterator<Item = u8>,
139 {
140 let delta = Vec::deserialize(iter);
141 let accept = bool::deserialize(iter);
142 Self { delta, accept }
143 }
144}
145
146impl SerdeByteStr for DeterministicFiniteAutomaton {
147 fn serialize(&self, buf: &mut Vec<u8>) {
148 self.states.serialize(buf);
149 self.initial_state.serialize(buf);
150 }
151
152 fn deserialize<I>(iter: &mut I) -> Self
153 where
154 I: Iterator<Item = u8>,
155 {
156 let states = Vec::deserialize(iter);
157 let initial_state = usize::deserialize(iter);
158 Self {
159 states,
160 initial_state,
161 }
162 }
163}
164
165pub struct WeightedFiniteAutomaton<F>
166where
167 F: Field,
168 F::Additive: Invertible,
169 F::Multiplicative: Invertible,
170{
171 pub initial_weights: Matrix<F>,
172 pub transitions: Vec<Matrix<F>>,
173 pub final_weights: Matrix<F>,
174}
175
176impl<F> Debug for WeightedFiniteAutomaton<F>
177where
178 F: Field,
179 F::Additive: Invertible,
180 F::Multiplicative: Invertible,
181 F::T: Debug,
182{
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 f.debug_struct("WeightedFiniteAutomaton")
185 .field("initial_weights", &self.initial_weights)
186 .field("transitions", &self.transitions)
187 .field("final_weights", &self.final_weights)
188 .finish()
189 }
190}
191
192impl<F> Clone for WeightedFiniteAutomaton<F>
193where
194 F: Field,
195 F::Additive: Invertible,
196 F::Multiplicative: Invertible,
197{
198 fn clone(&self) -> Self {
199 Self {
200 initial_weights: self.initial_weights.clone(),
201 transitions: self.transitions.clone(),
202 final_weights: self.final_weights.clone(),
203 }
204 }
205}
206
207impl<F> BlackBoxAutomaton for WeightedFiniteAutomaton<F>
208where
209 F: Field,
210 F::Additive: Invertible,
211 F::Multiplicative: Invertible,
212{
213 type Output = F::T;
214
215 fn sigma(&self) -> usize {
216 self.transitions.len()
217 }
218
219 fn behavior<I>(&self, input: I) -> Self::Output
220 where
221 I: IntoIterator<Item = usize>,
222 {
223 let mut weights = self.initial_weights.clone();
224 for x in input {
225 weights = &weights * &self.transitions[x];
226 }
227 let result = &weights * &self.final_weights;
228 if result.shape != (0, 0) {
229 result[0][0].clone()
230 } else {
231 F::zero()
232 }
233 }
234}
235
236impl<F> SerdeByteStr for WeightedFiniteAutomaton<F>
237where
238 F: Field,
239 F::Additive: Invertible,
240 F::Multiplicative: Invertible,
241 F::T: SerdeByteStr,
242{
243 fn serialize(&self, buf: &mut Vec<u8>) {
244 self.initial_weights.serialize(buf);
245 self.transitions.serialize(buf);
246 self.final_weights.serialize(buf);
247 }
248
249 fn deserialize<I>(iter: &mut I) -> Self
250 where
251 I: Iterator<Item = u8>,
252 {
253 let initial_weights = Matrix::deserialize(iter);
254 let transitions = Vec::deserialize(iter);
255 let final_weights = Matrix::deserialize(iter);
256 Self {
257 initial_weights,
258 transitions,
259 final_weights,
260 }
261 }
262}
263
264pub fn dense_sampling(sigma: usize, max_len: usize) -> impl Iterator<Item = Vec<usize>> {
265 assert_ne!(sigma, 0, "Sigma must be greater than 0");
266 let mut current = vec![];
267 once_with(Vec::new).chain(from_fn(move || {
268 let mut carry = true;
269 for i in (0..current.len()).rev() {
270 current[i] += 1;
271 if current[i] == sigma {
272 current[i] = 0;
273 } else {
274 carry = false;
275 break;
276 }
277 }
278 if carry {
279 current.push(0);
280 }
281 if current.len() <= max_len {
282 Some(current.to_vec())
283 } else {
284 None
285 }
286 }))
287}
288
289pub fn random_sampling(
290 sigma: usize,
291 len_spec: impl RandomSpec<usize>,
292 seconds: f64,
293) -> impl Iterator<Item = Vec<usize>> {
294 assert_ne!(sigma, 0, "Sigma must be greater than 0");
295 let now = Instant::now();
296 let mut rng = Xorshift::new();
297 from_fn(move || {
298 if now.elapsed().as_secs_f64() > seconds {
299 None
300 } else {
301 let n = rng.random(&len_spec);
302 Some(rng.random_iter(0..sigma).take(n).collect())
303 }
304 })
305}
306
307#[derive(Debug, Clone)]
308pub struct DfaLearning<A>
309where
310 A: BlackBoxAutomaton<Output = bool>,
311{
312 automaton: A,
313 prefixes: Vec<Vec<usize>>,
314 suffixes: Vec<Vec<usize>>,
315 table: Vec<BitSet>,
316 row_map: HashMap<BitSet, usize>,
317}
318
319impl<A> DfaLearning<A>
320where
321 A: BlackBoxAutomaton<Output = bool>,
322{
323 pub fn new(automaton: A) -> Self {
324 let mut this = Self {
325 automaton,
326 prefixes: vec![],
327 suffixes: vec![],
328 table: vec![],
329 row_map: HashMap::new(),
330 };
331 this.add_suffix(vec![]);
332 this.add_prefix(vec![]);
333 this
334 }
335 fn add_prefix(&mut self, prefix: Vec<usize>) -> usize {
336 let row: BitSet = self
337 .suffixes
338 .iter()
339 .map(|s| {
340 self.automaton
341 .behavior(prefix.iter().cloned().chain(s.iter().cloned()))
342 })
343 .collect();
344 *self.row_map.entry(row.clone()).or_insert_with(|| {
345 let idx = self.table.len();
346 self.table.push(row);
347 self.prefixes.push(prefix);
348 idx
349 })
350 }
351 fn add_suffix(&mut self, suffix: Vec<usize>) {
352 if self.suffixes.contains(&suffix) {
353 return;
354 }
355 for (prefix, table) in self.prefixes.iter_mut().zip(&mut self.table) {
356 table.push(
357 self.automaton
358 .behavior(prefix.iter().cloned().chain(suffix.iter().cloned())),
359 );
360 }
361 self.suffixes.push(suffix);
362 self.row_map.clear();
363 for (i_prefix, row) in self.table.iter().enumerate() {
364 self.row_map.insert(row.clone(), i_prefix);
365 }
366 }
367 pub fn construct_dfa(&mut self) -> DeterministicFiniteAutomaton {
368 let sigma = self.automaton.sigma();
369 let mut dfa = DeterministicFiniteAutomaton {
370 states: vec![],
371 initial_state: 0,
372 };
373 let mut i_prefix = 0;
374 while i_prefix < self.prefixes.len() {
375 let mut delta = vec![];
376 for x in 0..sigma {
377 let prefix: Vec<usize> =
378 self.prefixes[i_prefix].iter().cloned().chain([x]).collect();
379 let index = self.add_prefix(prefix);
380 delta.push(index);
381 }
382 dfa.states.push(DfaState {
383 delta,
384 accept: self.table[i_prefix].get(0),
385 });
386 i_prefix += 1;
387 }
388 dfa
389 }
390 pub fn train_sample(&mut self, dfa: &DeterministicFiniteAutomaton, sample: &[usize]) -> bool {
391 let expected = self.automaton.behavior(sample.iter().cloned());
392 if expected == dfa.behavior(sample.iter().cloned()) {
393 return false;
394 }
395 let n = sample.len();
396 let mut states: Vec<(usize, usize)> = Vec::with_capacity(n + 1);
397 let mut s = 0usize;
398 states.push((s, 0));
399 for (k, &x) in sample.iter().enumerate() {
400 s = dfa.states[s].delta[x];
401 states.push((s, k + 1));
402 }
403 let split = states.partition_point(|&(state, k)| {
404 self.automaton.behavior(
405 self.prefixes[state]
406 .iter()
407 .cloned()
408 .chain(sample[k..].iter().cloned()),
409 ) == expected
410 });
411 let new_prefix = sample[..split].to_vec();
412 let new_suffix = sample[split..].to_vec();
413 self.add_suffix(new_suffix);
414 self.add_prefix(new_prefix);
415 true
416 }
417 pub fn train(
418 &mut self,
419 samples: impl IntoIterator<Item = Vec<usize>>,
420 ) -> DeterministicFiniteAutomaton {
421 let mut dfa = self.construct_dfa();
422 for sample in samples {
423 if self.train_sample(&dfa, &sample) {
424 dfa = self.construct_dfa();
425 }
426 }
427 dfa
428 }
429}
430
431pub struct WfaLearning<F, A>
432where
433 F: Field,
434 F::Additive: Invertible,
435 F::Multiplicative: Invertible,
436 A: BlackBoxAutomaton<Output = F::T>,
437{
438 automaton: A,
439 prefixes: Vec<Vec<usize>>,
440 suffixes: Vec<Vec<usize>>,
441 inv_h: Matrix<F>,
442 nh: Vec<Matrix<F>>,
443 wfa: WeightedFiniteAutomaton<F>,
444 _marker: PhantomData<fn() -> F>,
445}
446
447impl<F, A> Debug for WfaLearning<F, A>
448where
449 F: Field,
450 F::Additive: Invertible,
451 F::Multiplicative: Invertible,
452 F::T: Debug,
453 A: BlackBoxAutomaton<Output = F::T> + Debug,
454{
455 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456 f.debug_struct("WfaLearning")
457 .field("automaton", &self.automaton)
458 .field("prefixes", &self.prefixes)
459 .field("suffixes", &self.suffixes)
460 .field("inv_h", &self.inv_h)
461 .field("nh", &self.nh)
462 .field("wfa", &self.wfa)
463 .finish()
464 }
465}
466
467impl<F, A> Clone for WfaLearning<F, A>
468where
469 F: Field,
470 F::Additive: Invertible,
471 F::Multiplicative: Invertible,
472 A: BlackBoxAutomaton<Output = F::T> + Clone,
473{
474 fn clone(&self) -> Self {
475 Self {
476 automaton: self.automaton.clone(),
477 prefixes: self.prefixes.clone(),
478 suffixes: self.suffixes.clone(),
479 inv_h: self.inv_h.clone(),
480 nh: self.nh.clone(),
481 wfa: self.wfa.clone(),
482 _marker: self._marker,
483 }
484 }
485}
486
487impl<F, A> WfaLearning<F, A>
488where
489 F: Field,
490 F::Additive: Invertible,
491 F::Multiplicative: Invertible,
492 F::T: PartialEq,
493 A: BlackBoxAutomaton<Output = F::T>,
494{
495 pub fn new(automaton: A) -> Self {
496 let sigma = automaton.sigma();
497 Self {
498 automaton,
499 prefixes: vec![],
500 suffixes: vec![],
501 inv_h: Matrix::zeros((0, 0)),
502 nh: vec![Matrix::zeros((0, 0)); sigma],
503 wfa: WeightedFiniteAutomaton {
504 initial_weights: Matrix::zeros((1, 0)),
505 transitions: vec![Matrix::zeros((0, 0)); sigma],
506 final_weights: Matrix::zeros((0, 1)),
507 },
508 _marker: PhantomData,
509 }
510 }
511 pub fn wfa(&self) -> &WeightedFiniteAutomaton<F> {
512 &self.wfa
513 }
514 fn split_sample(&mut self, sample: &[usize]) -> Option<(Vec<usize>, Vec<usize>)> {
515 if self.prefixes.is_empty() && !F::is_zero(&self.automaton.behavior(sample.iter().cloned()))
516 {
517 return Some((vec![], sample.to_vec()));
518 }
519 let expected = self.automaton.behavior(sample.iter().cloned());
520 if expected == self.wfa.behavior(sample.iter().cloned()) {
521 return None;
522 }
523 let n = sample.len();
524 let dim = self.wfa.final_weights.shape.0;
525 let mut states: Vec<(Matrix<F>, usize)> = Vec::with_capacity(n + 1);
526 let mut v = self.wfa.final_weights.clone();
527 states.push((v.clone(), n));
528 for k in (0..n).rev() {
529 v = &self.wfa.transitions[sample[k]] * &v;
530 states.push((v.clone(), k));
531 }
532 states.reverse();
533 let split = states.partition_point(|(state, k)| {
534 (0..dim).any(|j| {
535 self.automaton.behavior(
536 self.prefixes[j]
537 .iter()
538 .cloned()
539 .chain(sample[*k..].iter().cloned()),
540 ) != state[j][0]
541 })
542 });
543 Some((sample[..split].to_vec(), sample[split..].to_vec()))
544 }
545 pub fn train_sample(&mut self, sample: &[usize]) -> bool {
546 let Some((prefix, suffix)) = self.split_sample(sample) else {
547 return false;
548 };
549 self.prefixes.push(prefix);
550 self.suffixes.push(suffix);
551 let n = self.inv_h.shape.0;
552 let prefix = &self.prefixes[n];
553 let suffix = &self.suffixes[n];
554 let u = Matrix::<F>::new_with((n, 1), |i, _| {
555 self.automaton.behavior(
556 self.prefixes[i]
557 .iter()
558 .cloned()
559 .chain(suffix.iter().cloned()),
560 )
561 });
562 let v = Matrix::<F>::new_with((1, n), |_, j| {
563 self.automaton.behavior(
564 prefix
565 .iter()
566 .cloned()
567 .chain(self.suffixes[j].iter().cloned()),
568 )
569 });
570 let w = Matrix::<F>::new_with((1, 1), |_, _| {
571 self.automaton
572 .behavior(prefix.iter().cloned().chain(suffix.iter().cloned()))
573 });
574 let t = &self.inv_h * &u;
575 let s = &v * &self.inv_h;
576 let d = F::inv(&(&w - &(&v * &t))[0][0]);
577 let dh = &t * &s;
578 for i in 0..n {
579 for j in 0..n {
580 F::add_assign(&mut self.inv_h[i][j], &F::mul(&dh[i][j], &d));
581 }
582 }
583 self.inv_h
584 .add_col_with(|i, _| F::neg(&F::mul(&t[i][0], &d)));
585 self.inv_h.add_row_with(|_, j| {
586 if j != n {
587 F::neg(&F::mul(&s[0][j], &d))
588 } else {
589 d.clone()
590 }
591 });
592
593 for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
594 let b = &(&self.nh[x] * &t) * &s;
595 for i in 0..n {
596 for j in 0..n {
597 F::add_assign(&mut transition[i][j], &F::mul(&b[i][j], &d));
598 }
599 }
600 }
601 for (x, nh) in self.nh.iter_mut().enumerate() {
602 nh.add_col_with(|i, j| {
603 self.automaton.behavior(
604 self.prefixes[i]
605 .iter()
606 .cloned()
607 .chain([x])
608 .chain(self.suffixes[j].iter().cloned()),
609 )
610 });
611 nh.add_row_with(|i, j| {
612 self.automaton.behavior(
613 self.prefixes[i]
614 .iter()
615 .cloned()
616 .chain([x])
617 .chain(self.suffixes[j].iter().cloned()),
618 )
619 });
620 }
621 self.wfa
622 .initial_weights
623 .add_col_with(|_, _| if n == 0 { F::one() } else { F::zero() });
624 self.wfa
625 .final_weights
626 .add_row_with(|_, _| self.automaton.behavior(prefix.iter().cloned()));
627 for (x, transition) in self.wfa.transitions.iter_mut().enumerate() {
628 transition.add_col_with(|_, _| F::zero());
629 transition.add_row_with(|_, _| F::zero());
630 for i in 0..=n {
631 for j in 0..=n {
632 if i == n || j == n {
633 for k in 0..=n {
634 if i != n && j != n && k != n {
635 continue;
636 }
637 F::add_assign(
638 &mut transition[i][k],
639 &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
640 );
641 }
642 } else {
643 let k = n;
644 F::add_assign(
645 &mut transition[i][k],
646 &F::mul(&self.nh[x][i][j], &self.inv_h[j][k]),
647 );
648 }
649 }
650 }
651 }
652 true
653 }
654 pub fn train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
655 for sample in samples {
656 self.train_sample(&sample);
657 }
658 }
659 pub fn batch_train(&mut self, samples: impl IntoIterator<Item = Vec<usize>>) {
660 let mut prefix_set: HashSet<_> = self.prefixes.iter().cloned().collect();
661 let mut suffix_set: HashSet<_> = self.suffixes.iter().cloned().collect();
662 for sample in samples {
663 if prefix_set.insert(sample.to_vec()) {
664 self.prefixes.push(sample.to_vec());
665 }
666 if suffix_set.insert(sample.to_vec()) {
667 self.suffixes.push(sample);
668 }
669 }
670 let mut h = Matrix::<F>::new_with((self.prefixes.len(), self.suffixes.len()), |i, j| {
671 self.automaton.behavior(
672 self.prefixes[i]
673 .iter()
674 .cloned()
675 .chain(self.suffixes[j].iter().cloned()),
676 )
677 });
678 if !self.prefixes.is_empty() && !self.suffixes.is_empty() && F::is_zero(&h[0][0]) {
679 for j in 1..self.suffixes.len() {
680 if !F::is_zero(&h[0][j]) {
681 self.suffixes.swap(0, j);
682 for i in 0..self.prefixes.len() {
683 h.data[i].swap(0, j);
684 }
685 break;
686 }
687 }
688 }
689 let mut row_id: Vec<usize> = (0..h.shape.0).collect();
690 let mut pivots = vec![];
691 h.row_reduction_with(false, |r, p, c| {
692 row_id.swap(r, p);
693 pivots.push((row_id[r], c));
694 });
695 let mut new_prefixes = vec![];
696 let mut new_suffixes = vec![];
697 for (i, j) in pivots {
698 new_prefixes.push(self.prefixes[i].clone());
699 new_suffixes.push(self.suffixes[j].clone());
700 }
701 self.prefixes = new_prefixes;
702 self.suffixes = new_suffixes;
703 assert_eq!(self.prefixes.len(), self.suffixes.len());
704 let n = self.prefixes.len();
705 let h = Matrix::<F>::new_with((n, n), |i, j| {
706 self.automaton.behavior(
707 self.prefixes[i]
708 .iter()
709 .cloned()
710 .chain(self.suffixes[j].iter().cloned()),
711 )
712 });
713 self.inv_h = h.inverse().expect("Hankel matrix must be invertible");
714 self.wfa = WeightedFiniteAutomaton::<F> {
715 initial_weights: Matrix::new_with((1, n), |_, j| {
716 if self.prefixes[j].is_empty() {
717 F::one()
718 } else {
719 F::zero()
720 }
721 }),
722 transitions: (0..self.automaton.sigma())
723 .map(|x| {
724 &Matrix::new_with((n, n), |i, j| {
725 self.automaton.behavior(
726 self.prefixes[i]
727 .iter()
728 .cloned()
729 .chain([x])
730 .chain(self.suffixes[j].iter().cloned()),
731 )
732 }) * &self.inv_h
733 })
734 .collect(),
735 final_weights: Matrix::new_with((n, 1), |i, _| {
736 self.automaton.behavior(self.prefixes[i].iter().cloned())
737 }),
738 };
739 }
740}
741
742#[cfg(test)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AddMulOperation,
        num::{One as _, Zero as _, mint_basic::MInt998244353},
    };
    use std::collections::{HashSet, VecDeque};

    /// `dense_sampling(base, 3)` must list all words of length 0..=3, shortest
    /// first and lexicographically within each length, for alphabets 1..=10.
    #[test]
    fn test_dense_sampling() {
        for base in 1usize..=10 {
            let mut expected = vec![];
            for len in 0..=3 {
                for n in 0..base.pow(len) {
                    // Write n in base `base` with exactly `len` digits, most significant first.
                    let mut n = n;
                    let mut current = vec![];
                    for _ in 0..len {
                        current.push(n % base);
                        n /= base;
                    }
                    current.reverse();
                    expected.push(current);
                }
            }

            // NOTE(review): `zip` stops at the shorter side, so a truncated
            // output from `dense_sampling` would not be detected here.
            for (expected, result) in expected.into_iter().zip(dense_sampling(base, 3)) {
                assert_eq!(expected, result);
            }
        }
    }

    /// L* learning of several regular languages, checked on longer words than
    /// the training samples.
    #[test]
    fn test_lstar() {
        {
            // Words whose length is a multiple of 6.
            let automaton = BlackBoxAutomatonImpl::new(2, |input| input.len() % 6 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 6));
            for sample in dense_sampling(automaton.sigma(), 12) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            // Words over {0,1,2} whose symbol sum is a multiple of 4.
            let automaton =
                BlackBoxAutomatonImpl::new(3, |input| input.iter().sum::<usize>() % 4 == 0);
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(3, 4));
            for sample in dense_sampling(automaton.sigma(), 8) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        // All 16 binary operations on {0,1}, given by the table (a, b, c, d).
        for i in 0usize..16 {
            let a = i >> 3 & 1;
            let b = i >> 2 & 1;
            let c = i >> 1 & 1;
            let d = i & 1;
            // Brute force (BFS over all reachable words): can `t` be rewritten
            // to the single symbol 1? A step replaces an adjacent pair
            // (t[i], t[i+1]) with its table value.
            let naive = |t: &[usize]| {
                let mut set = HashSet::new();
                let mut deq = VecDeque::new();
                deq.push_back(t.to_vec());
                set.insert(t.to_vec());
                while let Some(t) = deq.pop_front() {
                    for i in 0..t.len().saturating_sub(1) {
                        let x = match (t[i], t[i + 1]) {
                            (0, 0) => a,
                            (0, 1) => b,
                            (1, 0) => c,
                            (1, 1) => d,
                            _ => unreachable!(),
                        };
                        // Drop t[i] and overwrite the old t[i+1] with the result.
                        let mut t = t.to_vec();
                        t.remove(i);
                        t[i] = x;
                        if set.insert(t.to_vec()) {
                            deq.push_back(t);
                        }
                    }
                }
                set.contains(&vec![1])
            };
            let automaton = BlackBoxAutomatonImpl::new(2, |t| naive(&t));
            let dfa = DfaLearning::new(&automaton).train(dense_sampling(2, 4));
            for sample in dense_sampling(automaton.sigma(), 8) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = dfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
    }

    /// WFA learning over MInt998244353. Incremental `train` and `batch_train`
    /// are both checked on dense and random words longer than the training data.
    #[test]
    fn test_wfa_learning() {
        {
            // Sum of symbols.
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                MInt998244353::from(input.iter().sum::<usize>())
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 12) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            // Alternating-sign sum: x_0 - x_1 + x_2 - ...
            let automaton = BlackBoxAutomatonImpl::new(3, |input| {
                let mut s = MInt998244353::zero();
                let mut c = MInt998244353::one();
                for &x in &input {
                    s += MInt998244353::from(x) * c;
                    c = -c;
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(3, 4));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                automaton.sigma(),
                6..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        {
            // Read the word as binary digits after a leading 1 to get n. Then
            // count pairs (u, v) in [0, n]^2 for which some a in [0, n]
            // satisfies a + (u ^ a) == v.
            let automaton = BlackBoxAutomatonImpl::new(2, |input| {
                let mut n = 1;
                for x in input {
                    n = n * 2 + x;
                }
                let mut s = MInt998244353::zero();
                for u in 0..=n {
                    for v in 0..=n {
                        let mut ok = false;
                        for a in 0..=n {
                            let b = u ^ a;
                            ok |= a + b == v;
                        }
                        s += MInt998244353::new(ok as _);
                    }
                }
                s
            });
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 4));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 6).chain(random_sampling(
                automaton.sigma(),
                6..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
        // All 16 binary operations on {0,1}, given by the table (a, b, c, d).
        for i in 0usize..16 {
            let a = i >> 3 & 1;
            let b = i >> 2 & 1;
            let c = i >> 1 & 1;
            let d = i & 1;
            // Brute force: can `t` be rewritten to the single symbol 1 by
            // repeatedly replacing an adjacent pair with its table value?
            let naive = |t: &[usize]| {
                let mut set = HashSet::new();
                let mut deq = VecDeque::new();
                deq.push_back(t.to_vec());
                set.insert(t.to_vec());
                while let Some(t) = deq.pop_front() {
                    for i in 0..t.len().saturating_sub(1) {
                        let x = match (t[i], t[i + 1]) {
                            (0, 0) => a,
                            (0, 1) => b,
                            (1, 0) => c,
                            (1, 1) => d,
                            _ => unreachable!(),
                        };
                        let mut t = t.to_vec();
                        t.remove(i);
                        t[i] = x;
                        if set.insert(t.to_vec()) {
                            deq.push_back(t);
                        }
                    }
                }
                set.contains(&vec![1])
            };
            // Shadows the predicate above: counts non-empty substrings
            // t[l..r] that can be reduced to [1].
            let naive = |t: &[usize]| {
                let mut s = MInt998244353::zero();
                for l in 0..t.len() {
                    for r in l + 1..=t.len() {
                        if naive(&t[l..r]) {
                            s += MInt998244353::one();
                        }
                    }
                }
                s
            };
            let automaton = BlackBoxAutomatonImpl::new(2, |t| naive(&t));
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.train(dense_sampling(2, 6));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                automaton.sigma(),
                9..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
            // Same target, learned in one shot with `batch_train`.
            let mut wl = WfaLearning::<AddMulOperation<_>, _>::new(&automaton);
            wl.batch_train(dense_sampling(2, 3));
            let wfa = wl.wfa();
            for sample in dense_sampling(automaton.sigma(), 8).chain(random_sampling(
                automaton.sigma(),
                9..=12,
                0.1,
            )) {
                let expected = automaton.behavior(sample.iter().cloned());
                let result = wfa.behavior(sample.iter().cloned());
                assert_eq!(expected, result);
            }
        }
    }
}