Skip to main content

polars_core/chunked_array/
random.rs

1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distr::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14    if len == 0 {
15        return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16    }
17    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18    let dist = Uniform::new(0, len as IdxSize).unwrap();
19    (0..n as IdxSize)
20        .map(move |_| dist.sample(&mut rng))
21        .collect_trusted::<NoNull<IdxCa>>()
22        .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26    n: usize,
27    len: usize,
28    seed: Option<u64>,
29    shuffle: bool,
30) -> IdxCa {
31    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32    let mut buf: Vec<IdxSize>;
33    if n == len {
34        buf = (0..len as IdxSize).collect();
35        if shuffle {
36            buf.shuffle(&mut rng)
37        }
38    } else {
39        // TODO: avoid extra potential copy by vendoring rand::seq::index::sample,
40        // or genericize take over slices over any unsigned type. The optimizer
41        // should get rid of the extra copy already if IdxSize matches the IndexVec
42        // size returned.
43        buf = match rand::seq::index::sample(&mut rng, len, n) {
44            IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45            #[cfg(target_pointer_width = "64")]
46            IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),
47        };
48        if !shuffle {
49            buf.sort_unstable();
50        }
51    }
52    IdxCa::new_vec(PlSmallStr::EMPTY, buf)
53}
54
55impl<T> ChunkedArray<T>
56where
57    T: PolarsNumericType,
58    StandardUniform: Distribution<T::Native>,
59{
60    pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
61        let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
62        (0..size)
63            .map(|_| {
64                if rng.random::<f32>() < null_density {
65                    None
66                } else {
67                    Some(rng.random())
68                }
69            })
70            .collect()
71    }
72}
73
74fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
75    polars_ensure!(
76        with_replacement || n <= len,
77        ShapeMismatch:
78        "cannot take a larger sample than the total population when `with_replacement=false`"
79    );
80    Ok(())
81}
82
83impl Series {
84    pub fn sample_n(
85        &self,
86        n: usize,
87        with_replacement: bool,
88        shuffle: bool,
89        seed: Option<u64>,
90    ) -> PolarsResult<Self> {
91        ensure_shape(n, self.len(), with_replacement)?;
92        if n == 0 {
93            return Ok(self.clear());
94        }
95        let len = self.len();
96
97        match with_replacement {
98            true => {
99                let idx = create_rand_index_with_replacement(n, len, seed);
100                debug_assert_eq!(len, self.len());
101                // SAFETY: we know that we never go out of bounds.
102                unsafe { Ok(self.take_unchecked(&idx)) }
103            },
104            false => {
105                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
106                debug_assert_eq!(len, self.len());
107                // SAFETY: we know that we never go out of bounds.
108                unsafe { Ok(self.take_unchecked(&idx)) }
109            },
110        }
111    }
112
113    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
114    pub fn sample_frac(
115        &self,
116        frac: f64,
117        with_replacement: bool,
118        shuffle: bool,
119        seed: Option<u64>,
120    ) -> PolarsResult<Self> {
121        let n = (self.len() as f64 * frac) as usize;
122        self.sample_n(n, with_replacement, shuffle, seed)
123    }
124
125    pub fn shuffle(&self, seed: Option<u64>) -> Self {
126        let len = self.len();
127        let n = len;
128        let idx = create_rand_index_no_replacement(n, len, seed, true);
129        debug_assert_eq!(len, self.len());
130        // SAFETY: we know that we never go out of bounds.
131        unsafe { self.take_unchecked(&idx) }
132    }
133}
134
135impl<T> ChunkedArray<T>
136where
137    T: PolarsDataType,
138    ChunkedArray<T>: ChunkTake<IdxCa>,
139{
140    /// Sample n datapoints from this [`ChunkedArray`].
141    pub fn sample_n(
142        &self,
143        n: usize,
144        with_replacement: bool,
145        shuffle: bool,
146        seed: Option<u64>,
147    ) -> PolarsResult<Self> {
148        ensure_shape(n, self.len(), with_replacement)?;
149        let len = self.len();
150
151        match with_replacement {
152            true => {
153                let idx = create_rand_index_with_replacement(n, len, seed);
154                debug_assert_eq!(len, self.len());
155                // SAFETY: we know that we never go out of bounds.
156                unsafe { Ok(self.take_unchecked(&idx)) }
157            },
158            false => {
159                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
160                debug_assert_eq!(len, self.len());
161                // SAFETY: we know that we never go out of bounds.
162                unsafe { Ok(self.take_unchecked(&idx)) }
163            },
164        }
165    }
166
167    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
168    pub fn sample_frac(
169        &self,
170        frac: f64,
171        with_replacement: bool,
172        shuffle: bool,
173        seed: Option<u64>,
174    ) -> PolarsResult<Self> {
175        let n = (self.len() as f64 * frac) as usize;
176        self.sample_n(n, with_replacement, shuffle, seed)
177    }
178}
179
180impl DataFrame {
181    /// Sample n datapoints from this [`DataFrame`].
182    pub fn sample_n(
183        &self,
184        n: &Series,
185        with_replacement: bool,
186        shuffle: bool,
187        seed: Option<u64>,
188    ) -> PolarsResult<Self> {
189        polars_ensure!(
190        n.len() == 1,
191        ComputeError: "Sample size must be a single value."
192        );
193
194        let n = n.strict_cast(&IDX_DTYPE)?;
195        let n = n.idx()?;
196
197        match n.get(0) {
198            Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
199            None => Ok(self.clear()),
200        }
201    }
202
203    pub fn sample_n_literal(
204        &self,
205        n: usize,
206        with_replacement: bool,
207        shuffle: bool,
208        seed: Option<u64>,
209    ) -> PolarsResult<Self> {
210        ensure_shape(n, self.height(), with_replacement)?;
211        // All columns should used the same indices. So we first create the indices.
212        let idx = match with_replacement {
213            true => create_rand_index_with_replacement(n, self.height(), seed),
214            false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
215        };
216        // SAFETY: the indices are within bounds.
217        Ok(unsafe { self.take_unchecked(&idx) })
218    }
219
220    /// Sample a fraction between 0.0-1.0 of this [`DataFrame`].
221    pub fn sample_frac(
222        &self,
223        frac: &Series,
224        with_replacement: bool,
225        shuffle: bool,
226        seed: Option<u64>,
227    ) -> PolarsResult<Self> {
228        polars_ensure!(
229        frac.len() == 1,
230        ComputeError: "Sample fraction must be a single value."
231        );
232
233        let frac = frac.cast(&Float64)?;
234        let frac = frac.f64()?;
235
236        match frac.get(0) {
237            Some(frac) => {
238                let n = (self.height() as f64 * frac) as usize;
239                self.sample_n_literal(n, with_replacement, shuffle, seed)
240            },
241            None => Ok(self.clear()),
242        }
243    }
244}
245
246impl<T> ChunkedArray<T>
247where
248    T: PolarsNumericType,
249    T::Native: Float,
250{
251    /// Create [`ChunkedArray`] with samples from a Normal distribution.
252    pub fn rand_normal(
253        name: PlSmallStr,
254        length: usize,
255        mean: f64,
256        std_dev: f64,
257    ) -> PolarsResult<Self> {
258        let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
259        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
260        let mut rng = rand::rng();
261        for _ in 0..length {
262            let smpl = normal.sample(&mut rng);
263            let smpl = NumCast::from(smpl).unwrap();
264            builder.append_value(smpl)
265        }
266        Ok(builder.finish())
267    }
268
269    /// Create [`ChunkedArray`] with samples from a Standard Normal distribution.
270    pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
271        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
272        let mut rng = rand::rng();
273        for _ in 0..length {
274            let smpl: f64 = rng.sample(StandardNormal);
275            let smpl = NumCast::from(smpl).unwrap();
276            builder.append_value(smpl)
277        }
278        builder.finish()
279    }
280
281    /// Create [`ChunkedArray`] with samples from a Uniform distribution.
282    pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
283        let uniform = Uniform::new(low, high).unwrap();
284        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
285        let mut rng = rand::rng();
286        for _ in 0..length {
287            let smpl = uniform.sample(&mut rng);
288            let smpl = NumCast::from(smpl).unwrap();
289            builder.append_value(smpl)
290        }
291        builder.finish()
292    }
293}
294
295impl BooleanChunked {
296    /// Create [`ChunkedArray`] with samples from a Bernoulli distribution.
297    pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
298        let dist = Bernoulli::new(p).map_err(to_compute_err)?;
299        let mut rng = rand::rng();
300        let mut builder = BooleanChunkedBuilder::new(name, length);
301        for _ in 0..length {
302            let smpl = dist.sample(&mut rng);
303            builder.append_value(smpl)
304        }
305        Ok(builder.finish())
306    }
307}
308
#[cfg(test)]
mod test {
    use super::*;

    /// Checks which combinations of sample size / fraction and
    /// `with_replacement` are accepted or rejected on a five-row frame.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        let three = Series::new(PlSmallStr::from_static("s"), &[3]);
        let two_fifths = Series::new(PlSmallStr::from_static("frac"), &[0.4]);
        let double = Series::new(PlSmallStr::from_static("frac"), &[2.0]);

        // Default samples are random and don't require seeds (`None`);
        // the second round repeats the calls with seeding.
        for seed in [None, Some(0u64)] {
            assert!(df.sample_n(&three, false, false, seed).is_ok());
            assert!(df.sample_frac(&two_fifths, false, false, seed).is_ok());
        }

        // Without replacement can not sample more than 100%.
        assert!(df.sample_frac(&double, false, false, Some(0)).is_err());

        assert!(df.sample_n(&three, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, true, false, Some(0)).is_ok());

        // With replacement can sample more than 100%.
        assert!(df.sample_frac(&double, true, false, Some(0)).is_ok());
    }
}