1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
//! Gaussian distribution module. //! //! Contains extension methods for the Normal struct //! found in the rand crate. This is provided through //! traits added within the containing stats module. use stats::dist::Distribution; use rand::Rng; use rand::distributions::{Sample, IndependentSample}; use rand::distributions::normal::StandardNormal; use super::consts as stat_consts; use std::f64::consts as float_consts; /// A Gaussian random variable. /// /// This struct stores both the variance and the standard deviation. /// This is to minimize the computation required for computing /// the distribution functions and sampling. /// /// It is most efficient to construct the struct using the `from_std_dev` constructor. #[derive(Debug, Clone, Copy)] pub struct Gaussian { mean: f64, variance: f64, _std_dev: f64, } /// The default Gaussian random variable. /// This is the Standard Normal random variable. /// /// The defaults are: /// /// - mean = 0 /// - variance = 1 impl Default for Gaussian { fn default() -> Gaussian { Gaussian { mean: 0f64, variance: 1f64, _std_dev: 1f64, } } } impl Gaussian { /// Creates a new Gaussian random variable from /// a given mean and variance. pub fn new(mean: f64, variance: f64) -> Gaussian { Gaussian { mean: mean, variance: variance, _std_dev: variance.sqrt(), } } /// Creates a new Gaussian random variable from /// a given mean and standard deviation. pub fn from_std_dev(mean: f64, std_dev: f64) -> Gaussian { Gaussian { mean: mean, variance: std_dev * std_dev, _std_dev: std_dev, } } } /// The distribution of the gaussian random variable. /// /// Accurately computes the PDF and log PDF. /// Estimates the CDF accurate only to 0.003. impl Distribution<f64> for Gaussian { /// The pdf of the normal distribution /// /// # Examples /// /// ``` /// use rusty_machine::stats::dist::Gaussian; /// use rusty_machine::stats::dist::Distribution; /// use rusty_machine::stats::dist::consts; /// /// let gauss = Gaussian::default(); /// /// let lpdf_zero = gauss.pdf(0f64); /// /// // The value should be very close to 1/sqrt(2 * pi) /// assert!((lpdf_zero - (1f64/consts::SQRT_2_PI).abs()) < 1e-20); /// ``` fn pdf(&self, x: f64) -> f64 { (-(x - self.mean) * (x - self.mean) / (2.0 * self.variance)).exp() / (stat_consts::SQRT_2_PI * self._std_dev) } /// The log pdf of the normal distribution. /// /// # Examples /// /// ``` /// use rusty_machine::stats::dist::Gaussian; /// use rusty_machine::stats::dist::Distribution; /// use rusty_machine::stats::dist::consts; /// /// let gauss = Gaussian::default(); /// /// let lpdf_zero = gauss.logpdf(0f64); /// /// // The value should be very close to -0.5*Ln(2 * pi) /// assert!((lpdf_zero + 0.5*consts::LN_2_PI).abs() < 1e-20); /// ``` fn logpdf(&self, x: f64) -> f64 { -self._std_dev.ln() - (stat_consts::LN_2_PI / 2.0) - ((x - self.mean) * (x - self.mean) / (2.0 * self.variance)) } /// Rough estimate for the cdf of the gaussian distribution. /// Accurate to 0.003. /// /// # Examples /// /// ``` /// use rusty_machine::stats::dist::Gaussian; /// use rusty_machine::stats::dist::Distribution; /// /// let gauss = Gaussian::new(10f64, 5f64); /// let cdf_mid = gauss.cdf(10f64); /// /// assert!((0.5 - cdf_mid).abs() < 0.004); /// ``` /// /// A slightly more involved test: /// /// ``` /// use rusty_machine::stats::dist::Gaussian; /// use rusty_machine::stats::dist::Distribution; /// /// let gauss = Gaussian::new(10f64, 4f64); /// let cdf = gauss.cdf(9f64); /// /// assert!((0.5*(1f64 - 0.382924922548) - cdf).abs() < 0.004); /// ``` fn cdf(&self, x: f64) -> f64 { 0.5 * (1f64 + (x - self.mean).signum() * (1f64 - (-float_consts::FRAC_2_PI * (x - self.mean) * (x - self.mean) / self.variance).exp()) .sqrt()) } } impl Sample<f64> for Gaussian { fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 { self.ind_sample(rng) } } impl IndependentSample<f64> for Gaussian { fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 { let StandardNormal(n) = rng.gen::<StandardNormal>(); self.mean + self._std_dev * n } }