1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distr::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14 if len == 0 {
15 return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16 }
17 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18 let dist = Uniform::new(0, len as IdxSize).unwrap();
19 (0..n as IdxSize)
20 .map(move |_| dist.sample(&mut rng))
21 .collect_trusted::<NoNull<IdxCa>>()
22 .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26 n: usize,
27 len: usize,
28 seed: Option<u64>,
29 shuffle: bool,
30) -> IdxCa {
31 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32 let mut buf: Vec<IdxSize>;
33 if n == len {
34 buf = (0..len as IdxSize).collect();
35 if shuffle {
36 buf.shuffle(&mut rng)
37 }
38 } else {
39 buf = match rand::seq::index::sample(&mut rng, len, n) {
44 IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45 #[cfg(target_pointer_width = "64")]
46 IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),
47 };
48 }
49 IdxCa::new_vec(PlSmallStr::EMPTY, buf)
50}
51
52impl<T> ChunkedArray<T>
53where
54 T: PolarsNumericType,
55 StandardUniform: Distribution<T::Native>,
56{
57 pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
58 let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
59 (0..size)
60 .map(|_| {
61 if rng.random::<f32>() < null_density {
62 None
63 } else {
64 Some(rng.random())
65 }
66 })
67 .collect()
68 }
69}
70
71fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
72 polars_ensure!(
73 with_replacement || n <= len,
74 ShapeMismatch:
75 "cannot take a larger sample than the total population when `with_replacement=false`"
76 );
77 Ok(())
78}
79
80impl Series {
81 pub fn sample_n(
82 &self,
83 n: usize,
84 with_replacement: bool,
85 shuffle: bool,
86 seed: Option<u64>,
87 ) -> PolarsResult<Self> {
88 ensure_shape(n, self.len(), with_replacement)?;
89 if n == 0 {
90 return Ok(self.clear());
91 }
92 let len = self.len();
93
94 match with_replacement {
95 true => {
96 let idx = create_rand_index_with_replacement(n, len, seed);
97 debug_assert_eq!(len, self.len());
98 unsafe { Ok(self.take_unchecked(&idx)) }
100 },
101 false => {
102 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
103 debug_assert_eq!(len, self.len());
104 unsafe { Ok(self.take_unchecked(&idx)) }
106 },
107 }
108 }
109
110 pub fn sample_frac(
112 &self,
113 frac: f64,
114 with_replacement: bool,
115 shuffle: bool,
116 seed: Option<u64>,
117 ) -> PolarsResult<Self> {
118 let n = (self.len() as f64 * frac) as usize;
119 self.sample_n(n, with_replacement, shuffle, seed)
120 }
121
122 pub fn shuffle(&self, seed: Option<u64>) -> Self {
123 let len = self.len();
124 let n = len;
125 let idx = create_rand_index_no_replacement(n, len, seed, true);
126 debug_assert_eq!(len, self.len());
127 unsafe { self.take_unchecked(&idx) }
129 }
130}
131
132impl<T> ChunkedArray<T>
133where
134 T: PolarsDataType,
135 ChunkedArray<T>: ChunkTake<IdxCa>,
136{
137 pub fn sample_n(
139 &self,
140 n: usize,
141 with_replacement: bool,
142 shuffle: bool,
143 seed: Option<u64>,
144 ) -> PolarsResult<Self> {
145 ensure_shape(n, self.len(), with_replacement)?;
146 let len = self.len();
147
148 match with_replacement {
149 true => {
150 let idx = create_rand_index_with_replacement(n, len, seed);
151 debug_assert_eq!(len, self.len());
152 unsafe { Ok(self.take_unchecked(&idx)) }
154 },
155 false => {
156 let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
157 debug_assert_eq!(len, self.len());
158 unsafe { Ok(self.take_unchecked(&idx)) }
160 },
161 }
162 }
163
164 pub fn sample_frac(
166 &self,
167 frac: f64,
168 with_replacement: bool,
169 shuffle: bool,
170 seed: Option<u64>,
171 ) -> PolarsResult<Self> {
172 let n = (self.len() as f64 * frac) as usize;
173 self.sample_n(n, with_replacement, shuffle, seed)
174 }
175}
176
177impl DataFrame {
178 pub fn sample_n(
180 &self,
181 n: &Series,
182 with_replacement: bool,
183 shuffle: bool,
184 seed: Option<u64>,
185 ) -> PolarsResult<Self> {
186 polars_ensure!(
187 n.len() == 1,
188 ComputeError: "Sample size must be a single value."
189 );
190
191 let n = n.cast(&IDX_DTYPE)?;
192 let n = n.idx()?;
193
194 match n.get(0) {
195 Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
196 None => Ok(self.clear()),
197 }
198 }
199
200 pub fn sample_n_literal(
201 &self,
202 n: usize,
203 with_replacement: bool,
204 shuffle: bool,
205 seed: Option<u64>,
206 ) -> PolarsResult<Self> {
207 ensure_shape(n, self.height(), with_replacement)?;
208 let idx = match with_replacement {
210 true => create_rand_index_with_replacement(n, self.height(), seed),
211 false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
212 };
213 Ok(unsafe { self.take_unchecked(&idx) })
215 }
216
217 pub fn sample_frac(
219 &self,
220 frac: &Series,
221 with_replacement: bool,
222 shuffle: bool,
223 seed: Option<u64>,
224 ) -> PolarsResult<Self> {
225 polars_ensure!(
226 frac.len() == 1,
227 ComputeError: "Sample fraction must be a single value."
228 );
229
230 let frac = frac.cast(&Float64)?;
231 let frac = frac.f64()?;
232
233 match frac.get(0) {
234 Some(frac) => {
235 let n = (self.height() as f64 * frac) as usize;
236 self.sample_n_literal(n, with_replacement, shuffle, seed)
237 },
238 None => Ok(self.clear()),
239 }
240 }
241}
242
243impl<T> ChunkedArray<T>
244where
245 T: PolarsNumericType,
246 T::Native: Float,
247{
248 pub fn rand_normal(
250 name: PlSmallStr,
251 length: usize,
252 mean: f64,
253 std_dev: f64,
254 ) -> PolarsResult<Self> {
255 let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
256 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
257 let mut rng = rand::rng();
258 for _ in 0..length {
259 let smpl = normal.sample(&mut rng);
260 let smpl = NumCast::from(smpl).unwrap();
261 builder.append_value(smpl)
262 }
263 Ok(builder.finish())
264 }
265
266 pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
268 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
269 let mut rng = rand::rng();
270 for _ in 0..length {
271 let smpl: f64 = rng.sample(StandardNormal);
272 let smpl = NumCast::from(smpl).unwrap();
273 builder.append_value(smpl)
274 }
275 builder.finish()
276 }
277
278 pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
280 let uniform = Uniform::new(low, high).unwrap();
281 let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
282 let mut rng = rand::rng();
283 for _ in 0..length {
284 let smpl = uniform.sample(&mut rng);
285 let smpl = NumCast::from(smpl).unwrap();
286 builder.append_value(smpl)
287 }
288 builder.finish()
289 }
290}
291
292impl BooleanChunked {
293 pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
295 let dist = Bernoulli::new(p).map_err(to_compute_err)?;
296 let mut rng = rand::rng();
297 let mut builder = BooleanChunkedBuilder::new(name, length);
298 for _ in 0..length {
299 let smpl = dist.sample(&mut rng);
300 builder.append_value(smpl)
301 }
302 Ok(builder.finish())
303 }
304}
305
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test for `DataFrame::sample_n` / `DataFrame::sample_frac` on a
    /// five-row frame: only success or failure of each call is checked.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Sample-size and fraction inputs shared by the assertions below.
        let three = Series::new(PlSmallStr::from_static("s"), &[3]);
        let two_fifths = Series::new(PlSmallStr::from_static("frac"), &[0.4]);
        let twofold = Series::new(PlSmallStr::from_static("frac"), &[2.0]);

        // Without replacement, no explicit seed.
        assert!(df.sample_n(&three, false, false, None).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, None).is_ok());

        // Without replacement, fixed seed.
        assert!(df.sample_n(&three, false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, Some(0)).is_ok());

        // Asking for more rows than exist must fail without replacement.
        assert!(df.sample_frac(&twofold, false, false, Some(0)).is_err());

        // With replacement, fixed seed; oversampling is allowed here.
        assert!(df.sample_n(&three, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&twofold, true, false, Some(0)).is_ok());
    }
}