polars_core/chunked_array/
random.rs

1use num_traits::{Float, NumCast};
2use polars_error::to_compute_err;
3use rand::distr::Bernoulli;
4use rand::prelude::*;
5use rand::seq::index::IndexVec;
6use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};
7
8use crate::prelude::DataType::Float64;
9use crate::prelude::*;
10use crate::random::get_global_random_u64;
11use crate::utils::NoNull;
12
13fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14    if len == 0 {
15        return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16    }
17    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18    let dist = Uniform::new(0, len as IdxSize).unwrap();
19    (0..n as IdxSize)
20        .map(move |_| dist.sample(&mut rng))
21        .collect_trusted::<NoNull<IdxCa>>()
22        .into_inner()
23}
24
25fn create_rand_index_no_replacement(
26    n: usize,
27    len: usize,
28    seed: Option<u64>,
29    shuffle: bool,
30) -> IdxCa {
31    let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32    let mut buf: Vec<IdxSize>;
33    if n == len {
34        buf = (0..len as IdxSize).collect();
35        if shuffle {
36            buf.shuffle(&mut rng)
37        }
38    } else {
39        // TODO: avoid extra potential copy by vendoring rand::seq::index::sample,
40        // or genericize take over slices over any unsigned type. The optimizer
41        // should get rid of the extra copy already if IdxSize matches the IndexVec
42        // size returned.
43        buf = match rand::seq::index::sample(&mut rng, len, n) {
44            IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45            #[cfg(target_pointer_width = "64")]
46            IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),
47        };
48    }
49    IdxCa::new_vec(PlSmallStr::EMPTY, buf)
50}
51
52impl<T> ChunkedArray<T>
53where
54    T: PolarsNumericType,
55    StandardUniform: Distribution<T::Native>,
56{
57    pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
58        let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
59        (0..size)
60            .map(|_| {
61                if rng.random::<f32>() < null_density {
62                    None
63                } else {
64                    Some(rng.random())
65                }
66            })
67            .collect()
68    }
69}
70
71fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
72    polars_ensure!(
73        with_replacement || n <= len,
74        ShapeMismatch:
75        "cannot take a larger sample than the total population when `with_replacement=false`"
76    );
77    Ok(())
78}
79
80impl Series {
81    pub fn sample_n(
82        &self,
83        n: usize,
84        with_replacement: bool,
85        shuffle: bool,
86        seed: Option<u64>,
87    ) -> PolarsResult<Self> {
88        ensure_shape(n, self.len(), with_replacement)?;
89        if n == 0 {
90            return Ok(self.clear());
91        }
92        let len = self.len();
93
94        match with_replacement {
95            true => {
96                let idx = create_rand_index_with_replacement(n, len, seed);
97                debug_assert_eq!(len, self.len());
98                // SAFETY: we know that we never go out of bounds.
99                unsafe { Ok(self.take_unchecked(&idx)) }
100            },
101            false => {
102                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
103                debug_assert_eq!(len, self.len());
104                // SAFETY: we know that we never go out of bounds.
105                unsafe { Ok(self.take_unchecked(&idx)) }
106            },
107        }
108    }
109
110    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
111    pub fn sample_frac(
112        &self,
113        frac: f64,
114        with_replacement: bool,
115        shuffle: bool,
116        seed: Option<u64>,
117    ) -> PolarsResult<Self> {
118        let n = (self.len() as f64 * frac) as usize;
119        self.sample_n(n, with_replacement, shuffle, seed)
120    }
121
122    pub fn shuffle(&self, seed: Option<u64>) -> Self {
123        let len = self.len();
124        let n = len;
125        let idx = create_rand_index_no_replacement(n, len, seed, true);
126        debug_assert_eq!(len, self.len());
127        // SAFETY: we know that we never go out of bounds.
128        unsafe { self.take_unchecked(&idx) }
129    }
130}
131
132impl<T> ChunkedArray<T>
133where
134    T: PolarsDataType,
135    ChunkedArray<T>: ChunkTake<IdxCa>,
136{
137    /// Sample n datapoints from this [`ChunkedArray`].
138    pub fn sample_n(
139        &self,
140        n: usize,
141        with_replacement: bool,
142        shuffle: bool,
143        seed: Option<u64>,
144    ) -> PolarsResult<Self> {
145        ensure_shape(n, self.len(), with_replacement)?;
146        let len = self.len();
147
148        match with_replacement {
149            true => {
150                let idx = create_rand_index_with_replacement(n, len, seed);
151                debug_assert_eq!(len, self.len());
152                // SAFETY: we know that we never go out of bounds.
153                unsafe { Ok(self.take_unchecked(&idx)) }
154            },
155            false => {
156                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
157                debug_assert_eq!(len, self.len());
158                // SAFETY: we know that we never go out of bounds.
159                unsafe { Ok(self.take_unchecked(&idx)) }
160            },
161        }
162    }
163
164    /// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
165    pub fn sample_frac(
166        &self,
167        frac: f64,
168        with_replacement: bool,
169        shuffle: bool,
170        seed: Option<u64>,
171    ) -> PolarsResult<Self> {
172        let n = (self.len() as f64 * frac) as usize;
173        self.sample_n(n, with_replacement, shuffle, seed)
174    }
175}
176
177impl DataFrame {
178    /// Sample n datapoints from this [`DataFrame`].
179    pub fn sample_n(
180        &self,
181        n: &Series,
182        with_replacement: bool,
183        shuffle: bool,
184        seed: Option<u64>,
185    ) -> PolarsResult<Self> {
186        polars_ensure!(
187        n.len() == 1,
188        ComputeError: "Sample size must be a single value."
189        );
190
191        let n = n.cast(&IDX_DTYPE)?;
192        let n = n.idx()?;
193
194        match n.get(0) {
195            Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
196            None => Ok(self.clear()),
197        }
198    }
199
200    pub fn sample_n_literal(
201        &self,
202        n: usize,
203        with_replacement: bool,
204        shuffle: bool,
205        seed: Option<u64>,
206    ) -> PolarsResult<Self> {
207        ensure_shape(n, self.height(), with_replacement)?;
208        // All columns should used the same indices. So we first create the indices.
209        let idx = match with_replacement {
210            true => create_rand_index_with_replacement(n, self.height(), seed),
211            false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
212        };
213        // SAFETY: the indices are within bounds.
214        Ok(unsafe { self.take_unchecked(&idx) })
215    }
216
217    /// Sample a fraction between 0.0-1.0 of this [`DataFrame`].
218    pub fn sample_frac(
219        &self,
220        frac: &Series,
221        with_replacement: bool,
222        shuffle: bool,
223        seed: Option<u64>,
224    ) -> PolarsResult<Self> {
225        polars_ensure!(
226        frac.len() == 1,
227        ComputeError: "Sample fraction must be a single value."
228        );
229
230        let frac = frac.cast(&Float64)?;
231        let frac = frac.f64()?;
232
233        match frac.get(0) {
234            Some(frac) => {
235                let n = (self.height() as f64 * frac) as usize;
236                self.sample_n_literal(n, with_replacement, shuffle, seed)
237            },
238            None => Ok(self.clear()),
239        }
240    }
241}
242
243impl<T> ChunkedArray<T>
244where
245    T: PolarsNumericType,
246    T::Native: Float,
247{
248    /// Create [`ChunkedArray`] with samples from a Normal distribution.
249    pub fn rand_normal(
250        name: PlSmallStr,
251        length: usize,
252        mean: f64,
253        std_dev: f64,
254    ) -> PolarsResult<Self> {
255        let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
256        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
257        let mut rng = rand::rng();
258        for _ in 0..length {
259            let smpl = normal.sample(&mut rng);
260            let smpl = NumCast::from(smpl).unwrap();
261            builder.append_value(smpl)
262        }
263        Ok(builder.finish())
264    }
265
266    /// Create [`ChunkedArray`] with samples from a Standard Normal distribution.
267    pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
268        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
269        let mut rng = rand::rng();
270        for _ in 0..length {
271            let smpl: f64 = rng.sample(StandardNormal);
272            let smpl = NumCast::from(smpl).unwrap();
273            builder.append_value(smpl)
274        }
275        builder.finish()
276    }
277
278    /// Create [`ChunkedArray`] with samples from a Uniform distribution.
279    pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
280        let uniform = Uniform::new(low, high).unwrap();
281        let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
282        let mut rng = rand::rng();
283        for _ in 0..length {
284            let smpl = uniform.sample(&mut rng);
285            let smpl = NumCast::from(smpl).unwrap();
286            builder.append_value(smpl)
287        }
288        builder.finish()
289    }
290}
291
292impl BooleanChunked {
293    /// Create [`ChunkedArray`] with samples from a Bernoulli distribution.
294    pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
295        let dist = Bernoulli::new(p).map_err(to_compute_err)?;
296        let mut rng = rand::rng();
297        let mut builder = BooleanChunkedBuilder::new(name, length);
298        for _ in 0..length {
299            let smpl = dist.sample(&mut rng);
300            builder.append_value(smpl)
301        }
302        Ok(builder.finish())
303    }
304}
305
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `DataFrame::sample_n` and `DataFrame::sample_frac` with and
    /// without seeds and replacement on a five-row frame.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Builders for the single-value series that the sampling API expects.
        let n_of = |v: i32| Series::new(PlSmallStr::from_static("s"), &[v]);
        let frac_of = |v: f64| Series::new(PlSmallStr::from_static("frac"), &[v]);

        // Default samples are random and don't require seeds.
        assert!(df.sample_n(&n_of(3), false, false, None).is_ok());
        assert!(df.sample_frac(&frac_of(0.4), false, false, None).is_ok());

        // With seeding.
        assert!(df.sample_n(&n_of(3), false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&frac_of(0.4), false, false, Some(0)).is_ok());

        // Without replacement can not sample more than 100%.
        assert!(df.sample_frac(&frac_of(2.0), false, false, Some(0)).is_err());

        assert!(df.sample_n(&n_of(3), true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&frac_of(0.4), true, false, Some(0)).is_ok());

        // With replacement can sample more than 100%.
        assert!(df.sample_frac(&frac_of(2.0), true, false, Some(0)).is_ok());
    }
}