1#![allow(unsafe_op_in_unsafe_fn)]
2use std::error::Error;
3
4use arrow::array::{Array, MutablePlString, StaticArray};
5use arrow::compute::utils::combine_validities_and;
6use polars_error::PolarsResult;
7use polars_utils::pl_str::PlSmallStr;
8
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter};
11use crate::prelude::{ChunkedArray, CompatLevel, PolarsDataType, Series, StringChunked};
12use crate::utils::{align_chunks_binary, align_chunks_binary_owned, align_chunks_ternary};
13
/// A one-argument `FnMut` whose output type can be named as [`UnaryFnMut::Ret`].
///
/// A plain `FnMut(A1) -> R` bound needs an extra generic parameter for `R`.
/// Going through this trait lets a `where` clause refer to the closure's
/// output as `<F as UnaryFnMut<A1>>::Ret` instead.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
    /// The value produced by calling the closure.
    type Ret;
}

/// Every `FnMut(Arg) -> Out` is a `UnaryFnMut<Arg>` with `Ret = Out`.
impl<Arg, Out, Func> UnaryFnMut<Arg> for Func
where
    Func: FnMut(Arg) -> Out,
{
    type Ret = Out;
}
23
/// A three-argument `FnMut` whose output type can be named as [`TernaryFnMut::Ret`].
///
/// This lets a `where` clause refer to the closure's output as
/// `<F as TernaryFnMut<A1, A2, A3>>::Ret` without a separate generic parameter.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
    /// The value produced by calling the closure.
    type Ret;
}

/// Every `FnMut(First, Second, Third) -> Out` is a `TernaryFnMut` with `Ret = Out`.
impl<First, Second, Third, Out, Func> TernaryFnMut<First, Second, Third> for Func
where
    Func: FnMut(First, Second, Third) -> Out,
{
    type Ret = Out;
}
33
/// A two-argument `FnMut` whose output type can be named as [`BinaryFnMut::Ret`].
///
/// This lets a `where` clause refer to the closure's output as
/// `<F as BinaryFnMut<A1, A2>>::Ret` without a separate generic parameter.
pub trait BinaryFnMut<A1, A2>: FnMut(A1, A2) -> Self::Ret {
    /// The value produced by calling the closure.
    type Ret;
}

/// Every `FnMut(Left, Right) -> Out` is a `BinaryFnMut<Left, Right>` with `Ret = Out`.
impl<Left, Right, Out, Func> BinaryFnMut<Left, Right> for Func
where
    Func: FnMut(Left, Right) -> Out,
{
    type Ret = Out;
}
43
44#[inline]
46pub fn unary_kernel<T, V, F, Arr>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<V>
47where
48 T: PolarsDataType,
49 V: PolarsDataType<Array = Arr>,
50 Arr: Array,
51 F: FnMut(&T::Array) -> Arr,
52{
53 let iter = ca.downcast_iter().map(op);
54 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
55}
56
57#[inline]
59pub fn unary_kernel_owned<T, V, F, Arr>(ca: ChunkedArray<T>, op: F) -> ChunkedArray<V>
60where
61 T: PolarsDataType,
62 V: PolarsDataType<Array = Arr>,
63 Arr: Array,
64 F: FnMut(T::Array) -> Arr,
65{
66 let name = ca.name().clone();
67 let iter = ca.downcast_into_iter().map(op);
68 ChunkedArray::from_chunk_iter(name, iter)
69}
70
71#[inline]
72pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
73where
74 T: PolarsDataType,
75 V: PolarsDataType,
76 F: UnaryFnMut<Option<T::Physical<'a>>>,
77 V::Array: ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
78{
79 if ca.has_nulls() {
80 let iter = ca
81 .downcast_iter()
82 .map(|arr| arr.iter().map(&mut op).collect_arr());
83 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
84 } else {
85 let iter = ca
86 .downcast_iter()
87 .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr());
88 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
89 }
90}
91
92#[inline]
93pub fn try_unary_elementwise<'a, T, V, F, K, E>(
94 ca: &'a ChunkedArray<T>,
95 mut op: F,
96) -> Result<ChunkedArray<V>, E>
97where
98 T: PolarsDataType,
99 V: PolarsDataType,
100 F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
101 V::Array: ArrayFromIter<Option<K>>,
102{
103 let iter = ca
104 .downcast_iter()
105 .map(|arr| arr.iter().map(&mut op).try_collect_arr());
106 ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter)
107}
108
109#[inline]
110pub fn unary_elementwise_values<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
111where
112 T: PolarsDataType,
113 V: PolarsDataType,
114 F: UnaryFnMut<T::Physical<'a>>,
115 V::Array: ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
116{
117 if ca.null_count() == ca.len() {
118 let arr = V::Array::full_null(
119 ca.len(),
120 V::get_static_dtype().to_arrow(CompatLevel::newest()),
121 );
122 return ChunkedArray::with_chunk(ca.name().clone(), arr);
123 }
124
125 let iter = ca.downcast_iter().map(|arr| {
126 let validity = arr.validity().cloned();
127 let arr: V::Array = arr.values_iter().map(&mut op).collect_arr();
128 arr.with_validity_typed(validity)
129 });
130 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
131}
132
133#[inline]
134pub fn try_unary_elementwise_values<'a, T, V, F, K, E>(
135 ca: &'a ChunkedArray<T>,
136 mut op: F,
137) -> Result<ChunkedArray<V>, E>
138where
139 T: PolarsDataType,
140 V: PolarsDataType,
141 F: FnMut(T::Physical<'a>) -> Result<K, E>,
142 V::Array: ArrayFromIter<K>,
143{
144 if ca.null_count() == ca.len() {
145 let arr = V::Array::full_null(
146 ca.len(),
147 V::get_static_dtype().to_arrow(CompatLevel::newest()),
148 );
149 return Ok(ChunkedArray::with_chunk(ca.name().clone(), arr));
150 }
151
152 let iter = ca.downcast_iter().map(|arr| {
153 let validity = arr.validity().cloned();
154 let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
155 Ok(arr.with_validity_typed(validity))
156 });
157 ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter)
158}
159
160#[inline]
165pub fn unary_mut_values<T, V, F, Arr>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
166where
167 T: PolarsDataType,
168 V: PolarsDataType<Array = Arr>,
169 Arr: Array + StaticArray,
170 F: FnMut(&T::Array) -> Arr,
171{
172 let iter = ca
173 .downcast_iter()
174 .map(|arr| op(arr).with_validity_typed(arr.validity().cloned()));
175 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
176}
177
178#[inline]
180pub fn unary_mut_with_options<T, V, F, Arr>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<V>
181where
182 T: PolarsDataType,
183 V: PolarsDataType<Array = Arr>,
184 Arr: Array + StaticArray,
185 F: FnMut(&T::Array) -> Arr,
186{
187 ChunkedArray::from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op))
188}
189
190#[inline]
191pub fn try_unary_mut_with_options<T, V, F, Arr, E>(
192 ca: &ChunkedArray<T>,
193 op: F,
194) -> Result<ChunkedArray<V>, E>
195where
196 T: PolarsDataType,
197 V: PolarsDataType<Array = Arr>,
198 Arr: Array + StaticArray,
199 F: FnMut(&T::Array) -> Result<Arr, E>,
200 E: Error,
201{
202 ChunkedArray::try_from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op))
203}
204
205#[inline]
206pub fn binary_elementwise<T, U, V, F>(
207 lhs: &ChunkedArray<T>,
208 rhs: &ChunkedArray<U>,
209 mut op: F,
210) -> ChunkedArray<V>
211where
212 T: PolarsDataType,
213 U: PolarsDataType,
214 V: PolarsDataType,
215 F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
216 V::Array: for<'a> ArrayFromIter<
217 <F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
218 >,
219{
220 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
221 let iter = lhs
222 .downcast_iter()
223 .zip(rhs.downcast_iter())
224 .map(|(lhs_arr, rhs_arr)| {
225 let element_iter = lhs_arr
226 .iter()
227 .zip(rhs_arr.iter())
228 .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
229 element_iter.collect_arr()
230 });
231 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
232}
233
234#[inline]
235pub fn binary_elementwise_for_each<'a, 'b, T, U, F>(
236 lhs: &'a ChunkedArray<T>,
237 rhs: &'b ChunkedArray<U>,
238 mut op: F,
239) where
240 T: PolarsDataType,
241 U: PolarsDataType,
242 F: FnMut(Option<T::Physical<'a>>, Option<U::Physical<'b>>),
243{
244 let mut lhs_arr_iter = lhs.downcast_iter();
245 let mut rhs_arr_iter = rhs.downcast_iter();
246
247 let lhs_arr = lhs_arr_iter.next().unwrap();
248 let rhs_arr = rhs_arr_iter.next().unwrap();
249
250 let mut lhs_remaining = lhs_arr.len();
251 let mut rhs_remaining = rhs_arr.len();
252 let mut lhs_iter = lhs_arr.iter();
253 let mut rhs_iter = rhs_arr.iter();
254
255 loop {
256 let range = std::cmp::min(lhs_remaining, rhs_remaining);
257
258 for _ in 0..range {
259 let lhs_opt_val = unsafe { lhs_iter.next().unwrap_unchecked() };
261 let rhs_opt_val = unsafe { rhs_iter.next().unwrap_unchecked() };
262 op(lhs_opt_val, rhs_opt_val)
263 }
264 lhs_remaining -= range;
265 rhs_remaining -= range;
266
267 if lhs_remaining == 0 {
268 let Some(new_arr) = lhs_arr_iter.next() else {
269 return;
270 };
271 lhs_remaining = new_arr.len();
272 lhs_iter = new_arr.iter();
273 }
274 if rhs_remaining == 0 {
275 let Some(new_arr) = rhs_arr_iter.next() else {
276 return;
277 };
278 rhs_remaining = new_arr.len();
279 rhs_iter = new_arr.iter();
280 }
281 }
282}
283
284#[inline]
285pub fn try_binary_elementwise<T, U, V, F, K, E>(
286 lhs: &ChunkedArray<T>,
287 rhs: &ChunkedArray<U>,
288 mut op: F,
289) -> Result<ChunkedArray<V>, E>
290where
291 T: PolarsDataType,
292 U: PolarsDataType,
293 V: PolarsDataType,
294 F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
295 V::Array: ArrayFromIter<Option<K>>,
296{
297 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
298 let iter = lhs
299 .downcast_iter()
300 .zip(rhs.downcast_iter())
301 .map(|(lhs_arr, rhs_arr)| {
302 let element_iter = lhs_arr
303 .iter()
304 .zip(rhs_arr.iter())
305 .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
306 element_iter.try_collect_arr()
307 });
308 ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter)
309}
310
311#[inline]
312pub fn binary_elementwise_values<T, U, V, F, K>(
313 lhs: &ChunkedArray<T>,
314 rhs: &ChunkedArray<U>,
315 mut op: F,
316) -> ChunkedArray<V>
317where
318 T: PolarsDataType,
319 U: PolarsDataType,
320 V: PolarsDataType,
321 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
322 V::Array: ArrayFromIter<K>,
323{
324 if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
325 let len = lhs.len().min(rhs.len());
326 let arr = V::Array::full_null(len, V::get_static_dtype().to_arrow(CompatLevel::newest()));
327
328 return ChunkedArray::with_chunk(lhs.name().clone(), arr);
329 }
330
331 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
332
333 let iter = lhs
334 .downcast_iter()
335 .zip(rhs.downcast_iter())
336 .map(|(lhs_arr, rhs_arr)| {
337 let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
338
339 let element_iter = lhs_arr
340 .values_iter()
341 .zip(rhs_arr.values_iter())
342 .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val));
343
344 let array: V::Array = element_iter.collect_arr();
345 array.with_validity_typed(validity)
346 });
347 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
348}
349
350#[inline]
354pub fn binary_elementwise_into_string_amortized<T, U, F>(
355 lhs: &ChunkedArray<T>,
356 rhs: &ChunkedArray<U>,
357 mut op: F,
358) -> StringChunked
359where
360 T: PolarsDataType,
361 U: PolarsDataType,
362 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>, &mut String),
363{
364 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
365 let mut buf = String::new();
366 let iter = lhs
367 .downcast_iter()
368 .zip(rhs.downcast_iter())
369 .map(|(lhs_arr, rhs_arr)| {
370 let mut mutarr = MutablePlString::with_capacity(lhs_arr.len());
371 lhs_arr
372 .iter()
373 .zip(rhs_arr.iter())
374 .for_each(|(lhs_opt, rhs_opt)| match (lhs_opt, rhs_opt) {
375 (None, _) | (_, None) => mutarr.push_null(),
376 (Some(lhs_val), Some(rhs_val)) => {
377 buf.clear();
378 op(lhs_val, rhs_val, &mut buf);
379 mutarr.push_value(&buf)
380 },
381 });
382 mutarr.freeze()
383 });
384 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
385}
386
387#[inline]
392pub fn binary_mut_values<T, U, V, F, Arr>(
393 lhs: &ChunkedArray<T>,
394 rhs: &ChunkedArray<U>,
395 mut op: F,
396 name: PlSmallStr,
397) -> ChunkedArray<V>
398where
399 T: PolarsDataType,
400 U: PolarsDataType,
401 V: PolarsDataType<Array = Arr>,
402 Arr: Array + StaticArray,
403 F: FnMut(&T::Array, &U::Array) -> Arr,
404{
405 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
406 let iter = lhs
407 .downcast_iter()
408 .zip(rhs.downcast_iter())
409 .map(|(lhs_arr, rhs_arr)| {
410 let ret = op(lhs_arr, rhs_arr);
411 let inp_val = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
412 let val = combine_validities_and(inp_val.as_ref(), ret.validity());
413 ret.with_validity_typed(val)
414 });
415 ChunkedArray::from_chunk_iter(name, iter)
416}
417
418#[inline]
420pub fn binary_mut_with_options<T, U, V, F, Arr>(
421 lhs: &ChunkedArray<T>,
422 rhs: &ChunkedArray<U>,
423 mut op: F,
424 name: PlSmallStr,
425) -> ChunkedArray<V>
426where
427 T: PolarsDataType,
428 U: PolarsDataType,
429 V: PolarsDataType<Array = Arr>,
430 Arr: Array,
431 F: FnMut(&T::Array, &U::Array) -> Arr,
432{
433 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
434 let iter = lhs
435 .downcast_iter()
436 .zip(rhs.downcast_iter())
437 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
438 ChunkedArray::from_chunk_iter(name, iter)
439}
440
441#[inline]
442pub fn try_binary_mut_with_options<T, U, V, F, Arr, E>(
443 lhs: &ChunkedArray<T>,
444 rhs: &ChunkedArray<U>,
445 mut op: F,
446 name: PlSmallStr,
447) -> Result<ChunkedArray<V>, E>
448where
449 T: PolarsDataType,
450 U: PolarsDataType,
451 V: PolarsDataType<Array = Arr>,
452 Arr: Array,
453 F: FnMut(&T::Array, &U::Array) -> Result<Arr, E>,
454 E: Error,
455{
456 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
457 let iter = lhs
458 .downcast_iter()
459 .zip(rhs.downcast_iter())
460 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
461 ChunkedArray::try_from_chunk_iter(name, iter)
462}
463
464pub fn binary<T, U, V, F, Arr>(
466 lhs: &ChunkedArray<T>,
467 rhs: &ChunkedArray<U>,
468 op: F,
469) -> ChunkedArray<V>
470where
471 T: PolarsDataType,
472 U: PolarsDataType,
473 V: PolarsDataType<Array = Arr>,
474 Arr: Array,
475 F: FnMut(&T::Array, &U::Array) -> Arr,
476{
477 binary_mut_with_options(lhs, rhs, op, lhs.name().clone())
478}
479
480pub fn binary_owned<L, R, V, F, Arr>(
482 lhs: ChunkedArray<L>,
483 rhs: ChunkedArray<R>,
484 mut op: F,
485) -> ChunkedArray<V>
486where
487 L: PolarsDataType,
488 R: PolarsDataType,
489 V: PolarsDataType<Array = Arr>,
490 Arr: Array,
491 F: FnMut(L::Array, R::Array) -> Arr,
492{
493 let name = lhs.name().clone();
494 let (lhs, rhs) = align_chunks_binary_owned(lhs, rhs);
495 let iter = lhs
496 .downcast_into_iter()
497 .zip(rhs.downcast_into_iter())
498 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
499 ChunkedArray::from_chunk_iter(name, iter)
500}
501
502pub fn try_binary<T, U, V, F, Arr, E>(
504 lhs: &ChunkedArray<T>,
505 rhs: &ChunkedArray<U>,
506 mut op: F,
507) -> Result<ChunkedArray<V>, E>
508where
509 T: PolarsDataType,
510 U: PolarsDataType,
511 V: PolarsDataType<Array = Arr>,
512 Arr: Array,
513 F: FnMut(&T::Array, &U::Array) -> Result<Arr, E>,
514 E: Error,
515{
516 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
517 let iter = lhs
518 .downcast_iter()
519 .zip(rhs.downcast_iter())
520 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
521 ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter)
522}
523
524#[inline]
529pub unsafe fn binary_unchecked_same_type<T, U, F>(
530 lhs: &ChunkedArray<T>,
531 rhs: &ChunkedArray<U>,
532 mut op: F,
533 keep_sorted: bool,
534 keep_fast_explode: bool,
535) -> ChunkedArray<T>
536where
537 T: PolarsDataType,
538 U: PolarsDataType,
539 F: FnMut(&T::Array, &U::Array) -> Box<dyn Array>,
540{
541 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
542 let chunks = lhs
543 .downcast_iter()
544 .zip(rhs.downcast_iter())
545 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
546 .collect();
547
548 let mut ca = lhs.copy_with_chunks(chunks);
549
550 let mut retain_flags = StatisticsFlags::empty();
551 use StatisticsFlags as F;
552 retain_flags.set(F::IS_SORTED_ANY, keep_sorted);
553 retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode);
554 ca.retain_flags_from(lhs.as_ref(), retain_flags);
555
556 ca
557}
558
559#[inline]
560pub fn binary_to_series<T, U, F>(
561 lhs: &ChunkedArray<T>,
562 rhs: &ChunkedArray<U>,
563 mut op: F,
564) -> PolarsResult<Series>
565where
566 T: PolarsDataType,
567 U: PolarsDataType,
568 F: FnMut(&T::Array, &U::Array) -> Box<dyn Array>,
569{
570 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
571 let chunks = lhs
572 .downcast_iter()
573 .zip(rhs.downcast_iter())
574 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
575 .collect::<Vec<_>>();
576 Series::try_from((lhs.name().clone(), chunks))
577}
578
579#[inline]
584pub unsafe fn try_binary_unchecked_same_type<T, U, F, E>(
585 lhs: &ChunkedArray<T>,
586 rhs: &ChunkedArray<U>,
587 mut op: F,
588 keep_sorted: bool,
589 keep_fast_explode: bool,
590) -> Result<ChunkedArray<T>, E>
591where
592 T: PolarsDataType,
593 U: PolarsDataType,
594 F: FnMut(&T::Array, &U::Array) -> Result<Box<dyn Array>, E>,
595 E: Error,
596{
597 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
598 let chunks = lhs
599 .downcast_iter()
600 .zip(rhs.downcast_iter())
601 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
602 .collect::<Result<Vec<_>, E>>()?;
603 let mut ca = lhs.copy_with_chunks(chunks);
604
605 let mut retain_flags = StatisticsFlags::empty();
606 use StatisticsFlags as F;
607 retain_flags.set(F::IS_SORTED_ANY, keep_sorted);
608 retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode);
609 ca.retain_flags_from(lhs.as_ref(), retain_flags);
610
611 Ok(ca)
612}
613
614#[inline]
615pub fn try_ternary_elementwise<T, U, V, G, F, K, E>(
616 ca1: &ChunkedArray<T>,
617 ca2: &ChunkedArray<U>,
618 ca3: &ChunkedArray<G>,
619 mut op: F,
620) -> Result<ChunkedArray<V>, E>
621where
622 T: PolarsDataType,
623 U: PolarsDataType,
624 V: PolarsDataType,
625 G: PolarsDataType,
626 F: for<'a> FnMut(
627 Option<T::Physical<'a>>,
628 Option<U::Physical<'a>>,
629 Option<G::Physical<'a>>,
630 ) -> Result<Option<K>, E>,
631 V::Array: ArrayFromIter<Option<K>>,
632{
633 let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3);
634 let iter = ca1
635 .downcast_iter()
636 .zip(ca2.downcast_iter())
637 .zip(ca3.downcast_iter())
638 .map(|((ca1_arr, ca2_arr), ca3_arr)| {
639 let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map(
640 |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| {
641 op(ca1_opt_val, ca2_opt_val, ca3_opt_val)
642 },
643 );
644 element_iter.try_collect_arr()
645 });
646 ChunkedArray::try_from_chunk_iter(ca1.name().clone(), iter)
647}
648
649#[inline]
650pub fn ternary_elementwise<T, U, V, G, F>(
651 ca1: &ChunkedArray<T>,
652 ca2: &ChunkedArray<U>,
653 ca3: &ChunkedArray<G>,
654 mut op: F,
655) -> ChunkedArray<V>
656where
657 T: PolarsDataType,
658 U: PolarsDataType,
659 G: PolarsDataType,
660 V: PolarsDataType,
661 F: for<'a> TernaryFnMut<
662 Option<T::Physical<'a>>,
663 Option<U::Physical<'a>>,
664 Option<G::Physical<'a>>,
665 >,
666 V::Array: for<'a> ArrayFromIter<
667 <F as TernaryFnMut<
668 Option<T::Physical<'a>>,
669 Option<U::Physical<'a>>,
670 Option<G::Physical<'a>>,
671 >>::Ret,
672 >,
673{
674 let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3);
675 let iter = ca1
676 .downcast_iter()
677 .zip(ca2.downcast_iter())
678 .zip(ca3.downcast_iter())
679 .map(|((ca1_arr, ca2_arr), ca3_arr)| {
680 let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map(
681 |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| {
682 op(ca1_opt_val, ca2_opt_val, ca3_opt_val)
683 },
684 );
685 element_iter.collect_arr()
686 });
687 ChunkedArray::from_chunk_iter(ca1.name().clone(), iter)
688}
689
690pub fn broadcast_binary_elementwise<T, U, V, F>(
691 lhs: &ChunkedArray<T>,
692 rhs: &ChunkedArray<U>,
693 mut op: F,
694) -> ChunkedArray<V>
695where
696 T: PolarsDataType,
697 U: PolarsDataType,
698 V: PolarsDataType,
699 F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
700 V::Array: for<'a> ArrayFromIter<
701 <F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
702 >,
703{
704 match (lhs.len(), rhs.len()) {
705 (1, _) => {
706 let a = unsafe { lhs.get_unchecked(0) };
707 unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone())
708 },
709 (_, 1) => {
710 let b = unsafe { rhs.get_unchecked(0) };
711 unary_elementwise(lhs, |a| op(a, b.clone()))
712 },
713 _ => binary_elementwise(lhs, rhs, op),
714 }
715}
716
717pub fn broadcast_try_binary_elementwise<T, U, V, F, K, E>(
718 lhs: &ChunkedArray<T>,
719 rhs: &ChunkedArray<U>,
720 mut op: F,
721) -> Result<ChunkedArray<V>, E>
722where
723 T: PolarsDataType,
724 U: PolarsDataType,
725 V: PolarsDataType,
726 F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
727 V::Array: ArrayFromIter<Option<K>>,
728{
729 match (lhs.len(), rhs.len()) {
730 (1, _) => {
731 let a = unsafe { lhs.get_unchecked(0) };
732 Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name().clone()))
733 },
734 (_, 1) => {
735 let b = unsafe { rhs.get_unchecked(0) };
736 try_unary_elementwise(lhs, |a| op(a, b.clone()))
737 },
738 _ => try_binary_elementwise(lhs, rhs, op),
739 }
740}
741
742pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
743 lhs: &ChunkedArray<T>,
744 rhs: &ChunkedArray<U>,
745 mut op: F,
746) -> ChunkedArray<V>
747where
748 T: PolarsDataType,
749 U: PolarsDataType,
750 V: PolarsDataType,
751 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
752 V::Array: ArrayFromIter<K>,
753{
754 if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
755 let min = lhs.len().min(rhs.len());
756 let max = lhs.len().max(rhs.len());
757 let len = if min == 1 { max } else { min };
758 let arr = V::Array::full_null(len, V::get_static_dtype().to_arrow(CompatLevel::newest()));
759
760 return ChunkedArray::with_chunk(lhs.name().clone(), arr);
761 }
762
763 match (lhs.len(), rhs.len()) {
764 (1, _) => {
765 let a = unsafe { lhs.value_unchecked(0) };
766 unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone())
767 },
768 (_, 1) => {
769 let b = unsafe { rhs.value_unchecked(0) };
770 unary_elementwise_values(lhs, |a| op(a, b.clone()))
771 },
772 _ => binary_elementwise_values(lhs, rhs, op),
773 }
774}
775
776pub fn apply_binary_kernel_broadcast<'l, 'r, L, R, O, K, LK, RK>(
777 lhs: &'l ChunkedArray<L>,
778 rhs: &'r ChunkedArray<R>,
779 kernel: K,
780 lhs_broadcast_kernel: LK,
781 rhs_broadcast_kernel: RK,
782) -> ChunkedArray<O>
783where
784 L: PolarsDataType,
785 R: PolarsDataType,
786 O: PolarsDataType,
787 K: Fn(&L::Array, &R::Array) -> O::Array,
788 LK: Fn(L::Physical<'l>, &R::Array) -> O::Array,
789 RK: Fn(&L::Array, R::Physical<'r>) -> O::Array,
790{
791 let name = lhs.name();
792 let out = match (lhs.len(), rhs.len()) {
793 (a, b) if a == b => binary(lhs, rhs, |lhs, rhs| kernel(lhs, rhs)),
794 (_, 1) => {
796 let opt_rhs = rhs.get(0);
797 match opt_rhs {
798 None => {
799 let arr = O::Array::full_null(
800 lhs.len(),
801 O::get_static_dtype().to_arrow(CompatLevel::newest()),
802 );
803 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
804 },
805 Some(rhs) => unary_kernel(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())),
806 }
807 },
808 (1, _) => {
809 let opt_lhs = lhs.get(0);
810 match opt_lhs {
811 None => {
812 let arr = O::Array::full_null(
813 rhs.len(),
814 O::get_static_dtype().to_arrow(CompatLevel::newest()),
815 );
816 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
817 },
818 Some(lhs) => unary_kernel(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)),
819 }
820 },
821 _ => panic!("Cannot apply operation on arrays of different lengths"),
822 };
823 out.with_name(name.clone())
824}
825
826pub fn apply_binary_kernel_broadcast_owned<L, R, O, K, LK, RK>(
827 lhs: ChunkedArray<L>,
828 rhs: ChunkedArray<R>,
829 kernel: K,
830 lhs_broadcast_kernel: LK,
831 rhs_broadcast_kernel: RK,
832) -> ChunkedArray<O>
833where
834 L: PolarsDataType,
835 R: PolarsDataType,
836 O: PolarsDataType,
837 K: Fn(L::Array, R::Array) -> O::Array,
838 for<'a> LK: Fn(L::Physical<'a>, R::Array) -> O::Array,
839 for<'a> RK: Fn(L::Array, R::Physical<'a>) -> O::Array,
840{
841 let name = lhs.name().to_owned();
842 let out = match (lhs.len(), rhs.len()) {
843 (a, b) if a == b => binary_owned(lhs, rhs, kernel),
844 (_, 1) => {
846 let opt_rhs = rhs.get(0);
847 match opt_rhs {
848 None => {
849 let arr = O::Array::full_null(
850 lhs.len(),
851 O::get_static_dtype().to_arrow(CompatLevel::newest()),
852 );
853 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
854 },
855 Some(rhs) => unary_kernel_owned(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())),
856 }
857 },
858 (1, _) => {
859 let opt_lhs = lhs.get(0);
860 match opt_lhs {
861 None => {
862 let arr = O::Array::full_null(
863 rhs.len(),
864 O::get_static_dtype().to_arrow(CompatLevel::newest()),
865 );
866 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
867 },
868 Some(lhs) => unary_kernel_owned(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)),
869 }
870 },
871 _ => panic!("Cannot apply operation on arrays of different lengths"),
872 };
873 out.with_name(name)
874}