1#![allow(unsafe_op_in_unsafe_fn)]
2use std::error::Error;
3
4use arrow::array::{Array, MutablePlString, StaticArray};
5use arrow::compute::utils::combine_validities_and;
6use polars_error::PolarsResult;
7use polars_utils::pl_str::PlSmallStr;
8
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter};
11use crate::prelude::{ChunkedArray, CompatLevel, PolarsDataType, Series, StringChunked};
12use crate::utils::{align_chunks_binary, align_chunks_binary_owned, align_chunks_ternary};
13
/// Names the return type of any single-argument `FnMut`.
///
/// The associated type [`UnaryFnMut::Ret`] lets generic signatures refer to
/// "whatever this callable returns" without a separate type parameter.
pub trait UnaryFnMut<A1>: FnMut(A1) -> Self::Ret {
    /// Value produced by one invocation of the callable.
    type Ret;
}

/// Blanket implementation: every `FnMut(A1) -> R` is a `UnaryFnMut<A1>` with `Ret = R`.
impl<A1, R, T> UnaryFnMut<A1> for T
where
    T: FnMut(A1) -> R,
{
    type Ret = R;
}
23
/// Names the return type of any three-argument `FnMut`.
///
/// The associated type [`TernaryFnMut::Ret`] lets generic signatures refer to
/// the callable's output without a separate type parameter.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
    /// Value produced by one invocation of the callable.
    type Ret;
}

/// Blanket implementation: every `FnMut(A1, A2, A3) -> R` is a
/// `TernaryFnMut<A1, A2, A3>` with `Ret = R`.
impl<A1, A2, A3, R, T> TernaryFnMut<A1, A2, A3> for T
where
    T: FnMut(A1, A2, A3) -> R,
{
    type Ret = R;
}
33
/// Names the return type of any two-argument `FnMut`.
///
/// The associated type [`BinaryFnMut::Ret`] lets generic signatures refer to
/// the callable's output without a separate type parameter.
pub trait BinaryFnMut<A1, A2>: FnMut(A1, A2) -> Self::Ret {
    /// Value produced by one invocation of the callable.
    type Ret;
}

/// Blanket implementation: every `FnMut(A1, A2) -> R` is a `BinaryFnMut<A1, A2>`
/// with `Ret = R`.
impl<A1, A2, R, T> BinaryFnMut<A1, A2> for T
where
    T: FnMut(A1, A2) -> R,
{
    type Ret = R;
}
43
44#[inline]
46pub fn unary_kernel<T, V, F, Arr>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<V>
47where
48 T: PolarsDataType,
49 V: PolarsDataType<Array = Arr>,
50 Arr: Array,
51 F: FnMut(&T::Array) -> Arr,
52{
53 let iter = ca.downcast_iter().map(op);
54 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
55}
56
57#[inline]
59pub fn unary_kernel_owned<T, V, F, Arr>(ca: ChunkedArray<T>, op: F) -> ChunkedArray<V>
60where
61 T: PolarsDataType,
62 V: PolarsDataType<Array = Arr>,
63 Arr: Array,
64 F: FnMut(T::Array) -> Arr,
65{
66 let name = ca.name().clone();
67 let iter = ca.downcast_into_iter().map(op);
68 ChunkedArray::from_chunk_iter(name, iter)
69}
70
71#[inline]
72pub fn unary_elementwise<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
73where
74 T: PolarsDataType,
75 V: PolarsDataType,
76 F: UnaryFnMut<Option<T::Physical<'a>>>,
77 V::Array: ArrayFromIter<<F as UnaryFnMut<Option<T::Physical<'a>>>>::Ret>,
78{
79 if ca.has_nulls() {
80 let iter = ca
81 .downcast_iter()
82 .map(|arr| arr.iter().map(&mut op).collect_arr());
83 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
84 } else {
85 let iter = ca
86 .downcast_iter()
87 .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr());
88 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
89 }
90}
91
92#[inline]
93pub fn try_unary_elementwise<'a, T, V, F, K, E>(
94 ca: &'a ChunkedArray<T>,
95 mut op: F,
96) -> Result<ChunkedArray<V>, E>
97where
98 T: PolarsDataType,
99 V: PolarsDataType,
100 F: FnMut(Option<T::Physical<'a>>) -> Result<Option<K>, E>,
101 V::Array: ArrayFromIter<Option<K>>,
102{
103 let iter = ca
104 .downcast_iter()
105 .map(|arr| arr.iter().map(&mut op).try_collect_arr());
106 ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter)
107}
108
109#[inline]
110pub fn unary_elementwise_values<'a, T, V, F>(ca: &'a ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
111where
112 T: PolarsDataType,
113 V: PolarsDataType,
114 F: UnaryFnMut<T::Physical<'a>>,
115 V::Array: ArrayFromIter<<F as UnaryFnMut<T::Physical<'a>>>::Ret>,
116{
117 if ca.null_count() == ca.len() {
118 let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest()));
119 return ChunkedArray::with_chunk(ca.name().clone(), arr);
120 }
121
122 let iter = ca.downcast_iter().map(|arr| {
123 let validity = arr.validity().cloned();
124 let arr: V::Array = arr.values_iter().map(&mut op).collect_arr();
125 arr.with_validity_typed(validity)
126 });
127 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
128}
129
130#[inline]
131pub fn try_unary_elementwise_values<'a, T, V, F, K, E>(
132 ca: &'a ChunkedArray<T>,
133 mut op: F,
134) -> Result<ChunkedArray<V>, E>
135where
136 T: PolarsDataType,
137 V: PolarsDataType,
138 F: FnMut(T::Physical<'a>) -> Result<K, E>,
139 V::Array: ArrayFromIter<K>,
140{
141 if ca.null_count() == ca.len() {
142 let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest()));
143 return Ok(ChunkedArray::with_chunk(ca.name().clone(), arr));
144 }
145
146 let iter = ca.downcast_iter().map(|arr| {
147 let validity = arr.validity().cloned();
148 let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?;
149 Ok(arr.with_validity_typed(validity))
150 });
151 ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter)
152}
153
154#[inline]
159pub fn unary_mut_values<T, V, F, Arr>(ca: &ChunkedArray<T>, mut op: F) -> ChunkedArray<V>
160where
161 T: PolarsDataType,
162 V: PolarsDataType<Array = Arr>,
163 Arr: Array + StaticArray,
164 F: FnMut(&T::Array) -> Arr,
165{
166 let iter = ca
167 .downcast_iter()
168 .map(|arr| op(arr).with_validity_typed(arr.validity().cloned()));
169 ChunkedArray::from_chunk_iter(ca.name().clone(), iter)
170}
171
172#[inline]
174pub fn unary_mut_with_options<T, V, F, Arr>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<V>
175where
176 T: PolarsDataType,
177 V: PolarsDataType<Array = Arr>,
178 Arr: Array + StaticArray,
179 F: FnMut(&T::Array) -> Arr,
180{
181 ChunkedArray::from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op))
182}
183
184#[inline]
185pub fn try_unary_mut_with_options<T, V, F, Arr, E>(
186 ca: &ChunkedArray<T>,
187 op: F,
188) -> Result<ChunkedArray<V>, E>
189where
190 T: PolarsDataType,
191 V: PolarsDataType<Array = Arr>,
192 Arr: Array + StaticArray,
193 F: FnMut(&T::Array) -> Result<Arr, E>,
194 E: Error,
195{
196 ChunkedArray::try_from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op))
197}
198
199#[inline]
200pub fn binary_elementwise<T, U, V, F>(
201 lhs: &ChunkedArray<T>,
202 rhs: &ChunkedArray<U>,
203 mut op: F,
204) -> ChunkedArray<V>
205where
206 T: PolarsDataType,
207 U: PolarsDataType,
208 V: PolarsDataType,
209 F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
210 V::Array: for<'a> ArrayFromIter<
211 <F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
212 >,
213{
214 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
215 let iter = lhs
216 .downcast_iter()
217 .zip(rhs.downcast_iter())
218 .map(|(lhs_arr, rhs_arr)| {
219 let element_iter = lhs_arr
220 .iter()
221 .zip(rhs_arr.iter())
222 .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
223 element_iter.collect_arr()
224 });
225 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
226}
227
228#[inline]
229pub fn binary_elementwise_for_each<'a, 'b, T, U, F>(
230 lhs: &'a ChunkedArray<T>,
231 rhs: &'b ChunkedArray<U>,
232 mut op: F,
233) where
234 T: PolarsDataType,
235 U: PolarsDataType,
236 F: FnMut(Option<T::Physical<'a>>, Option<U::Physical<'b>>),
237{
238 let mut lhs_arr_iter = lhs.downcast_iter();
239 let mut rhs_arr_iter = rhs.downcast_iter();
240
241 let lhs_arr = lhs_arr_iter.next().unwrap();
242 let rhs_arr = rhs_arr_iter.next().unwrap();
243
244 let mut lhs_remaining = lhs_arr.len();
245 let mut rhs_remaining = rhs_arr.len();
246 let mut lhs_iter = lhs_arr.iter();
247 let mut rhs_iter = rhs_arr.iter();
248
249 loop {
250 let range = std::cmp::min(lhs_remaining, rhs_remaining);
251
252 for _ in 0..range {
253 let lhs_opt_val = unsafe { lhs_iter.next().unwrap_unchecked() };
255 let rhs_opt_val = unsafe { rhs_iter.next().unwrap_unchecked() };
256 op(lhs_opt_val, rhs_opt_val)
257 }
258 lhs_remaining -= range;
259 rhs_remaining -= range;
260
261 if lhs_remaining == 0 {
262 let Some(new_arr) = lhs_arr_iter.next() else {
263 return;
264 };
265 lhs_remaining = new_arr.len();
266 lhs_iter = new_arr.iter();
267 }
268 if rhs_remaining == 0 {
269 let Some(new_arr) = rhs_arr_iter.next() else {
270 return;
271 };
272 rhs_remaining = new_arr.len();
273 rhs_iter = new_arr.iter();
274 }
275 }
276}
277
278#[inline]
279pub fn try_binary_elementwise<T, U, V, F, K, E>(
280 lhs: &ChunkedArray<T>,
281 rhs: &ChunkedArray<U>,
282 mut op: F,
283) -> Result<ChunkedArray<V>, E>
284where
285 T: PolarsDataType,
286 U: PolarsDataType,
287 V: PolarsDataType,
288 F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
289 V::Array: ArrayFromIter<Option<K>>,
290{
291 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
292 let iter = lhs
293 .downcast_iter()
294 .zip(rhs.downcast_iter())
295 .map(|(lhs_arr, rhs_arr)| {
296 let element_iter = lhs_arr
297 .iter()
298 .zip(rhs_arr.iter())
299 .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val));
300 element_iter.try_collect_arr()
301 });
302 ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter)
303}
304
305#[inline]
306pub fn binary_elementwise_values<T, U, V, F, K>(
307 lhs: &ChunkedArray<T>,
308 rhs: &ChunkedArray<U>,
309 mut op: F,
310) -> ChunkedArray<V>
311where
312 T: PolarsDataType,
313 U: PolarsDataType,
314 V: PolarsDataType,
315 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
316 V::Array: ArrayFromIter<K>,
317{
318 if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
319 let len = lhs.len().min(rhs.len());
320 let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest()));
321
322 return ChunkedArray::with_chunk(lhs.name().clone(), arr);
323 }
324
325 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
326
327 let iter = lhs
328 .downcast_iter()
329 .zip(rhs.downcast_iter())
330 .map(|(lhs_arr, rhs_arr)| {
331 let validity = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
332
333 let element_iter = lhs_arr
334 .values_iter()
335 .zip(rhs_arr.values_iter())
336 .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val));
337
338 let array: V::Array = element_iter.collect_arr();
339 array.with_validity_typed(validity)
340 });
341 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
342}
343
344#[inline]
348pub fn binary_elementwise_into_string_amortized<T, U, F>(
349 lhs: &ChunkedArray<T>,
350 rhs: &ChunkedArray<U>,
351 mut op: F,
352) -> StringChunked
353where
354 T: PolarsDataType,
355 U: PolarsDataType,
356 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>, &mut String),
357{
358 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
359 let mut buf = String::new();
360 let iter = lhs
361 .downcast_iter()
362 .zip(rhs.downcast_iter())
363 .map(|(lhs_arr, rhs_arr)| {
364 let mut mutarr = MutablePlString::with_capacity(lhs_arr.len());
365 lhs_arr
366 .iter()
367 .zip(rhs_arr.iter())
368 .for_each(|(lhs_opt, rhs_opt)| match (lhs_opt, rhs_opt) {
369 (None, _) | (_, None) => mutarr.push_null(),
370 (Some(lhs_val), Some(rhs_val)) => {
371 buf.clear();
372 op(lhs_val, rhs_val, &mut buf);
373 mutarr.push_value(&buf)
374 },
375 });
376 mutarr.freeze()
377 });
378 ChunkedArray::from_chunk_iter(lhs.name().clone(), iter)
379}
380
381#[inline]
386pub fn binary_mut_values<T, U, V, F, Arr>(
387 lhs: &ChunkedArray<T>,
388 rhs: &ChunkedArray<U>,
389 mut op: F,
390 name: PlSmallStr,
391) -> ChunkedArray<V>
392where
393 T: PolarsDataType,
394 U: PolarsDataType,
395 V: PolarsDataType<Array = Arr>,
396 Arr: Array + StaticArray,
397 F: FnMut(&T::Array, &U::Array) -> Arr,
398{
399 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
400 let iter = lhs
401 .downcast_iter()
402 .zip(rhs.downcast_iter())
403 .map(|(lhs_arr, rhs_arr)| {
404 let ret = op(lhs_arr, rhs_arr);
405 let inp_val = combine_validities_and(lhs_arr.validity(), rhs_arr.validity());
406 let val = combine_validities_and(inp_val.as_ref(), ret.validity());
407 ret.with_validity_typed(val)
408 });
409 ChunkedArray::from_chunk_iter(name, iter)
410}
411
412#[inline]
414pub fn binary_mut_with_options<T, U, V, F, Arr>(
415 lhs: &ChunkedArray<T>,
416 rhs: &ChunkedArray<U>,
417 mut op: F,
418 name: PlSmallStr,
419) -> ChunkedArray<V>
420where
421 T: PolarsDataType,
422 U: PolarsDataType,
423 V: PolarsDataType<Array = Arr>,
424 Arr: Array,
425 F: FnMut(&T::Array, &U::Array) -> Arr,
426{
427 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
428 let iter = lhs
429 .downcast_iter()
430 .zip(rhs.downcast_iter())
431 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
432 ChunkedArray::from_chunk_iter(name, iter)
433}
434
435#[inline]
436pub fn try_binary_mut_with_options<T, U, V, F, Arr, E>(
437 lhs: &ChunkedArray<T>,
438 rhs: &ChunkedArray<U>,
439 mut op: F,
440 name: PlSmallStr,
441) -> Result<ChunkedArray<V>, E>
442where
443 T: PolarsDataType,
444 U: PolarsDataType,
445 V: PolarsDataType<Array = Arr>,
446 Arr: Array,
447 F: FnMut(&T::Array, &U::Array) -> Result<Arr, E>,
448 E: Error,
449{
450 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
451 let iter = lhs
452 .downcast_iter()
453 .zip(rhs.downcast_iter())
454 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
455 ChunkedArray::try_from_chunk_iter(name, iter)
456}
457
458pub fn binary<T, U, V, F, Arr>(
460 lhs: &ChunkedArray<T>,
461 rhs: &ChunkedArray<U>,
462 op: F,
463) -> ChunkedArray<V>
464where
465 T: PolarsDataType,
466 U: PolarsDataType,
467 V: PolarsDataType<Array = Arr>,
468 Arr: Array,
469 F: FnMut(&T::Array, &U::Array) -> Arr,
470{
471 binary_mut_with_options(lhs, rhs, op, lhs.name().clone())
472}
473
474pub fn binary_owned<L, R, V, F, Arr>(
476 lhs: ChunkedArray<L>,
477 rhs: ChunkedArray<R>,
478 mut op: F,
479) -> ChunkedArray<V>
480where
481 L: PolarsDataType,
482 R: PolarsDataType,
483 V: PolarsDataType<Array = Arr>,
484 Arr: Array,
485 F: FnMut(L::Array, R::Array) -> Arr,
486{
487 let name = lhs.name().clone();
488 let (lhs, rhs) = align_chunks_binary_owned(lhs, rhs);
489 let iter = lhs
490 .downcast_into_iter()
491 .zip(rhs.downcast_into_iter())
492 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
493 ChunkedArray::from_chunk_iter(name, iter)
494}
495
496pub fn try_binary<T, U, V, F, Arr, E>(
498 lhs: &ChunkedArray<T>,
499 rhs: &ChunkedArray<U>,
500 mut op: F,
501) -> Result<ChunkedArray<V>, E>
502where
503 T: PolarsDataType,
504 U: PolarsDataType,
505 V: PolarsDataType<Array = Arr>,
506 Arr: Array,
507 F: FnMut(&T::Array, &U::Array) -> Result<Arr, E>,
508 E: Error,
509{
510 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
511 let iter = lhs
512 .downcast_iter()
513 .zip(rhs.downcast_iter())
514 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr));
515 ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter)
516}
517
518#[inline]
523pub unsafe fn binary_unchecked_same_type<T, U, F>(
524 lhs: &ChunkedArray<T>,
525 rhs: &ChunkedArray<U>,
526 mut op: F,
527 keep_sorted: bool,
528 keep_fast_explode: bool,
529) -> ChunkedArray<T>
530where
531 T: PolarsDataType,
532 U: PolarsDataType,
533 F: FnMut(&T::Array, &U::Array) -> Box<dyn Array>,
534{
535 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
536 let chunks = lhs
537 .downcast_iter()
538 .zip(rhs.downcast_iter())
539 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
540 .collect();
541
542 let mut ca = lhs.copy_with_chunks(chunks);
543
544 let mut retain_flags = StatisticsFlags::empty();
545 use StatisticsFlags as F;
546 retain_flags.set(F::IS_SORTED_ANY, keep_sorted);
547 retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode);
548 ca.retain_flags_from(lhs.as_ref(), retain_flags);
549
550 ca
551}
552
553#[inline]
554pub fn binary_to_series<T, U, F>(
555 lhs: &ChunkedArray<T>,
556 rhs: &ChunkedArray<U>,
557 mut op: F,
558) -> PolarsResult<Series>
559where
560 T: PolarsDataType,
561 U: PolarsDataType,
562 F: FnMut(&T::Array, &U::Array) -> Box<dyn Array>,
563{
564 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
565 let chunks = lhs
566 .downcast_iter()
567 .zip(rhs.downcast_iter())
568 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
569 .collect::<Vec<_>>();
570 Series::try_from((lhs.name().clone(), chunks))
571}
572
573#[inline]
578pub unsafe fn try_binary_unchecked_same_type<T, U, F, E>(
579 lhs: &ChunkedArray<T>,
580 rhs: &ChunkedArray<U>,
581 mut op: F,
582 keep_sorted: bool,
583 keep_fast_explode: bool,
584) -> Result<ChunkedArray<T>, E>
585where
586 T: PolarsDataType,
587 U: PolarsDataType,
588 F: FnMut(&T::Array, &U::Array) -> Result<Box<dyn Array>, E>,
589 E: Error,
590{
591 let (lhs, rhs) = align_chunks_binary(lhs, rhs);
592 let chunks = lhs
593 .downcast_iter()
594 .zip(rhs.downcast_iter())
595 .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr))
596 .collect::<Result<Vec<_>, E>>()?;
597 let mut ca = lhs.copy_with_chunks(chunks);
598
599 let mut retain_flags = StatisticsFlags::empty();
600 use StatisticsFlags as F;
601 retain_flags.set(F::IS_SORTED_ANY, keep_sorted);
602 retain_flags.set(F::CAN_FAST_EXPLODE_LIST, keep_fast_explode);
603 ca.retain_flags_from(lhs.as_ref(), retain_flags);
604
605 Ok(ca)
606}
607
608#[inline]
609pub fn try_ternary_elementwise<T, U, V, G, F, K, E>(
610 ca1: &ChunkedArray<T>,
611 ca2: &ChunkedArray<U>,
612 ca3: &ChunkedArray<G>,
613 mut op: F,
614) -> Result<ChunkedArray<V>, E>
615where
616 T: PolarsDataType,
617 U: PolarsDataType,
618 V: PolarsDataType,
619 G: PolarsDataType,
620 F: for<'a> FnMut(
621 Option<T::Physical<'a>>,
622 Option<U::Physical<'a>>,
623 Option<G::Physical<'a>>,
624 ) -> Result<Option<K>, E>,
625 V::Array: ArrayFromIter<Option<K>>,
626{
627 let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3);
628 let iter = ca1
629 .downcast_iter()
630 .zip(ca2.downcast_iter())
631 .zip(ca3.downcast_iter())
632 .map(|((ca1_arr, ca2_arr), ca3_arr)| {
633 let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map(
634 |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| {
635 op(ca1_opt_val, ca2_opt_val, ca3_opt_val)
636 },
637 );
638 element_iter.try_collect_arr()
639 });
640 ChunkedArray::try_from_chunk_iter(ca1.name().clone(), iter)
641}
642
643#[inline]
644pub fn ternary_elementwise<T, U, V, G, F>(
645 ca1: &ChunkedArray<T>,
646 ca2: &ChunkedArray<U>,
647 ca3: &ChunkedArray<G>,
648 mut op: F,
649) -> ChunkedArray<V>
650where
651 T: PolarsDataType,
652 U: PolarsDataType,
653 G: PolarsDataType,
654 V: PolarsDataType,
655 F: for<'a> TernaryFnMut<
656 Option<T::Physical<'a>>,
657 Option<U::Physical<'a>>,
658 Option<G::Physical<'a>>,
659 >,
660 V::Array: for<'a> ArrayFromIter<
661 <F as TernaryFnMut<
662 Option<T::Physical<'a>>,
663 Option<U::Physical<'a>>,
664 Option<G::Physical<'a>>,
665 >>::Ret,
666 >,
667{
668 let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3);
669 let iter = ca1
670 .downcast_iter()
671 .zip(ca2.downcast_iter())
672 .zip(ca3.downcast_iter())
673 .map(|((ca1_arr, ca2_arr), ca3_arr)| {
674 let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map(
675 |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| {
676 op(ca1_opt_val, ca2_opt_val, ca3_opt_val)
677 },
678 );
679 element_iter.collect_arr()
680 });
681 ChunkedArray::from_chunk_iter(ca1.name().clone(), iter)
682}
683
684pub fn broadcast_binary_elementwise<T, U, V, F>(
685 lhs: &ChunkedArray<T>,
686 rhs: &ChunkedArray<U>,
687 mut op: F,
688) -> ChunkedArray<V>
689where
690 T: PolarsDataType,
691 U: PolarsDataType,
692 V: PolarsDataType,
693 F: for<'a> BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>,
694 V::Array: for<'a> ArrayFromIter<
695 <F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
696 >,
697{
698 match (lhs.len(), rhs.len()) {
699 (1, _) => {
700 let a = unsafe { lhs.get_unchecked(0) };
701 unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone())
702 },
703 (_, 1) => {
704 let b = unsafe { rhs.get_unchecked(0) };
705 unary_elementwise(lhs, |a| op(a, b.clone()))
706 },
707 _ => binary_elementwise(lhs, rhs, op),
708 }
709}
710
711pub fn broadcast_try_binary_elementwise<T, U, V, F, K, E>(
712 lhs: &ChunkedArray<T>,
713 rhs: &ChunkedArray<U>,
714 mut op: F,
715) -> Result<ChunkedArray<V>, E>
716where
717 T: PolarsDataType,
718 U: PolarsDataType,
719 V: PolarsDataType,
720 F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
721 V::Array: ArrayFromIter<Option<K>>,
722{
723 match (lhs.len(), rhs.len()) {
724 (1, _) => {
725 let a = unsafe { lhs.get_unchecked(0) };
726 Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name().clone()))
727 },
728 (_, 1) => {
729 let b = unsafe { rhs.get_unchecked(0) };
730 try_unary_elementwise(lhs, |a| op(a, b.clone()))
731 },
732 _ => try_binary_elementwise(lhs, rhs, op),
733 }
734}
735
736pub fn broadcast_binary_elementwise_values<T, U, V, F, K>(
737 lhs: &ChunkedArray<T>,
738 rhs: &ChunkedArray<U>,
739 mut op: F,
740) -> ChunkedArray<V>
741where
742 T: PolarsDataType,
743 U: PolarsDataType,
744 V: PolarsDataType,
745 F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
746 V::Array: ArrayFromIter<K>,
747{
748 if lhs.null_count() == lhs.len() || rhs.null_count() == rhs.len() {
749 let min = lhs.len().min(rhs.len());
750 let max = lhs.len().max(rhs.len());
751 let len = if min == 1 { max } else { min };
752 let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest()));
753
754 return ChunkedArray::with_chunk(lhs.name().clone(), arr);
755 }
756
757 match (lhs.len(), rhs.len()) {
758 (1, _) => {
759 let a = unsafe { lhs.value_unchecked(0) };
760 unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone())
761 },
762 (_, 1) => {
763 let b = unsafe { rhs.value_unchecked(0) };
764 unary_elementwise_values(lhs, |a| op(a, b.clone()))
765 },
766 _ => binary_elementwise_values(lhs, rhs, op),
767 }
768}
769
770pub fn apply_binary_kernel_broadcast<'l, 'r, L, R, O, K, LK, RK>(
771 lhs: &'l ChunkedArray<L>,
772 rhs: &'r ChunkedArray<R>,
773 kernel: K,
774 lhs_broadcast_kernel: LK,
775 rhs_broadcast_kernel: RK,
776) -> ChunkedArray<O>
777where
778 L: PolarsDataType,
779 R: PolarsDataType,
780 O: PolarsDataType,
781 K: Fn(&L::Array, &R::Array) -> O::Array,
782 LK: Fn(L::Physical<'l>, &R::Array) -> O::Array,
783 RK: Fn(&L::Array, R::Physical<'r>) -> O::Array,
784{
785 let name = lhs.name();
786 let out = match (lhs.len(), rhs.len()) {
787 (a, b) if a == b => binary(lhs, rhs, |lhs, rhs| kernel(lhs, rhs)),
788 (_, 1) => {
790 let opt_rhs = rhs.get(0);
791 match opt_rhs {
792 None => {
793 let arr = O::Array::full_null(
794 lhs.len(),
795 O::get_dtype().to_arrow(CompatLevel::newest()),
796 );
797 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
798 },
799 Some(rhs) => unary_kernel(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())),
800 }
801 },
802 (1, _) => {
803 let opt_lhs = lhs.get(0);
804 match opt_lhs {
805 None => {
806 let arr = O::Array::full_null(
807 rhs.len(),
808 O::get_dtype().to_arrow(CompatLevel::newest()),
809 );
810 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
811 },
812 Some(lhs) => unary_kernel(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)),
813 }
814 },
815 _ => panic!("Cannot apply operation on arrays of different lengths"),
816 };
817 out.with_name(name.clone())
818}
819
820pub fn apply_binary_kernel_broadcast_owned<L, R, O, K, LK, RK>(
821 lhs: ChunkedArray<L>,
822 rhs: ChunkedArray<R>,
823 kernel: K,
824 lhs_broadcast_kernel: LK,
825 rhs_broadcast_kernel: RK,
826) -> ChunkedArray<O>
827where
828 L: PolarsDataType,
829 R: PolarsDataType,
830 O: PolarsDataType,
831 K: Fn(L::Array, R::Array) -> O::Array,
832 for<'a> LK: Fn(L::Physical<'a>, R::Array) -> O::Array,
833 for<'a> RK: Fn(L::Array, R::Physical<'a>) -> O::Array,
834{
835 let name = lhs.name().to_owned();
836 let out = match (lhs.len(), rhs.len()) {
837 (a, b) if a == b => binary_owned(lhs, rhs, kernel),
838 (_, 1) => {
840 let opt_rhs = rhs.get(0);
841 match opt_rhs {
842 None => {
843 let arr = O::Array::full_null(
844 lhs.len(),
845 O::get_dtype().to_arrow(CompatLevel::newest()),
846 );
847 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
848 },
849 Some(rhs) => unary_kernel_owned(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())),
850 }
851 },
852 (1, _) => {
853 let opt_lhs = lhs.get(0);
854 match opt_lhs {
855 None => {
856 let arr = O::Array::full_null(
857 rhs.len(),
858 O::get_dtype().to_arrow(CompatLevel::newest()),
859 );
860 ChunkedArray::<O>::with_chunk(lhs.name().clone(), arr)
861 },
862 Some(lhs) => unary_kernel_owned(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)),
863 }
864 },
865 _ => panic!("Cannot apply operation on arrays of different lengths"),
866 };
867 out.with_name(name)
868}