1use super::{Matrix, One, Zero};
2use std::{
3 collections::HashMap,
4 fmt::Debug,
5 hash::Hash,
6 marker::PhantomData,
7 ops::{Add, Div, Mul, Sub},
8};
9
/// Zero-sized marker that records the type parameters a struct is generic over without
/// storing any value of them.
///
/// Wrapping `T` as `fn() -> T` keeps the owning struct covariant in `T`. It also keeps the
/// struct's `Send`/`Sync` auto traits and drop check independent of `T`.
type Marker<T> = std::marker::PhantomData<fn() -> T>;
11
12#[derive(Debug, Clone)]
13pub struct EsperEstimator<T, Input, Class, FC, FF>
14where
15 Class: Eq + Hash,
16 FC: Fn(&Input) -> Class,
17 FF: Fn(&Input) -> Vec<T>,
18{
19 class: FC,
20 feature: FF,
21 data: HashMap<Class, (Vec<Vec<T>>, Vec<T>)>,
22 _marker: Marker<(T, Input, Class)>,
23}
24
25#[derive(Debug, Clone)]
26pub struct EsperSolver<T, Input, Class, FC, FF>
27where
28 Class: Eq + Hash,
29 FC: Fn(&Input) -> Class,
30 FF: Fn(&Input) -> Vec<T>,
31{
32 class: FC,
33 feature: FF,
34 data: HashMap<Class, Option<Vec<T>>>,
35 _marker: Marker<(T, Input, Class)>,
36}
37
38impl<T, Input, Class, FC, FF> EsperEstimator<T, Input, Class, FC, FF>
39where
40 Class: Eq + Hash,
41 FC: Fn(&Input) -> Class,
42 FF: Fn(&Input) -> Vec<T>,
43{
44 pub fn new(class: FC, feature: FF) -> Self {
45 Self {
46 class,
47 feature,
48 data: Default::default(),
49 _marker: PhantomData,
50 }
51 }
52
53 pub fn push(&mut self, input: Input, output: T) {
54 let class = (self.class)(&input);
55 let feature = (self.feature)(&input);
56 let entry = self.data.entry(class).or_default();
57 entry.0.push(feature);
58 entry.1.push(output);
59 }
60}
61
62impl<T, Input, Class, FC, FF> EsperEstimator<T, Input, Class, FC, FF>
63where
64 Class: Eq + Hash,
65 T: Copy + PartialEq + Zero + One + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
66 FC: Fn(&Input) -> Class,
67 FF: Fn(&Input) -> Vec<T>,
68{
69 pub fn solve(self) -> EsperSolver<T, Input, Class, FC, FF> {
70 let data: HashMap<_, _> = self
71 .data
72 .into_iter()
73 .map(|(key, (a, b))| {
74 (
75 key,
76 Matrix::from_vec(a).solve_system_of_linear_equations(&b),
77 )
78 })
79 .collect();
80 EsperSolver {
81 class: self.class,
82 feature: self.feature,
83 data,
84 _marker: PhantomData,
85 }
86 }
87
88 pub fn solve_checked(self) -> EsperSolver<T, Input, Class, FC, FF>
89 where
90 Class: Debug,
91 T: Debug,
92 {
93 let data: HashMap<_, _> = self
94 .data
95 .into_iter()
96 .map(|(key, (a, b))| {
97 let mat = Matrix::from_vec(a);
98 let coeff = mat.solve_system_of_linear_equations(&b);
99 if coeff.is_none() {
100 eprintln!(
101 "failed to solve linear equations: key={:?} A={:?} b={:?}",
102 key, &mat.data, &b
103 );
104 }
105 (key, coeff)
106 })
107 .collect();
108 EsperSolver {
109 class: self.class,
110 feature: self.feature,
111 data,
112 _marker: PhantomData,
113 }
114 }
115}
116
117impl<T, Input, Class, FC, FF> EsperSolver<T, Input, Class, FC, FF>
118where
119 Class: Eq + Hash,
120 T: Copy + Zero + Add<Output = T> + Mul<Output = T>,
121 FC: Fn(&Input) -> Class,
122 FF: Fn(&Input) -> Vec<T>,
123{
124 pub fn solve(&self, input: Input) -> T {
125 let coeff = self
126 .data
127 .get(&(self.class)(&input))
128 .expect("unrecognized class")
129 .as_ref()
130 .expect("failed to solve");
131 let feature = (self.feature)(&input);
132 feature
133 .into_iter()
134 .zip(coeff)
135 .map(|(x, &y)| x * y)
136 .fold(T::zero(), |x, y| x + y)
137 }
138}