use std::ops::{Add, Div, Mul, Rem, Sub};
use rayon::prelude::*;
use crate::prelude::*;
use crate::utils::try_get_supertype;
use crate::POOL;
fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {
df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| {
try_get_supertype(s.dtype(), &dt)
})
}
macro_rules! impl_arithmetic {
($self:expr, $rhs:expr, $operand: tt) => {{
let st = get_supertype_all($self, $rhs)?;
let rhs = $rhs.cast(&st)?;
let cols = POOL.install(|| {$self.columns.par_iter().map(|s| {
Ok(&s.cast(&st)? $operand &rhs)
}).collect::<PolarsResult<_>>()})?;
Ok(unsafe { DataFrame::new_no_checks(cols) })
}}
}
impl Add<&Series> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn add(self, rhs: &Series) -> Self::Output {
impl_arithmetic!(self, rhs, +)
}
}
impl Add<&Series> for DataFrame {
type Output = PolarsResult<DataFrame>;
fn add(self, rhs: &Series) -> Self::Output {
(&self).add(rhs)
}
}
impl Sub<&Series> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn sub(self, rhs: &Series) -> Self::Output {
impl_arithmetic!(self, rhs, -)
}
}
impl Sub<&Series> for DataFrame {
type Output = PolarsResult<DataFrame>;
fn sub(self, rhs: &Series) -> Self::Output {
(&self).sub(rhs)
}
}
impl Mul<&Series> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn mul(self, rhs: &Series) -> Self::Output {
impl_arithmetic!(self, rhs, *)
}
}
impl Mul<&Series> for DataFrame {
type Output = PolarsResult<DataFrame>;
fn mul(self, rhs: &Series) -> Self::Output {
(&self).mul(rhs)
}
}
impl Div<&Series> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn div(self, rhs: &Series) -> Self::Output {
impl_arithmetic!(self, rhs, /)
}
}
impl Div<&Series> for DataFrame {
type Output = PolarsResult<DataFrame>;
fn div(self, rhs: &Series) -> Self::Output {
(&self).div(rhs)
}
}
impl Rem<&Series> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn rem(self, rhs: &Series) -> Self::Output {
impl_arithmetic!(self, rhs, %)
}
}
impl Rem<&Series> for DataFrame {
type Output = PolarsResult<DataFrame>;
fn rem(self, rhs: &Series) -> Self::Output {
(&self).rem(rhs)
}
}
impl DataFrame {
fn binary_aligned(
&self,
other: &DataFrame,
f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),
) -> PolarsResult<DataFrame> {
let max_len = std::cmp::max(self.height(), other.height());
let max_width = std::cmp::max(self.width(), other.width());
let cols = self
.get_columns()
.par_iter()
.zip(other.get_columns().par_iter())
.map(|(l, r)| {
let diff_l = max_len - l.len();
let diff_r = max_len - r.len();
let st = try_get_supertype(l.dtype(), r.dtype())?;
let mut l = l.cast(&st)?;
let mut r = r.cast(&st)?;
if diff_l > 0 {
l = l.extend_constant(AnyValue::Null, diff_l)?;
};
if diff_r > 0 {
r = r.extend_constant(AnyValue::Null, diff_r)?;
};
f(&l, &r)
});
let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;
let col_len = cols.len();
if col_len < max_width {
let df = if col_len < self.width() { self } else { other };
for i in col_len..max_len {
let s = &df.get_columns()[i];
let name = s.name();
let dtype = s.dtype();
let vals: &[Option<i32>] = &[None];
let s = Series::new(name, vals).cast(dtype)?;
cols.push(s.new_from_index(0, max_len))
}
}
DataFrame::new(cols)
}
}
impl Add<&DataFrame> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn add(self, rhs: &DataFrame) -> Self::Output {
self.binary_aligned(rhs, &|a, b| Ok(a + b))
}
}
impl Sub<&DataFrame> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn sub(self, rhs: &DataFrame) -> Self::Output {
self.binary_aligned(rhs, &|a, b| Ok(a - b))
}
}
impl Div<&DataFrame> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn div(self, rhs: &DataFrame) -> Self::Output {
self.binary_aligned(rhs, &|a, b| Ok(a / b))
}
}
impl Mul<&DataFrame> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn mul(self, rhs: &DataFrame) -> Self::Output {
self.binary_aligned(rhs, &|a, b| Ok(a * b))
}
}
impl Rem<&DataFrame> for &DataFrame {
type Output = PolarsResult<DataFrame>;
fn rem(self, rhs: &DataFrame) -> Self::Output {
self.binary_aligned(rhs, &|a, b| Ok(a % b))
}
}