use arrow::legacy::time_zone::Tz;
use arrow::trusted_len::TrustedLen;
use polars_core::export::rayon::prelude::*;
use polars_core::prelude::*;
use polars_core::utils::_split_offsets;
use polars_core::utils::flatten::flatten_par;
use polars_core::POOL;
use polars_utils::slice::GetSaferUnchecked;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::prelude::*;
/// Which edge(s) of a window count as belonging to it.
///
/// This value is handed to the `Bounds` membership checks (`is_member`, `is_member_entry`,
/// `is_member_exit`, `is_future`).
/// NOTE(review): presumably follows the usual interval convention (`Left` = `[start, stop)`,
/// `Right` = `(start, stop]`, `Both` = `[start, stop]`, `None` = `(start, stop)`) — confirm
/// against `Bounds`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum ClosedWindow {
    Left,
    Right,
    Both,
    None,
}
/// Choice of label for a window.
///
/// NOTE(review): not used in this part of the file; presumably selects whether a window is
/// labelled by its left edge, its right edge, or by the data point — confirm at call sites.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum Label {
    Left,
    Right,
    DataPoint,
}
/// Strategy for choosing where windows start.
///
/// It is forwarded to `Window::get_overlapping_bounds_iter` by `group_by_windows`. The
/// default is [`StartBy::WindowBound`]. The weekday variants can be turned into a weekday
/// number with [`StartBy::weekday`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StartBy {
    #[default]
    WindowBound,
    DataPoint,
    Monday,
    Tuesday,
    Wednesday,
    Thursday,
    Friday,
    Saturday,
    Sunday,
}

impl StartBy {
    /// Weekday number of this variant, counting from `0` (Monday) to `6` (Sunday).
    ///
    /// Returns `None` for the non-weekday variants `WindowBound` and `DataPoint`.
    pub fn weekday(&self) -> Option<u32> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            StartBy::Monday => Some(0),
            StartBy::Tuesday => Some(1),
            StartBy::Wednesday => Some(2),
            StartBy::Thursday => Some(3),
            StartBy::Friday => Some(4),
            StartBy::Saturday => Some(5),
            StartBy::Sunday => Some(6),
            StartBy::WindowBound | StartBy::DataPoint => None,
        }
    }
}
/// Assign rows of `time` to each window produced by `bounds_iter`.
///
/// Each emitted window appends one `[first_index, len]` entry to `groups`. `start` is a
/// cursor into `time` that only ever moves forward across windows. When
/// `include_lower_bound` / `include_upper_bound` are set, the window's `start` / `stop` is
/// appended to `lower_bound` / `upper_bound` for every emitted group, so those vectors stay
/// aligned with `groups`.
///
/// A window is skipped (nothing is pushed) in two cases:
/// - `is_future` reports one of the rows scanned while looking for the window's first member;
/// - the cursor has reached the last row and that row is not a member.
///
/// In the second situation, if the last row is a member it is emitted as a group of length 1.
/// Every other window is emitted, possibly with length 0.
///
/// NOTE(review): expects `time` to be non-empty (`time.len() - 1` would underflow) and
/// presumably sorted ascending — confirm against callers.
#[allow(clippy::too_many_arguments)]
fn update_groups_and_bounds(
    bounds_iter: BoundsIter<'_>,
    mut start: usize,
    time: &[i64],
    closed_window: ClosedWindow,
    include_lower_bound: bool,
    include_upper_bound: bool,
    lower_bound: &mut Vec<i64>,
    upper_bound: &mut Vec<i64>,
    groups: &mut Vec<[IdxSize; 2]>,
) {
    'bounds: for bi in bounds_iter {
        // Advance `start` to the first member of this window. The last row is excluded from
        // this scan and handled separately below.
        for &t in &time[start..time.len().saturating_sub(1)] {
            // Skip this window entirely; `start` keeps the progress made so far.
            if bi.is_future(t, closed_window) {
                continue 'bounds;
            }
            if bi.is_member_entry(t, closed_window) {
                break;
            }
            start += 1;
        }
        let mut end = start;
        // The cursor sits on the last row: emit it on its own, but only if it is a member.
        if end == time.len() - 1 {
            let t = time[end];
            if bi.is_member(t, closed_window) {
                if include_lower_bound {
                    lower_bound.push(bi.start);
                }
                if include_upper_bound {
                    upper_bound.push(bi.stop);
                }
                groups.push([end as IdxSize, 1])
            }
            continue;
        }
        // Extend `end` over the consecutive members of this window.
        for &t in &time[end..] {
            if !bi.is_member_exit(t, closed_window) {
                break;
            }
            end += 1;
        }
        let len = end - start;
        if include_lower_bound {
            lower_bound.push(bi.start);
        }
        if include_upper_bound {
            upper_bound.push(bi.stop);
        }
        groups.push([start as IdxSize, len as IdxSize])
    }
}
#[allow(clippy::too_many_arguments)]
pub fn group_by_windows(
    window: Window,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: &Option<TimeZone>,
    include_lower_bound: bool,
    include_upper_bound: bool,
    start_by: StartBy,
) -> (GroupsSlice, Vec<i64>, Vec<i64>) {
    let start = time[0];
    let boundary = if time.len() > 1 {
        let stop = time[time.len() - 1] + 1;
        Bounds::new_checked(start, stop)
    } else {
        let stop = start + 1;
        Bounds::new_checked(start, stop)
    };
    let size = {
        match tu {
            TimeUnit::Nanoseconds => window.estimate_overlapping_bounds_ns(boundary),
            TimeUnit::Microseconds => window.estimate_overlapping_bounds_us(boundary),
            TimeUnit::Milliseconds => window.estimate_overlapping_bounds_ms(boundary),
        }
    };
    let size_lower = if include_lower_bound { size } else { 0 };
    let size_upper = if include_upper_bound { size } else { 0 };
    let mut lower_bound = Vec::with_capacity(size_lower);
    let mut upper_bound = Vec::with_capacity(size_upper);
    let mut groups = Vec::with_capacity(size);
    let start_offset = 0;
    match tz {
        #[cfg(feature = "timezones")]
        Some(tz) => {
            update_groups_and_bounds(
                window
                    .get_overlapping_bounds_iter(
                        boundary,
                        closed_window,
                        tu,
                        tz.parse::<Tz>().ok().as_ref(),
                        start_by,
                    )
                    .unwrap(),
                start_offset,
                time,
                closed_window,
                include_lower_bound,
                include_upper_bound,
                &mut lower_bound,
                &mut upper_bound,
                &mut groups,
            );
        },
        _ => {
            update_groups_and_bounds(
                window
                    .get_overlapping_bounds_iter(boundary, closed_window, tu, None, start_by)
                    .unwrap(),
                start_offset,
                time,
                closed_window,
                include_lower_bound,
                include_upper_bound,
                &mut lower_bound,
                &mut upper_bound,
                &mut groups,
            );
        },
    };
    (groups, lower_bound, upper_bound)
}
/// Look-behind group iterator for the case `offset == -period`.
///
/// For every row `t` in `time[start_offset..upper_bound]`, the window is
/// `Bounds::new(t + offset, t)`, i.e. it ends exactly at `t`. `upper_bound` defaults to
/// `time.len()`.
///
/// Yields `(first_index, len)` into `time` for every such row. Member search is done over
/// the whole of `time`, so a window may reach back before `start_offset`.
///
/// # Errors
/// Fails (outer result) or yields an error (inner result) when the `Duration::add_*`
/// arithmetic for the window start fails.
///
/// NOTE(review): assumes `time` is sorted ascending (the binary search and the forward-only
/// cursors rely on it) and that the iterator is consumed front to back — confirm against
/// callers.
#[inline]
#[allow(clippy::too_many_arguments)]
pub(crate) fn group_by_values_iter_lookbehind(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
    start_offset: usize,
    upper_bound: Option<usize>,
) -> PolarsResult<impl TrustedLen<Item = PolarsResult<(IdxSize, IdxSize)>> + '_> {
    debug_assert!(offset.duration_ns() == period.duration_ns());
    debug_assert!(offset.negative);
    let add = match tu {
        TimeUnit::Nanoseconds => Duration::add_ns,
        TimeUnit::Microseconds => Duration::add_us,
        TimeUnit::Milliseconds => Duration::add_ms,
    };
    let upper_bound = upper_bound.unwrap_or(time.len());
    // The first window of this partition may start before `start_offset`. Find its first
    // member with a binary search over the preceding rows.
    let mut start = if let Some(&t) = time.get(start_offset) {
        let lower = add(&offset, t, tz.as_ref())?;
        // With `offset == -period`, `t + offset + period` equals `t`. Using `t` directly
        // avoids relying on calendar arithmetic round-tripping.
        let upper = t;
        let b = Bounds::new(lower, upper);
        let slice = &time[..start_offset];
        // Index of the first member in the prefix (`start_offset` if none). Presumes
        // non-members precede members in the sorted prefix.
        slice.partition_point(|v| !b.is_member(*v, closed_window))
    } else {
        0
    };
    // Start `end` at `start` (not `start_offset`) so `time[end..]` below stays in bounds.
    let mut end = start;
    Ok(time[start_offset..upper_bound]
        .iter()
        .enumerate()
        .map(move |(mut i, t)| {
            // Convert the partition-local index into an index into `time`.
            i += start_offset;
            let lower = add(&offset, *t, tz.as_ref())?;
            let upper = *t;
            let b = Bounds::new(lower, upper);
            // SAFETY: `start <= i < time.len()`. `start` begins at most at `start_offset`
            // (the partition_point result is bounded by the prefix length). This loop never
            // advances it past `i`, and `i` only grows between calls when iterating forward.
            for &t in unsafe { time.get_unchecked_release(start..i) } {
                if b.is_member_entry(t, closed_window) {
                    break;
                }
                start += 1;
            }
            // Fast path: if `t` itself is a member, members reach at least up to `i`.
            if b.is_member_exit(*t, closed_window) {
                end = i;
            } else {
                end = std::cmp::max(end, start);
            }
            // Still scan forward to pick up any following duplicates of `t`.
            // SAFETY: `end <= time.len()`. `end` is either `i` or `max(end, start)`, both
            // within bounds, and it only advances while elements remain.
            for &t in unsafe { time.get_unchecked_release(end..) } {
                if !b.is_member_exit(t, closed_window) {
                    break;
                }
                end += 1;
            }
            let len = end - start;
            let offset = start as IdxSize;
            Ok((offset, len as IdxSize))
        }))
}
/// Group iterator for windows that lie entirely behind their row `t`.
///
/// For each `t`, the window spans from `t + offset` to `t + offset + period`. Yields
/// `(first_index, len)` into `time`, and `(0, 0)` whenever `Bounds::is_future` reports
/// `time[0]` for the window.
///
/// # Errors
/// Each item carries any error from the `Duration::add_*` arithmetic.
pub(crate) fn group_by_values_iter_window_behind_t(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
) -> impl TrustedLen<Item = PolarsResult<(IdxSize, IdxSize)>> + '_ {
    let shift = match tu {
        TimeUnit::Nanoseconds => Duration::add_ns,
        TimeUnit::Microseconds => Duration::add_us,
        TimeUnit::Milliseconds => Duration::add_ms,
    };

    // Forward-only cursors: `first` is the first member of the current window and
    // `past_last` is one past its last member.
    let mut first = 0usize;
    let mut past_last = 0usize;
    time.iter().map(move |&t| {
        let window_start = shift(&offset, t, tz.as_ref())?;
        let window_stop = shift(&period, window_start, tz.as_ref())?;
        let bounds = Bounds::new(window_start, window_stop);

        // NOTE(review): presumably the window ends before the earliest timestamp here, so it
        // has no members.
        if bounds.is_future(time[0], closed_window) {
            return Ok((0, 0));
        }

        let tail = &time[first..];
        first += tail
            .iter()
            .position(|&v| bounds.is_member_entry(v, closed_window))
            .unwrap_or(tail.len());

        past_last = past_last.max(first);
        past_last += time[past_last..]
            .iter()
            .take_while(|&&v| bounds.is_member_exit(v, closed_window))
            .count();

        Ok((first as IdxSize, (past_last - first) as IdxSize))
    })
}
/// Group iterator for windows that start behind their row `t` and may reach `t`.
///
/// For each row `t` at index `row`, the window spans from `t + offset` to
/// `t + offset + period`. The first member is never searched past `row` itself. Yields
/// `(first_index, len)` into `time`.
///
/// # Errors
/// Each item carries any error from the `Duration::add_*` arithmetic.
pub(crate) fn group_by_values_iter_partial_lookbehind(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
) -> impl TrustedLen<Item = PolarsResult<(IdxSize, IdxSize)>> + '_ {
    let shift = match tu {
        TimeUnit::Nanoseconds => Duration::add_ns,
        TimeUnit::Microseconds => Duration::add_us,
        TimeUnit::Milliseconds => Duration::add_ms,
    };

    // Forward-only cursors: `first` is the first member of the current window and
    // `past_last` is one past its last member.
    let mut first = 0usize;
    let mut past_last = 0usize;
    time.iter().enumerate().map(move |(row, &t)| {
        let window_start = shift(&offset, t, tz.as_ref())?;
        let window_stop = shift(&period, window_start, tz.as_ref())?;
        let bounds = Bounds::new(window_start, window_stop);

        // Advance to the first window member, but stop at the current row at the latest.
        let base = first;
        let tail = &time[base..];
        first = base
            + tail
                .iter()
                .enumerate()
                .position(|(k, &v)| bounds.is_member_entry(v, closed_window) || base + k == row)
                .unwrap_or(tail.len());

        past_last = past_last.max(first);
        past_last += time[past_last..]
            .iter()
            .take_while(|&&v| bounds.is_member_exit(v, closed_window))
            .count();

        Ok((first as IdxSize, (past_last - first) as IdxSize))
    })
}
/// Look-ahead group iterator.
///
/// For each row `t` in `time[start_offset..upper_bound]` (with `upper_bound` defaulting to
/// `time.len()`), the window spans from `t + offset` to `t + offset + period`. Yields
/// `(first_index, len)` into `time`. Member search starts at `start_offset` and may run past
/// `upper_bound`.
///
/// # Errors
/// Each item carries any error from the `Duration::add_*` arithmetic.
#[allow(clippy::too_many_arguments)]
pub(crate) fn group_by_values_iter_lookahead(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
    start_offset: usize,
    upper_bound: Option<usize>,
) -> impl TrustedLen<Item = PolarsResult<(IdxSize, IdxSize)>> + '_ {
    let stop_row = upper_bound.unwrap_or(time.len());
    let shift = match tu {
        TimeUnit::Nanoseconds => Duration::add_ns,
        TimeUnit::Microseconds => Duration::add_us,
        TimeUnit::Milliseconds => Duration::add_ms,
    };

    // Forward-only cursors: `first` is the first member of the current window and
    // `past_last` is one past its last member.
    let mut first = start_offset;
    let mut past_last = start_offset;
    time[start_offset..stop_row].iter().map(move |&t| {
        let window_start = shift(&offset, t, tz.as_ref())?;
        let window_stop = shift(&period, window_start, tz.as_ref())?;
        let bounds = Bounds::new(window_start, window_stop);

        let tail = &time[first..];
        first += tail
            .iter()
            .position(|&v| bounds.is_member_entry(v, closed_window))
            .unwrap_or(tail.len());

        past_last = past_last.max(first);
        past_last += time[past_last..]
            .iter()
            .take_while(|&&v| bounds.is_member_exit(v, closed_window))
            .count();

        Ok((first as IdxSize, (past_last - first) as IdxSize))
    })
}
/// Rolling look-behind groups over the whole of `time`.
///
/// For each row `t`, the window ends at `t` and reaches one `period` back: the offset is
/// `period` with its sign flipped to negative.
#[cfg(feature = "rolling_window_by")]
#[inline]
pub(crate) fn group_by_values_iter(
    period: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
) -> PolarsResult<impl TrustedLen<Item = PolarsResult<(IdxSize, IdxSize)>> + '_> {
    // Same magnitude as `period`, pointing backwards in time.
    let lookbehind = {
        let mut negated = period;
        negated.negative = true;
        negated
    };
    group_by_values_iter_lookbehind(period, lookbehind, time, closed_window, tu, tz, 0, None)
}
/// Merge thread partitions so that no partition boundary splits a run of equal timestamps.
///
/// `thread_offsets` holds contiguous `(start, len)` blocks covering `time`, as produced by
/// `_split_offsets`. A block boundary at `start` is kept only when
/// `time[start - 1] != time[start]`; any other block is merged into the block before it.
///
/// The result still covers all of `time`. It collapses to the single block
/// `(0, time.len())` when no valid boundary remains.
///
/// `thread_offsets` is left untouched when `time` is empty, when there is at most one
/// block, or when every boundary is already valid.
fn prune_splits_on_duplicates(time: &[i64], thread_offsets: &mut Vec<(usize, usize)>) {
    // A boundary between adjacent blocks is valid when the last value of the left block
    // differs from the first value of the right block.
    let is_valid = |window: &[(usize, usize)]| -> bool {
        debug_assert_eq!(window.len(), 2);
        let left_block_end = window[0].0 + window[0].1.saturating_sub(1);
        let right_block_start = window[1].0;
        time[left_block_end] != time[right_block_start]
    };

    if time.is_empty() || thread_offsets.len() <= 1 || thread_offsets.windows(2).all(is_valid) {
        return;
    }

    // Judge every boundary on its own. The first block always starts at 0. A later start
    // survives only if it moves forward and sits between two distinct values.
    let mut kept_starts = Vec::with_capacity(thread_offsets.len());
    kept_starts.push(0);
    let mut last_kept = 0;
    for &(start, _) in &thread_offsets[1..] {
        if start > last_kept && start < time.len() && time[start - 1] != time[start] {
            kept_starts.push(start);
            last_kept = start;
        }
    }

    // Rebuild the lengths so that each block runs up to the next kept start.
    let mut pruned = Vec::with_capacity(kept_starts.len());
    for (idx, &start) in kept_starts.iter().enumerate() {
        let end = kept_starts.get(idx + 1).copied().unwrap_or(time.len());
        pruned.push((start, end - start));
    }
    debug_assert_eq!(pruned.iter().map(|w| w.1).sum::<usize>(), time.len());
    *thread_offsets = pruned;
}
/// Collect the groups of [`group_by_values_iter_lookbehind`] into `[first_index, len]`
/// pairs.
///
/// # Errors
/// Returns the error from building the iterator, or the first error it yields.
#[allow(clippy::too_many_arguments)]
fn group_by_values_iter_lookbehind_collected(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
    start_offset: usize,
    upper_bound: Option<usize>,
) -> PolarsResult<Vec<[IdxSize; 2]>> {
    let groups_iter = group_by_values_iter_lookbehind(
        period,
        offset,
        time,
        closed_window,
        tu,
        tz,
        start_offset,
        upper_bound,
    )?;
    let mut groups = Vec::new();
    for item in groups_iter {
        let (first, len) = item?;
        groups.push([first, len]);
    }
    Ok(groups)
}
/// Collect the groups of [`group_by_values_iter_lookahead`] into `[first_index, len]`
/// pairs.
///
/// `start_offset` / `upper_bound` select the rows (a thread partition) whose groups are
/// produced.
///
/// # Errors
/// Returns the first error yielded by the iterator.
#[allow(clippy::too_many_arguments)]
pub(crate) fn group_by_values_iter_lookahead_collected(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
    start_offset: usize,
    upper_bound: Option<usize>,
) -> PolarsResult<Vec<[IdxSize; 2]>> {
    let iter = group_by_values_iter_lookahead(
        period,
        offset,
        time,
        closed_window,
        tu,
        tz,
        start_offset,
        upper_bound,
    );
    // The iterator already yields `IdxSize` offsets, so no cast is needed.
    iter.map(|result| result.map(|(offset, len)| [offset, len]))
        .collect::<PolarsResult<Vec<_>>>()
}
pub fn group_by_values(
    period: Duration,
    offset: Duration,
    time: &[i64],
    closed_window: ClosedWindow,
    tu: TimeUnit,
    tz: Option<Tz>,
) -> PolarsResult<GroupsSlice> {
    let mut thread_offsets = _split_offsets(time.len(), POOL.current_num_threads());
    prune_splits_on_duplicates(time, &mut thread_offsets);
    let run_parallel = !POOL.current_thread_has_pending_tasks().unwrap_or(false);
    if offset.negative && !offset.is_zero() {
        if offset.duration_ns() == period.duration_ns() {
            if !run_parallel {
                let vecs = group_by_values_iter_lookbehind_collected(
                    period,
                    offset,
                    time,
                    closed_window,
                    tu,
                    tz,
                    0,
                    None,
                )?;
                return Ok(GroupsSlice::from(vecs));
            }
            POOL.install(|| {
                let vals = thread_offsets
                    .par_iter()
                    .copied()
                    .map(|(base_offset, len)| {
                        let upper_bound = base_offset + len;
                        group_by_values_iter_lookbehind_collected(
                            period,
                            offset,
                            time,
                            closed_window,
                            tu,
                            tz,
                            base_offset,
                            Some(upper_bound),
                        )
                    })
                    .collect::<PolarsResult<Vec<_>>>()?;
                Ok(flatten_par(&vals))
            })
        } else if ((offset.duration_ns() >= period.duration_ns())
            && matches!(closed_window, ClosedWindow::Left | ClosedWindow::None))
            || ((offset.duration_ns() > period.duration_ns())
                && matches!(closed_window, ClosedWindow::Right | ClosedWindow::Both))
        {
            let iter =
                group_by_values_iter_window_behind_t(period, offset, time, closed_window, tu, tz);
            iter.map(|result| result.map(|(offset, len)| [offset, len]))
                .collect::<PolarsResult<_>>()
        }
        else {
            let iter = group_by_values_iter_partial_lookbehind(
                period,
                offset,
                time,
                closed_window,
                tu,
                tz,
            );
            iter.map(|result| result.map(|(offset, len)| [offset, len]))
                .collect::<PolarsResult<_>>()
        }
    } else if !offset.is_zero()
        || closed_window == ClosedWindow::Right
        || closed_window == ClosedWindow::None
    {
        if !run_parallel {
            let vecs = group_by_values_iter_lookahead_collected(
                period,
                offset,
                time,
                closed_window,
                tu,
                tz,
                0,
                None,
            )?;
            return Ok(GroupsSlice::from(vecs));
        }
        POOL.install(|| {
            let vals = thread_offsets
                .par_iter()
                .copied()
                .map(|(base_offset, len)| {
                    let lower_bound = base_offset;
                    let upper_bound = base_offset + len;
                    group_by_values_iter_lookahead_collected(
                        period,
                        offset,
                        time,
                        closed_window,
                        tu,
                        tz,
                        lower_bound,
                        Some(upper_bound),
                    )
                })
                .collect::<PolarsResult<Vec<_>>>()?;
            Ok(flatten_par(&vals))
        })
    } else {
        if !run_parallel {
            let vecs = group_by_values_iter_lookahead_collected(
                period,
                offset,
                time,
                closed_window,
                tu,
                tz,
                0,
                None,
            )?;
            return Ok(GroupsSlice::from(vecs));
        }
        POOL.install(|| {
            let vals = thread_offsets
                .par_iter()
                .copied()
                .map(|(base_offset, len)| {
                    let lower_bound = base_offset;
                    let upper_bound = base_offset + len;
                    group_by_values_iter_lookahead_collected(
                        period,
                        offset,
                        time,
                        closed_window,
                        tu,
                        tz,
                        lower_bound,
                        Some(upper_bound),
                    )
                })
                .collect::<PolarsResult<Vec<_>>>()?;
            Ok(flatten_par(&vals))
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_prune_duplicates() {
        // The split at index 2 falls inside a run of `1`s and must be merged away. The
        // splits at 6 (2 -> 3) and 8 (4 -> 5) sit between distinct values and are kept.
        let timestamps: &[i64] = &[0, 1, 1, 2, 2, 2, 3, 4, 5, 6, 5];
        let mut offsets = vec![(0, 2), (2, 4), (6, 2), (8, 3)];
        prune_splits_on_duplicates(timestamps, &mut offsets);
        let expected: &[(usize, usize)] = &[(0, 6), (6, 2), (8, 3)];
        assert_eq!(offsets, expected);
    }
}