use arrow::bitmap::Bitmap;
use arrow::legacy::utils::FromTrustedLenIterator;
use polars_compute::comparisons::TotalOrdKernel;
use crate::chunked_array::cast::CastOptions;
use crate::prelude::nulls::replace_non_null;
use crate::prelude::*;
#[cfg(feature = "dtype-categorical")]
fn cat_equality_helper<'a, Compare, Missing>(
lhs: &'a CategoricalChunked,
rhs: &'a CategoricalChunked,
missing_function: Missing,
compare_function: Compare,
) -> PolarsResult<BooleanChunked>
where
Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
Missing: Fn(&'a CategoricalChunked) -> BooleanChunked,
{
let rev_map_l = lhs.get_rev_map();
polars_ensure!(rev_map_l.same_src(rhs.get_rev_map()), string_cache_mismatch);
let rhs = rhs.physical();
if rhs.len() == 1 && rhs.null_count() == 0 {
let rhs = rhs.get(0).unwrap();
if rev_map_l.get_optional(rhs).is_none() {
return Ok(missing_function(lhs));
}
}
Ok(compare_function(lhs.physical(), rhs))
}
fn cat_compare_helper<'a, Compare, CompareString>(
lhs: &'a CategoricalChunked,
rhs: &'a CategoricalChunked,
compare_function: Compare,
compare_str_function: CompareString,
) -> PolarsResult<BooleanChunked>
where
Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
CompareString: Fn(&str, &str) -> bool,
{
let rev_map_l = lhs.get_rev_map();
let rev_map_r = rhs.get_rev_map();
polars_ensure!(rev_map_l.same_src(rev_map_r), ComputeError: "can only compare categoricals of the same type with the same categories");
if lhs.is_enum() || !lhs.uses_lexical_ordering() {
Ok(compare_function(lhs.physical(), rhs.physical()))
} else {
match (lhs.len(), rhs.len()) {
(lhs_len, 1) => {
let v = unsafe {
rhs.physical()
.get(0)
.map(|phys| rev_map_r.get_unchecked(phys))
};
let Some(v) = v else {
return Ok(BooleanChunked::full_null(lhs.name(), lhs_len));
};
Ok(lhs
.iter_str()
.map(|opt_s| opt_s.map(|s| compare_str_function(s, v)))
.collect_ca_trusted(lhs.name()))
},
(1, rhs_len) => {
let v = unsafe {
lhs.physical()
.get(0)
.map(|phys| rev_map_l.get_unchecked(phys))
};
let Some(v) = v else {
return Ok(BooleanChunked::full_null(lhs.name(), rhs_len));
};
Ok(rhs
.iter_str()
.map(|opt_s| opt_s.map(|s| compare_str_function(v, s)))
.collect_ca_trusted(lhs.name()))
},
(lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs
.iter_str()
.zip(rhs.iter_str())
.map(|(l, r)| match (l, r) {
(None, _) => None,
(_, None) => None,
(Some(l), Some(r)) => Some(compare_str_function(l, r)),
})
.collect_ca_trusted(lhs.name())),
(lhs_len, rhs_len) => {
polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len)
},
}
}
}
impl ChunkCompare<&CategoricalChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;
fn equal(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false),
UInt32Chunked::equal,
)
}
fn equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), false, lhs.len()),
UInt32Chunked::equal_missing,
)
}
fn not_equal(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true),
UInt32Chunked::not_equal,
)
}
fn not_equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), true, lhs.len()),
UInt32Chunked::not_equal_missing,
)
}
fn gt(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r)
}
fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::gt_eq, |l, r| l >= r)
}
fn lt(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::lt, |l, r| l < r)
}
fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::lt_eq, |l, r| l <= r)
}
}
fn cat_str_equality_helper<'a, Missing, CompareNone, CompareCat, ComparePhys, CompareString>(
lhs: &'a CategoricalChunked,
rhs: &'a StringChunked,
missing_function: Missing,
compare_to_none: CompareNone,
cat_compare_function: CompareCat,
phys_compare_function: ComparePhys,
str_compare_function: CompareString,
) -> PolarsResult<BooleanChunked>
where
Missing: Fn(&CategoricalChunked) -> BooleanChunked,
CompareNone: Fn(&CategoricalChunked) -> BooleanChunked,
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
{
if lhs.is_enum() {
let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
cat_compare_function(lhs, rhs_cat.categorical().unwrap())
} else if rhs.len() == 1 {
match rhs.get(0) {
None => Ok(compare_to_none(lhs)),
Some(s) => {
cat_single_str_equality_helper(lhs, s, missing_function, phys_compare_function)
},
}
} else {
let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
}
}
fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, CompareString>(
lhs: &'a CategoricalChunked,
rhs: &'a StringChunked,
cat_compare_function: CompareCat,
phys_compare_function: ComparePhys,
str_single_compare_function: CompareStringSingle,
str_compare_function: CompareString,
) -> PolarsResult<BooleanChunked>
where
CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
{
if lhs.is_enum() {
let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
cat_compare_function(lhs, rhs_cat.categorical().unwrap())
} else if rhs.len() == 1 {
match rhs.get(0) {
None => Ok(BooleanChunked::full_null(lhs.name(), lhs.len())),
Some(s) => cat_single_str_compare_helper(
lhs,
s,
phys_compare_function,
str_single_compare_function,
),
}
} else {
let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
}
}
impl ChunkCompare<&StringChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;
fn equal(&self, rhs: &StringChunked) -> Self::Item {
cat_str_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false),
|lhs| BooleanChunked::full_null(lhs.name(), lhs.len()),
|s1, s2| CategoricalChunked::equal(s1, s2),
UInt32Chunked::equal,
StringChunked::equal,
)
}
fn equal_missing(&self, rhs: &StringChunked) -> Self::Item {
cat_str_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), false, lhs.len()),
|lhs| lhs.physical().is_null(),
|s1, s2| CategoricalChunked::equal_missing(s1, s2),
UInt32Chunked::equal_missing,
StringChunked::equal_missing,
)
}
fn not_equal(&self, rhs: &StringChunked) -> Self::Item {
cat_str_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true),
|lhs| BooleanChunked::full_null(lhs.name(), lhs.len()),
|s1, s2| CategoricalChunked::not_equal(s1, s2),
UInt32Chunked::not_equal,
StringChunked::not_equal,
)
}
fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item {
cat_str_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), true, lhs.len()),
|lhs| !lhs.physical().is_null(),
|s1, s2| CategoricalChunked::not_equal_missing(s1, s2),
UInt32Chunked::not_equal_missing,
StringChunked::not_equal_missing,
)
}
fn gt(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
self,
rhs,
|s1, s2| CategoricalChunked::gt(s1, s2),
UInt32Chunked::gt,
Utf8ViewArray::tot_gt_kernel_broadcast,
StringChunked::gt,
)
}
fn gt_eq(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
self,
rhs,
|s1, s2| CategoricalChunked::gt_eq(s1, s2),
UInt32Chunked::gt_eq,
Utf8ViewArray::tot_ge_kernel_broadcast,
StringChunked::gt_eq,
)
}
fn lt(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
self,
rhs,
|s1, s2| CategoricalChunked::lt(s1, s2),
UInt32Chunked::lt,
Utf8ViewArray::tot_lt_kernel_broadcast,
StringChunked::lt,
)
}
fn lt_eq(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
self,
rhs,
|s1, s2| CategoricalChunked::lt_eq(s1, s2),
UInt32Chunked::lt_eq,
Utf8ViewArray::tot_le_kernel_broadcast,
StringChunked::lt_eq,
)
}
}
fn cat_single_str_equality_helper<'a, ComparePhys, Missing>(
lhs: &'a CategoricalChunked,
rhs: &'a str,
missing_function: Missing,
phys_compare_function: ComparePhys,
) -> PolarsResult<BooleanChunked>
where
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
Missing: Fn(&CategoricalChunked) -> BooleanChunked,
{
let rev_map = lhs.get_rev_map();
let idx = rev_map.find(rhs);
if lhs.is_enum() {
let Some(idx) = idx else {
polars_bail!(
not_in_enum,
value = rhs,
categories = rev_map.get_categories()
)
};
Ok(phys_compare_function(lhs.physical(), idx))
} else {
match rev_map.find(rhs) {
None => Ok(missing_function(lhs)),
Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
}
}
}
fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>(
lhs: &'a CategoricalChunked,
rhs: &'a str,
phys_compare_function: ComparePhys,
str_single_compare_function: CompareStringSingle,
) -> PolarsResult<BooleanChunked>
where
CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
{
let rev_map = lhs.get_rev_map();
if lhs.is_enum() {
match rev_map.find(rhs) {
None => {
polars_bail!(
not_in_enum,
value = rhs,
categories = rev_map.get_categories()
)
},
Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
}
} else {
let bitmap = str_single_compare_function(lhs.get_rev_map().get_categories(), rhs);
Ok(
BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(|opt_idx| {
opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) })
}))
.with_name(lhs.name()),
)
}
}
impl ChunkCompare<&str> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;
fn equal(&self, rhs: &str) -> Self::Item {
cat_single_str_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false),
UInt32Chunked::equal,
)
}
fn equal_missing(&self, rhs: &str) -> Self::Item {
cat_single_str_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), false, lhs.len()),
UInt32Chunked::equal_missing,
)
}
fn not_equal(&self, rhs: &str) -> Self::Item {
cat_single_str_equality_helper(
self,
rhs,
|lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true),
UInt32Chunked::not_equal,
)
}
fn not_equal_missing(&self, rhs: &str) -> Self::Item {
cat_single_str_equality_helper(
self,
rhs,
|lhs| BooleanChunked::full(lhs.name(), true, lhs.len()),
UInt32Chunked::equal_missing,
)
}
fn gt(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
self,
rhs,
UInt32Chunked::gt,
Utf8ViewArray::tot_gt_kernel_broadcast,
)
}
fn gt_eq(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
self,
rhs,
UInt32Chunked::gt_eq,
Utf8ViewArray::tot_ge_kernel_broadcast,
)
}
fn lt(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
self,
rhs,
UInt32Chunked::lt,
Utf8ViewArray::tot_lt_kernel_broadcast,
)
}
fn lt_eq(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
self,
rhs,
UInt32Chunked::lt_eq,
Utf8ViewArray::tot_le_kernel_broadcast,
)
}
}