use num_traits::{Float, NumCast};
use polars_error::to_compute_err;
use rand::distributions::Bernoulli;
use rand::prelude::*;
use rand::seq::index::IndexVec;
use rand_distr::{Normal, Standard, StandardNormal, Uniform};
use crate::prelude::DataType::Float64;
use crate::prelude::*;
use crate::random::get_global_random_u64;
use crate::utils::NoNull;
/// Draws `n` indices uniformly at random from `0..len`, allowing repeats.
///
/// An empty population (`len == 0`) yields an empty index regardless of `n`.
/// When `seed` is `None`, the seed comes from `get_global_random_u64`.
fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
    // Handle the empty population up front, before a `0..0` range
    // distribution would have to be constructed.
    if len == 0 {
        return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
    }

    let seed = seed.unwrap_or_else(get_global_random_u64);
    let mut rng = SmallRng::seed_from_u64(seed);
    // Half-open range: every drawn index is strictly below `len`.
    let index_range = Uniform::new(0, len as IdxSize);

    let drawn: NoNull<IdxCa> = (0..n as IdxSize)
        .map(|_| index_range.sample(&mut rng))
        .collect_trusted();
    drawn.into_inner()
}
/// Builds an index of `n` distinct positions taken from `0..len`.
///
/// * `n == len`: the index is `0..len`, permuted only when `shuffle` is set.
/// * otherwise: the positions come from `rand::seq::index::sample` and are
///   used exactly as returned; `shuffle` is not consulted on this path.
///
/// When `seed` is `None`, the seed comes from `get_global_random_u64`.
/// Callers are expected to have checked `n <= len` beforehand.
fn create_rand_index_no_replacement(
    n: usize,
    len: usize,
    seed: Option<u64>,
    shuffle: bool,
) -> IdxCa {
    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));

    let positions: Vec<IdxSize> = if n == len {
        // The whole population is requested: start from the identity order.
        let mut all: Vec<IdxSize> = (0..len as IdxSize).collect();
        if shuffle {
            all.shuffle(&mut rng);
        }
        all
    } else {
        // `IndexVec` stores either `u32` or `usize`; normalise both to `IdxSize`.
        match rand::seq::index::sample(&mut rng, len, n) {
            IndexVec::U32(picked) => picked.into_iter().map(|i| i as IdxSize).collect(),
            IndexVec::USize(picked) => picked.into_iter().map(|i| i as IdxSize).collect(),
        }
    };

    IdxCa::new_vec(PlSmallStr::EMPTY, positions)
}
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    Standard: Distribution<T::Native>,
{
    /// Creates an array of `size` randomly generated values, some of them null.
    ///
    /// For every slot an `f32` is drawn first; the slot becomes null when that
    /// draw is below `null_density`, otherwise a value of the native type is
    /// drawn from the `Standard` distribution.
    ///
    /// When `seed` is `None`, the seed comes from `get_global_random_u64`.
    pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
        let seed = seed.unwrap_or_else(get_global_random_u64);
        let mut rng = SmallRng::seed_from_u64(seed);

        (0..size)
            .map(|_| {
                let roll: f32 = rng.gen();
                if roll < null_density {
                    return None;
                }
                Some(rng.gen::<T::Native>())
            })
            .collect()
    }
}
/// Validates that a sample of `n` items can be taken from a population of `len`.
///
/// Any `n` is accepted when sampling with replacement; without replacement
/// `n` must not exceed `len`.
///
/// # Errors
///
/// Returns a `ShapeMismatch` error when `with_replacement` is `false` and
/// `n > len`.
fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
    let fits_population = n <= len;
    polars_ensure!(
        with_replacement || fits_population,
        ShapeMismatch:
        "cannot take a larger sample than the total population when `with_replacement=false`"
    );
    Ok(())
}
impl Series {
    /// Samples `n` values from this `Series`.
    ///
    /// * `with_replacement` - whether the same position may be drawn more than once.
    /// * `shuffle` - forwarded to the no-replacement index builder; unused
    ///   when sampling with replacement.
    /// * `seed` - RNG seed; `None` falls back to the global random seed source.
    ///
    /// A request for `n == 0` returns `self.clear()`.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeMismatch` error when `n` exceeds the length and
    /// `with_replacement` is `false`.
    pub fn sample_n(
        &self,
        n: usize,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        let len = self.len();
        ensure_shape(n, len, with_replacement)?;
        if n == 0 {
            return Ok(self.clear());
        }

        let idx = if with_replacement {
            create_rand_index_with_replacement(n, len, seed)
        } else {
            create_rand_index_no_replacement(n, len, seed, shuffle)
        };
        debug_assert_eq!(len, self.len());
        // SAFETY: both index builders were given `len` as the population size
        // and only produce positions below it.
        Ok(unsafe { self.take_unchecked(&idx) })
    }

    /// Samples a fraction `frac` of this `Series`.
    ///
    /// The sample size is `len * frac`, truncated towards zero by the
    /// float-to-integer cast; the remaining arguments are passed to
    /// [`Series::sample_n`] unchanged.
    pub fn sample_frac(
        &self,
        frac: f64,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        let sample_size = (self.len() as f64 * frac) as usize;
        self.sample_n(sample_size, with_replacement, shuffle, seed)
    }

    /// Returns this `Series` with its values in a random order.
    ///
    /// `None` for `seed` falls back to the global random seed source.
    pub fn shuffle(&self, seed: Option<u64>) -> Self {
        let len = self.len();
        // Requesting the whole population with `shuffle = true` yields a
        // permutation of `0..len`.
        let permutation = create_rand_index_no_replacement(len, len, seed, true);
        debug_assert_eq!(len, self.len());
        // SAFETY: `permutation` only contains positions from `0..len`.
        unsafe { self.take_unchecked(&permutation) }
    }
}
impl<T> ChunkedArray<T>
where
    T: PolarsDataType,
    ChunkedArray<T>: ChunkTake<IdxCa>,
{
    /// Samples `n` values from this array.
    ///
    /// * `with_replacement` - whether the same position may be drawn more than once.
    /// * `shuffle` - forwarded to the no-replacement index builder; unused
    ///   when sampling with replacement.
    /// * `seed` - RNG seed; `None` falls back to the global random seed source.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeMismatch` error when `n` exceeds the length and
    /// `with_replacement` is `false`.
    pub fn sample_n(
        &self,
        n: usize,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        let len = self.len();
        ensure_shape(n, len, with_replacement)?;

        let idx = if with_replacement {
            create_rand_index_with_replacement(n, len, seed)
        } else {
            create_rand_index_no_replacement(n, len, seed, shuffle)
        };
        debug_assert_eq!(len, self.len());
        // SAFETY: both index builders were given `len` as the population size
        // and only produce positions below it.
        Ok(unsafe { self.take_unchecked(&idx) })
    }

    /// Samples a fraction `frac` of this array.
    ///
    /// The sample size is `len * frac`, truncated towards zero by the
    /// float-to-integer cast; the remaining arguments are passed to
    /// `sample_n` unchanged.
    pub fn sample_frac(
        &self,
        frac: f64,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        let sample_size = (self.len() as f64 * frac) as usize;
        self.sample_n(sample_size, with_replacement, shuffle, seed)
    }
}
impl DataFrame {
    /// Samples rows from this `DataFrame`, with the sample size given as a
    /// single-value `Series`.
    ///
    /// `n` is cast to the index dtype. If its only value is null, the result
    /// is `self.clear()`; otherwise the work is delegated to
    /// [`DataFrame::sample_n_literal`].
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` when `n` does not hold exactly one value, and
    /// propagates errors from the cast and from `sample_n_literal`.
    pub fn sample_n(
        &self,
        n: &Series,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        polars_ensure!(
            n.len() == 1,
            ComputeError: "Sample size must be a single value."
        );

        let size_series = n.cast(&IDX_DTYPE)?;
        let size_ca = size_series.idx()?;
        let Some(count) = size_ca.get(0) else {
            return Ok(self.clear());
        };
        self.sample_n_literal(count as usize, with_replacement, shuffle, seed)
    }

    /// Samples `n` rows from this `DataFrame`.
    ///
    /// * `with_replacement` - whether the same row may be drawn more than once.
    /// * `shuffle` - forwarded to the no-replacement index builder; unused
    ///   when sampling with replacement.
    /// * `seed` - RNG seed; `None` falls back to the global random seed source.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeMismatch` error when `n` exceeds the height and
    /// `with_replacement` is `false`.
    pub fn sample_n_literal(
        &self,
        n: usize,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        let height = self.height();
        ensure_shape(n, height, with_replacement)?;

        let idx = if with_replacement {
            create_rand_index_with_replacement(n, height, seed)
        } else {
            create_rand_index_no_replacement(n, height, seed, shuffle)
        };
        // SAFETY: both index builders were given `height` as the population
        // size and only produce row positions below it.
        Ok(unsafe { self.take_unchecked(&idx) })
    }

    /// Samples a fraction of the rows, with the fraction given as a
    /// single-value `Series`.
    ///
    /// `frac` is cast to `Float64`. If its only value is null, the result is
    /// `self.clear()`; otherwise the sample size is `height * frac`,
    /// truncated towards zero by the float-to-integer cast.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` when `frac` does not hold exactly one value,
    /// and propagates errors from the cast and from `sample_n_literal`.
    pub fn sample_frac(
        &self,
        frac: &Series,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        polars_ensure!(
            frac.len() == 1,
            ComputeError: "Sample fraction must be a single value."
        );

        let frac_series = frac.cast(&Float64)?;
        let frac_ca = frac_series.f64()?;
        let Some(fraction) = frac_ca.get(0) else {
            return Ok(self.clear());
        };
        let sample_size = (self.height() as f64 * fraction) as usize;
        self.sample_n_literal(sample_size, with_replacement, shuffle, seed)
    }
}
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    T::Native: Float,
{
    /// Creates an array of `length` samples from a normal distribution with
    /// the given `mean` and `std_dev`.
    ///
    /// Samples are drawn as `f64` from `rand::thread_rng()` and then converted
    /// to the native type, so this function takes no seed.
    ///
    /// # Errors
    ///
    /// Returns a compute error when `Normal::new` rejects the parameters.
    pub fn rand_normal(
        name: PlSmallStr,
        length: usize,
        mean: f64,
        std_dev: f64,
    ) -> PolarsResult<Self> {
        let dist = Normal::new(mean, std_dev).map_err(to_compute_err)?;
        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
        let mut rng = rand::thread_rng();

        (0..length).for_each(|_| {
            let draw: f64 = dist.sample(&mut rng);
            let value: T::Native = NumCast::from(draw).unwrap();
            builder.append_value(value);
        });
        Ok(builder.finish())
    }

    /// Creates an array of `length` samples from the standard normal
    /// distribution.
    ///
    /// Samples are drawn as `f64` from `rand::thread_rng()` and then converted
    /// to the native type, so this function takes no seed.
    pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
        let mut rng = rand::thread_rng();

        (0..length).for_each(|_| {
            let draw: f64 = rng.sample(StandardNormal);
            let value: T::Native = NumCast::from(draw).unwrap();
            builder.append_value(value);
        });
        builder.finish()
    }

    /// Creates an array of `length` samples from `Uniform::new(low, high)`.
    ///
    /// Samples are drawn as `f64` from `rand::thread_rng()` and then converted
    /// to the native type, so this function takes no seed.
    ///
    /// NOTE(review): the bounds are passed to `Uniform::new` unchecked; confirm
    /// against the `rand` version in use how an empty or inverted range
    /// (`low >= high`) is handled.
    pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
        let dist = Uniform::new(low, high);
        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
        let mut rng = rand::thread_rng();

        (0..length).for_each(|_| {
            let draw: f64 = dist.sample(&mut rng);
            let value: T::Native = NumCast::from(draw).unwrap();
            builder.append_value(value);
        });
        builder.finish()
    }
}
impl BooleanChunked {
    /// Creates a boolean array of `length` samples from a Bernoulli
    /// distribution with parameter `p`.
    ///
    /// Samples are drawn from `rand::thread_rng()`, so this function takes no
    /// seed.
    ///
    /// # Errors
    ///
    /// Returns a compute error when `Bernoulli::new` rejects `p`.
    pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
        let coin = Bernoulli::new(p).map_err(to_compute_err)?;
        let mut rng = rand::thread_rng();
        let mut builder = BooleanChunkedBuilder::new(name, length);

        (0..length).for_each(|_| {
            let outcome: bool = coin.sample(&mut rng);
            builder.append_value(outcome);
        });
        Ok(builder.finish())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `DataFrame::sample_n` / `DataFrame::sample_frac` on a
    /// five-row frame, with and without replacement and with and without an
    /// explicit seed.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Fixtures shared by the assertions below.
        let three = Series::new(PlSmallStr::from_static("s"), &[3]);
        let frac_small = Series::new(PlSmallStr::from_static("frac"), &[0.4]);
        let frac_double = Series::new(PlSmallStr::from_static("frac"), &[2.0]);

        // Without replacement, seed taken from the global source.
        assert!(df.sample_n(&three, false, false, None).is_ok());
        assert!(df.sample_frac(&frac_small, false, false, None).is_ok());

        // Without replacement, fixed seed.
        assert!(df.sample_n(&three, false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&frac_small, false, false, Some(0)).is_ok());

        // Without replacement, a fraction of 2.0 asks for more rows than exist.
        assert!(df.sample_frac(&frac_double, false, false, Some(0)).is_err());

        // With replacement, fixed seed; oversampling is allowed here.
        assert!(df.sample_n(&three, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&frac_small, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&frac_double, true, false, Some(0)).is_ok());
    }
}