1use super::{Field, Invertible, Matrix};
2use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
3
/// Zero-sized type marker used to mention `T` in a struct without storing a `T`.
///
/// Wrapping `fn() -> T` instead of `T` itself means the marker does not claim
/// ownership of a `T`, and it stays `Send + Sync` whatever `T` is.
type Marker<T> = PhantomData<fn() -> T>;
5
/// A linear system `a * x = b` that is built up one equation at a time.
///
/// Row `i` of `a` and entry `i` of `b` together form the `i`-th equation.
#[derive(Debug, Clone)]
struct SystemOfLinearEquation<T> {
    /// Coefficient rows, one `Vec<T>` per equation.
    a: Vec<Vec<T>>,
    /// Right-hand side values, one per equation.
    b: Vec<T>,
}
11
12impl<T> Default for SystemOfLinearEquation<T> {
13 fn default() -> Self {
14 Self {
15 a: Default::default(),
16 b: Default::default(),
17 }
18 }
19}
20
21#[derive(Debug, Clone)]
22pub struct EsperEstimator<R, Input, Class, FC, FF>
23where
24 R: Field<Additive: Invertible, Multiplicative: Invertible>,
25 Class: Eq + Hash,
26 FC: Fn(&Input) -> Class,
27 FF: Fn(&Input) -> Vec<R::T>,
28{
29 class: FC,
30 feature: FF,
31 data: HashMap<Class, SystemOfLinearEquation<R::T>>,
32 _marker: Marker<(R::T, Input, Class)>,
33}
34
35#[derive(Debug, Clone)]
36pub struct EsperSolver<R, Input, Class, FC, FF>
37where
38 R: Field<Additive: Invertible, Multiplicative: Invertible>,
39 Class: Eq + Hash,
40 FC: Fn(&Input) -> Class,
41 FF: Fn(&Input) -> Vec<R::T>,
42{
43 class: FC,
44 feature: FF,
45 data: HashMap<Class, Option<Vec<R::T>>>,
46 _marker: Marker<(R::T, Input, Class)>,
47}
48
49impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
50where
51 R: Field<Additive: Invertible, Multiplicative: Invertible>,
52 Class: Eq + Hash,
53 FC: Fn(&Input) -> Class,
54 FF: Fn(&Input) -> Vec<R::T>,
55{
56 pub fn new(class: FC, feature: FF) -> Self {
57 Self {
58 class,
59 feature,
60 data: Default::default(),
61 _marker: PhantomData,
62 }
63 }
64
65 pub fn push(&mut self, input: Input, output: R::T) {
66 let class = (self.class)(&input);
67 let feature = (self.feature)(&input);
68 let entry = self.data.entry(class).or_default();
69 entry.a.push(feature);
70 entry.b.push(output);
71 }
72}
73
74impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
75where
76 R: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
77 Class: Eq + Hash,
78 FC: Fn(&Input) -> Class,
79 FF: Fn(&Input) -> Vec<R::T>,
80{
81 pub fn solve(self) -> EsperSolver<R, Input, Class, FC, FF> {
82 let data: HashMap<_, _> = self
83 .data
84 .into_iter()
85 .map(|(key, SystemOfLinearEquation { a, b })| {
86 (
87 key,
88 Matrix::<R>::from_vec(a)
89 .solve_system_of_linear_equations(&b)
90 .map(|sol| sol.particular),
91 )
92 })
93 .collect();
94 EsperSolver {
95 class: self.class,
96 feature: self.feature,
97 data,
98 _marker: PhantomData,
99 }
100 }
101
102 pub fn solve_checked(self) -> EsperSolver<R, Input, Class, FC, FF>
103 where
104 Class: Debug,
105 R: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
106 {
107 let data: HashMap<_, _> = self
108 .data
109 .into_iter()
110 .map(|(key, SystemOfLinearEquation { a, b })| {
111 let mat = Matrix::<R>::from_vec(a);
112 let coeff = mat
113 .solve_system_of_linear_equations(&b)
114 .map(|sol| sol.particular);
115 if coeff.is_none() {
116 eprintln!(
117 "failed to solve linear equations: key={:?} A={:?} b={:?}",
118 key, &mat.data, &b
119 );
120 }
121 (key, coeff)
122 })
123 .collect();
124 EsperSolver {
125 class: self.class,
126 feature: self.feature,
127 data,
128 _marker: PhantomData,
129 }
130 }
131}
132
133impl<R, Input, Class, FC, FF> EsperSolver<R, Input, Class, FC, FF>
134where
135 R: Field<Additive: Invertible, Multiplicative: Invertible>,
136 Class: Eq + Hash,
137 FC: Fn(&Input) -> Class,
138 FF: Fn(&Input) -> Vec<R::T>,
139{
140 pub fn solve(&self, input: Input) -> R::T {
141 let coeff = self
142 .data
143 .get(&(self.class)(&input))
144 .expect("unrecognized class")
145 .as_ref()
146 .expect("failed to solve");
147 let feature = (self.feature)(&input);
148 feature
149 .into_iter()
150 .zip(coeff)
151 .map(|(x, y)| R::mul(&x, y))
152 .fold(R::zero(), |x, y| R::add(&x, &y))
153 }
154}