1use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8use crate::prelude::DataType;
9
10impl NumOpsDispatchInner for ListType {
11 fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
12 NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
13 }
14
15 fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
16 NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
17 }
18
19 fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
20 NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
21 }
22
23 fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
24 NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
25 }
26
27 fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
28 NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
29 }
30}
31
/// A numeric arithmetic operation (a wrapped [`NumericOp`]) applied to two columns where at
/// least one operand is a list column; see [`NumericListOp::execute`].
#[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
#[derive(Clone)]
pub struct NumericListOp(NumericOp);
35
36impl NumericListOp {
37 pub fn add() -> Self {
38 Self(NumericOp::Add)
39 }
40
41 pub fn sub() -> Self {
42 Self(NumericOp::Sub)
43 }
44
45 pub fn mul() -> Self {
46 Self(NumericOp::Mul)
47 }
48
49 pub fn div() -> Self {
50 Self(NumericOp::Div)
51 }
52
53 pub fn rem() -> Self {
54 Self(NumericOp::Rem)
55 }
56
57 pub fn floor_div() -> Self {
58 Self(NumericOp::FloorDiv)
59 }
60
61 pub fn try_get_leaf_supertype(
62 &self,
63 prim_dtype_lhs: &DataType,
64 prim_dtype_rhs: &DataType,
65 ) -> PolarsResult<DataType> {
66 self.0
67 .try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)
68 }
69}
70
71impl NumericListOp {
72 #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
73 pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
74 feature_gated!("list_arithmetic", {
75 use std::borrow::Cow;
76
77 use either::Either;
78
79 let lhs = lhs
83 .trim_lists_to_normalized_offsets()
84 .map_or(Cow::Borrowed(lhs), Cow::Owned);
85 let rhs = rhs
86 .trim_lists_to_normalized_offsets()
87 .map_or(Cow::Borrowed(rhs), Cow::Owned);
88
89 let lhs = lhs.rechunk();
90 let rhs = rhs.rechunk();
91
92 let binary_op_exec = match ListNumericOpHelper::try_new(
93 self.clone(),
94 lhs.name().clone(),
95 lhs.dtype(),
96 rhs.dtype(),
97 lhs.len(),
98 rhs.len(),
99 {
100 let (a, b) = lhs.list_offsets_and_validities_recursive();
101 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
102 (a, b, lhs.clone())
103 },
104 {
105 let (a, b) = rhs.list_offsets_and_validities_recursive();
106 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
107 (a, b, rhs.clone())
108 },
109 lhs.rechunk_validity(),
110 rhs.rechunk_validity(),
111 )? {
112 Either::Left(v) => v,
113 Either::Right(ca) => return Ok(ca.into_series()),
114 };
115
116 Ok(binary_op_exec.finish()?.into_series())
117 })
118 }
119}
120
121#[cfg(feature = "list_arithmetic")]
122use inner::ListNumericOpHelper;
123
#[cfg(feature = "list_arithmetic")]
mod inner {
    use arrow::bitmap::Bitmap;
    use arrow::compute::utils::combine_validities_and;
    use arrow::offset::OffsetsBuffer;
    use either::Either;
    use list_utils::with_match_pl_num_arith;
    use num_traits::Zero;
    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
    use polars_utils::float::IsFloat;

    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
    use super::super::*;

    /// Performs a binary arithmetic operation between the leaf (primitive) values of two
    /// columns, at least one of which is a (possibly nested) list column.
    ///
    /// Built by [`ListNumericOpHelper::try_new`], which validates dtypes and lengths. The
    /// result is produced by [`ListNumericOpHelper::finish`].
    pub(super) struct ListNumericOpHelper {
        op: NumericListOp,
        output_name: PlSmallStr,
        op_apply_type: BinaryOpApplyType,
        broadcast: Broadcast,
        output_dtype: DataType,
        output_primitive_dtype: DataType,
        output_len: usize,
        /// Outer validity of the result. It is always materialized (all-true when there are
        /// no outer nulls) so that fewer code paths are needed.
        outer_validity: Bitmap,
        /// Offsets and validities of every list nesting level (outermost first), plus the
        /// series itself, which is needed when a broadcasted list has to be materialized.
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Same layout as `data_lhs`, for the right-hand operand.
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Materialized leaf array of a broadcasted list operand, with its number of values.
        /// Only `Some` on the code paths where `finish` materializes the broadcast.
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        /// Whether `finish` exchanged the operands. It is forwarded to the op helpers,
        /// presumably so that non-commutative operations keep the original operand order.
        swapped: bool,
    }

    impl ListNumericOpHelper {
        /// Validates the operands and prepares the operation.
        ///
        /// Checks that:
        /// * the two leaf dtypes have a supertype for this operation;
        /// * the dtypes are one of these combinations:
        ///   * list<->list, where both immediate inner dtypes satisfy
        ///     `is_supported_list_arithmetic_input`;
        ///   * list<->primitive or primitive<->list, where the list dtype is `List` at every
        ///     nesting level down to a supported leaf;
        /// * the lengths are equal, or one side has length 1 (broadcast).
        ///
        /// Returns `Either::Right(all_null)` when the result is trivially all-null. Otherwise
        /// returns `Either::Left(helper)`, ready for [`Self::finish`].
        ///
        /// # Errors
        /// * `InvalidOperation` for unsupported dtype combinations.
        /// * `ShapeMismatch` for incompatible lengths.
        /// * Any error from resolving the leaf supertype.
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            /// Whether `dtype` is `List` at every nesting level down to a leaf that is a
            /// supported list-arithmetic input.
            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // Both sides are lists. Only the immediate inner dtypes are checked here.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            // The output has the list structure of the list operand, with the leaf replaced by
            // the supertype.
            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // All-null shortcut. The result is an all-null list column when either:
            // * the output is empty, or
            // * a list-typed operand has an outer validity with no set bits (every row null).
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name,
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // Outer validity of the result:
            // * Two non-broadcast lists: their validities are AND-ed.
            // * Otherwise: the validity of the non-broadcast list operand.
            // A unit-length (broadcast) list reaching this point is valid, because a null one
            // returned early above. Primitive operands never contribute here.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype: output_dtype.clone(),
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }

        /// Computes the result.
        ///
        /// Only four (apply type, broadcast) combinations have physical implementations in
        /// `_finish_impl`. The rest are rewritten into one of those by swapping operands
        /// and/or materializing a broadcasted list:
        ///
        /// | apply type      | Left                          | NoBroadcast        | Right                         |
        /// |-----------------|-------------------------------|--------------------|-------------------------------|
        /// | ListToList      | swap -> ListToList/Right      | physical           | physical                      |
        /// | ListToPrimitive | materialize -> NoBroadcast    | physical           | physical                      |
        /// | PrimitiveToList | swap -> ListToPrimitive/Right | swap -> ListToPrim | materialize + swap -> NoBcast |
        ///
        /// # Errors
        /// Propagates errors from `_finish_impl_dispatch`.
        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
            self.swapped = true;

            match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToList, Broadcast::Right)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_lhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.broadcast = Broadcast::NoBroadcast;

                    // No swap here: the broadcasted LHS list was materialized to full length
                    // and stays on the left.
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
                    // Materialize the broadcasted RHS list to full length rather than
                    // broadcasting offsets/validities across nesting levels. The materialized
                    // leaf buffer is later reused in place for the result.
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_rhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::NoBroadcast;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
            }
        }

        /// Casts both leaf arrays to the output leaf dtype (rechunked to one chunk).
        /// Then dispatches to `_finish_impl` monomorphized for that dtype.
        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
            let output_dtype = self.output_dtype.clone();
            let output_len = self.output_len;

            let prim_lhs = self
                .data_lhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();
            let prim_rhs = self
                .data_rhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();

            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
            let prim_dtype = prim_lhs.dtype();
            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);

            // The leaf dtype is the numeric supertype resolved in `try_new`.
            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
                self._finish_impl::<$T>(prim_lhs, prim_rhs)
            })?;

            debug_assert_eq!(out.dtype(), &output_dtype);
            assert_eq!(out.len(), output_len);

            Ok(out)
        }

        /// Physical implementations, monomorphized over the leaf type `T`.
        ///
        /// Only handles the combinations that `finish` dispatches to:
        /// * list<->list, without broadcast or with broadcast right;
        /// * list<->primitive without broadcast, optionally with a materialized broadcasted
        ///   LHS;
        /// * list<->primitive with broadcast right.
        ///
        /// # Errors
        /// `ShapeMismatch` when a valid row of a list<->list operation pairs lists of
        /// different lengths.
        fn _finish_impl<T: PolarsNumericType>(
            &mut self,
            prim_s_lhs: Series,
            prim_s_rhs: Series,
        ) -> PolarsResult<ListChunked>
        where
            T::Native: PlNumArithmetic,
            PrimitiveArray<T::Native>:
                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
            T::Native: Zero + IsFloat,
        {
            /// Errors if `mismatch_pos` is smaller than the row count. `mismatch_pos` is the
            /// number of leading rows whose list lengths matched or that were null. The error
            /// reports the lengths found at that row.
            #[inline(never)]
            fn check_mismatch_pos(
                mismatch_pos: usize,
                offsets_lhs: &OffsetsBuffer<i64>,
                offsets_rhs: &OffsetsBuffer<i64>,
            ) -> PolarsResult<()> {
                if mismatch_pos < offsets_lhs.len_proxy() {
                    // The RHS may be a single broadcasted list.
                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
                        0
                    } else {
                        mismatch_pos
                    });
                    polars_bail!(
                        ShapeMismatch:
                        "list lengths differed at index {}: {} != {}",
                        mismatch_pos,
                        offsets_lhs.length_at(mismatch_pos), len_r
                    )
                }
                Ok(())
            }

            let mut arr_lhs = {
                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            let mut arr_rhs = {
                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            match (&self.op_apply_type, &self.broadcast) {
                // Skipped: this path uses `apply_array_to_scalar` below, which is given the
                // raw arrays.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
                _ if self.list_to_prim_lhs.is_none() => {
                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    )
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // `self.list_to_prim_lhs` is `Some(_)`: validities are prepared on the
                    // materialized array in its own arm below.
                },
                _ => unreachable!(),
            }

            let out = match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());

                    // The output values (and leaf validity) are aligned to the LHS leaf array.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
                    let zero = <T::Native as Zero>::zero();

                    // Incremented while rows match in length or are null. It therefore stops
                    // at the first valid row with mismatching list lengths.
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
                            .offset_and_length_iter()
                            .zip(offsets_rhs.offset_and_length_iter())
                            .enumerate()
                        {
                            if
                                (mismatch_pos == i)
                                & (
                                    (lhs_len == rhs_len)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            // Restrict to the shorter list so neither side is read out of
                            // bounds.
                            let len: usize = lhs_len.min(rhs_len);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe { out_ptr.add(l_idx).write(v) };
                            }

                            // A longer LHS list here can only belong to a null row: a valid
                            // mismatch returns an error before `set_len` below. Zero its tail
                            // so every slot of `out_vec` is initialized before `set_len`.
                            for k in len..lhs_len {
                                unsafe { out_ptr.add(k + lhs_start).write(zero) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    unsafe { out_vec.set_len(n_values) };

                    /// Leaf validity for list<->list without broadcast. It ANDs the RHS leaf
                    /// validity into a copy of the LHS leaf validity, over the first
                    /// `min(lhs_len, rhs_len)` elements of each row. Returns `None` when
                    /// neither side has a validity.
                    #[inline(never)]
                    fn combine_validities_list_to_list_no_broadcast(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        offsets_rhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
                                .offset_and_length_iter()
                                .zip(offsets_rhs.offset_and_length_iter())
                            {
                                let len: usize = lhs_len.min(rhs_len);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // The output validity is sized and aligned to the LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
                        offsets_lhs,
                        offsets_rhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    // The output values (and leaf validity) are aligned to the LHS leaf array.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
                    let zero = <T::Native as Zero>::zero();

                    // The RHS is a single list that is applied to every LHS row.
                    assert_eq!(offsets_rhs.len_proxy(), 1);
                    let rhs_start = *offsets_rhs.first() as usize;
                    let width = offsets_rhs.range() as usize;

                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
                            // Same rule as the non-broadcast path: advance only while every
                            // row so far matched the width or was null. `mismatch_pos` then
                            // stops at the first valid mismatched row, and the error reports
                            // that row. Previously, null rows after a mismatch kept advancing
                            // it, so the error named the wrong row.
                            if (mismatch_pos == i)
                                & ((lhs_len == width)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) })
                            {
                                mismatch_pos += 1;
                            }

                            let len: usize = lhs_len.min(width);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe {
                                    out_ptr.add(l_idx).write(v);
                                }
                            }

                            // A longer LHS list here can only belong to a null row: a valid
                            // mismatch returns an error before `set_len` below. Zero its tail
                            // so every slot of `out_vec` is initialized before `set_len`.
                            for k in len..lhs_len {
                                unsafe { out_ptr.add(k + lhs_start).write(zero) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    unsafe { out_vec.set_len(n_values) };

                    /// Leaf validity for list<->list with a broadcasted RHS list. It ANDs the
                    /// RHS list's leaf validity into a copy of the LHS leaf validity for every
                    /// LHS row, over the first `min(lhs_len, width)` elements. Returns `None`
                    /// when neither side has a validity.
                    #[inline(never)]
                    fn combine_validities_list_to_list_broadcast_right(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                        width: usize,
                        rhs_start: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
                                let len: usize = lhs_len.min(width);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // The output validity is sized and aligned to the LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                        width,
                        rhs_start,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                    if self.list_to_prim_lhs.is_none() =>
                {
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    // Every leaf value of outer row `i` is combined with primitive `i`.
                    let n_values = arr_lhs.len();
                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
                    let out_ptr = out_vec.as_mut_ptr();

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs.value_unchecked(l_idx);
                                    let v = $OP(l, r);
                                    out_ptr.add(l_idx).write(v);
                                }
                            }
                        }
                    });

                    unsafe { out_vec.set_len(n_values) }

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // The LHS is a broadcasted list materialized in `finish`. Its leaf buffer
                    // is reused in place for the result.
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
                    let arr = arr
                        .as_any_mut()
                        .downcast_mut::<PrimitiveArray<T::Native>>()
                        .unwrap();
                    let mut arr_lhs = std::mem::take(arr);

                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    );

                    // `get_mut_values` only succeeds when the buffer is exclusively owned.
                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
                    assert_eq!(arr_lhs_mut_slice.len(), n_values);

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
                                    *l = $OP(*l, r);
                                }
                            }
                        }
                    });

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr = arr_lhs.with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    assert_eq!(arr_rhs.len(), 1);

                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
                        // The RHS is a single null primitive, so every leaf value becomes null.
                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                        return Ok(self.finish_offsets_and_validities(
                            Box::new(
                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
                                    false,
                                    arr_lhs.len(),
                                ))),
                            ),
                            offsets,
                            validities,
                        ));
                    };

                    let arr = self
                        .op
                        .0
                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    // `finish` rewrites all of these into one of the physical impls above.
                    if cfg!(debug_assertions) {
                        panic!("operation was not re-written: {v:?}")
                    } else {
                        unreachable!()
                    }
                },
            };

            Ok(out)
        }

        /// Wraps `leaf_array` in one `LargeListArray` per nesting level, innermost first.
        ///
        /// `offsets` and `validities` are ordered outermost first. The outermost level uses
        /// the precomputed `self.outer_validity` instead of its own validity. Takes
        /// `outer_validity` and `output_name` out of `self`.
        fn finish_offsets_and_validities(
            &mut self,
            leaf_array: Box<dyn Array>,
            offsets: Vec<OffsetsBuffer<i64>>,
            validities: Vec<Option<Bitmap>>,
        ) -> ListChunked {
            assert!(!offsets.is_empty());
            assert_eq!(offsets.len(), validities.len());
            let mut results = leaf_array;

            let mut iter = offsets.into_iter().zip(validities).rev();

            while iter.len() > 1 {
                let (offsets, validity) = iter.next().unwrap();
                let dtype = LargeListArray::default_datatype(results.dtype().clone());
                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
            }

            // The outermost level uses the combined outer validity computed in `try_new`.
            let (offsets, _) = iter.next().unwrap();
            let validity = std::mem::take(&mut self.outer_validity);
            let dtype = LargeListArray::default_datatype(results.dtype().clone());
            let results = LargeListArray::new(dtype, offsets, results, Some(validity));

            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
        }

        /// Broadcasts a unit-length list operand to `output_len` rows, with its leaf cast to
        /// `output_primitive_dtype`.
        ///
        /// `side_data` is replaced by the offsets/validities of the materialized list, and its
        /// series slot is reset to `Series::default()`. Returns the materialized leaf array
        /// and its number of values.
        ///
        /// # Panics
        /// Panics if the operand's length is not 1 or if `output_len <= 1`.
        fn materialize_broadcasted_list(
            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            output_len: usize,
            output_primitive_dtype: &DataType,
        ) -> (Box<dyn Array>, usize) {
            let s = &side_data.2;
            assert_eq!(s.len(), 1);

            let expected_n_values = {
                let offsets = s.list_offsets_and_validities_recursive().0;
                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
            };

            let ca = s.list().unwrap();
            // Cast the leaf primitives to the output supertype.
            let ca = ca
                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
                .unwrap();
            // NOTE(review): presumably this guards against a fast path in `new_from_index`
            // that would not yield exclusively owned data. The leaf buffer is later mutated in
            // place via `get_mut_values().unwrap()`. TODO confirm.
            assert!(output_len > 1);
            let ca = ca.new_from_index(0, output_len).rechunk();

            let s = ca.into_series();

            *side_data = {
                let (a, b) = s.list_offsets_and_validities_recursive();
                // The series slot is not used after materialization.
                (a, b, Series::default())
            };

            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
            assert_eq!(n_values, expected_n_values);

            // `s` is a local that is dropped right after, so leaving it without chunks is
            // harmless.
            let mut s = s.get_leaf_array();
            let v = unsafe { s.chunks_mut() };

            assert_eq!(v.len(), 1);
            (v.swap_remove(0), n_values)
        }
    }

    /// Leaf validity for list<->primitive without broadcast.
    ///
    /// ANDs the validity of primitive `i` into a copy of the LHS leaf validity for every leaf
    /// value of outer row `i`. Returns `None` when neither side has a validity.
    #[inline(never)]
    fn combine_validities_list_to_primitive_no_broadcast(
        offsets_lhs: &[OffsetsBuffer<i64>],
        validity_lhs: Option<&Bitmap>,
        validity_rhs: Option<&Bitmap>,
        len_lhs: usize,
    ) -> Option<Bitmap> {
        match (validity_lhs, validity_rhs) {
            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
            (Some(v), None) => return Some(v.clone()),
            // The LHS has no validity: start from all-valid so the RHS nulls can be applied.
            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
            (None, None) => None,
        }
        .map(|(mut validity_out, validity_rhs)| {
            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
                for l_idx in l_range {
                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                    let is_valid = l_valid & r_valid;

                    // The output validity is sized and aligned to the LHS.
                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                }
            }

            validity_out.freeze()
        })
    }
}