polars_core/series/arithmetic/
list.rs

1//! Allow arithmetic operations for ListChunked.
2//! use polars_error::{feature_gated, PolarsResult};
3
4use polars_error::{PolarsResult, feature_gated};
5
6use super::list_utils::NumericOp;
7use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};
8use crate::prelude::DataType;
9
10impl NumOpsDispatchInner for ListType {
11    fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
12        NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
13    }
14
15    fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
16        NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
17    }
18
19    fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
20        NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
21    }
22
23    fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
24        NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
25    }
26
27    fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
28        NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
29    }
30}
31
/// Arithmetic operator applied to list columns.
///
/// A newtype around [`NumericOp`]; instances are created through the
/// operator-named constructors (`add()`, `sub()`, ...).
// `unused` is only allowed when the `list_arithmetic` feature is disabled.
#[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
#[derive(Clone)]
pub struct NumericListOp(NumericOp);
35
36impl NumericListOp {
37    pub fn add() -> Self {
38        Self(NumericOp::Add)
39    }
40
41    pub fn sub() -> Self {
42        Self(NumericOp::Sub)
43    }
44
45    pub fn mul() -> Self {
46        Self(NumericOp::Mul)
47    }
48
49    pub fn div() -> Self {
50        Self(NumericOp::Div)
51    }
52
53    pub fn rem() -> Self {
54        Self(NumericOp::Rem)
55    }
56
57    pub fn floor_div() -> Self {
58        Self(NumericOp::FloorDiv)
59    }
60
61    pub fn try_get_leaf_supertype(
62        &self,
63        prim_dtype_lhs: &DataType,
64        prim_dtype_rhs: &DataType,
65    ) -> PolarsResult<DataType> {
66        self.0
67            .try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)
68    }
69}
70
71impl NumericListOp {
72    #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
73    pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
74        feature_gated!("list_arithmetic", {
75            use std::borrow::Cow;
76
77            use either::Either;
78
79            // `trim_to_normalized_offsets` ensures we don't perform excessive
80            // memory allocation / compute on memory regions that have been
81            // sliced out.
82            let lhs = lhs
83                .trim_lists_to_normalized_offsets()
84                .map_or(Cow::Borrowed(lhs), Cow::Owned);
85            let rhs = rhs
86                .trim_lists_to_normalized_offsets()
87                .map_or(Cow::Borrowed(rhs), Cow::Owned);
88
89            let lhs = lhs.rechunk();
90            let rhs = rhs.rechunk();
91
92            let binary_op_exec = match ListNumericOpHelper::try_new(
93                self.clone(),
94                lhs.name().clone(),
95                lhs.dtype(),
96                rhs.dtype(),
97                lhs.len(),
98                rhs.len(),
99                {
100                    let (a, b) = lhs.list_offsets_and_validities_recursive();
101                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
102                    (a, b, lhs.clone())
103                },
104                {
105                    let (a, b) = rhs.list_offsets_and_validities_recursive();
106                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
107                    (a, b, rhs.clone())
108                },
109                lhs.rechunk_validity(),
110                rhs.rechunk_validity(),
111            )? {
112                Either::Left(v) => v,
113                Either::Right(ca) => return Ok(ca.into_series()),
114            };
115
116            Ok(binary_op_exec.finish()?.into_series())
117        })
118    }
119}
120
121#[cfg(feature = "list_arithmetic")]
122use inner::ListNumericOpHelper;
123
124#[cfg(feature = "list_arithmetic")]
125mod inner {
126    use arrow::bitmap::Bitmap;
127    use arrow::compute::utils::combine_validities_and;
128    use arrow::offset::OffsetsBuffer;
129    use either::Either;
130    use list_utils::with_match_pl_num_arith;
131    use num_traits::Zero;
132    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
133    use polars_utils::float::IsFloat;
134
135    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
136    use super::super::*;
137
    /// Utility to perform a binary operation between the primitive values of
    /// 2 columns, where at least one of the columns is a `ListChunked` type.
    ///
    /// Constructed by `try_new()` and consumed by `finish()`.
    pub(super) struct ListNumericOpHelper {
        /// The arithmetic operator to apply.
        op: NumericListOp,
        /// Name for the output; `execute()` passes the name of its LHS here.
        output_name: PlSmallStr,
        /// Which operand kinds are combined (list<->list, list<->primitive,
        /// primitive<->list). `finish()` may rewrite this when it swaps operands.
        op_apply_type: BinaryOpApplyType,
        /// Which side (if any) has unit length and is broadcast. `finish()` may
        /// rewrite this when it swaps operands or materializes a broadcast.
        broadcast: Broadcast,
        /// Dtype of the result: the list operand's dtype with its leaf cast to
        /// `output_primitive_dtype`.
        output_dtype: DataType,
        /// Leaf dtype of the result, as resolved by the operator from both
        /// operands' leaf dtypes.
        output_primitive_dtype: DataType,
        /// Number of outer rows in the result.
        output_len: usize,
        /// Outer validity of the result, we always materialize this to reduce the
        /// amount of code paths we need.
        outer_validity: Bitmap,
        // The series are stored as they are used for list broadcasting.
        /// LHS data: (offsets per nesting level, validities per nesting level, series).
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// RHS data, same layout as `data_lhs`.
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        /// Set by `finish()` to the output of `materialize_broadcasted_list` when a
        /// unit-length list operand is broadcast against a primitive column.
        /// NOTE(review): the meaning of the `usize` is not visible from this struct.
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        /// Whether `finish()` exchanged `data_lhs` and `data_rhs`. Forwarded to the
        /// arithmetic kernels, presumably so the operator keeps the original
        /// operand order.
        swapped: bool,
    }
157
158    /// This lets us separate some logic into `new()` to reduce the amount of
159    /// monomorphized code.
160    impl ListNumericOpHelper {
        /// Validates both operands and builds the helper.
        ///
        /// Checks that:
        /// * Dtypes are compatible:
        ///   * list<->primitive | primitive<->list
        ///   * list<->list both contain primitives (e.g. List<Int8>)
        /// * The primitive (leaf) dtypes have a supertype for `op`
        /// * Lengths are compatible:
        ///   * 1<->n | n<->1
        ///   * n<->n
        ///
        /// Does not check:
        /// * Whether the offsets are aligned for list<->list, this will be checked during execution.
        ///
        /// This returns an `Either` which may contain the final result to simplify
        /// the implementation: `Either::Right` holds an all-NULL `ListChunked` when
        /// the output is empty or a participating list side has a validity mask
        /// with no valid outer row. Otherwise `Either::Left` holds the helper.
        ///
        /// # Errors
        /// * Whatever `try_get_leaf_supertype` returns for incompatible leaf dtypes.
        /// * `InvalidOperation` if the dtype combination is not supported.
        /// * `ShapeMismatch` if the lengths differ and neither side has length 1.
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            // Fails first if the two leaf dtypes cannot be combined by this operator.
            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            // True if `dtype` is made of `List` levels only, ending in a leaf that is
            // a supported list arithmetic input.
            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            // Builds an `InvalidOperation` error that names the operator and both dtypes.
            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            // Classify the operand kinds. `output_dtype` starts as the list operand's
            // dtype (the LHS for list<->list); its leaf is replaced further below.
            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because if a user
                    // directly adds 2 series together it bypasses the DSL.
                    // This is currently duplicated code and should be replaced one day with an assert after Series ops get
                    // checked properly.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            // Equal lengths take precedence, so 1<->1 is `NoBroadcast`.
            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            // `output_dtype` was derived from a `List` dtype in every arm above, so
            // this pattern is expected to always match.
            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // # NULL semantics
            // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]]
            //   * Essentially as if the NULL primitive was added to every primitive in the row of the list column.
            // * NULL (List[Int64]) + 1   (Int64)       => NULL
            // * NULL (List[Int64]) + [1] (List[Int64]) => NULL

            // Short-circuit to an all-NULL result when there is nothing to compute:
            // empty output, or a list operand whose validity mask has no set bit.
            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name,
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // At this point:
            // * All unit length list columns have a valid outer value.

            // The outer validity is just the validity of any non-broadcasting lists.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                // Both lists with same length, we combine the validity.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                // Match all other combinations that have non-broadcasting lists.
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            // No mask available: every outer row is valid.
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype: output_dtype.clone(),
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                // Both are filled in by `finish()` where needed.
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }
333
334        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
335            // We have physical codepaths for a subset of the possible combinations of broadcasting and
336            // column types. The remaining combinations are handled by dispatching to the physical
337            // codepaths after operand swapping and/or materialized broadcasting.
338            //
339            // # Physical impl table
340            // Legend
341            // * |  N  | // impl "N"
342            // * | [N] | // dispatches to impl "N"
343            //
344            //                  |  L  |  N  |  R  | // Broadcast (L)eft, (N)oBroadcast, (R)ight
345            // ListToList       | [1] |  0  |  1  |
346            // ListToPrimitive  | [2] |  2  |  3  | // list broadcasting just materializes and dispatches to NoBroadcast
347            // PrimitiveToList  | [3] | [2] | [2] |
348
349            self.swapped = true;
350
351            match (&self.op_apply_type, &self.broadcast) {
352                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
353                | (BinaryOpApplyType::ListToList, Broadcast::Right)
354                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
355                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
356                    self.swapped = false;
357                    self._finish_impl_dispatch()
358                },
359                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
360                    self.broadcast = Broadcast::Right;
361
362                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
363                    self._finish_impl_dispatch()
364                },
365                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
366                    self.list_to_prim_lhs
367                        .replace(Self::materialize_broadcasted_list(
368                            &mut self.data_lhs,
369                            self.output_len,
370                            &self.output_primitive_dtype,
371                        ));
372
373                    self.broadcast = Broadcast::NoBroadcast;
374
375                    // This does not swap! We are just dispatching to `NoBroadcast`
376                    // after materializing the broadcasted list array.
377                    self.swapped = false;
378                    self._finish_impl_dispatch()
379                },
380                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
381                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
382
383                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
384                    self._finish_impl_dispatch()
385                },
386                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
387                    // We materialize the list columns with `new_from_index`, as otherwise we'd have to
388                    // implement logic that broadcasts the offsets and validities across multiple levels
389                    // of nesting. But we will re-use the materialized memory to store the result.
390
391                    self.list_to_prim_lhs
392                        .replace(Self::materialize_broadcasted_list(
393                            &mut self.data_rhs,
394                            self.output_len,
395                            &self.output_primitive_dtype,
396                        ));
397
398                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
399                    self.broadcast = Broadcast::NoBroadcast;
400
401                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
402                    self._finish_impl_dispatch()
403                },
404                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
405                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
406                    self.broadcast = Broadcast::Right;
407
408                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
409                    self._finish_impl_dispatch()
410                },
411            }
412        }
413
414        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
415            let output_dtype = self.output_dtype.clone();
416            let output_len = self.output_len;
417
418            let prim_lhs = self
419                .data_lhs
420                .2
421                .get_leaf_array()
422                .cast(&self.output_primitive_dtype)?
423                .rechunk();
424            let prim_rhs = self
425                .data_rhs
426                .2
427                .get_leaf_array()
428                .cast(&self.output_primitive_dtype)?
429                .rechunk();
430
431            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
432            let prim_dtype = prim_lhs.dtype();
433            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);
434
435            // Safety: Leaf dtypes have been checked to be numeric by `try_new()`
436            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
437                self._finish_impl::<$T>(prim_lhs, prim_rhs)
438            })?;
439
440            debug_assert_eq!(out.dtype(), &output_dtype);
441            assert_eq!(out.len(), output_len);
442
443            Ok(out)
444        }
445
446        /// Internal use only - contains physical impls.
447        fn _finish_impl<T: PolarsNumericType>(
448            &mut self,
449            prim_s_lhs: Series,
450            prim_s_rhs: Series,
451        ) -> PolarsResult<ListChunked>
452        where
453            T::Native: PlNumArithmetic,
454            PrimitiveArray<T::Native>:
455                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
456            T::Native: Zero + IsFloat,
457        {
458            #[inline(never)]
459            fn check_mismatch_pos(
460                mismatch_pos: usize,
461                offsets_lhs: &OffsetsBuffer<i64>,
462                offsets_rhs: &OffsetsBuffer<i64>,
463            ) -> PolarsResult<()> {
464                if mismatch_pos < offsets_lhs.len_proxy() {
465                    // RHS could be broadcasted
466                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
467                        0
468                    } else {
469                        mismatch_pos
470                    });
471                    polars_bail!(
472                        ShapeMismatch:
473                        "list lengths differed at index {}: {} != {}",
474                        mismatch_pos,
475                        offsets_lhs.length_at(mismatch_pos), len_r
476                    )
477                }
478                Ok(())
479            }
480
481            let mut arr_lhs = {
482                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
483                assert_eq!(ca.chunks().len(), 1);
484                ca.downcast_get(0).unwrap().clone()
485            };
486
487            let mut arr_rhs = {
488                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
489                assert_eq!(ca.chunks().len(), 1);
490                ca.downcast_get(0).unwrap().clone()
491            };
492
493            match (&self.op_apply_type, &self.broadcast) {
494                // We skip for this because it dispatches to `ArithmeticKernel`, which handles the
495                // validities for us.
496                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
497                _ if self.list_to_prim_lhs.is_none() => {
498                    self.op.0.prepare_numeric_op_side_validities::<T>(
499                        &mut arr_lhs,
500                        &mut arr_rhs,
501                        self.swapped,
502                    )
503                },
504                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
505                    // `self.list_to_prim_lhs` is `Some(_)`, this is handled later.
506                },
507                _ => unreachable!(),
508            }
509
510            //
511            // General notes
512            // * Lists can be:
513            //   * Sliced, in which case the primitive/leaf array needs to be indexed starting from an
514            //     offset instead of 0.
515            //   * Masked, in which case the masked rows are permitted to have non-matching widths.
516            //
517
518            let out = match (&self.op_apply_type, &self.broadcast) {
519                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
520                    let offsets_lhs = &self.data_lhs.0[0];
521                    let offsets_rhs = &self.data_rhs.0[0];
522
523                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());
524
525                    // Output primitive (and optional validity) are aligned to the LHS input.
526                    let n_values = arr_lhs.len();
527                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
528                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
529
530                    // Counter that stops being incremented at the first row position with mismatching
531                    // list lengths.
532                    let mut mismatch_pos = 0;
533
534                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
535                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
536                            .offset_and_length_iter()
537                            .zip(offsets_rhs.offset_and_length_iter())
538                            .enumerate()
539                        {
540                            if
541                                (mismatch_pos == i)
542                                & (
543                                    (lhs_len == rhs_len)
544                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
545                                )
546                            {
547                                mismatch_pos += 1;
548                            }
549
550                            // Both sides are lists, we restrict the index to the min length to avoid
551                            // OOB memory access.
552                            let len: usize = lhs_len.min(rhs_len);
553
554                            for i in 0..len {
555                                let l_idx = i + lhs_start;
556                                let r_idx = i + rhs_start;
557
558                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
559                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
560                                let v = $OP(l, r);
561
562                                unsafe { out_ptr.add(l_idx).write(v) };
563                            }
564                        }
565                    });
566
567                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
568
569                    unsafe { out_vec.set_len(n_values) };
570
571                    /// Reduce monomorphization
572                    #[inline(never)]
573                    fn combine_validities_list_to_list_no_broadcast(
574                        offsets_lhs: &OffsetsBuffer<i64>,
575                        offsets_rhs: &OffsetsBuffer<i64>,
576                        validity_lhs: Option<&Bitmap>,
577                        validity_rhs: Option<&Bitmap>,
578                        len_lhs: usize,
579                    ) -> Option<Bitmap> {
580                        match (validity_lhs, validity_rhs) {
581                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
582                            (Some(v), None) => return Some(v.clone()),
583                            (None, Some(v)) => {
584                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
585                            },
586                            (None, None) => None,
587                        }
588                        .map(|(mut validity_out, validity_rhs)| {
589                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
590                                .offset_and_length_iter()
591                                .zip(offsets_rhs.offset_and_length_iter())
592                            {
593                                let len: usize = lhs_len.min(rhs_len);
594
595                                for i in 0..len {
596                                    let l_idx = i + lhs_start;
597                                    let r_idx = i + rhs_start;
598
599                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
600                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
601                                    let is_valid = l_valid & r_valid;
602
603                                    // Size and alignment of validity vec are based on LHS.
604                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
605                                }
606                            }
607
608                            validity_out.freeze()
609                        })
610                    }
611
612                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
613                        offsets_lhs,
614                        offsets_rhs,
615                        arr_lhs.validity(),
616                        arr_rhs.validity(),
617                        arr_lhs.len(),
618                    );
619
620                    let arr =
621                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
622
623                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
624                    assert_eq!(offsets.len(), 1);
625
626                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
627                },
628                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
629                    let offsets_lhs = &self.data_lhs.0[0];
630                    let offsets_rhs = &self.data_rhs.0[0];
631
632                    // Output primitive (and optional validity) are aligned to the LHS input.
633                    let n_values = arr_lhs.len();
634                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
635                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();
636
637                    assert_eq!(offsets_rhs.len_proxy(), 1);
638                    let rhs_start = *offsets_rhs.first() as usize;
639                    let width = offsets_rhs.range() as usize;
640
641                    let mut mismatch_pos = 0;
642
643                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
644                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
645                            if ((lhs_len == width) & (mismatch_pos == i))
646                                | unsafe { !self.outer_validity.get_bit_unchecked(i) }
647                            {
648                                mismatch_pos += 1;
649                            }
650
651                            let len: usize = lhs_len.min(width);
652
653                            for i in 0..len {
654                                let l_idx = i + lhs_start;
655                                let r_idx = i + rhs_start;
656
657                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
658                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
659                                let v = $OP(l, r);
660
661                                unsafe {
662                                    out_ptr.add(l_idx).write(v);
663                                }
664                            }
665                        }
666                    });
667
668                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;
669
670                    unsafe { out_vec.set_len(n_values) };
671
672                    #[inline(never)]
673                    fn combine_validities_list_to_list_broadcast_right(
674                        offsets_lhs: &OffsetsBuffer<i64>,
675                        validity_lhs: Option<&Bitmap>,
676                        validity_rhs: Option<&Bitmap>,
677                        len_lhs: usize,
678                        width: usize,
679                        rhs_start: usize,
680                    ) -> Option<Bitmap> {
681                        match (validity_lhs, validity_rhs) {
682                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
683                            (Some(v), None) => return Some(v.clone()),
684                            (None, Some(v)) => {
685                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
686                            },
687                            (None, None) => None,
688                        }
689                        .map(|(mut validity_out, validity_rhs)| {
690                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
691                                let len: usize = lhs_len.min(width);
692
693                                for i in 0..len {
694                                    let l_idx = i + lhs_start;
695                                    let r_idx = i + rhs_start;
696
697                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
698                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
699                                    let is_valid = l_valid & r_valid;
700
701                                    // Size and alignment of validity vec are based on LHS.
702                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
703                                }
704                            }
705
706                            validity_out.freeze()
707                        })
708                    }
709
710                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
711                        offsets_lhs,
712                        arr_lhs.validity(),
713                        arr_rhs.validity(),
714                        arr_lhs.len(),
715                        width,
716                        rhs_start,
717                    );
718
719                    let arr =
720                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
721
722                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
723                    assert_eq!(offsets.len(), 1);
724
725                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
726                },
727                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
728                    if self.list_to_prim_lhs.is_none() =>
729                {
730                    let offsets_lhs = self.data_lhs.0.as_slice();
731
732                    // Notes
733                    // * Primitive indexing starts from 0
734                    // * Output is aligned to LHS array
735
736                    let n_values = arr_lhs.len();
737                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
738                    let out_ptr = out_vec.as_mut_ptr();
739
740                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
741                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
742                        {
743                            let r = unsafe { arr_rhs.value_unchecked(i) };
744                            for l_idx in l_range {
745                                unsafe {
746                                    let l = arr_lhs.value_unchecked(l_idx);
747                                    let v = $OP(l, r);
748                                    out_ptr.add(l_idx).write(v);
749                                }
750                            }
751                        }
752                    });
753
754                    unsafe { out_vec.set_len(n_values) }
755
756                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
757                        offsets_lhs,
758                        arr_lhs.validity(),
759                        arr_rhs.validity(),
760                        arr_lhs.len(),
761                    );
762
763                    let arr =
764                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);
765
766                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
767                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
768                },
769                // If we are dispatched here, it means that the LHS array is a unique allocation created
770                // after a unit-length list column was broadcasted, so this codepath mutably stores the
771                // results back into the LHS array to save memory.
772                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
773                    let offsets_lhs = self.data_lhs.0.as_slice();
774
775                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
776                    let arr = arr
777                        .as_any_mut()
778                        .downcast_mut::<PrimitiveArray<T::Native>>()
779                        .unwrap();
780                    let mut arr_lhs = std::mem::take(arr);
781
782                    self.op.0.prepare_numeric_op_side_validities::<T>(
783                        &mut arr_lhs,
784                        &mut arr_rhs,
785                        self.swapped,
786                    );
787
788                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
789                    assert_eq!(arr_lhs_mut_slice.len(), n_values);
790
791                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
792                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
793                        {
794                            let r = unsafe { arr_rhs.value_unchecked(i) };
795                            for l_idx in l_range {
796                                unsafe {
797                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
798                                    *l = $OP(*l, r);
799                                }
800                            }
801                        }
802                    });
803
804                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
805                        offsets_lhs,
806                        arr_lhs.validity(),
807                        arr_rhs.validity(),
808                        arr_lhs.len(),
809                    );
810
811                    let arr = arr_lhs.with_validity(leaf_validity);
812
813                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
814                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
815                },
816                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
817                    assert_eq!(arr_rhs.len(), 1);
818
819                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
820                        // RHS is single primitive NULL, create the result by setting the leaf validity to all-NULL.
821                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
822                        return Ok(self.finish_offsets_and_validities(
823                            Box::new(
824                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
825                                    false,
826                                    arr_lhs.len(),
827                                ))),
828                            ),
829                            offsets,
830                            validities,
831                        ));
832                    };
833
834                    let arr = self
835                        .op
836                        .0
837                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
838                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
839
840                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
841                },
842                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
843                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
844                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
845                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
846                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
847                    if cfg!(debug_assertions) {
848                        panic!("operation was not re-written: {v:?}")
849                    } else {
850                        unreachable!()
851                    }
852                },
853            };
854
855            Ok(out)
856        }
857
858        /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every
859        /// level.
860        fn finish_offsets_and_validities(
861            &mut self,
862            leaf_array: Box<dyn Array>,
863            offsets: Vec<OffsetsBuffer<i64>>,
864            validities: Vec<Option<Bitmap>>,
865        ) -> ListChunked {
866            assert!(!offsets.is_empty());
867            assert_eq!(offsets.len(), validities.len());
868            let mut results = leaf_array;
869
870            let mut iter = offsets.into_iter().zip(validities).rev();
871
872            while iter.len() > 1 {
873                let (offsets, validity) = iter.next().unwrap();
874                let dtype = LargeListArray::default_datatype(results.dtype().clone());
875                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
876            }
877
878            // The combined outer validity is pre-computed during `try_new()`
879            let (offsets, _) = iter.next().unwrap();
880            let validity = std::mem::take(&mut self.outer_validity);
881            let dtype = LargeListArray::default_datatype(results.dtype().clone());
882            let results = LargeListArray::new(dtype, offsets, results, Some(validity));
883
884            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
885        }
886
887        fn materialize_broadcasted_list(
888            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
889            output_len: usize,
890            output_primitive_dtype: &DataType,
891        ) -> (Box<dyn Array>, usize) {
892            let s = &side_data.2;
893            assert_eq!(s.len(), 1);
894
895            let expected_n_values = {
896                let offsets = s.list_offsets_and_validities_recursive().0;
897                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
898            };
899
900            let ca = s.list().unwrap();
901            // Remember to cast the leaf primitives to the supertype.
902            let ca = ca
903                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
904                .unwrap();
905            assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data.
906            let ca = ca.new_from_index(0, output_len).rechunk();
907
908            let s = ca.into_series();
909
910            *side_data = {
911                let (a, b) = s.list_offsets_and_validities_recursive();
912                // `Series::default()`: This field in the tuple is no longer used.
913                (a, b, Series::default())
914            };
915
916            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
917            assert_eq!(n_values, expected_n_values);
918
919            let mut s = s.get_leaf_array();
920            let v = unsafe { s.chunks_mut() };
921
922            assert_eq!(v.len(), 1);
923            (v.swap_remove(0), n_values)
924        }
925    }
926
927    /// Used in 2 places, so it's outside here.
928    #[inline(never)]
929    fn combine_validities_list_to_primitive_no_broadcast(
930        offsets_lhs: &[OffsetsBuffer<i64>],
931        validity_lhs: Option<&Bitmap>,
932        validity_rhs: Option<&Bitmap>,
933        len_lhs: usize,
934    ) -> Option<Bitmap> {
935        match (validity_lhs, validity_rhs) {
936            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
937            (Some(v), None) => return Some(v.clone()),
938            // Materialize a full-true validity to re-use the codepath, as we still
939            // need to spread the bits from the RHS to the correct positions.
940            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
941            (None, None) => None,
942        }
943        .map(|(mut validity_out, validity_rhs)| {
944            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
945                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
946                for l_idx in l_range {
947                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
948                    let is_valid = l_valid & r_valid;
949
950                    // Size and alignment of validity vec are based on LHS.
951                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
952                }
953            }
954
955            validity_out.freeze()
956        })
957    }
958}