Skip to main content

polars_ops/chunked_array/binary/
slice.rs

1use std::cmp::Ordering;
2
3use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
4use polars_core::prelude::{BinaryChunked, ChunkFullNull, DataType, Int64Chunked, UInt64Chunked};
5use polars_error::{PolarsResult, polars_ensure};
6
7fn head_binary(opt_bytes: Option<&[u8]>, opt_n: Option<i64>) -> Option<&[u8]> {
8    if let (Some(bytes), Some(n)) = (opt_bytes, opt_n) {
9        let end_idx = head_binary_values(bytes, n);
10        Some(&bytes[..end_idx])
11    } else {
12        None
13    }
14}
15
/// Computes the exclusive end index of the "head" of `bytes`.
///
/// * `n == 0` — the head is empty, so the end index is `0`.
/// * `n > 0`  — keep the first `n` bytes (clamped to `bytes.len()`).
/// * `n < 0`  — drop the last `|n|` bytes (clamped so the result never goes below `0`).
///
/// The returned index is always in `0..=bytes.len()`, so `&bytes[..idx]` cannot panic.
fn head_binary_values(bytes: &[u8], n: i64) -> usize {
    let len = bytes.len();
    // `unsigned_abs` is defined for every `i64` (including `i64::MIN`, where `-n`
    // would overflow), and the checked conversion clamps to `len` instead of
    // truncating when `usize` is narrower than 64 bits.
    let magnitude = usize::try_from(n.unsigned_abs()).map_or(len, |m| m.min(len));

    match n.cmp(&0) {
        Ordering::Equal => 0,
        // Take first n bytes
        Ordering::Greater => magnitude,
        // End n bytes from the end
        Ordering::Less => len - magnitude,
    }
}
29
30fn tail_binary(opt_bytes: Option<&[u8]>, opt_n: Option<i64>) -> Option<&[u8]> {
31    if let (Some(bytes), Some(n)) = (opt_bytes, opt_n) {
32        let start_idx = tail_binary_values(bytes, n);
33        Some(&bytes[start_idx..])
34    } else {
35        None
36    }
37}
38
/// Computes the inclusive start index of the "tail" of `bytes`.
///
/// * `n == 0` — the tail is empty, so the start index is `bytes.len()`.
/// * `n > 0`  — keep the last `n` bytes (clamped so the start never goes below `0`).
/// * `n < 0`  — skip the first `|n|` bytes (clamped to `bytes.len()`).
///
/// The returned index is always in `0..=bytes.len()`, so `&bytes[idx..]` cannot panic.
fn tail_binary_values(bytes: &[u8], n: i64) -> usize {
    let max_len = bytes.len();
    // `unsigned_abs` is defined for every `i64` (including `i64::MIN`, where `-n`
    // would overflow), and the checked conversion clamps to `max_len` instead of
    // truncating when `usize` is narrower than 64 bits.
    let magnitude = usize::try_from(n.unsigned_abs()).map_or(max_len, |m| m.min(max_len));

    match n.cmp(&0) {
        Ordering::Equal => max_len,
        // Start from nth byte from the end
        Ordering::Greater => max_len - magnitude,
        // Start after the nth byte
        Ordering::Less => magnitude,
    }
}
54
55fn slice_ternary_offsets(
56    opt_bytes: Option<&[u8]>,
57    opt_offset: Option<i64>,
58    opt_length: Option<u64>,
59) -> Option<(usize, usize)> {
60    let bytes = opt_bytes?;
61    let offset = opt_offset?;
62    Some(slice_ternary_offsets_value(
63        bytes,
64        offset,
65        opt_length.unwrap_or(u64::MAX),
66    ))
67}
68
/// Resolves an `(offset, length)` slice request against `bytes` into a
/// `(start, end)` byte range that is always valid for `&bytes[start..end]`.
///
/// * A zero `length`, or a non-negative `offset` at or past the end, yields the
///   empty range `(0, 0)`.
/// * A negative `offset` counts back from the end of `bytes`.
/// * A negative `offset` that reaches before the start of `bytes` begins at `0`
///   and shortens `length` by the number of positions that fell before the start.
/// * `length` is clamped to the bytes remaining after `start`.
///
/// All arithmetic on `offset` and `length` is done in `u64` and clamped to
/// `bytes.len()` before converting to `usize`, so extreme inputs such as
/// `i64::MIN` or lengths wider than `usize` neither overflow nor truncate.
pub fn slice_ternary_offsets_value(bytes: &[u8], offset: i64, length: u64) -> (usize, usize) {
    let len = bytes.len();
    // Converts a 64-bit count to `usize`, clamping to `len` instead of truncating.
    let clamp_to_len = |count: u64| usize::try_from(count).map_or(len, |count| count.min(len));

    // Fast-path: always empty slice
    if length == 0 {
        return (0, 0);
    }

    let start_byte_offset = if offset >= 0 {
        match usize::try_from(offset) {
            Ok(start) if start < len => start,
            // Offset is at or past the end: nothing to slice.
            _ => return (0, 0),
        }
    } else {
        // If `offset` is negative, it counts from the end. `unsigned_abs` is
        // defined for `i64::MIN`, where plain negation would overflow.
        let abs_offset = offset.unsigned_abs();
        let len_u64 = u64::try_from(len).unwrap_or(u64::MAX);
        if abs_offset > len_u64 {
            // Offset is before the start - handle length reduction
            let length_reduction = abs_offset - len_u64;
            let adjusted_length = length.saturating_sub(length_reduction);
            return (0, clamp_to_len(adjusted_length));
        }
        // `abs_offset <= len` here, so the clamp is an exact conversion.
        len - clamp_to_len(abs_offset)
    };

    let remaining = len - start_byte_offset;
    let end_byte_offset = start_byte_offset + clamp_to_len(length).min(remaining);

    (start_byte_offset, end_byte_offset)
}
94
95fn slice_ternary(
96    opt_bytes: Option<&[u8]>,
97    opt_offset: Option<i64>,
98    opt_length: Option<u64>,
99) -> Option<&[u8]> {
100    let (start, end) = slice_ternary_offsets(opt_bytes, opt_offset, opt_length)?;
101    opt_bytes.map(|bytes| &bytes[start..end])
102}
103
104pub(super) fn slice(
105    ca: &BinaryChunked,
106    offset: &Int64Chunked,
107    length: &UInt64Chunked,
108) -> BinaryChunked {
109    match (ca.len(), offset.len(), length.len()) {
110        (1, 1, _) => {
111            let bytes = ca.get(0);
112            let offset = offset.get(0);
113            unary_elementwise(length, |length| slice_ternary(bytes, offset, length))
114                .with_name(ca.name().clone())
115        },
116        (_, 1, 1) => {
117            let offset = offset.get(0);
118            let length = length.get(0).unwrap_or(u64::MAX);
119
120            let Some(offset) = offset else {
121                return BinaryChunked::full_null(ca.name().clone(), ca.len());
122            };
123
124            ca.apply_nonnull_values_generic(DataType::Binary, |val| {
125                let (start, end) = slice_ternary_offsets_value(val, offset, length);
126                &val[start..end]
127            })
128        },
129        (1, _, 1) => {
130            let bytes = ca.get(0);
131            let length = length.get(0);
132            unary_elementwise(offset, |offset| slice_ternary(bytes, offset, length))
133                .with_name(ca.name().clone())
134        },
135        (1, len_b, len_c) if len_b == len_c => {
136            let bytes = ca.get(0);
137            binary_elementwise(offset, length, |offset, length| {
138                slice_ternary(bytes, offset, length)
139            })
140        },
141        (len_a, 1, len_c) if len_a == len_c => {
142            fn infer<F: for<'a> FnMut(Option<&'a [u8]>, Option<u64>) -> Option<&'a [u8]>>(
143                f: F,
144            ) -> F {
145                f
146            }
147            let offset = offset.get(0);
148            binary_elementwise(
149                ca,
150                length,
151                infer(|bytes, length| slice_ternary(bytes, offset, length)),
152            )
153        },
154        (len_a, len_b, 1) if len_a == len_b => {
155            fn infer<F: for<'a> FnMut(Option<&'a [u8]>, Option<i64>) -> Option<&'a [u8]>>(
156                f: F,
157            ) -> F {
158                f
159            }
160            let length = length.get(0);
161            binary_elementwise(
162                ca,
163                offset,
164                infer(|bytes, offset| slice_ternary(bytes, offset, length)),
165            )
166        },
167        _ => ternary_elementwise(ca, offset, length, slice_ternary),
168    }
169}
170
171pub(super) fn head(ca: &BinaryChunked, n: &Int64Chunked) -> PolarsResult<BinaryChunked> {
172    match (ca.len(), n.len()) {
173        (len, 1) => {
174            let n = n.get(0);
175            let Some(n) = n else {
176                return Ok(BinaryChunked::full_null(ca.name().clone(), len));
177            };
178
179            Ok(ca.apply_nonnull_values_generic(DataType::Binary, |val| {
180                let end = head_binary_values(val, n);
181                &val[..end]
182            }))
183        },
184        (1, _) => {
185            let bytes = ca.get(0);
186            Ok(unary_elementwise(n, |n| head_binary(bytes, n)).with_name(ca.name().clone()))
187        },
188        (a, b) => {
189            polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'bin.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
190            Ok(binary_elementwise(ca, n, head_binary))
191        },
192    }
193}
194
195pub(super) fn tail(ca: &BinaryChunked, n: &Int64Chunked) -> PolarsResult<BinaryChunked> {
196    Ok(match (ca.len(), n.len()) {
197        (len, 1) => {
198            let n = n.get(0);
199            let Some(n) = n else {
200                return Ok(BinaryChunked::full_null(ca.name().clone(), len));
201            };
202
203            ca.apply_nonnull_values_generic(DataType::Binary, |val| {
204                let start = tail_binary_values(val, n);
205                &val[start..]
206            })
207        },
208        (1, _) => {
209            let bytes = ca.get(0);
210            unary_elementwise(n, |n| tail_binary(bytes, n)).with_name(ca.name().clone())
211        },
212        (a, b) => {
213            polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'bin.tail' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b);
214            binary_elementwise(ca, n, tail_binary)
215        },
216    })
217}