// competitive/algorithm/esper.rs

1#![allow(clippy::type_complexity)]
2
3use super::{Field, Invertible, Matrix};
4use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
5
/// Zero-sized marker that mentions the type `T` without storing a `T`.
///
/// Wrapping `T` in `fn() -> T` keeps the marker covariant in `T` while making it
/// `Send`/`Sync` (and free of drop-check obligations) no matter what `T` is.
type Marker<T> = std::marker::PhantomData<fn() -> T>;

8#[derive(Debug, Clone)]
9pub struct EsperEstimator<R, Input, Class, FC, FF>
10where
11    R: Field,
12    R::Additive: Invertible,
13    R::Multiplicative: Invertible,
14    Class: Eq + Hash,
15    FC: Fn(&Input) -> Class,
16    FF: Fn(&Input) -> Vec<R::T>,
17{
18    class: FC,
19    feature: FF,
20    data: HashMap<Class, (Vec<Vec<R::T>>, Vec<R::T>)>,
21    _marker: Marker<(R::T, Input, Class)>,
22}
23
24#[derive(Debug, Clone)]
25pub struct EsperSolver<R, Input, Class, FC, FF>
26where
27    R: Field,
28    R::Additive: Invertible,
29    R::Multiplicative: Invertible,
30    Class: Eq + Hash,
31    FC: Fn(&Input) -> Class,
32    FF: Fn(&Input) -> Vec<R::T>,
33{
34    class: FC,
35    feature: FF,
36    data: HashMap<Class, Option<Vec<R::T>>>,
37    _marker: Marker<(R::T, Input, Class)>,
38}
39
40impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
41where
42    R: Field,
43    R::Additive: Invertible,
44    R::Multiplicative: Invertible,
45    Class: Eq + Hash,
46    FC: Fn(&Input) -> Class,
47    FF: Fn(&Input) -> Vec<R::T>,
48{
49    pub fn new(class: FC, feature: FF) -> Self {
50        Self {
51            class,
52            feature,
53            data: Default::default(),
54            _marker: PhantomData,
55        }
56    }
57
58    pub fn push(&mut self, input: Input, output: R::T) {
59        let class = (self.class)(&input);
60        let feature = (self.feature)(&input);
61        let entry = self.data.entry(class).or_default();
62        entry.0.push(feature);
63        entry.1.push(output);
64    }
65}
66
67impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
68where
69    R: Field,
70    R::Additive: Invertible,
71    R::Multiplicative: Invertible,
72    R::T: PartialEq,
73    Class: Eq + Hash,
74    FC: Fn(&Input) -> Class,
75    FF: Fn(&Input) -> Vec<R::T>,
76{
77    pub fn solve(self) -> EsperSolver<R, Input, Class, FC, FF> {
78        let data: HashMap<_, _> = self
79            .data
80            .into_iter()
81            .map(|(key, (a, b))| {
82                (
83                    key,
84                    Matrix::<R>::from_vec(a)
85                        .solve_system_of_linear_equations(&b)
86                        .map(|sol| sol.particular),
87                )
88            })
89            .collect();
90        EsperSolver {
91            class: self.class,
92            feature: self.feature,
93            data,
94            _marker: PhantomData,
95        }
96    }
97
98    pub fn solve_checked(self) -> EsperSolver<R, Input, Class, FC, FF>
99    where
100        Class: Debug,
101        R::T: Debug,
102    {
103        let data: HashMap<_, _> = self
104            .data
105            .into_iter()
106            .map(|(key, (a, b))| {
107                let mat = Matrix::<R>::from_vec(a);
108                let coeff = mat
109                    .solve_system_of_linear_equations(&b)
110                    .map(|sol| sol.particular);
111                if coeff.is_none() {
112                    eprintln!(
113                        "failed to solve linear equations: key={:?} A={:?} b={:?}",
114                        key, &mat.data, &b
115                    );
116                }
117                (key, coeff)
118            })
119            .collect();
120        EsperSolver {
121            class: self.class,
122            feature: self.feature,
123            data,
124            _marker: PhantomData,
125        }
126    }
127}
128
129impl<R, Input, Class, FC, FF> EsperSolver<R, Input, Class, FC, FF>
130where
131    R: Field,
132    R::Additive: Invertible,
133    R::Multiplicative: Invertible,
134    Class: Eq + Hash,
135    FC: Fn(&Input) -> Class,
136    FF: Fn(&Input) -> Vec<R::T>,
137{
138    pub fn solve(&self, input: Input) -> R::T {
139        let coeff = self
140            .data
141            .get(&(self.class)(&input))
142            .expect("unrecognized class")
143            .as_ref()
144            .expect("failed to solve");
145        let feature = (self.feature)(&input);
146        feature
147            .into_iter()
148            .zip(coeff)
149            .map(|(x, y)| R::mul(&x, y))
150            .fold(R::zero(), |x, y| R::add(&x, &y))
151    }
152}