1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distr::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14 if len == 0 {
15 return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16 }
17 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18 let dist = Uniform::new(0, len as IdxSize).unwrap();
19 (0..n as IdxSize)
20 .map(move |_| dist.sample(&mut rng))
21 .collect_trusted::<NoNull<IdxCa>>()
22 .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26 n: usize,
27 len: usize,
28 seed: Option<u64>,
29 shuffle: bool,
30) -> IdxCa {
31 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32 let mut buf: Vec<IdxSize>;
33 if n == len {
34 buf = (0..len as IdxSize).collect();
35 if shuffle {
36 buf.shuffle(&mut rng)
37 }
38 } else {
39 buf = match rand::seq::index::sample(&mut rng, len, n) {
44 IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45 #[cfg(target_pointer_width = "64")]
46 IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),
47 };
48 if !shuffle {
49 buf.sort_unstable();
50 }
51 }
52 IdxCa::new_vec(PlSmallStr::EMPTY, buf)
53}
54
55impl<T> ChunkedArray<T>
56where
57 T: PolarsNumericType,
58 StandardUniform: Distribution<T::Native>,
59{
60 pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
61 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
62 (0..size)
63 .map(|_| {
64 if rng.random::<f32>() < null_density {
65 None
66 } else {
67 Some(rng.random())
68 }
69 })
70 .collect()
71 }
72}
73
74fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
75 polars_ensure!(
76 with_replacement || n <= len,
77 ShapeMismatch:
78 "cannot take a larger sample than the total population when `with_replacement=false`"
79 );
80 Ok(())
81}
82
83impl Series {
84 pub fn sample_n(
85 &self,
86 n: usize,
87 with_replacement: bool,
88 shuffle: bool,
89 seed: Option<u64>,
90 ) -> PolarsResult<Self> {
91 ensure_shape(n, self.len(), with_replacement)?;
92 if n == 0 {
93 return Ok(self.clear());
94 }
95 let len = self.len();
96
97 match with_replacement {
98 true => {
99 let idx = create_rand_index_with_replacement(n, len, seed);
100 debug_assert_eq!(len, self.len());
101 unsafe { Ok(self.take_unchecked(&idx)) }
103 },
104 false => {
105 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
106 debug_assert_eq!(len, self.len());
107 unsafe { Ok(self.take_unchecked(&idx)) }
109 },
110 }
111 }
112
113 pub fn sample_frac(
115 &self,
116 frac: f64,
117 with_replacement: bool,
118 shuffle: bool,
119 seed: Option<u64>,
120 ) -> PolarsResult<Self> {
121 let n = (self.len() as f64 * frac) as usize;
122 self.sample_n(n, with_replacement, shuffle, seed)
123 }
124
125 pub fn shuffle(&self, seed: Option<u64>) -> Self {
126 let len = self.len();
127 let n = len;
128 let idx = create_rand_index_no_replacement(n, len, seed, true);
129 debug_assert_eq!(len, self.len());
130 unsafe { self.take_unchecked(&idx) }
132 }
133}
134
135impl<T> ChunkedArray<T>
136where
137 T: PolarsDataType,
138 ChunkedArray<T>: ChunkTake<IdxCa>,
139{
140 pub fn sample_n(
142 &self,
143 n: usize,
144 with_replacement: bool,
145 shuffle: bool,
146 seed: Option<u64>,
147 ) -> PolarsResult<Self> {
148 ensure_shape(n, self.len(), with_replacement)?;
149 let len = self.len();
150
151 match with_replacement {
152 true => {
153 let idx = create_rand_index_with_replacement(n, len, seed);
154 debug_assert_eq!(len, self.len());
155 unsafe { Ok(self.take_unchecked(&idx)) }
157 },
158 false => {
159 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
160 debug_assert_eq!(len, self.len());
161 unsafe { Ok(self.take_unchecked(&idx)) }
163 },
164 }
165 }
166
167 pub fn sample_frac(
169 &self,
170 frac: f64,
171 with_replacement: bool,
172 shuffle: bool,
173 seed: Option<u64>,
174 ) -> PolarsResult<Self> {
175 let n = (self.len() as f64 * frac) as usize;
176 self.sample_n(n, with_replacement, shuffle, seed)
177 }
178}
179
180impl DataFrame {
181 pub fn sample_n(
183 &self,
184 n: &Series,
185 with_replacement: bool,
186 shuffle: bool,
187 seed: Option<u64>,
188 ) -> PolarsResult<Self> {
189 polars_ensure!(
190 n.len() == 1,
191 ComputeError: "Sample size must be a single value."
192 );
193
194 let n = n.strict_cast(&IDX_DTYPE)?;
195 let n = n.idx()?;
196
197 match n.get(0) {
198 Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
199 None => Ok(self.clear()),
200 }
201 }
202
203 pub fn sample_n_literal(
204 &self,
205 n: usize,
206 with_replacement: bool,
207 shuffle: bool,
208 seed: Option<u64>,
209 ) -> PolarsResult<Self> {
210 ensure_shape(n, self.height(), with_replacement)?;
211 let idx = match with_replacement {
213 true => create_rand_index_with_replacement(n, self.height(), seed),
214 false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
215 };
216 Ok(unsafe { self.take_unchecked(&idx) })
218 }
219
220 pub fn sample_frac(
222 &self,
223 frac: &Series,
224 with_replacement: bool,
225 shuffle: bool,
226 seed: Option<u64>,
227 ) -> PolarsResult<Self> {
228 polars_ensure!(
229 frac.len() == 1,
230 ComputeError: "Sample fraction must be a single value."
231 );
232
233 let frac = frac.cast(&Float64)?;
234 let frac = frac.f64()?;
235
236 match frac.get(0) {
237 Some(frac) => {
238 let n = (self.height() as f64 * frac) as usize;
239 self.sample_n_literal(n, with_replacement, shuffle, seed)
240 },
241 None => Ok(self.clear()),
242 }
243 }
244}
245
246impl<T> ChunkedArray<T>
247where
248 T: PolarsNumericType,
249 T::Native: Float,
250{
251 pub fn rand_normal(
253 name: PlSmallStr,
254 length: usize,
255 mean: f64,
256 std_dev: f64,
257 ) -> PolarsResult<Self> {
258 let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
259 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
260 let mut rng = rand::rng();
261 for _ in 0..length {
262 let smpl = normal.sample(&mut rng);
263 let smpl = NumCast::from(smpl).unwrap();
264 builder.append_value(smpl)
265 }
266 Ok(builder.finish())
267 }
268
269 pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
271 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
272 let mut rng = rand::rng();
273 for _ in 0..length {
274 let smpl: f64 = rng.sample(StandardNormal);
275 let smpl = NumCast::from(smpl).unwrap();
276 builder.append_value(smpl)
277 }
278 builder.finish()
279 }
280
281 pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
283 let uniform = Uniform::new(low, high).unwrap();
284 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
285 let mut rng = rand::rng();
286 for _ in 0..length {
287 let smpl = uniform.sample(&mut rng);
288 let smpl = NumCast::from(smpl).unwrap();
289 builder.append_value(smpl)
290 }
291 builder.finish()
292 }
293}
294
295impl BooleanChunked {
296 pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
298 let dist = Bernoulli::new(p).map_err(to_compute_err)?;
299 let mut rng = rand::rng();
300 let mut builder = BooleanChunkedBuilder::new(name, length);
301 for _ in 0..length {
302 let smpl = dist.sample(&mut rng);
303 builder.append_value(smpl)
304 }
305 Ok(builder.finish())
306 }
307}
308
#[cfg(test)]
mod test {
    use super::*;

    /// Wrap a sample size in the single-value series `DataFrame::sample_n` expects.
    fn size_series(n: i32) -> Series {
        Series::new(PlSmallStr::from_static("s"), &[n])
    }

    /// Wrap a fraction in the single-value series `DataFrame::sample_frac` expects.
    fn frac_series(frac: f64) -> Series {
        Series::new(PlSmallStr::from_static("frac"), &[frac])
    }

    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Without replacement, both unseeded and seeded.
        for seed in [None, Some(0)] {
            let out = df.sample_n(&size_series(3), false, false, seed).unwrap();
            assert_eq!(out.height(), 3);
            // 5 * 0.4 -> 2 rows.
            let out = df.sample_frac(&frac_series(0.4), false, false, seed).unwrap();
            assert_eq!(out.height(), 2);
        }

        // A fraction above 1.0 needs more rows than exist without replacement.
        assert!(
            df.sample_frac(&frac_series(2.0), false, false, Some(0))
                .is_err()
        );

        // With replacement the population size is not a limit.
        let out = df.sample_n(&size_series(3), true, false, Some(0)).unwrap();
        assert_eq!(out.height(), 3);
        let out = df.sample_frac(&frac_series(0.4), true, false, Some(0)).unwrap();
        assert_eq!(out.height(), 2);
        let out = df.sample_frac(&frac_series(2.0), true, false, Some(0)).unwrap();
        assert_eq!(out.height(), 10);

        // A full-size shuffled sample keeps every row.
        let out = df.sample_n(&size_series(5), false, true, Some(0)).unwrap();
        assert_eq!(out.height(), 5);

        // The same seed must reproduce the same sample.
        let a = df.sample_n(&size_series(3), false, true, Some(42)).unwrap();
        let b = df.sample_n(&size_series(3), false, true, Some(42)).unwrap();
        assert!(a.equals(&b));
    }
}