polars_core/chunked_array/
random.rs

1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distributions::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, Standard, StandardNormal, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14    if len == 0 {
15        return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16    }
17    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18    let dist = Uniform::new(0, len as IdxSize);
19    (0..n as IdxSize)
20        .map(move |_| dist.sample(&mut rng))
21        .collect_trusted::<NoNull<IdxCa>>()
22        .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26    n: usize,
27    len: usize,
28    seed: Option<u64>,
29    shuffle: bool,
30) -> IdxCa {
31    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32    let mut buf: Vec<IdxSize>;
33    if n == len {
34        buf = (0..len as IdxSize).collect();
35        if shuffle {
36            buf.shuffle(&mut rng)
37        }
38    } else {
39        // TODO: avoid extra potential copy by vendoring rand::seq::index::sample,
40        // or genericize take over slices over any unsigned type. The optimizer
41        // should get rid of the extra copy already if IdxSize matches the IndexVec
42        // size returned.
43        buf = match rand::seq::index::sample(&mut rng, len, n) {
44            IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45            IndexVec::USize(v) => v.into_iter().map(|x| x as IdxSize).collect(),
46        };
47    }
48    IdxCa::new_vec(PlSmallStr::EMPTY, buf)
49}
50
51impl<T> ChunkedArray<T>
52where
53    T: PolarsNumericType,
54    Standard: Distribution<T::Native>,
55{
56    pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
57        let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
58        (0..size)
59            .map(|_| {
60                if rng.r#gen::<f32>() < null_density {
61                    None
62                } else {
63                    Some(rng.r#gen())
64                }
65            })
66            .collect()
67    }
68}
69
70fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
71    polars_ensure!(
72        with_replacement || n <= len,
73        ShapeMismatch:
74        "cannot take a larger sample than the total population when `with_replacement=false`"
75    );
76    Ok(())
77}
78
79impl Series {
80    pub fn sample_n(
81        &self,
82        n: usize,
83        with_replacement: bool,
84        shuffle: bool,
85        seed: Option<u64>,
86    ) -> PolarsResult<Self> {
87        ensure_shape(n, self.len(), with_replacement)?;
88        if n == 0 {
89            return Ok(self.clear());
90        }
91        let len = self.len();
92
93        match with_replacement {
94            true => {
95                let idx = create_rand_index_with_replacement(n, len, seed);
96                debug_assert_eq!(len, self.len());
97                // SAFETY: we know that we never go out of bounds.
98                unsafe { Ok(self.take_unchecked(&idx)) }
99            },
100            false => {
101                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
102                debug_assert_eq!(len, self.len());
103                // SAFETY: we know that we never go out of bounds.
104                unsafe { Ok(self.take_unchecked(&idx)) }
105            },
106        }
107    }
108
109    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
110    pub fn sample_frac(
111        &self,
112        frac: f64,
113        with_replacement: bool,
114        shuffle: bool,
115        seed: Option<u64>,
116    ) -> PolarsResult<Self> {
117        let n = (self.len() as f64 * frac) as usize;
118        self.sample_n(n, with_replacement, shuffle, seed)
119    }
120
121    pub fn shuffle(&self, seed: Option<u64>) -> Self {
122        let len = self.len();
123        let n = len;
124        let idx = create_rand_index_no_replacement(n, len, seed, true);
125        debug_assert_eq!(len, self.len());
126        // SAFETY: we know that we never go out of bounds.
127        unsafe { self.take_unchecked(&idx) }
128    }
129}
130
131impl<T> ChunkedArray<T>
132where
133    T: PolarsDataType,
134    ChunkedArray<T>: ChunkTake<IdxCa>,
135{
136    /// Sample n datapoints from this [`ChunkedArray`].
137    pub fn sample_n(
138        &self,
139        n: usize,
140        with_replacement: bool,
141        shuffle: bool,
142        seed: Option<u64>,
143    ) -> PolarsResult<Self> {
144        ensure_shape(n, self.len(), with_replacement)?;
145        let len = self.len();
146
147        match with_replacement {
148            true => {
149                let idx = create_rand_index_with_replacement(n, len, seed);
150                debug_assert_eq!(len, self.len());
151                // SAFETY: we know that we never go out of bounds.
152                unsafe { Ok(self.take_unchecked(&idx)) }
153            },
154            false => {
155                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
156                debug_assert_eq!(len, self.len());
157                // SAFETY: we know that we never go out of bounds.
158                unsafe { Ok(self.take_unchecked(&idx)) }
159            },
160        }
161    }
162
163    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
164    pub fn sample_frac(
165        &self,
166        frac: f64,
167        with_replacement: bool,
168        shuffle: bool,
169        seed: Option<u64>,
170    ) -> PolarsResult<Self> {
171        let n = (self.len() as f64 * frac) as usize;
172        self.sample_n(n, with_replacement, shuffle, seed)
173    }
174}
175
176impl DataFrame {
177    /// Sample n datapoints from this [`DataFrame`].
178    pub fn sample_n(
179        &self,
180        n: &Series,
181        with_replacement: bool,
182        shuffle: bool,
183        seed: Option<u64>,
184    ) -> PolarsResult<Self> {
185        polars_ensure!(
186        n.len() == 1,
187        ComputeError: "Sample size must be a single value."
188        );
189
190        let n = n.cast(&IDX_DTYPE)?;
191        let n = n.idx()?;
192
193        match n.get(0) {
194            Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
195            None => Ok(self.clear()),
196        }
197    }
198
199    pub fn sample_n_literal(
200        &self,
201        n: usize,
202        with_replacement: bool,
203        shuffle: bool,
204        seed: Option<u64>,
205    ) -> PolarsResult<Self> {
206        ensure_shape(n, self.height(), with_replacement)?;
207        // All columns should used the same indices. So we first create the indices.
208        let idx = match with_replacement {
209            true => create_rand_index_with_replacement(n, self.height(), seed),
210            false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
211        };
212        // SAFETY: the indices are within bounds.
213        Ok(unsafe { self.take_unchecked(&idx) })
214    }
215
216    /// Sample a fraction between 0.0-1.0 of this [`DataFrame`].
217    pub fn sample_frac(
218        &self,
219        frac: &Series,
220        with_replacement: bool,
221        shuffle: bool,
222        seed: Option<u64>,
223    ) -> PolarsResult<Self> {
224        polars_ensure!(
225        frac.len() == 1,
226        ComputeError: "Sample fraction must be a single value."
227        );
228
229        let frac = frac.cast(&Float64)?;
230        let frac = frac.f64()?;
231
232        match frac.get(0) {
233            Some(frac) => {
234                let n = (self.height() as f64 * frac) as usize;
235                self.sample_n_literal(n, with_replacement, shuffle, seed)
236            },
237            None => Ok(self.clear()),
238        }
239    }
240}
241
242impl<T> ChunkedArray<T>
243where
244    T: PolarsNumericType,
245    T::Native: Float,
246{
247    /// Create [`ChunkedArray`] with samples from a Normal distribution.
248    pub fn rand_normal(
249        name: PlSmallStr,
250        length: usize,
251        mean: f64,
252        std_dev: f64,
253    ) -> PolarsResult<Self> {
254        let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
255        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
256        let mut rng = rand::thread_rng();
257        for _ in 0..length {
258            let smpl = normal.sample(&mut rng);
259            let smpl = NumCast::from(smpl).unwrap();
260            builder.append_value(smpl)
261        }
262        Ok(builder.finish())
263    }
264
265    /// Create [`ChunkedArray`] with samples from a Standard Normal distribution.
266    pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
267        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
268        let mut rng = rand::thread_rng();
269        for _ in 0..length {
270            let smpl: f64 = rng.sample(StandardNormal);
271            let smpl = NumCast::from(smpl).unwrap();
272            builder.append_value(smpl)
273        }
274        builder.finish()
275    }
276
277    /// Create [`ChunkedArray`] with samples from a Uniform distribution.
278    pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
279        let uniform = Uniform::new(low, high);
280        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
281        let mut rng = rand::thread_rng();
282        for _ in 0..length {
283            let smpl = uniform.sample(&mut rng);
284            let smpl = NumCast::from(smpl).unwrap();
285            builder.append_value(smpl)
286        }
287        builder.finish()
288    }
289}
290
291impl BooleanChunked {
292    /// Create [`ChunkedArray`] with samples from a Bernoulli distribution.
293    pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
294        let dist = Bernoulli::new(p).map_err(to_compute_err)?;
295        let mut rng = rand::thread_rng();
296        let mut builder = BooleanChunkedBuilder::new(name, length);
297        for _ in 0..length {
298            let smpl = dist.sample(&mut rng);
299            builder.append_value(smpl)
300        }
301        Ok(builder.finish())
302    }
303}
304
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test for `DataFrame::sample_n` / `DataFrame::sample_frac`: checks
    /// only whether each call succeeds or fails, not which rows are returned.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Single-value inputs shared by the assertions below.
        let three = Series::new(PlSmallStr::from_static("s"), &[3]);
        let two_fifths = Series::new(PlSmallStr::from_static("frac"), &[0.4]);
        let double = Series::new(PlSmallStr::from_static("frac"), &[2.0]);

        // Default samples are random and don't require seeds.
        assert!(df.sample_n(&three, false, false, None).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, None).is_ok());

        // With seeding.
        assert!(df.sample_n(&three, false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, Some(0)).is_ok());

        // Without replacement can not sample more than 100%.
        assert!(df.sample_frac(&double, false, false, Some(0)).is_err());

        assert!(df.sample_n(&three, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, true, false, Some(0)).is_ok());

        // With replacement can sample more than 100%.
        assert!(df.sample_frac(&double, true, false, Some(0)).is_ok());
    }
}