use super::*;
/// Flat buffer of `IdxSize` join indices.
pub(super) type JoinIds = Vec<IdxSize>;
/// Id pair used by left joins: a `ChunkJoinIds` together with a `ChunkJoinOptIds`.
// NOTE(review): presumably (left-side ids, optional right-side ids) — confirm against the callers.
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
/// Id pair used by inner joins: two `JoinIds` buffers.
// NOTE(review): presumably (left-side ids, right-side ids) — confirm against the callers.
pub type InnerJoinIds = (JoinIds, JoinIds);
// With the `chunked_ids` feature the ids are either a `Vec<IdxSize>` or a `Vec<ChunkId>`.
// NOTE(review): this alias is `pub(super)` here but `pub` in the non-`chunked_ids`
// configuration below, while the public `LeftJoinIds` refers to it — confirm the intended
// visibility.
#[cfg(feature = "chunked_ids")]
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
// Nullable counterpart of `ChunkJoinIds`: either a `Vec<NullableIdxSize>` or a `Vec<ChunkId>`.
#[cfg(feature = "chunked_ids")]
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
// Without the `chunked_ids` feature the nullable ids are always a flat `Vec<NullableIdxSize>`.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
// Without the `chunked_ids` feature the ids are always a flat `Vec<IdxSize>`.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;
use polars_core::export::once_cell::sync::Lazy;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
/// Options describing how a join should be performed.
///
/// Build one with `JoinArgs::new`, `JoinArgs::default` (an inner join) or
/// `JoinArgs::from(JoinType)`, then adjust it with the `with_*` methods.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinArgs {
/// The kind of join to perform.
pub how: JoinType,
/// Cardinality validation requested for the join keys.
pub validation: JoinValidation,
/// Optional suffix override; when `None`, `JoinArgs::suffix` falls back to `"_right"`.
// NOTE(review): presumably appended to clashing right-hand column names — confirm.
pub suffix: Option<PlSmallStr>,
/// Optional `(i64, usize)` slice.
// NOTE(review): presumably (offset, length) applied to the join result — confirm.
pub slice: Option<(i64, usize)>,
/// Null-handling flag for the join keys.
// NOTE(review): presumably `true` means null keys match each other — confirm.
pub join_nulls: bool,
/// Column-coalescing policy; resolved against `how` by `JoinArgs::should_coalesce`.
pub coalesce: JoinCoalesce,
}
impl JoinArgs {
    /// Resolves the configured coalescing policy against the configured join type.
    ///
    /// Returns whatever `JoinCoalesce::coalesce` decides for `self.how`.
    pub fn should_coalesce(&self) -> bool {
        let Self { how, coalesce, .. } = self;
        coalesce.coalesce(how)
    }
}
/// Policy that decides whether a join coalesces columns; see `JoinCoalesce::coalesce`.
// NOTE(review): "coalesce" presumably means merging the left and right join-key columns into
// one — confirm against the join implementation.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
/// Default. Let the join type decide: `coalesce` returns `true` for left, inner, right
/// (and as-of) joins and `false` for every other join type, including full joins.
#[default]
JoinSpecific,
/// Coalesce wherever it is supported: left, inner, right, full (and as-of) joins.
/// Cross, semi/anti and IE joins still report `false`.
CoalesceColumns,
/// Never coalesce: `coalesce` returns `false` for every join type.
KeepColumns,
}
impl JoinCoalesce {
pub fn coalesce(&self, join_type: &JoinType) -> bool {
use JoinCoalesce::*;
use JoinType::*;
match join_type {
Left | Inner | Right => {
matches!(self, JoinSpecific | CoalesceColumns)
},
Full { .. } => {
matches!(self, CoalesceColumns)
},
#[cfg(feature = "asof_join")]
AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
#[cfg(feature = "iejoin")]
IEJoin(_) => false,
Cross => false,
#[cfg(feature = "semi_anti_join")]
Semi | Anti => false,
}
}
}
impl Default for JoinArgs {
    /// Returns the arguments for a plain inner join.
    ///
    /// Delegates to `JoinArgs::new` so the default field values live in exactly one place
    /// and cannot drift apart when a field is added: default `JoinValidation`, no suffix
    /// override, no slice, `join_nulls` disabled and the default `JoinCoalesce` policy.
    fn default() -> Self {
        Self::new(JoinType::Inner)
    }
}
impl JoinArgs {
    /// Creates arguments for a join of kind `how`, with every other option at its default:
    /// default validation and coalescing policy, no suffix override, no slice, and
    /// `join_nulls` disabled.
    pub fn new(how: JoinType) -> Self {
        let validation = Default::default();
        let coalesce = Default::default();
        Self {
            how,
            validation,
            suffix: None,
            slice: None,
            join_nulls: false,
            coalesce,
        }
    }

    /// Returns `self` with the coalescing policy replaced by `coalesce`.
    pub fn with_coalesce(self, coalesce: JoinCoalesce) -> Self {
        Self { coalesce, ..self }
    }

    /// Returns `self` with the suffix override replaced by `suffix`
    /// (`None` restores the default suffix).
    pub fn with_suffix(self, suffix: Option<PlSmallStr>) -> Self {
        Self { suffix, ..self }
    }

    /// Returns the configured suffix, or `"_right"` when none was set.
    pub fn suffix(&self) -> &PlSmallStr {
        // Built on first use and shared by every caller that has no override.
        static DEFAULT: Lazy<PlSmallStr> = Lazy::new(|| PlSmallStr::from_static("_right"));
        match &self.suffix {
            Some(custom) => custom,
            None => &*DEFAULT,
        }
    }
}
/// The kind of join to perform.
///
/// `Display` and `Debug` both render the upper-case name (e.g. `LEFT`); the derived
/// `IntoStaticStr` is configured with strum's `snake_case` serialization.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum JoinType {
/// Inner join (displayed as `INNER`).
Inner,
/// Left join (displayed as `LEFT`).
Left,
/// Right join (displayed as `RIGHT`).
Right,
/// Full join (displayed as `FULL`).
Full,
/// As-of join configured by `AsOfOptions`; only compiled with the `asof_join` feature.
#[cfg(feature = "asof_join")]
AsOf(AsOfOptions),
/// Cross join (displayed as `CROSS`).
Cross,
/// Semi join; only compiled with the `semi_anti_join` feature.
#[cfg(feature = "semi_anti_join")]
Semi,
/// Anti join; only compiled with the `semi_anti_join` feature.
#[cfg(feature = "semi_anti_join")]
Anti,
/// IE join configured by `IEJoinOptions`; only compiled with the `iejoin` feature.
// NOTE(review): "IE" presumably stands for inequality join — confirm.
#[cfg(feature = "iejoin")]
IEJoin(IEJoinOptions),
}
impl From<JoinType> for JoinArgs {
fn from(value: JoinType) -> Self {
JoinArgs::new(value)
}
}
impl Display for JoinType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use JoinType::*;
let val = match self {
Left => "LEFT",
Right => "RIGHT",
Inner => "INNER",
Full { .. } => "FULL",
#[cfg(feature = "asof_join")]
AsOf(_) => "ASOF",
#[cfg(feature = "iejoin")]
IEJoin(_) => "IEJOIN",
Cross => "CROSS",
#[cfg(feature = "semi_anti_join")]
Semi => "SEMI",
#[cfg(feature = "semi_anti_join")]
Anti => "ANTI",
};
write!(f, "{val}")
}
}
impl Debug for JoinType {
    /// Debug output is identical to the `Display` output (the upper-case join name).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
impl JoinType {
pub fn is_asof(&self) -> bool {
#[cfg(feature = "asof_join")]
{
matches!(self, JoinType::AsOf(_))
}
#[cfg(not(feature = "asof_join"))]
{
false
}
}
pub fn is_cross(&self) -> bool {
matches!(self, JoinType::Cross)
}
pub fn is_ie(&self) -> bool {
#[cfg(feature = "iejoin")]
{
matches!(self, JoinType::IEJoin(_))
}
#[cfg(not(feature = "iejoin"))]
{
false
}
}
}
/// Cardinality check requested for the join keys.
///
/// `Display` renders the short forms `m:m`, `m:1`, `1:m` and `1:1`.
// NOTE(review): the left/right uniqueness meaning given on the variants below is inferred
// from the names and from `validate_probe`/`validate_build` — confirm.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
/// `m:m` — the default; `needs_checks` returns `false`, so nothing is validated.
#[default]
ManyToMany,
/// `m:1` — presumably the join keys must be unique on the right side.
ManyToOne,
/// `1:m` — presumably the join keys must be unique on the left side.
OneToMany,
/// `1:1` — presumably the join keys must be unique on both sides.
OneToOne,
}
impl JoinValidation {
pub fn needs_checks(&self) -> bool {
!matches!(self, JoinValidation::ManyToMany)
}
fn swap(self, swap: bool) -> Self {
use JoinValidation::*;
if swap {
match self {
ManyToMany => ManyToMany,
ManyToOne => OneToMany,
OneToMany => ManyToOne,
OneToOne => OneToOne,
}
} else {
self
}
}
pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
if !self.needs_checks() {
return Ok(());
}
polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full{..} | JoinType::Left),
ComputeError: "{self} validation on a {join_type} join is not supported");
Ok(())
}
pub(super) fn validate_probe(
&self,
s_left: &Series,
s_right: &Series,
build_shortest_table: bool,
join_nulls: bool,
) -> PolarsResult<()> {
let should_swap = build_shortest_table && s_left.len() <= s_right.len();
let probe = if should_swap { s_right } else { s_left };
use JoinValidation::*;
let valid = match self.swap(should_swap) {
ManyToMany | ManyToOne => true,
OneToMany | OneToOne => {
if !join_nulls && probe.null_count() > 0 {
probe.n_unique()? - 1 == probe.len() - probe.null_count()
} else {
probe.n_unique()? == probe.len()
}
},
};
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
Ok(())
}
pub(super) fn validate_build(
&self,
build_size: usize,
expected_size: usize,
swapped: bool,
) -> PolarsResult<()> {
use JoinValidation::*;
let valid = match self.swap(swapped) {
ManyToMany | OneToMany => true,
ManyToOne | OneToOne => build_size == expected_size,
};
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
Ok(())
}
}
impl Display for JoinValidation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let s = match self {
JoinValidation::ManyToMany => "m:m",
JoinValidation::ManyToOne => "m:1",
JoinValidation::OneToMany => "1:m",
JoinValidation::OneToOne => "1:1",
};
write!(f, "{s}")
}
}
impl Debug for JoinValidation {
    /// Writes `JoinValidation: ` followed by the `Display` form (e.g. `JoinValidation: 1:1`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("JoinValidation: ")?;
        std::fmt::Display::fmt(self, f)
    }
}