1use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8
9impl NumOpsDispatchInner for ListType {
10 fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
11 NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
12 }
13
14 fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
15 NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
16 }
17
18 fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
19 NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
20 }
21
22 fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
23 NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
24 }
25
26 fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
27 NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
28 }
29}
30
31#[derive(Clone)]
32pub struct NumericListOp(NumericOp);
33
34impl NumericListOp {
35 pub fn add() -> Self {
36 Self(NumericOp::Add)
37 }
38
39 pub fn sub() -> Self {
40 Self(NumericOp::Sub)
41 }
42
43 pub fn mul() -> Self {
44 Self(NumericOp::Mul)
45 }
46
47 pub fn div() -> Self {
48 Self(NumericOp::Div)
49 }
50
51 pub fn rem() -> Self {
52 Self(NumericOp::Rem)
53 }
54
55 pub fn floor_div() -> Self {
56 Self(NumericOp::FloorDiv)
57 }
58}
59
60impl NumericListOp {
61 #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
62 pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
63 feature_gated!("list_arithmetic", {
64 use either::Either;
65
66 let lhs = lhs.list_rechunk_and_trim_to_normalized_offsets();
70 let rhs = rhs.list_rechunk_and_trim_to_normalized_offsets();
71
72 let binary_op_exec = match ListNumericOpHelper::try_new(
73 self.clone(),
74 lhs.name().clone(),
75 lhs.dtype(),
76 rhs.dtype(),
77 lhs.len(),
78 rhs.len(),
79 {
80 let (a, b) = lhs.list_offsets_and_validities_recursive();
81 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
82 (a, b, lhs.clone())
83 },
84 {
85 let (a, b) = rhs.list_offsets_and_validities_recursive();
86 debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
87 (a, b, rhs.clone())
88 },
89 lhs.rechunk_validity(),
90 rhs.rechunk_validity(),
91 )? {
92 Either::Left(v) => v,
93 Either::Right(ca) => return Ok(ca.into_series()),
94 };
95
96 Ok(binary_op_exec.finish()?.into_series())
97 })
98 }
99}
100
101#[cfg(feature = "list_arithmetic")]
102use inner::ListNumericOpHelper;
103
#[cfg(feature = "list_arithmetic")]
mod inner {
    use arrow::bitmap::Bitmap;
    use arrow::compute::utils::combine_validities_and;
    use arrow::offset::OffsetsBuffer;
    use either::Either;
    use list_utils::with_match_pl_num_arith;
    use num_traits::Zero;
    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
    use polars_utils::float::IsFloat;

    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
    use super::super::*;

    /// Performs a numeric binary operation between the leaf (primitive) values of two columns,
    /// where at least one of the columns is a list column.
    ///
    /// Built (and validated) by [`ListNumericOpHelper::try_new`]; the result is computed by
    /// [`ListNumericOpHelper::finish`].
    pub(super) struct ListNumericOpHelper {
        op: NumericListOp,
        output_name: PlSmallStr,
        /// Which operand(s) are lists. May be rewritten by `finish()` when operands are swapped.
        op_apply_type: BinaryOpApplyType,
        /// Which operand (if any) has length 1. May be rewritten by `finish()`.
        broadcast: Broadcast,
        output_dtype: DataType,
        /// Leaf dtype of the result; both operands' leaves are cast to it before computing.
        output_primitive_dtype: DataType,
        output_len: usize,
        /// Outer (row-level) validity of the result. Always materialized to reduce the number
        /// of code paths.
        outer_validity: Bitmap,
        /// (offsets per nesting level, validities per nesting level, series). The series is
        /// kept because it is needed for leaf extraction and list broadcasting.
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Same layout as `data_lhs`.
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Materialized leaf array (and its value count) of a broadcasted length-1 list
        /// operand; set by `finish()` for the cases that materialize the broadcast.
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        /// Whether the operands were swapped relative to the caller's left/right; this is
        /// forwarded to the kernels so they can respect the original operand order.
        swapped: bool,
    }

    // Validation lives in `try_new()` (not generic) to reduce the amount of monomorphized code.
    impl ListNumericOpHelper {
        /// Validates the operands and prepares the operation.
        ///
        /// Checks that:
        /// * the dtypes are compatible: list<->list where both inner types are supported, or
        ///   list<->primitive / primitive<->list where the list is a list on every nesting level;
        /// * the leaf dtypes have a supertype for this operation;
        /// * the lengths are compatible: n<->n, 1<->n or n<->1.
        ///
        /// Whether the list lengths of a list<->list operation line up is NOT checked here; that
        /// is checked during execution.
        ///
        /// Returns `Either::Right` holding the final (fully NULL) result when the output is
        /// empty or a list operand is NULL in every outer row, and `Either::Left` otherwise.
        ///
        /// # Errors
        /// `InvalidOperation` for unsupported dtypes, `ShapeMismatch` for incompatible lengths,
        /// and any error raised by `try_get_leaf_supertype`.
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // Series-level arithmetic can reach this without any prior dtype check, so
                    // the inner types are validated here as well.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // NULL semantics:
            // * list + NULL primitive => the NULL is applied to every leaf of the row, so the
            //   leaves become NULL while the list structure is kept.
            // * NULL list + anything  => NULL row.
            // Hence a list operand that is NULL in every row gives an all-NULL result.
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name,
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // From here on, every length-1 list operand has a valid outer value.
            //
            // The outer validity of the result is the validity of the non-broadcasting list
            // operand(s); primitive operands do not affect it.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                // Two lists of equal length: combine both validities.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                // Every other combination that has a non-broadcasting list operand.
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype,
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }

        /// Computes the result, consuming the helper.
        ///
        /// Only some (apply type, broadcast) combinations have a physical implementation in
        /// `_finish_impl`; the others are rewritten into one of those by swapping the operands
        /// and/or materializing a broadcasted list:
        ///
        /// ```text
        ///                  | Left (lhs len 1) | NoBroadcast     | Right (rhs len 1) |
        /// ListToList       | swap -> LtL/R    | impl            | impl              |
        /// ListToPrimitive  | mat. -> LtP/N    | impl            | impl              |
        /// PrimitiveToList  | swap -> LtP/R    | swap -> LtP/N   | mat.+swap -> LtP/N|
        /// ```
        ///
        /// `self.swapped` is true whenever the operands end up swapped.
        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
            self.swapped = true;

            match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToList, Broadcast::Right)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_lhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.broadcast = Broadcast::NoBroadcast;

                    // Not a swap: after materializing the broadcasted list we simply dispatch to
                    // the `NoBroadcast` implementation.
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
                    // The length-1 list is materialized with `new_from_index` instead of
                    // broadcasting offsets/validities across every nesting level; its memory is
                    // then reused to store the result.
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_rhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::NoBroadcast;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
            }
        }

        /// Casts both leaf arrays to the output primitive dtype and dispatches to the typed
        /// `_finish_impl`.
        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
            let output_dtype = self.output_dtype.clone();
            let output_len = self.output_len;

            let prim_lhs = self
                .data_lhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();
            let prim_rhs = self
                .data_rhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();

            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
            let prim_dtype = prim_lhs.dtype();
            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);

            // Leaf dtypes were checked to be supported numeric inputs by `try_new()`.
            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
                self._finish_impl::<$T>(prim_lhs, prim_rhs)
            })?;

            debug_assert_eq!(out.dtype(), &output_dtype);
            assert_eq!(out.len(), output_len);

            Ok(out)
        }

        /// Typed physical implementations. Expects `finish()` to have rewritten the operation
        /// into one of: ListToList/{NoBroadcast, Right}, ListToPrimitive/{NoBroadcast, Right}.
        fn _finish_impl<T: PolarsNumericType>(
            &mut self,
            prim_s_lhs: Series,
            prim_s_rhs: Series,
        ) -> PolarsResult<ListChunked>
        where
            T::Native: PlNumArithmetic,
            PrimitiveArray<T::Native>:
                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
            T::Native: Zero + IsFloat,
        {
            /// Errors when `mismatch_pos` (the first valid row whose list lengths differ) is
            /// in bounds.
            #[inline(never)]
            fn check_mismatch_pos(
                mismatch_pos: usize,
                offsets_lhs: &OffsetsBuffer<i64>,
                offsets_rhs: &OffsetsBuffer<i64>,
            ) -> PolarsResult<()> {
                if mismatch_pos < offsets_lhs.len_proxy() {
                    // The RHS may be a broadcasted single row.
                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
                        0
                    } else {
                        mismatch_pos
                    });
                    polars_bail!(
                        ShapeMismatch:
                        "list lengths differed at index {}: {} != {}",
                        mismatch_pos,
                        offsets_lhs.length_at(mismatch_pos), len_r
                    )
                }
                Ok(())
            }

            let mut arr_lhs = {
                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            let mut arr_rhs = {
                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            match (&self.op_apply_type, &self.broadcast) {
                // Dispatches to `apply_array_to_scalar`, which handles validities itself.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
                _ if self.list_to_prim_lhs.is_none() => {
                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    )
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // `self.list_to_prim_lhs` is `Some(_)`; handled in the branch below.
                },
                _ => unreachable!(),
            }

            let out = match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());

                    // The output leaf values (and validity) are aligned to the LHS leaf array.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    // Advances once per leading row whose lengths match (or whose outer value is
                    // NULL) and stops at the first valid row with mismatching lengths.
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
                            .offset_and_length_iter()
                            .zip(offsets_rhs.offset_and_length_iter())
                            .enumerate()
                        {
                            if
                                (mismatch_pos == i)
                                & (
                                    (lhs_len == rhs_len)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            // Restrict to the shorter row to avoid out-of-bounds access.
                            let len: usize = lhs_len.min(rhs_len);

                            for j in 0..len {
                                let l_idx = j + lhs_start;
                                let r_idx = j + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe { out_ptr.add(l_idx).write(v) };
                            }

                            // Initialize the rest of a longer LHS row (outer-NULL rows, or rows
                            // that error below) so that every element is initialized before
                            // `set_len`. Empty when `lhs_len <= rhs_len`.
                            for j in len..lhs_len {
                                unsafe { out_ptr.add(lhs_start + j).write(T::Native::zero()) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    // SAFETY: every LHS leaf position was written in the loop above.
                    unsafe { out_vec.set_len(n_values) };

                    /// Leaf validity for list<->list without broadcasting: a leaf is valid only
                    /// if it is valid on both sides.
                    #[inline(never)]
                    fn combine_validities_list_to_list_no_broadcast(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        offsets_rhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
                                .offset_and_length_iter()
                                .zip(offsets_rhs.offset_and_length_iter())
                            {
                                let len: usize = lhs_len.min(rhs_len);

                                for j in 0..len {
                                    let l_idx = j + lhs_start;
                                    let r_idx = j + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
                        offsets_lhs,
                        offsets_rhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    // The output leaf values (and validity) are aligned to the LHS leaf array.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    // The RHS is a single row that is applied to every LHS row.
                    assert_eq!(offsets_rhs.len_proxy(), 1);
                    let rhs_start = *offsets_rhs.first() as usize;
                    let width = offsets_rhs.range() as usize;

                    // Same counter semantics as the `NoBroadcast` path: it must stop advancing
                    // at the first mismatch, even if outer-NULL rows follow.
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
                            if (mismatch_pos == i)
                                & (
                                    (lhs_len == width)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            let len: usize = lhs_len.min(width);

                            for j in 0..len {
                                let l_idx = j + lhs_start;
                                let r_idx = j + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe {
                                    out_ptr.add(l_idx).write(v);
                                }
                            }

                            // Initialize the rest of a longer LHS row so every element is
                            // initialized before `set_len`. Empty when `lhs_len <= width`.
                            for j in len..lhs_len {
                                unsafe { out_ptr.add(lhs_start + j).write(T::Native::zero()) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    // SAFETY: every LHS leaf position was written in the loop above.
                    unsafe { out_vec.set_len(n_values) };

                    /// Leaf validity for list<->list with a broadcasted single-row RHS.
                    #[inline(never)]
                    fn combine_validities_list_to_list_broadcast_right(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                        width: usize,
                        rhs_start: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
                                let len: usize = lhs_len.min(width);

                                for j in 0..len {
                                    let l_idx = j + lhs_start;
                                    let r_idx = j + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                        width,
                        rhs_start,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                    if self.list_to_prim_lhs.is_none() =>
                {
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    // Output is aligned to the LHS leaf array; each RHS row value is applied to
                    // every leaf of the corresponding LHS row.
                    let n_values = arr_lhs.len();
                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
                    let out_ptr = out_vec.as_mut_ptr();

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs.value_unchecked(l_idx);
                                    let v = $OP(l, r);
                                    out_ptr.add(l_idx).write(v);
                                }
                            }
                        }
                    });

                    unsafe { out_vec.set_len(n_values) }

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // `self.list_to_prim_lhs` holds the materialized broadcast list; the result
                    // is computed in place into its leaf values.
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
                    let arr = arr
                        .as_any_mut()
                        .downcast_mut::<PrimitiveArray<T::Native>>()
                        .unwrap();
                    let mut arr_lhs = std::mem::take(arr);

                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    );

                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
                    assert_eq!(arr_lhs_mut_slice.len(), n_values);

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
                                    *l = $OP(*l, r);
                                }
                            }
                        }
                    });

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr = arr_lhs.with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    assert_eq!(arr_rhs.len(), 1);

                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
                        // NULL scalar: every leaf of the result is NULL; the list structure is
                        // kept from the LHS.
                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                        let n_leaves = arr_lhs.len();
                        return Ok(self.finish_offsets_and_validities(
                            Box::new(
                                arr_lhs.with_validity(Some(Bitmap::new_with_value(
                                    false, n_leaves,
                                ))),
                            ),
                            offsets,
                            validities,
                        ));
                    };

                    let arr = self
                        .op
                        .0
                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    if cfg!(debug_assertions) {
                        panic!("operation was not re-written: {:?}", v)
                    } else {
                        unreachable!()
                    }
                },
            };

            Ok(out)
        }

        /// Wraps `leaf_array` back into (nested) large-list arrays using `offsets` and
        /// `validities` (outermost level first). The outermost level uses the precomputed
        /// `self.outer_validity` instead of `validities[0]`.
        fn finish_offsets_and_validities(
            &mut self,
            leaf_array: Box<dyn Array>,
            offsets: Vec<OffsetsBuffer<i64>>,
            validities: Vec<Option<Bitmap>>,
        ) -> ListChunked {
            assert!(!offsets.is_empty());
            assert_eq!(offsets.len(), validities.len());
            let mut results = leaf_array;

            // Build from the innermost level outwards.
            let mut iter = offsets.into_iter().zip(validities).rev();

            while iter.len() > 1 {
                let (offsets, validity) = iter.next().unwrap();
                let dtype = LargeListArray::default_datatype(results.dtype().clone());
                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
            }

            // The combined outer validity was computed by `try_new()`.
            let (offsets, _) = iter.next().unwrap();
            let validity = std::mem::take(&mut self.outer_validity);
            let dtype = LargeListArray::default_datatype(results.dtype().clone());
            let results = LargeListArray::new(dtype, offsets, results, Some(validity));

            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
        }

        /// Repeats a length-1 list column to `output_len` rows, with its leaf values cast to
        /// `output_primitive_dtype`.
        ///
        /// `side_data` is replaced by the offsets/validities of the materialized column (its
        /// series slot becomes `Series::default()`, as it is no longer used). Returns the
        /// materialized leaf array and its number of values, so that the operation can be
        /// computed in place on that array.
        fn materialize_broadcasted_list(
            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            output_len: usize,
            output_primitive_dtype: &DataType,
        ) -> (Box<dyn Array>, usize) {
            let s = &side_data.2;
            assert_eq!(s.len(), 1);

            let expected_n_values = {
                let offsets = s.list_offsets_and_validities_recursive().0;
                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
            };

            let ca = s.list().unwrap();
            // Cast the leaves before materializing so the arithmetic can be done in place.
            let ca = ca
                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
                .unwrap();
            // Broadcasting only reaches here with a non-empty, non-unit output (`try_new()`
            // returns early for empty outputs), so `new_from_index` yields owned data.
            assert!(output_len > 1);
            let ca = ca.new_from_index(0, output_len).rechunk();

            let s = ca.into_series();

            *side_data = {
                let (a, b) = s.list_offsets_and_validities_recursive();
                (a, b, Series::default())
            };

            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
            assert_eq!(n_values, expected_n_values);

            let mut s = s.get_leaf_array();
            let v = unsafe { s.chunks_mut() };

            assert_eq!(v.len(), 1);
            (v.swap_remove(0), n_values)
        }
    }

    /// Leaf validity for list<->primitive without broadcasting: every leaf in row `i` is valid
    /// only if it is valid on the LHS and the primitive value `rhs[i]` is valid.
    #[inline(never)]
    fn combine_validities_list_to_primitive_no_broadcast(
        offsets_lhs: &[OffsetsBuffer<i64>],
        validity_lhs: Option<&Bitmap>,
        validity_rhs: Option<&Bitmap>,
        len_lhs: usize,
    ) -> Option<Bitmap> {
        match (validity_lhs, validity_rhs) {
            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
            (Some(v), None) => return Some(v.clone()),
            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
            (None, None) => None,
        }
        .map(|(mut validity_out, validity_rhs)| {
            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
                for l_idx in l_range {
                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                    let is_valid = l_valid & r_valid;

                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                }
            }

            validity_out.freeze()
        })
    }
}