polars_core/series/arithmetic/list.rs

//! Allow arithmetic operations for ListChunked.

use polars_error::{PolarsResult, feature_gated};

use super::list_utils::NumericOp;
use super::{IntoSeries, ListChunked, ListType, NumOpsDispatchInner, Series};

impl NumOpsDispatchInner for ListType {
    fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        NumericListOp::add().execute(&lhs.clone().into_series(), rhs)
    }

    fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        NumericListOp::sub().execute(&lhs.clone().into_series(), rhs)
    }

    fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        NumericListOp::mul().execute(&lhs.clone().into_series(), rhs)
    }

    fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        NumericListOp::div().execute(&lhs.clone().into_series(), rhs)
    }

    fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
        NumericListOp::rem().execute(&lhs.clone().into_series(), rhs)
    }
}
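
// Illustrative note (not part of the crate API): with the `list_arithmetic`
// feature enabled, the hooks above route a Series-level operation such as
// "List(Int64) column + Int64 column" into `NumericListOp::add().execute(&lhs, &rhs)`,
// which applies each RHS row value to every element of the matching LHS list row,
// e.g. [[1, 2], [3]] + [10, 20] => [[11, 12], [23]].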

#[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
#[derive(Clone)]
pub struct NumericListOp(NumericOp);

impl NumericListOp {
    pub fn add() -> Self {
        Self(NumericOp::Add)
    }

    pub fn sub() -> Self {
        Self(NumericOp::Sub)
    }

    pub fn mul() -> Self {
        Self(NumericOp::Mul)
    }

    pub fn div() -> Self {
        Self(NumericOp::Div)
    }

    pub fn rem() -> Self {
        Self(NumericOp::Rem)
    }

    pub fn floor_div() -> Self {
        Self(NumericOp::FloorDiv)
    }
}

impl NumericListOp {
    #[cfg_attr(not(feature = "list_arithmetic"), allow(unused))]
    pub fn execute(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
        feature_gated!("list_arithmetic", {
            use std::borrow::Cow;

            use either::Either;

            // `trim_lists_to_normalized_offsets` ensures we don't perform excessive
            // memory allocation / compute on memory regions that have been
            // sliced out.
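            // For example, a column produced by `lhs.slice(1_000, 10)` would otherwise
            // still drag along the full backing buffers of the original column
            // (illustrative; any prior slicing of the inputs has the same effect).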
            let lhs = lhs
                .trim_lists_to_normalized_offsets()
                .map_or(Cow::Borrowed(lhs), Cow::Owned);
            let rhs = rhs
                .trim_lists_to_normalized_offsets()
                .map_or(Cow::Borrowed(rhs), Cow::Owned);

            let lhs = lhs.rechunk();
            let rhs = rhs.rechunk();

            let binary_op_exec = match ListNumericOpHelper::try_new(
                self.clone(),
                lhs.name().clone(),
                lhs.dtype(),
                rhs.dtype(),
                lhs.len(),
                rhs.len(),
                {
                    let (a, b) = lhs.list_offsets_and_validities_recursive();
                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
                    (a, b, lhs.clone())
                },
                {
                    let (a, b) = rhs.list_offsets_and_validities_recursive();
                    debug_assert!(a.iter().all(|x| *x.first() as usize == 0));
                    (a, b, rhs.clone())
                },
                lhs.rechunk_validity(),
                rhs.rechunk_validity(),
            )? {
                Either::Left(v) => v,
                Either::Right(ca) => return Ok(ca.into_series()),
            };

            Ok(binary_op_exec.finish()?.into_series())
        })
    }
}

#[cfg(feature = "list_arithmetic")]
use inner::ListNumericOpHelper;

#[cfg(feature = "list_arithmetic")]
mod inner {
    use arrow::bitmap::Bitmap;
    use arrow::compute::utils::combine_validities_and;
    use arrow::offset::OffsetsBuffer;
    use either::Either;
    use list_utils::with_match_pl_num_arith;
    use num_traits::Zero;
    use polars_compute::arithmetic::pl_num::PlNumArithmetic;
    use polars_utils::float::IsFloat;

    use super::super::list_utils::{BinaryOpApplyType, Broadcast, NumericOp};
    use super::super::*;

    /// Utility to perform a binary operation between the primitive values of
    /// 2 columns, where at least one of the columns is a `ListChunked` type.
    pub(super) struct ListNumericOpHelper {
        op: NumericListOp,
        output_name: PlSmallStr,
        op_apply_type: BinaryOpApplyType,
        broadcast: Broadcast,
        output_dtype: DataType,
        output_primitive_dtype: DataType,
        output_len: usize,
        /// Outer validity of the result; we always materialize this to reduce the
        /// number of code paths we need.
        outer_validity: Bitmap,
        // The series are stored because they are used for list broadcasting.
        data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
        list_to_prim_lhs: Option<(Box<dyn Array>, usize)>,
        swapped: bool,
    }

    /// This lets us separate some logic into `try_new()` to reduce the amount of
    /// monomorphized code.
    impl ListNumericOpHelper {
        /// Checks that:
        /// * Dtypes are compatible:
        ///   * list<->primitive | primitive<->list
        ///   * list<->list both contain primitives (e.g. List<Int8>)
        /// * Primitive dtypes match
        /// * Lengths are compatible:
        ///   * 1<->n | n<->1
        ///   * n<->n
        /// * Both sides have at least 1 non-NULL outer row.
        ///
        /// Does not check:
        /// * Whether the offsets are aligned for list<->list; this will be checked during execution.
        ///
        /// This returns an `Either` which may contain the final result to simplify
        /// the implementation.
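        ///
        /// # Example (illustrative)
        ///
        /// For `dtype_lhs = List(Int32)` and `dtype_rhs = Int64` with equal lengths,
        /// this resolves to `BinaryOpApplyType::ListToPrimitive` with
        /// `Broadcast::NoBroadcast`, and the output dtype becomes `List(Int64)`
        /// (the list dtype with its leaf cast to the leaf supertype).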
        #[allow(clippy::too_many_arguments)]
        pub(super) fn try_new(
            op: NumericListOp,
            output_name: PlSmallStr,
            dtype_lhs: &DataType,
            dtype_rhs: &DataType,
            len_lhs: usize,
            len_rhs: usize,
            data_lhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            data_rhs: (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            validity_lhs: Option<Bitmap>,
            validity_rhs: Option<Bitmap>,
        ) -> PolarsResult<Either<Self, ListChunked>> {
            let prim_dtype_lhs = dtype_lhs.leaf_dtype();
            let prim_dtype_rhs = dtype_rhs.leaf_dtype();

            let output_primitive_dtype =
                op.0.try_get_leaf_supertype(prim_dtype_lhs, prim_dtype_rhs)?;

            fn is_list_type_at_all_levels(dtype: &DataType) -> bool {
                match dtype {
                    DataType::List(inner) => is_list_type_at_all_levels(inner),
                    dt if dt.is_supported_list_arithmetic_input() => true,
                    _ => false,
                }
            }

            let op_err_msg = |err_reason: &str| {
                polars_err!(
                    InvalidOperation:
                    "cannot {} columns: {}: (left: {}, right: {})",
                    op.0.name(), err_reason, dtype_lhs, dtype_rhs,
                )
            };

            let ensure_list_type_at_all_levels = |dtype: &DataType| {
                if !is_list_type_at_all_levels(dtype) {
                    Err(op_err_msg("dtype was not list on all nesting levels"))
                } else {
                    Ok(())
                }
            };

            let (op_apply_type, output_dtype) = match (dtype_lhs, dtype_rhs) {
                (l @ DataType::List(a), r @ DataType::List(b)) => {
                    // `get_arithmetic_field()` in the DSL checks this, but we also have to check here because a user
                    // directly adding 2 series together bypasses the DSL.
                    // This is currently duplicated code and should one day be replaced with an assert once Series ops
                    // are checked properly.
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input())
                    {
                        polars_bail!(
                            InvalidOperation:
                            "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                            op.0.name(), l, r,
                        );
                    }
                    (BinaryOpApplyType::ListToList, l)
                },
                (list_dtype @ DataType::List(_), x) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::ListToPrimitive, list_dtype)
                },
                (x, list_dtype @ DataType::List(_)) if x.is_supported_list_arithmetic_input() => {
                    ensure_list_type_at_all_levels(list_dtype)?;
                    (BinaryOpApplyType::PrimitiveToList, list_dtype)
                },
                (l, r) => polars_bail!(
                    InvalidOperation:
                    "{} operation not supported for dtypes: {} != {}",
                    op.0.name(), l, r,
                ),
            };

            let output_dtype = output_dtype.cast_leaf(output_primitive_dtype.clone());

            let (broadcast, output_len) = match (len_lhs, len_rhs) {
                (l, r) if l == r => (Broadcast::NoBroadcast, l),
                (1, v) => (Broadcast::Left, v),
                (v, 1) => (Broadcast::Right, v),
                (l, r) => polars_bail!(
                    ShapeMismatch:
                    "cannot {} two columns of differing lengths: {} != {}",
                    op.0.name(), l, r
                ),
            };

            let DataType::List(output_inner_dtype) = &output_dtype else {
                unreachable!()
            };

            // # NULL semantics
            // * [[1, 2]] (List[List[Int64]]) + NULL (Int64) => [[NULL, NULL]]
            //   * Essentially as if the NULL primitive was added to every primitive in the row of the list column.
            // * NULL (List[Int64]) + 1   (Int64)       => NULL
            // * NULL (List[Int64]) + [1] (List[Int64]) => NULL

            if output_len == 0
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive
                ) && validity_lhs.as_ref().is_some_and(|x| x.set_bits() == 0))
                || (matches!(
                    &op_apply_type,
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList
                ) && validity_rhs.as_ref().is_some_and(|x| x.set_bits() == 0))
            {
                return Ok(Either::Right(ListChunked::full_null_with_dtype(
                    output_name,
                    output_len,
                    output_inner_dtype.as_ref(),
                )));
            }

            // At this point:
            // * All unit-length list columns have a valid outer value.

            // The outer validity is just the validity of any non-broadcasting lists.
            let outer_validity = match (&op_apply_type, &broadcast, validity_lhs, validity_rhs) {
                // Both lists have the same length; combine the validities.
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast, l, r) => {
                    combine_validities_and(l.as_ref(), r.as_ref())
                },
                // Match all other combinations that have non-broadcasting lists.
                (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::ListToPrimitive,
                    Broadcast::NoBroadcast | Broadcast::Right,
                    v,
                    _,
                )
                | (
                    BinaryOpApplyType::ListToList | BinaryOpApplyType::PrimitiveToList,
                    Broadcast::NoBroadcast | Broadcast::Left,
                    _,
                    v,
                ) => v,
                _ => None,
            }
            .unwrap_or_else(|| Bitmap::new_with_value(true, output_len));
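            // Illustration (not exhaustive): for `[[1], NULL] + 1` (ListToPrimitive,
            // NoBroadcast) the LHS validity `[true, false]` is taken as the outer
            // validity, so row 1 of the output stays NULL, matching the NULL semantics
            // described above.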

            Ok(Either::Left(Self {
                op,
                output_name,
                op_apply_type,
                broadcast,
                output_dtype: output_dtype.clone(),
                output_primitive_dtype,
                output_len,
                outer_validity,
                data_lhs,
                data_rhs,
                list_to_prim_lhs: None,
                swapped: false,
            }))
        }

        pub(super) fn finish(mut self) -> PolarsResult<ListChunked> {
            // We have physical codepaths for a subset of the possible combinations of broadcasting and
            // column types. The remaining combinations are handled by dispatching to the physical
            // codepaths after operand swapping and/or materialized broadcasting.
            //
            // # Physical impl table
            // Legend
            // * |  N  | // impl "N"
            // * | [N] | // dispatches to impl "N"
            //
            //                  |  L  |  N  |  R  | // Broadcast (L)eft, (N)oBroadcast, (R)ight
            // ListToList       | [1] |  0  |  1  |
            // ListToPrimitive  | [2] |  2  |  3  | // list broadcasting just materializes and dispatches to NoBroadcast
            // PrimitiveToList  | [3] | [2] | [2] |

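            // Worked example of a rewrite (illustrative): `[10, 20] + [[1, 2], [3]]`
            // (PrimitiveToList, NoBroadcast) swaps the operands and becomes
            // (ListToPrimitive, NoBroadcast); `swapped` stays `true` so that
            // non-commutative ops (e.g. Sub, Div) keep the original operand order.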
            self.swapped = true;

            match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToList, Broadcast::Right)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                | (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToList, Broadcast::Left) => {
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Left) => {
                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_lhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.broadcast = Broadcast::NoBroadcast;

                    // This does not swap! We are just dispatching to `NoBroadcast`
                    // after materializing the broadcasted list array.
                    self.swapped = false;
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Right) => {
                    // We materialize the list column with `new_from_index`, as otherwise we'd have to
                    // implement logic that broadcasts the offsets and validities across multiple levels
                    // of nesting. But we re-use the materialized memory to store the result.

                    self.list_to_prim_lhs
                        .replace(Self::materialize_broadcasted_list(
                            &mut self.data_rhs,
                            self.output_len,
                            &self.output_primitive_dtype,
                        ));

                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::NoBroadcast;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
                (BinaryOpApplyType::PrimitiveToList, Broadcast::Left) => {
                    self.op_apply_type = BinaryOpApplyType::ListToPrimitive;
                    self.broadcast = Broadcast::Right;

                    std::mem::swap(&mut self.data_lhs, &mut self.data_rhs);
                    self._finish_impl_dispatch()
                },
            }
        }

        fn _finish_impl_dispatch(&mut self) -> PolarsResult<ListChunked> {
            let output_dtype = self.output_dtype.clone();
            let output_len = self.output_len;

            let prim_lhs = self
                .data_lhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();
            let prim_rhs = self
                .data_rhs
                .2
                .get_leaf_array()
                .cast(&self.output_primitive_dtype)?
                .rechunk();

            debug_assert_eq!(prim_lhs.dtype(), prim_rhs.dtype());
            let prim_dtype = prim_lhs.dtype();
            debug_assert_eq!(prim_dtype, &self.output_primitive_dtype);

            // Safety: Leaf dtypes have been checked to be numeric by `try_new()`
            let out = with_match_physical_numeric_polars_type!(&prim_dtype, |$T| {
                self._finish_impl::<$T>(prim_lhs, prim_rhs)
            })?;

            debug_assert_eq!(out.dtype(), &output_dtype);
            assert_eq!(out.len(), output_len);

            Ok(out)
        }

        /// Internal use only - contains physical impls.
        fn _finish_impl<T: PolarsNumericType>(
            &mut self,
            prim_s_lhs: Series,
            prim_s_rhs: Series,
        ) -> PolarsResult<ListChunked>
        where
            T::Native: PlNumArithmetic,
            PrimitiveArray<T::Native>:
                polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
            T::Native: Zero + IsFloat,
        {
            #[inline(never)]
            fn check_mismatch_pos(
                mismatch_pos: usize,
                offsets_lhs: &OffsetsBuffer<i64>,
                offsets_rhs: &OffsetsBuffer<i64>,
            ) -> PolarsResult<()> {
                if mismatch_pos < offsets_lhs.len_proxy() {
                    // RHS could be broadcasted
                    let len_r = offsets_rhs.length_at(if offsets_rhs.len_proxy() == 1 {
                        0
                    } else {
                        mismatch_pos
                    });
                    polars_bail!(
                        ShapeMismatch:
                        "list lengths differed at index {}: {} != {}",
                        mismatch_pos,
                        offsets_lhs.length_at(mismatch_pos), len_r
                    )
                }
                Ok(())
            }

            let mut arr_lhs = {
                let ca: &ChunkedArray<T> = prim_s_lhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            let mut arr_rhs = {
                let ca: &ChunkedArray<T> = prim_s_rhs.as_ref().as_ref();
                assert_eq!(ca.chunks().len(), 1);
                ca.downcast_get(0).unwrap().clone()
            };

            match (&self.op_apply_type, &self.broadcast) {
                // We skip this case because it dispatches to `ArithmeticKernel`, which handles the
                // validities for us.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {},
                _ if self.list_to_prim_lhs.is_none() => {
                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    )
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    // `self.list_to_prim_lhs` is `Some(_)`; this is handled later.
                },
                _ => unreachable!(),
            }

            //
            // General notes
            // * Lists can be:
            //   * Sliced, in which case the primitive/leaf array needs to be indexed starting from an
            //     offset instead of 0.
            //   * Masked, in which case the masked rows are permitted to have non-matching widths.
            //

            let out = match (&self.op_apply_type, &self.broadcast) {
                (BinaryOpApplyType::ListToList, Broadcast::NoBroadcast) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    assert_eq!(offsets_lhs.len_proxy(), offsets_rhs.len_proxy());

                    // Output primitive (and optional validity) are aligned to the LHS input.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    // Counter that stops being incremented at the first row position with mismatching
                    // list lengths.
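                    // Illustration (assuming all rows are valid): for LHS list lengths
                    // [2, 3, 2] and RHS list lengths [2, 1, 2] the counter stops at 1,
                    // and `check_mismatch_pos` below reports the mismatch at index 1.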
                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, ((lhs_start, lhs_len), (rhs_start, rhs_len))) in offsets_lhs
                            .offset_and_length_iter()
                            .zip(offsets_rhs.offset_and_length_iter())
                            .enumerate()
                        {
                            if
                                (mismatch_pos == i)
                                & (
                                    (lhs_len == rhs_len)
                                    | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                                )
                            {
                                mismatch_pos += 1;
                            }

                            // Both sides are lists; we restrict the index to the min length to avoid
                            // OOB memory access.
                            let len: usize = lhs_len.min(rhs_len);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe { out_ptr.add(l_idx).write(v) };
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    unsafe { out_vec.set_len(n_values) };

                    /// Reduce monomorphization
                    #[inline(never)]
                    fn combine_validities_list_to_list_no_broadcast(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        offsets_rhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for ((lhs_start, lhs_len), (rhs_start, rhs_len)) in offsets_lhs
                                .offset_and_length_iter()
                                .zip(offsets_rhs.offset_and_length_iter())
                            {
                                let len: usize = lhs_len.min(rhs_len);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // Size and alignment of validity vec are based on LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_no_broadcast(
                        offsets_lhs,
                        offsets_rhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToList, Broadcast::Right) => {
                    let offsets_lhs = &self.data_lhs.0[0];
                    let offsets_rhs = &self.data_rhs.0[0];

                    // Output primitive (and optional validity) are aligned to the LHS input.
                    let n_values = arr_lhs.len();
                    let mut out_vec: Vec<T::Native> = Vec::with_capacity(n_values);
                    let out_ptr: *mut T::Native = out_vec.as_mut_ptr();

                    assert_eq!(offsets_rhs.len_proxy(), 1);
                    let rhs_start = *offsets_rhs.first() as usize;
                    let width = offsets_rhs.range() as usize;

                    let mut mismatch_pos = 0;

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, (lhs_start, lhs_len)) in offsets_lhs.offset_and_length_iter().enumerate() {
                            if ((lhs_len == width) & (mismatch_pos == i))
                                | unsafe { !self.outer_validity.get_bit_unchecked(i) }
                            {
                                mismatch_pos += 1;
                            }

                            let len: usize = lhs_len.min(width);

                            for i in 0..len {
                                let l_idx = i + lhs_start;
                                let r_idx = i + rhs_start;

                                let l = unsafe { arr_lhs.value_unchecked(l_idx) };
                                let r = unsafe { arr_rhs.value_unchecked(r_idx) };
                                let v = $OP(l, r);

                                unsafe {
                                    out_ptr.add(l_idx).write(v);
                                }
                            }
                        }
                    });

                    check_mismatch_pos(mismatch_pos, offsets_lhs, offsets_rhs)?;

                    unsafe { out_vec.set_len(n_values) };

                    #[inline(never)]
                    fn combine_validities_list_to_list_broadcast_right(
                        offsets_lhs: &OffsetsBuffer<i64>,
                        validity_lhs: Option<&Bitmap>,
                        validity_rhs: Option<&Bitmap>,
                        len_lhs: usize,
                        width: usize,
                        rhs_start: usize,
                    ) -> Option<Bitmap> {
                        match (validity_lhs, validity_rhs) {
                            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
                            (Some(v), None) => return Some(v.clone()),
                            (None, Some(v)) => {
                                Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v))
                            },
                            (None, None) => None,
                        }
                        .map(|(mut validity_out, validity_rhs)| {
                            for (lhs_start, lhs_len) in offsets_lhs.offset_and_length_iter() {
                                let len: usize = lhs_len.min(width);

                                for i in 0..len {
                                    let l_idx = i + lhs_start;
                                    let r_idx = i + rhs_start;

                                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                                    let r_valid = unsafe { validity_rhs.get_bit_unchecked(r_idx) };
                                    let is_valid = l_valid & r_valid;

                                    // Size and alignment of validity vec are based on LHS.
                                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                                }
                            }

                            validity_out.freeze()
                        })
                    }

                    let leaf_validity = combine_validities_list_to_list_broadcast_right(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                        width,
                        rhs_start,
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    assert_eq!(offsets.len(), 1);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast)
                    if self.list_to_prim_lhs.is_none() =>
                {
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    // Notes
                    // * Primitive indexing starts from 0
                    // * Output is aligned to LHS array

                    let n_values = arr_lhs.len();
                    let mut out_vec = Vec::<T::Native>::with_capacity(n_values);
                    let out_ptr = out_vec.as_mut_ptr();

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs.value_unchecked(l_idx);
                                    let v = $OP(l, r);
                                    out_ptr.add(l_idx).write(v);
                                }
                            }
                        }
                    });

                    unsafe { out_vec.set_len(n_values) }

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr =
                        PrimitiveArray::<T::Native>::from_vec(out_vec).with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                // If we are dispatched here, it means that the LHS array is a unique allocation created
                // after a unit-length list column was broadcasted, so this codepath mutably stores the
                // results back into the LHS array to save memory.
                (BinaryOpApplyType::ListToPrimitive, Broadcast::NoBroadcast) => {
                    let offsets_lhs = self.data_lhs.0.as_slice();

                    let (mut arr, n_values) = Option::take(&mut self.list_to_prim_lhs).unwrap();
                    let arr = arr
                        .as_any_mut()
                        .downcast_mut::<PrimitiveArray<T::Native>>()
                        .unwrap();
                    let mut arr_lhs = std::mem::take(arr);

                    self.op.0.prepare_numeric_op_side_validities::<T>(
                        &mut arr_lhs,
                        &mut arr_rhs,
                        self.swapped,
                    );

                    let arr_lhs_mut_slice = arr_lhs.get_mut_values().unwrap();
                    assert_eq!(arr_lhs_mut_slice.len(), n_values);

                    with_match_pl_num_arith!(&self.op.0, self.swapped, |$OP| {
                        for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate()
                        {
                            let r = unsafe { arr_rhs.value_unchecked(i) };
                            for l_idx in l_range {
                                unsafe {
                                    let l = arr_lhs_mut_slice.get_unchecked_mut(l_idx);
                                    *l = $OP(*l, r);
                                }
                            }
                        }
                    });

                    let leaf_validity = combine_validities_list_to_primitive_no_broadcast(
                        offsets_lhs,
                        arr_lhs.validity(),
                        arr_rhs.validity(),
                        arr_lhs.len(),
                    );

                    let arr = arr_lhs.with_validity(leaf_validity);

                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                (BinaryOpApplyType::ListToPrimitive, Broadcast::Right) => {
                    assert_eq!(arr_rhs.len(), 1);

                    let Some(r) = (unsafe { arr_rhs.get_unchecked(0) }) else {
                        // The RHS is a single primitive NULL: create the result by setting the leaf validity to all-NULL.
                        let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);
                        return Ok(self.finish_offsets_and_validities(
                            Box::new(
                                arr_lhs.clone().with_validity(Some(Bitmap::new_with_value(
                                    false,
                                    arr_lhs.len(),
                                ))),
                            ),
                            offsets,
                            validities,
                        ));
                    };

                    let arr = self
                        .op
                        .0
                        .apply_array_to_scalar::<T>(arr_lhs, r, self.swapped);
                    let (offsets, validities, _) = std::mem::take(&mut self.data_lhs);

                    self.finish_offsets_and_validities(Box::new(arr), offsets, validities)
                },
                v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Right)
                | v @ (BinaryOpApplyType::ListToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::ListToPrimitive, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::Left)
                | v @ (BinaryOpApplyType::PrimitiveToList, Broadcast::NoBroadcast) => {
                    if cfg!(debug_assertions) {
                        panic!("operation was not re-written: {v:?}")
                    } else {
                        unreachable!()
                    }
                },
            };

            Ok(out)
        }

        /// Construct the result `ListChunked` from the leaf array and the offsets/validities of every
        /// level.
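        ///
        /// Illustrative example: for an output dtype of `List(List(Int64))` there are
        /// two levels of offsets/validities; the leaf `Int64` array is wrapped in a
        /// `LargeListArray` for the inner level first, then for the outer level, which
        /// also receives the pre-computed `outer_validity`.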
        fn finish_offsets_and_validities(
            &mut self,
            leaf_array: Box<dyn Array>,
            offsets: Vec<OffsetsBuffer<i64>>,
            validities: Vec<Option<Bitmap>>,
        ) -> ListChunked {
            assert!(!offsets.is_empty());
            assert_eq!(offsets.len(), validities.len());
            let mut results = leaf_array;

            let mut iter = offsets.into_iter().zip(validities).rev();

            while iter.len() > 1 {
                let (offsets, validity) = iter.next().unwrap();
                let dtype = LargeListArray::default_datatype(results.dtype().clone());
                results = Box::new(LargeListArray::new(dtype, offsets, results, validity));
            }

            // The combined outer validity is pre-computed during `try_new()`
            let (offsets, _) = iter.next().unwrap();
            let validity = std::mem::take(&mut self.outer_validity);
            let dtype = LargeListArray::default_datatype(results.dtype().clone());
            let results = LargeListArray::new(dtype, offsets, results, Some(validity));

            ListChunked::with_chunk(std::mem::take(&mut self.output_name), results)
        }

        fn materialize_broadcasted_list(
            side_data: &mut (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>, Series),
            output_len: usize,
            output_primitive_dtype: &DataType,
        ) -> (Box<dyn Array>, usize) {
            let s = &side_data.2;
            assert_eq!(s.len(), 1);

            let expected_n_values = {
                let offsets = s.list_offsets_and_validities_recursive().0;
                output_len * OffsetsBuffer::<i64>::leaf_full_start_end(&offsets).len()
            };

            let ca = s.list().unwrap();
            // Remember to cast the leaf primitives to the supertype.
            let ca = ca
                .cast(&ca.dtype().cast_leaf(output_primitive_dtype.clone()))
                .unwrap();
            assert!(output_len > 1); // In case there is a fast-path that doesn't give us owned data.
            let ca = ca.new_from_index(0, output_len).rechunk();

            let s = ca.into_series();

            *side_data = {
                let (a, b) = s.list_offsets_and_validities_recursive();
                // `Series::default()`: This field in the tuple is no longer used.
                (a, b, Series::default())
            };

            let n_values = OffsetsBuffer::<i64>::leaf_full_start_end(&side_data.0).len();
            assert_eq!(n_values, expected_n_values);

            let mut s = s.get_leaf_array();
            let v = unsafe { s.chunks_mut() };

            assert_eq!(v.len(), 1);
            (v.swap_remove(0), n_values)
        }
    }

    /// Used in 2 places, so it lives outside the `impl`.
    #[inline(never)]
    fn combine_validities_list_to_primitive_no_broadcast(
        offsets_lhs: &[OffsetsBuffer<i64>],
        validity_lhs: Option<&Bitmap>,
        validity_rhs: Option<&Bitmap>,
        len_lhs: usize,
    ) -> Option<Bitmap> {
        match (validity_lhs, validity_rhs) {
            (Some(l), Some(r)) => Some((l.clone().make_mut(), r)),
            (Some(v), None) => return Some(v.clone()),
            // Materialize a full-true validity to re-use the codepath, as we still
            // need to spread the bits from the RHS to the correct positions.
            (None, Some(v)) => Some((Bitmap::new_with_value(true, len_lhs).make_mut(), v)),
            (None, None) => None,
        }
        .map(|(mut validity_out, validity_rhs)| {
            for (i, l_range) in OffsetsBuffer::<i64>::leaf_ranges_iter(offsets_lhs).enumerate() {
                let r_valid = unsafe { validity_rhs.get_bit_unchecked(i) };
                for l_idx in l_range {
                    let l_valid = unsafe { validity_out.get_unchecked(l_idx) };
                    let is_valid = l_valid & r_valid;

                    // Size and alignment of validity vec are based on LHS.
                    unsafe { validity_out.set_unchecked(l_idx, is_valid) };
                }
            }

            validity_out.freeze()
        })
    }
}