1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distributions::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, Standard, StandardNormal, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14 if len == 0 {
15 return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16 }
17 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18 let dist = Uniform::new(0, len as IdxSize);
19 (0..n as IdxSize)
20 .map(move |_| dist.sample(&mut rng))
21 .collect_trusted::<NoNull<IdxCa>>()
22 .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26 n: usize,
27 len: usize,
28 seed: Option<u64>,
29 shuffle: bool,
30) -> IdxCa {
31 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32 let mut buf: Vec<IdxSize>;
33 if n == len {
34 buf = (0..len as IdxSize).collect();
35 if shuffle {
36 buf.shuffle(&mut rng)
37 }
38 } else {
39 buf = match rand::seq::index::sample(&mut rng, len, n) {
44 IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45 IndexVec::USize(v) => v.into_iter().map(|x| x as IdxSize).collect(),
46 };
47 }
48 IdxCa::new_vec(PlSmallStr::EMPTY, buf)
49}
50
51impl<T> ChunkedArray<T>
52where
53 T: PolarsNumericType,
54 Standard: Distribution<T::Native>,
55{
56 pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
57 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
58 (0..size)
59 .map(|_| {
60 if rng.r#gen::<f32>() < null_density {
61 None
62 } else {
63 Some(rng.r#gen())
64 }
65 })
66 .collect()
67 }
68}
69
70fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
71 polars_ensure!(
72 with_replacement || n <= len,
73 ShapeMismatch:
74 "cannot take a larger sample than the total population when `with_replacement=false`"
75 );
76 Ok(())
77}
78
79impl Series {
80 pub fn sample_n(
81 &self,
82 n: usize,
83 with_replacement: bool,
84 shuffle: bool,
85 seed: Option<u64>,
86 ) -> PolarsResult<Self> {
87 ensure_shape(n, self.len(), with_replacement)?;
88 if n == 0 {
89 return Ok(self.clear());
90 }
91 let len = self.len();
92
93 match with_replacement {
94 true => {
95 let idx = create_rand_index_with_replacement(n, len, seed);
96 debug_assert_eq!(len, self.len());
97 unsafe { Ok(self.take_unchecked(&idx)) }
99 },
100 false => {
101 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
102 debug_assert_eq!(len, self.len());
103 unsafe { Ok(self.take_unchecked(&idx)) }
105 },
106 }
107 }
108
109 pub fn sample_frac(
111 &self,
112 frac: f64,
113 with_replacement: bool,
114 shuffle: bool,
115 seed: Option<u64>,
116 ) -> PolarsResult<Self> {
117 let n = (self.len() as f64 * frac) as usize;
118 self.sample_n(n, with_replacement, shuffle, seed)
119 }
120
121 pub fn shuffle(&self, seed: Option<u64>) -> Self {
122 let len = self.len();
123 let n = len;
124 let idx = create_rand_index_no_replacement(n, len, seed, true);
125 debug_assert_eq!(len, self.len());
126 unsafe { self.take_unchecked(&idx) }
128 }
129}
130
131impl<T> ChunkedArray<T>
132where
133 T: PolarsDataType,
134 ChunkedArray<T>: ChunkTake<IdxCa>,
135{
136 pub fn sample_n(
138 &self,
139 n: usize,
140 with_replacement: bool,
141 shuffle: bool,
142 seed: Option<u64>,
143 ) -> PolarsResult<Self> {
144 ensure_shape(n, self.len(), with_replacement)?;
145 let len = self.len();
146
147 match with_replacement {
148 true => {
149 let idx = create_rand_index_with_replacement(n, len, seed);
150 debug_assert_eq!(len, self.len());
151 unsafe { Ok(self.take_unchecked(&idx)) }
153 },
154 false => {
155 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
156 debug_assert_eq!(len, self.len());
157 unsafe { Ok(self.take_unchecked(&idx)) }
159 },
160 }
161 }
162
163 pub fn sample_frac(
165 &self,
166 frac: f64,
167 with_replacement: bool,
168 shuffle: bool,
169 seed: Option<u64>,
170 ) -> PolarsResult<Self> {
171 let n = (self.len() as f64 * frac) as usize;
172 self.sample_n(n, with_replacement, shuffle, seed)
173 }
174}
175
176impl DataFrame {
177 pub fn sample_n(
179 &self,
180 n: &Series,
181 with_replacement: bool,
182 shuffle: bool,
183 seed: Option<u64>,
184 ) -> PolarsResult<Self> {
185 polars_ensure!(
186 n.len() == 1,
187 ComputeError: "Sample size must be a single value."
188 );
189
190 let n = n.cast(&IDX_DTYPE)?;
191 let n = n.idx()?;
192
193 match n.get(0) {
194 Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
195 None => Ok(self.clear()),
196 }
197 }
198
199 pub fn sample_n_literal(
200 &self,
201 n: usize,
202 with_replacement: bool,
203 shuffle: bool,
204 seed: Option<u64>,
205 ) -> PolarsResult<Self> {
206 ensure_shape(n, self.height(), with_replacement)?;
207 let idx = match with_replacement {
209 true => create_rand_index_with_replacement(n, self.height(), seed),
210 false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
211 };
212 Ok(unsafe { self.take_unchecked(&idx) })
214 }
215
216 pub fn sample_frac(
218 &self,
219 frac: &Series,
220 with_replacement: bool,
221 shuffle: bool,
222 seed: Option<u64>,
223 ) -> PolarsResult<Self> {
224 polars_ensure!(
225 frac.len() == 1,
226 ComputeError: "Sample fraction must be a single value."
227 );
228
229 let frac = frac.cast(&Float64)?;
230 let frac = frac.f64()?;
231
232 match frac.get(0) {
233 Some(frac) => {
234 let n = (self.height() as f64 * frac) as usize;
235 self.sample_n_literal(n, with_replacement, shuffle, seed)
236 },
237 None => Ok(self.clear()),
238 }
239 }
240}
241
242impl<T> ChunkedArray<T>
243where
244 T: PolarsNumericType,
245 T::Native: Float,
246{
247 pub fn rand_normal(
249 name: PlSmallStr,
250 length: usize,
251 mean: f64,
252 std_dev: f64,
253 ) -> PolarsResult<Self> {
254 let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
255 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
256 let mut rng = rand::thread_rng();
257 for _ in 0..length {
258 let smpl = normal.sample(&mut rng);
259 let smpl = NumCast::from(smpl).unwrap();
260 builder.append_value(smpl)
261 }
262 Ok(builder.finish())
263 }
264
265 pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
267 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
268 let mut rng = rand::thread_rng();
269 for _ in 0..length {
270 let smpl: f64 = rng.sample(StandardNormal);
271 let smpl = NumCast::from(smpl).unwrap();
272 builder.append_value(smpl)
273 }
274 builder.finish()
275 }
276
277 pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
279 let uniform = Uniform::new(low, high);
280 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
281 let mut rng = rand::thread_rng();
282 for _ in 0..length {
283 let smpl = uniform.sample(&mut rng);
284 let smpl = NumCast::from(smpl).unwrap();
285 builder.append_value(smpl)
286 }
287 builder.finish()
288 }
289}
290
291impl BooleanChunked {
292 pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
294 let dist = Bernoulli::new(p).map_err(to_compute_err)?;
295 let mut rng = rand::thread_rng();
296 let mut builder = BooleanChunkedBuilder::new(name, length);
297 for _ in 0..length {
298 let smpl = dist.sample(&mut rng);
299 builder.append_value(smpl)
300 }
301 Ok(builder.finish())
302 }
303}
304
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Single-element series used as the sample-size / sample-fraction argument.
        let size = |v: i32| Series::new(PlSmallStr::from_static("s"), &[v]);
        let fraction = |v: f64| Series::new(PlSmallStr::from_static("frac"), &[v]);

        // Without replacement, unseeded.
        assert!(df.sample_n(&size(3), false, false, None).is_ok());
        assert!(df.sample_frac(&fraction(0.4), false, false, None).is_ok());

        // Without replacement, seeded.
        assert!(df.sample_n(&size(3), false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&fraction(0.4), false, false, Some(0)).is_ok());
        // More rows than the frame has is rejected without replacement.
        assert!(df.sample_frac(&fraction(2.0), false, false, Some(0)).is_err());

        // With replacement, seeded.
        assert!(df.sample_n(&size(3), true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&fraction(0.4), true, false, Some(0)).is_ok());
        // Oversampling is allowed with replacement.
        assert!(df.sample_frac(&fraction(2.0), true, false, Some(0)).is_ok());
    }
}