competitive/algorithm/esper.rs

1use super::{Field, Invertible, Matrix};
2use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
3
/// Zero-sized marker used to mention generic parameters that no field stores.
///
/// Wrapping `T` as `fn() -> T` keeps the marker covariant in `T` and makes it
/// `Send + Sync` regardless of `T`, without claiming ownership of a `T`.
type Marker<T> = std::marker::PhantomData<fn() -> T>;
5
/// Linear system `a * x = b` accumulated for a single class.
///
/// Row `i` of `a` is the feature vector of the `i`-th sample pushed into the
/// class, and `b[i]` is that sample's observed output.
#[derive(Debug, Clone)]
struct SystemOfLinearEquation<T> {
    /// Coefficient rows, one per sample.
    a: Vec<Vec<T>>,
    /// Right-hand side entries, parallel to the rows of `a`.
    b: Vec<T>,
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default`, which the element type does not need to satisfy.
impl<T> Default for SystemOfLinearEquation<T> {
    fn default() -> Self {
        SystemOfLinearEquation {
            a: Vec::new(),
            b: Vec::new(),
        }
    }
}
20
21#[derive(Debug, Clone)]
22pub struct EsperEstimator<R, Input, Class, FC, FF>
23where
24    R: Field<Additive: Invertible, Multiplicative: Invertible>,
25    Class: Eq + Hash,
26    FC: Fn(&Input) -> Class,
27    FF: Fn(&Input) -> Vec<R::T>,
28{
29    class: FC,
30    feature: FF,
31    data: HashMap<Class, SystemOfLinearEquation<R::T>>,
32    _marker: Marker<(R::T, Input, Class)>,
33}
34
35#[derive(Debug, Clone)]
36pub struct EsperSolver<R, Input, Class, FC, FF>
37where
38    R: Field<Additive: Invertible, Multiplicative: Invertible>,
39    Class: Eq + Hash,
40    FC: Fn(&Input) -> Class,
41    FF: Fn(&Input) -> Vec<R::T>,
42{
43    class: FC,
44    feature: FF,
45    data: HashMap<Class, Option<Vec<R::T>>>,
46    _marker: Marker<(R::T, Input, Class)>,
47}
48
49impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
50where
51    R: Field<Additive: Invertible, Multiplicative: Invertible>,
52    Class: Eq + Hash,
53    FC: Fn(&Input) -> Class,
54    FF: Fn(&Input) -> Vec<R::T>,
55{
56    pub fn new(class: FC, feature: FF) -> Self {
57        Self {
58            class,
59            feature,
60            data: Default::default(),
61            _marker: PhantomData,
62        }
63    }
64
65    pub fn push(&mut self, input: Input, output: R::T) {
66        let class = (self.class)(&input);
67        let feature = (self.feature)(&input);
68        let entry = self.data.entry(class).or_default();
69        entry.a.push(feature);
70        entry.b.push(output);
71    }
72}
73
74impl<R, Input, Class, FC, FF> EsperEstimator<R, Input, Class, FC, FF>
75where
76    R: Field<T: PartialEq, Additive: Invertible, Multiplicative: Invertible>,
77    Class: Eq + Hash,
78    FC: Fn(&Input) -> Class,
79    FF: Fn(&Input) -> Vec<R::T>,
80{
81    pub fn solve(self) -> EsperSolver<R, Input, Class, FC, FF> {
82        let data: HashMap<_, _> = self
83            .data
84            .into_iter()
85            .map(|(key, SystemOfLinearEquation { a, b })| {
86                (
87                    key,
88                    Matrix::<R>::from_vec(a)
89                        .solve_system_of_linear_equations(&b)
90                        .map(|sol| sol.particular),
91                )
92            })
93            .collect();
94        EsperSolver {
95            class: self.class,
96            feature: self.feature,
97            data,
98            _marker: PhantomData,
99        }
100    }
101
102    pub fn solve_checked(self) -> EsperSolver<R, Input, Class, FC, FF>
103    where
104        Class: Debug,
105        R: Field<T: Debug, Additive: Invertible, Multiplicative: Invertible>,
106    {
107        let data: HashMap<_, _> = self
108            .data
109            .into_iter()
110            .map(|(key, SystemOfLinearEquation { a, b })| {
111                let mat = Matrix::<R>::from_vec(a);
112                let coeff = mat
113                    .solve_system_of_linear_equations(&b)
114                    .map(|sol| sol.particular);
115                if coeff.is_none() {
116                    eprintln!(
117                        "failed to solve linear equations: key={:?} A={:?} b={:?}",
118                        key, &mat.data, &b
119                    );
120                }
121                (key, coeff)
122            })
123            .collect();
124        EsperSolver {
125            class: self.class,
126            feature: self.feature,
127            data,
128            _marker: PhantomData,
129        }
130    }
131}
132
133impl<R, Input, Class, FC, FF> EsperSolver<R, Input, Class, FC, FF>
134where
135    R: Field<Additive: Invertible, Multiplicative: Invertible>,
136    Class: Eq + Hash,
137    FC: Fn(&Input) -> Class,
138    FF: Fn(&Input) -> Vec<R::T>,
139{
140    pub fn solve(&self, input: Input) -> R::T {
141        let coeff = self
142            .data
143            .get(&(self.class)(&input))
144            .expect("unrecognized class")
145            .as_ref()
146            .expect("failed to solve");
147        let feature = (self.feature)(&input);
148        feature
149            .into_iter()
150            .zip(coeff)
151            .map(|(x, y)| R::mul(&x, y))
152            .fold(R::zero(), |x, y| R::add(&x, &y))
153    }
154}