1#![allow(clippy::type_complexity)]
2
3use super::{Field, Invertible, Matrix};
4use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
5
/// Zero-sized marker for type parameters that a struct mentions but does not store.
///
/// Using `fn() -> T` instead of plain `T` keeps the marker covariant in `T` without
/// implying ownership of a `T`. As a result it never affects `Send`/`Sync` or drop checking.
type Marker<T> = PhantomData<fn() -> T>;
7
/// Collects `(input, output)` samples and fits linear coefficients per class.
///
/// Every input is assigned a class by `class` and mapped to a feature vector by
/// `feature`. Samples are grouped by class. [`EsperEstimator::solve`] then solves,
/// for each class, the linear system whose rows are the recorded feature vectors
/// and whose right-hand side is the recorded outputs, producing an [`EsperSolver`].
#[derive(Debug, Clone)]
pub struct EsperEstimator<R, Input, Class, FC, FF>
where
    R: Field,
    R::Additive: Invertible,
    R::Multiplicative: Invertible,
    Class: Eq + Hash,
    FC: Fn(&Input) -> Class,
    FF: Fn(&Input) -> Vec<R::T>,
{
    /// Maps an input to the class whose coefficients it uses.
    class: FC,
    /// Maps an input to its feature vector.
    feature: FF,
    /// Per class: the recorded feature rows and, in the same order, their outputs.
    data: HashMap<Class, (Vec<Vec<R::T>>, Vec<R::T>)>,
    /// Zero-sized; only carries the type parameters (`Input` appears in no other field).
    _marker: Marker<(R::T, Input, Class)>,
}
23
/// Predicts outputs from the per-class coefficients fitted by an [`EsperEstimator`].
///
/// Produced by [`EsperEstimator::solve`] or [`EsperEstimator::solve_checked`].
#[derive(Debug, Clone)]
pub struct EsperSolver<R, Input, Class, FC, FF>
where
    R: Field,
    R::Additive: Invertible,
    R::Multiplicative: Invertible,
    Class: Eq + Hash,
    FC: Fn(&Input) -> Class,
    FF: Fn(&Input) -> Vec<R::T>,
{
    /// Maps an input to the class whose coefficients it uses.
    class: FC,
    /// Maps an input to its feature vector.
    feature: FF,
    /// Per class: the fitted coefficients, or `None` when
    /// `solve_system_of_linear_equations` returned no solution for that class.
    data: HashMap<Class, Option<Vec<R::T>>>,
    /// Zero-sized; only carries the type parameters (`Input` appears in no other field).
    _marker: Marker<(R::T, Input, Class)>,
}
39
40impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
41where
42 R: Field,
43 R::Additive: Invertible,
44 R::Multiplicative: Invertible,
45 Class: Eq + Hash,
46 FC: Fn(&Input) -> Class,
47 FF: Fn(&Input) -> Vec<R::T>,
48{
49 pub fn new(class: FC, feature: FF) -> Self {
50 Self {
51 class,
52 feature,
53 data: Default::default(),
54 _marker: PhantomData,
55 }
56 }
57
58 pub fn push(&mut self, input: Input, output: R::T) {
59 let class = (self.class)(&input);
60 let feature = (self.feature)(&input);
61 let entry = self.data.entry(class).or_default();
62 entry.0.push(feature);
63 entry.1.push(output);
64 }
65}
66
67impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
68where
69 R: Field,
70 R::Additive: Invertible,
71 R::Multiplicative: Invertible,
72 R::T: PartialEq,
73 Class: Eq + Hash,
74 FC: Fn(&Input) -> Class,
75 FF: Fn(&Input) -> Vec<R::T>,
76{
77 pub fn solve(self) -> EsperSolver<R, Input, Class, FC, FF> {
78 let data: HashMap<_, _> = self
79 .data
80 .into_iter()
81 .map(|(key, (a, b))| {
82 (
83 key,
84 Matrix::<R>::from_vec(a)
85 .solve_system_of_linear_equations(&b)
86 .map(|sol| sol.particular),
87 )
88 })
89 .collect();
90 EsperSolver {
91 class: self.class,
92 feature: self.feature,
93 data,
94 _marker: PhantomData,
95 }
96 }
97
98 pub fn solve_checked(self) -> EsperSolver<R, Input, Class, FC, FF>
99 where
100 Class: Debug,
101 R::T: Debug,
102 {
103 let data: HashMap<_, _> = self
104 .data
105 .into_iter()
106 .map(|(key, (a, b))| {
107 let mat = Matrix::<R>::from_vec(a);
108 let coeff = mat
109 .solve_system_of_linear_equations(&b)
110 .map(|sol| sol.particular);
111 if coeff.is_none() {
112 eprintln!(
113 "failed to solve linear equations: key={:?} A={:?} b={:?}",
114 key, &mat.data, &b
115 );
116 }
117 (key, coeff)
118 })
119 .collect();
120 EsperSolver {
121 class: self.class,
122 feature: self.feature,
123 data,
124 _marker: PhantomData,
125 }
126 }
127}
128
129impl<R, Input, Class, FC, FF> EsperSolver<R, Input, Class, FC, FF>
130where
131 R: Field,
132 R::Additive: Invertible,
133 R::Multiplicative: Invertible,
134 Class: Eq + Hash,
135 FC: Fn(&Input) -> Class,
136 FF: Fn(&Input) -> Vec<R::T>,
137{
138 pub fn solve(&self, input: Input) -> R::T {
139 let coeff = self
140 .data
141 .get(&(self.class)(&input))
142 .expect("unrecognized class")
143 .as_ref()
144 .expect("failed to solve");
145 let feature = (self.feature)(&input);
146 feature
147 .into_iter()
148 .zip(coeff)
149 .map(|(x, y)| R::mul(&x, y))
150 .fold(R::zero(), |x, y| R::add(&x, &y))
151 }
152}