// competitive/algorithm/esper.rs

1use super::{Matrix, One, Zero};
2use std::{
3    collections::HashMap,
4    fmt::Debug,
5    hash::Hash,
6    marker::PhantomData,
7    ops::{Add, Div, Mul, Sub},
8};
9
/// Zero-sized marker tying the otherwise-unused type parameters to a struct.
///
/// `PhantomData<fn() -> T>` is covariant in `T` and is `Send + Sync` regardless of `T`,
/// and (unlike `PhantomData<T>`) does not claim ownership of a `T` value.
type Marker<T> = std::marker::PhantomData<fn() -> T>;
11
12#[derive(Debug, Clone)]
13pub struct EsperEstimator<T, Input, Class, FC, FF>
14where
15    Class: Eq + Hash,
16    FC: Fn(&Input) -> Class,
17    FF: Fn(&Input) -> Vec<T>,
18{
19    class: FC,
20    feature: FF,
21    data: HashMap<Class, (Vec<Vec<T>>, Vec<T>)>,
22    _marker: Marker<(T, Input, Class)>,
23}
24
25#[derive(Debug, Clone)]
26pub struct EsperSolver<T, Input, Class, FC, FF>
27where
28    Class: Eq + Hash,
29    FC: Fn(&Input) -> Class,
30    FF: Fn(&Input) -> Vec<T>,
31{
32    class: FC,
33    feature: FF,
34    data: HashMap<Class, Option<Vec<T>>>,
35    _marker: Marker<(T, Input, Class)>,
36}
37
38impl<T, Input, Class, FC, FF> EsperEstimator<T, Input, Class, FC, FF>
39where
40    Class: Eq + Hash,
41    FC: Fn(&Input) -> Class,
42    FF: Fn(&Input) -> Vec<T>,
43{
44    pub fn new(class: FC, feature: FF) -> Self {
45        Self {
46            class,
47            feature,
48            data: Default::default(),
49            _marker: PhantomData,
50        }
51    }
52
53    pub fn push(&mut self, input: Input, output: T) {
54        let class = (self.class)(&input);
55        let feature = (self.feature)(&input);
56        let entry = self.data.entry(class).or_default();
57        entry.0.push(feature);
58        entry.1.push(output);
59    }
60}
61
62impl<T, Input, Class, FC, FF> EsperEstimator<T, Input, Class, FC, FF>
63where
64    Class: Eq + Hash,
65    T: Copy + PartialEq + Zero + One + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
66    FC: Fn(&Input) -> Class,
67    FF: Fn(&Input) -> Vec<T>,
68{
69    pub fn solve(self) -> EsperSolver<T, Input, Class, FC, FF> {
70        let data: HashMap<_, _> = self
71            .data
72            .into_iter()
73            .map(|(key, (a, b))| {
74                (
75                    key,
76                    Matrix::from_vec(a).solve_system_of_linear_equations(&b),
77                )
78            })
79            .collect();
80        EsperSolver {
81            class: self.class,
82            feature: self.feature,
83            data,
84            _marker: PhantomData,
85        }
86    }
87
88    pub fn solve_checked(self) -> EsperSolver<T, Input, Class, FC, FF>
89    where
90        Class: Debug,
91        T: Debug,
92    {
93        let data: HashMap<_, _> = self
94            .data
95            .into_iter()
96            .map(|(key, (a, b))| {
97                let mat = Matrix::from_vec(a);
98                let coeff = mat.solve_system_of_linear_equations(&b);
99                if coeff.is_none() {
100                    eprintln!(
101                        "failed to solve linear equations: key={:?} A={:?} b={:?}",
102                        key, &mat.data, &b
103                    );
104                }
105                (key, coeff)
106            })
107            .collect();
108        EsperSolver {
109            class: self.class,
110            feature: self.feature,
111            data,
112            _marker: PhantomData,
113        }
114    }
115}
116
117impl<T, Input, Class, FC, FF> EsperSolver<T, Input, Class, FC, FF>
118where
119    Class: Eq + Hash,
120    T: Copy + Zero + Add<Output = T> + Mul<Output = T>,
121    FC: Fn(&Input) -> Class,
122    FF: Fn(&Input) -> Vec<T>,
123{
124    pub fn solve(&self, input: Input) -> T {
125        let coeff = self
126            .data
127            .get(&(self.class)(&input))
128            .expect("unrecognized class")
129            .as_ref()
130            .expect("failed to solve");
131        let feature = (self.feature)(&input);
132        feature
133            .into_iter()
134            .zip(coeff)
135            .map(|(x, &y)| x * y)
136            .fold(T::zero(), |x, y| x + y)
137    }
138}