1use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8
9impl NumOpsDispatchInner for ListType {
10 fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
11 NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
12 }
13
14 fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
15 NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
16 }
17
18 fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
19 NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
20 }
21
22 fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
23 NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
24 }
25
26 fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
27 NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
28 }
29}
30
31#[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
32#[derive(Clone)]
33pub struct NumericListOp(NumericOp);
34
35impl NumericListOp {
36 pub fn add() -> Self {
37 Self(NumericOp::Add)
38 }
39
40 pub fn sub() -> Self {
41 Self(NumericOp::Sub)
42 }
43
44 pub fn mul() -> Self {
45 Self(NumericOp::Mul)
46 }
47
48 pub fn div() -> Self {
49 Self(NumericOp::Div)
50 }
51
52 pub fn rem() -> Self {
53 Self(NumericOp::Rem)
54 }
55
56 pub fn floor_div() -> Self {
57 Self(NumericOp::FloorDiv)
58 }
59}
60
61impl NumericListOp {
62 #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
63 pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
64 feature_gated!("list_arithmetic", {
65 use std::borrow::Cow;
66
67 use either::Either;
68
69 let lhs = lhs
73 .trim_lists_to_normalized_offsets()
74 .map_or(Cow::Borrowed(lhs), Cow::Owned);
75 let rhs = rhs
76 .trim_lists_to_normalized_offsets()
77 .map_or(Cow::Borrowed(rhs), Cow::Owned);
78
79 let lhs = lhs.rechunk();
80 let rhs = rhs.rechunk();
81
82 let binary_op_exec = match ListNumericOpHelper::try_new(
83 self.clone(),
84 lhs.name().clone(),
85 lhs.dtype(),
86 rhs.dtype(),
87 lhs.len(),
88 rhs.len(),
89 {
90 let (a, b) = lhs.list_offsets_and_validities_recursive();
91 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
92 (a, b, lhs.clone())
93 },
94 {
95 let (a, b) = rhs.list_offsets_and_validities_recursive();
96 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
97 (a, b, rhs.clone())
98 },
99 lhs.rechunk_validity(),
100 rhs.rechunk_validity(),
101 )? {
102 Either::Left(v) => v,
103 Either::Right(ca) => return Ok(ca.into_series()),
104 };
105
106 Ok(binary_op_exec.finish()?.into_series())
107 })
108 }
109}
110
111#[cfg(feature = "list_arithmetic")]
112use inner::ListNumericOpHelper;
113
#[cfg(feature = "list_arithmetic")]
mod inner {
    use arrow::bitmap::Bitmap;
    use arrow::compute::utils::combine_validities_and;
    use arrow::offset::OffsetsBuffer;
    use either::Either;
    use list_utils::with_match_pl_num_arith;
    use num_traits::Zero;
    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
    use polars_utils::float::IsFloat;

    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
    use super::super::*;

    /// State of one list arithmetic operation: validated by
    /// [`ListNumericOpHelper::try_new`] and executed by
    /// [`ListNumericOpHelper::finish`].
    pub(super) struct ListNumericOpHelper {
        op: NumericListOp,
        output_name: PlSmallStr,
        op_apply_type: BinaryOpApplyType,
        broadcast: Broadcast,
        output_dtype: DataType,
        output_primitive_dtype: DataType,
        output_len: usize,
        /// Outer (row-level) validity of the result. Always materialized
        /// (all-true when no operand contributes one).
        outer_validity: Bitmap,
        /// Per-nesting-level offsets and validities of each side, plus the
        /// series the leaf values are read from.
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Leaf array (and its number of values) of a unit-length list side
        /// that was materialized to the output length; its buffer is updated
        /// in place to hold the result.
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        /// Set when `data_lhs` holds the caller's right-hand operand; passed
        /// to the op kernels, presumably so the operation is still applied in
        /// the caller's operand order.
        swapped: bool,
    }

    impl ListNumericOpHelper {
        /// Validates the operands and prepares the operation.
        ///
        /// Checks that:
        /// * the dtypes form a supported combination (list<->list,
        ///   list<->primitive or primitive<->list) and the leaf dtypes have a
        ///   supertype for this operation;
        /// * the lengths are equal, or one side has length 1 (broadcast).
        ///
        /// Whether list<->list rows have matching lengths is only checked
        /// during [`Self::finish`].
        ///
        /// Returns `Either::Right` with the final all-NULL result when the
        /// output is empty or a list operand is NULL in every row.
        ///
        /// # Errors
        /// * `InvalidOperation` for unsupported dtype combinations.
        /// * `ShapeMismatch` for incompatible lengths.
        /// * Any error from resolving the leaf supertype.
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // Both inner dtypes must be directly usable as arithmetic inputs.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // Short-circuit to an all-NULL result when the output is empty or a
            // list-typed operand is NULL in every row. An all-NULL *primitive*
            // operand does not short-circuit here: it nulls the leaf values.
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name,
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // The outer validity comes from the non-broadcast list operand(s).
            // A broadcast unit-length list is known to be non-NULL at this point.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                // Two full-length lists: combine both validities.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                // Any other combination containing a non-broadcast list.
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype: output_dtype.clone(),
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }

        /// Executes the operation.
        ///
        /// Only some (apply type, broadcast) combinations have physical
        /// implementations; the others are rewritten onto them by swapping the
        /// operands and/or materializing a broadcast list (`[n]` = rewritten
        /// onto impl `n`):
        ///
        /// ```text
        ///                  | Left | NoBroadcast | Right |
        /// ListToList       | [1]  |      0      |   1   |
        /// ListToPrimitive  | [2]  |      2      |   3   |
        /// PrimitiveToList  | [3]  |     [2]     |  [2]  |
        /// ```
        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
            self.swapped = true;

            match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToList, Broadcast::Right)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    // Materialize the unit-length list to the output length with
                    // `new_from_index`, instead of broadcasting offsets and
                    // validities across nesting levels; the materialized leaf
                    // buffer is then updated in place.
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_lhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.broadcast = Broadcast::NoBroadcast;

                    // The list is already on the left: no swap.
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
                    // Same materialization as above, but for the RHS list.
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_rhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::NoBroadcast;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
            }
        }

        /// Casts both leaf arrays to the output primitive dtype and dispatches
        /// to the monomorphized `_finish_impl`.
        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
            let output_dtype = self.output_dtype.clone();
            let output_len = self.output_len;

            let prim_lhs = self
                .data_lhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();
            let prim_rhs = self
                .data_rhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();

            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
            let prim_dtype = prim_lhs.dtype();
            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);

            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
                self._finish_impl::<$T>(prim_lhs, prim_rhs)
            })?;

            debug_assert_eq!(out.dtype(), &output_dtype);
            assert_eq!(out.len(), output_len);

            Ok(out)
        }

        /// Physical implementations for the canonical (apply type, broadcast)
        /// combinations; see the table on [`Self::finish`].
        fn _finish_impl<T: PolarsNumericType>(
            &mut self,
            prim_s_lhs: Series,
            prim_s_rhs: Series,
        ) -> PolarsResult<ListChunked>
        where
            T::Native: PlNumArithmetic,
            PrimitiveArray<T::Native>:
                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
            T::Native: Zero + IsFloat,
        {
            // Errors if `mismatch_pos` (the first row whose list lengths differ)
            // is within the LHS. `swapped` restores the caller's left/right
            // order of the reported lengths.
            #[inline(never)]
            fn check_mismatch_pos(
                mismatch_pos: usize,
                offsets_lhs: &OffsetsBuffer<i64>,
                offsets_rhs: &OffsetsBuffer<i64>,
                swapped: bool,
            ) -> PolarsResult<()> {
                if mismatch_pos < offsets_lhs.len_proxy() {
                    // The RHS may be a broadcast unit-length list.
                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
                        0
                    } else {
                        mismatch_pos
                    });
                    let len_l = offsets_lhs.length_at(mismatch_pos);
                    let (len_left, len_right) = if swapped {
                        (len_r, len_l)
                    } else {
                        (len_l, len_r)
                    };
                    polars_bail!(
                        ShapeMismatch:
                        "list lengths differed at index {}: {} != {}",
                        mismatch_pos,
                        len_left, len_right
                    )
                }
                Ok(())
            }

            let mut arr_lhs = {
                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            let mut arr_rhs = {
                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            match (&self.op_apply_type, &self.broadcast) {
                // Dispatches to `apply_array_to_scalar`, which is left to handle
                // validities itself.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
                _ if self.list_to_prim_lhs.is_none() => {
                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    )
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // `list_to_prim_lhs` is `Some`; validities are prepared in
                    // that branch below.
                },
                _ => unreachable!(),
            }

            let out = match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());

                    // Output leaf values (and validity) are aligned to the LHS.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    // Advances once per row and stops at the first valid row
                    // whose list lengths differ (NULL rows may differ).
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
                            .offset_and_length_iter()
                            .zip(offsets_rhs.offset_and_length_iter())
                            .enumerate()
                        {
                            if
                                (mismatch_pos == i)
                                & (
                                    (lhs_len == rhs_len)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            // Restrict to the shorter list to avoid OOB access.
                            let len: usize = lhs_len.min(rhs_len);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe { out_ptr.add(l_idx).write(v) };
                            }

                            // NULL rows may have a longer LHS list: initialize the
                            // remaining positions so `set_len` below never exposes
                            // uninitialized memory.
                            for k in len..lhs_len {
                                unsafe { out_ptr.add(k + lhs_start).write(T::Native::zero()) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs, self.swapped)?;

                    // SAFETY: every one of the `n_values` LHS-aligned positions was
                    // written above (rows cover the normalized offsets).
                    unsafe { out_vec.set_len(n_values) };

                    // Per-leaf AND of the LHS validity with the aligned RHS validity.
                    #[inline(never)]
                    fn combine_validities_list_to_list_no_broadcast(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        offsets_rhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
                                .offset_and_length_iter()
                                .zip(offsets_rhs.offset_and_length_iter())
                            {
                                let len: usize = lhs_len.min(rhs_len);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
                        offsets_lhs,
                        offsets_rhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    // Output leaf values (and validity) are aligned to the LHS.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    // The RHS is a single list of `width` values starting at `rhs_start`.
                    assert_eq!(offsets_rhs.len_proxy(), 1);
                    let rhs_start = *offsets_rhs.first() as usize;
                    let width = offsets_rhs.range() as usize;

                    // Advances once per row and stops at the first valid row
                    // whose length differs from `width` (NULL rows may differ).
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
                            if (mismatch_pos == i)
                                & (
                                    (lhs_len == width)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            // Restrict to the shorter list to avoid OOB access.
                            let len: usize = lhs_len.min(width);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe {
                                    out_ptr.add(l_idx).write(v);
                                }
                            }

                            // NULL rows may have a longer LHS list: initialize the
                            // remaining positions so `set_len` below never exposes
                            // uninitialized memory.
                            for k in len..lhs_len {
                                unsafe { out_ptr.add(k + lhs_start).write(T::Native::zero()) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs, self.swapped)?;

                    // SAFETY: every one of the `n_values` LHS-aligned positions was
                    // written above (rows cover the normalized offsets).
                    unsafe { out_vec.set_len(n_values) };

                    // Per-leaf AND of the LHS validity with the broadcast RHS list's validity.
                    #[inline(never)]
                    fn combine_validities_list_to_list_broadcast_right(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                        width: usize,
                        rhs_start: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
                                let len: usize = lhs_len.min(width);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                        width,
                        rhs_start,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                    if self.list_to_prim_lhs.is_none() =>
                {
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    // Output leaf values (and validity) are aligned to the LHS.
                    let n_values = arr_lhs.len();
                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
                    let out_ptr = out_vec.as_mut_ptr();

                    // Every leaf value in row `i` is combined with primitive `i`.
                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs.value_unchecked(l_idx);
                                    let v = $OP(l, r);
                                    out_ptr.add(l_idx).write(v);
                                }
                            }
                        }
                    });

                    unsafe { out_vec.set_len(n_values) }

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // The LHS list was materialized from a broadcast; compute the
                    // result in place in its (uniquely owned) leaf buffer.
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
                    let arr = arr
                        .as_any_mut()
                        .downcast_mut::<PrimitiveArray<T::Native>>()
                        .unwrap();
                    let mut arr_lhs = std::mem::take(arr);

                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    );

                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
                    assert_eq!(arr_lhs_mut_slice.len(), n_values);

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
                                    *l = $OP(*l, r);
                                }
                            }
                        }
                    });

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr = arr_lhs.with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    assert_eq!(arr_rhs.len(), 1);

                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
                        // The RHS is a single NULL primitive: every leaf becomes NULL.
                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                        return Ok(self.finish_offsets_and_validities(
                            Box::new(
                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
                                    false,
                                    arr_lhs.len(),
                                ))),
                            ),
                            offsets,
                            validities,
                        ));
                    };

                    let arr = self
                        .op
                        .0
                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    // `finish` rewrites all of these onto the cases above.
                    if cfg!(debug_assertions) {
                        panic!("operation was not re-written: {v:?}")
                    } else {
                        unreachable!()
                    }
                },
            };

            Ok(out)
        }

        /// Rebuilds the nested `LargeListArray` around `leaf_array`, from the
        /// innermost nesting level outwards. The outermost level uses the
        /// precomputed `self.outer_validity` rather than the passed validity.
        fn finish_offsets_and_validities(
            &mut self,
            leaf_array: Box<dyn Array>,
            offsets: Vec<OffsetsBuffer<i64>>,
            validities: Vec<Option<Bitmap>>,
        ) -> ListChunked {
            assert!(!offsets.is_empty());
            assert_eq!(offsets.len(), validities.len());
            let mut results = leaf_array;

            let mut iter = offsets.into_iter().zip(validities).rev();

            while iter.len() > 1 {
                let (offsets, validity) = iter.next().unwrap();
                let dtype = LargeListArray::default_datatype(results.dtype().clone());
                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
            }

            // Outermost level: the combined outer validity from `try_new`.
            let (offsets, _) = iter.next().unwrap();
            let validity = std::mem::take(&mut self.outer_validity);
            let dtype = LargeListArray::default_datatype(results.dtype().clone());
            let results = LargeListArray::new(dtype, offsets, results, Some(validity));

            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
        }

        /// Repeats the single row of a unit-length list side `output_len` times
        /// (after casting its leaf dtype to `output_primitive_dtype`).
        /// `side_data` is replaced by the expanded offsets and validities; its
        /// series slot is replaced by `Series::default()`.
        ///
        /// Returns the expanded leaf array and its number of values.
        fn materialize_broadcasted_list(
            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            output_len: usize,
            output_primitive_dtype: &DataType,
        ) -> (Box<dyn Array>, usize) {
            let s = &side_data.2;
            assert_eq!(s.len(), 1);

            let expected_n_values = {
                let offsets = s.list_offsets_and_validities_recursive().0;
                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
            };

            let ca = s.list().unwrap();
            // Cast before expanding so only the single source row is cast.
            let ca = ca
                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
                .unwrap();
            assert!(output_len > 1);
            let ca = ca.new_from_index(0, output_len).rechunk();

            let s = ca.into_series();

            *side_data = {
                let (a, b) = s.list_offsets_and_validities_recursive();
                (a, b, Series::default())
            };

            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
            assert_eq!(n_values, expected_n_values);

            let mut s = s.get_leaf_array();
            let v = unsafe { s.chunks_mut() };

            assert_eq!(v.len(), 1);
            (v.swap_remove(0), n_values)
        }
    }

    /// Leaf validity for list<->primitive without broadcasting: every leaf in
    /// row `i` is ANDed with bit `i` of the primitive's validity. Returns
    /// `None` when neither side has a validity.
    #[inline(never)]
    fn combine_validities_list_to_primitive_no_broadcast(
        offsets_lhs: &[OffsetsBuffer<i64>],
        validity_lhs: Option<&Bitmap>,
        validity_rhs: Option<&Bitmap>,
        len_lhs: usize,
    ) -> Option<Bitmap> {
        match (validity_lhs, validity_rhs) {
            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
            (Some(v), None) => return Some(v.clone()),
            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
            (None, None) => None,
        }
        .map(|(mut validity_out, validity_rhs)| {
            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
                for l_idx in l_range {
                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                    let is_valid = l_valid & r_valid;

                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                }
            }

            validity_out.freeze()
        })
    }
}