










use num_traits::{Float, FloatConst};
use crate::{Cauchy, Distribution, Standard};
use rand::Rng;
use core::fmt;






/// The Poisson distribution `Poisson(lambda)`.
///
/// Samples are non-negative whole numbers, returned as the float type `F`.
/// Build one with [`Poisson::new`] and draw values through the
/// [`Distribution`] trait.
///
/// Sampling strategy (see the `Distribution` impl):
/// - `lambda < 12`: multiply uniform variates until the running product is
///   no longer greater than `exp(-lambda)`; the number of factors minus one
///   is the sample.
/// - otherwise: rejection sampling from a Cauchy proposal centred on
///   `lambda` with scale `sqrt(2 * lambda)`.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
    /// Shape parameter (the mean) of the distribution; `new` only accepts
    /// values satisfying `lambda > 0`.
    lambda: F,
    // The fields below are derived from `lambda` once in `new` so that
    // `sample` does not recompute them.
    // NOTE(review): with the `serde1` feature these fields are deserialized
    // as-is, bypassing `new`'s validation — confirm that is intended.
    /// `exp(-lambda)` (note the negative exponent despite the name); the
    /// stopping threshold of the small-`lambda` branch.
    exp_lambda: F,
    /// `ln(lambda)`, used in the acceptance bound of the large-`lambda` branch.
    log_lambda: F,
    /// `sqrt(2 * lambda)`, the scale applied to the Cauchy proposal.
    sqrt_2lambda: F,
    /// `lambda * ln(lambda) - log_gamma(1 + lambda)`, subtracted inside the
    /// exponent of the large-`lambda` acceptance bound.
    magic_val: F,
}


/// Error type returned from [`Poisson::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `lambda` failed validation; in particular this is returned for
    /// `lambda <= 0` and for NaN.
    ShapeTooSmall,
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Error::ShapeTooSmall => "lambda is not positive in Poisson distribution",
        };
        f.write_str(description)
    }
}

// `std::error::Error` is only available with the standard library, hence the
// `std` feature gate; the derived `Debug` and the `Display` impl above it
// supply everything the trait needs, so the body is empty.
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

impl<F> Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
    /// Construct a new `Poisson` with the given shape parameter (mean) `lambda`.
    ///
    /// Every constant that `sample` needs and that depends only on `lambda`
    /// is computed once here.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ShapeTooSmall`] if `lambda` is not strictly positive
    /// (this includes NaN) or is infinite.
    pub fn new(lambda: F) -> Result<Poisson<F>, Error> {
        // Written as `!(lambda > 0)` rather than `lambda <= 0` so NaN is rejected too.
        if !(lambda > F::zero()) {
            return Err(Error::ShapeTooSmall);
        }
        // An infinite `lambda` turns `magic_val` into inf/NaN, which makes the
        // acceptance bound in `sample`'s rejection loop NaN; no candidate is
        // then ever accepted and sampling never terminates. `Error` has no
        // dedicated variant for this case, so the existing one is reused to
        // keep the public enum unchanged.
        // NOTE(review): a finite but enormous `lambda` (close to `F::max_value()`)
        // can still overflow `lambda * log_lambda`; that is not checked here.
        if !lambda.is_finite() {
            return Err(Error::ShapeTooSmall);
        }
        let log_lambda = lambda.ln();
        Ok(Poisson {
            lambda,
            exp_lambda: (-lambda).exp(),
            log_lambda,
            sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(),
            magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda),
        })
    }
}

impl<F> Distribution<F> for Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
    /// Draws one sample from `Poisson(lambda)`, returned as a whole number in `F`.
    ///
    /// The strategy depends on `lambda`:
    ///
    /// * `lambda < 12`: multiply uniform variates (`rng.gen::<F>()`) until the
    ///   running product is no longer greater than `exp(-lambda)`; the number
    ///   of factors minus one is the sample.
    /// * otherwise: rejection sampling. A candidate is
    ///   `floor(sqrt(2 * lambda) * c + lambda)` with `c` drawn from a standard
    ///   Cauchy distribution (draws giving a negative value are repeated). A
    ///   candidate `k` is accepted when a uniform variate is at most
    ///   `0.9 * (1 + c^2) * exp(k * ln(lambda) - log_gamma(1 + k) - magic_val)`.
    ///   The loop has no iteration cap; if that bound is NaN, no candidate is
    ///   ever accepted.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        let one = F::one();

        if self.lambda < F::from(12.0).unwrap() {
            let mut product = one;
            let mut factors = F::zero();
            while product > self.exp_lambda {
                product = product * rng.gen::<F>();
                factors = factors + one;
            }
            // The factor that took the product to or below the threshold is
            // not part of the count.
            return factors - one;
        }

        // Standard Cauchy proposal; it is rescaled and shifted per draw below.
        let cauchy = Cauchy::new(F::zero(), one).unwrap();
        loop {
            // Redraw until the shifted deviate is non-negative, then round it
            // down to an integer-valued candidate.
            let (candidate, deviate) = loop {
                let deviate: F = rng.sample(cauchy);
                let shifted = self.sqrt_2lambda * deviate + self.lambda;
                if shifted >= F::zero() {
                    break (shifted.floor(), deviate);
                }
            };

            let exponent = candidate * self.log_lambda
                - crate::utils::log_gamma(one + candidate)
                - self.magic_val;
            let acceptance = F::from(0.9).unwrap() * (one + deviate * deviate) * exponent.exp();

            if rng.gen::<F>() <= acceptance {
                return candidate;
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Draws 1000 samples from `Poisson(lambda)` with a fixed seed and checks
    /// that their mean lies within `tol` of `lambda` (the distribution's mean).
    fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
        where Standard: Distribution<F>
    {
        let poisson = Poisson::new(lambda).unwrap();
        let mut rng = crate::test::rng(123);
        let sum = (0..1000).fold(F::zero(), |acc, _| acc + poisson.sample(&mut rng));
        let avg = sum / F::from(1000.0).unwrap();
        assert!((avg - lambda).abs() < tol, "sample mean too far from lambda");
    }

    #[test]
    fn test_poisson_avg() {
        // 10 exercises the multiplicative branch, 15 the rejection branch.
        test_poisson_avg_gen::<f64>(10.0, 0.5);
        test_poisson_avg_gen::<f64>(15.0, 0.5);
        test_poisson_avg_gen::<f32>(10.0, 0.5);
        test_poisson_avg_gen::<f32>(15.0, 0.5);
    }

    /// Every sample must be a non-negative whole number, in both branches
    /// and at the `lambda == 12` switch-over point.
    #[test]
    fn test_poisson_samples_are_non_negative_integers() {
        let mut rng = crate::test::rng(124);
        for &lambda in &[0.5f64, 5.0, 12.0, 50.0] {
            let poisson = Poisson::new(lambda).unwrap();
            for _ in 0..100 {
                let x = poisson.sample(&mut rng);
                assert!(
                    x >= 0.0 && x == x.floor(),
                    "invalid sample {} for lambda {}",
                    x,
                    lambda
                );
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_poisson_invalid_lambda_zero() {
        Poisson::new(0.0).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_poisson_invalid_lambda_neg() {
        Poisson::new(-10.0).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_poisson_invalid_lambda_nan() {
        Poisson::new(core::f64::NAN).unwrap();
    }
}
