use std::borrow::Cow;
use std::collections::VecDeque;
use std::ops::{Deref, Range};
use arrow::array::new_empty_array;
use arrow::datatypes::ArrowSchemaRef;
use polars_core::prelude::*;
use polars_core::utils::{accumulate_dataframes_vertical, split_df};
use polars_core::POOL;
use polars_parquet::read::{self, ArrayIter, FileMetaData, PhysicalType, RowGroupMetaData};
use rayon::prelude::*;
#[cfg(feature = "cloud")]
use super::async_impl::FetchRowGroupsFromObjectStore;
use super::mmap::{mmap_columns, ColumnStore};
use super::predicates::read_this_row_group;
use super::to_metadata::ToMetadata;
use super::utils::materialize_empty_df;
use super::{mmap, ParallelStrategy};
use crate::mmap::{MmapBytesReader, ReaderBytes};
use crate::parquet::metadata::FileMetaDataRef;
use crate::predicates::{apply_predicate, PhysicalIoExpr};
use crate::utils::get_reader_bytes;
use crate::RowIndex;
#[cfg(debug_assertions)]
fn assert_dtypes(data_type: &ArrowDataType) {
    match data_type {
        ArrowDataType::Utf8 => {
            unreachable!()
        },
        ArrowDataType::Binary => {
            unreachable!()
        },
        ArrowDataType::List(_) => {
            unreachable!()
        },
        ArrowDataType::LargeList(inner) => {
            assert_dtypes(&inner.data_type);
        },
        ArrowDataType::Struct(fields) => {
            for fld in fields {
                assert_dtypes(fld.data_type())
            }
        },
        _ => {},
    }
}
/// Deserializes field `column_i` of `file_schema` from row group `md` into a [`Series`].
/// It then attaches metadata derived from the parquet column-chunk statistics, when those
/// statistics can be attributed to this field.
///
/// * `remaining_rows`: maximum number of rows to read from this row group. When it is
///   smaller than the row group, it is also forwarded to [`array_iter_to_series`] as a cap.
/// * `chunk_size`: forwarded to the deserializer.
///
/// Statistics are taken from the column chunk(s) that `mmap_columns` returns for
/// `field.name`. They are not taken by indexing the row group's column chunks with
/// `column_i`: a nested field may be stored in several column chunks, so schema field
/// indices and column-chunk indices diverge as soon as a nested field precedes this one.
/// Fields backed by more than one column chunk get no statistics metadata.
fn column_idx_to_series(
    column_i: usize,
    md: &RowGroupMetaData,
    remaining_rows: usize,
    file_schema: &ArrowSchema,
    store: &mmap::ColumnStore,
    chunk_size: usize,
) -> PolarsResult<Series> {
    let field = &file_schema.fields[column_i];

    #[cfg(debug_assertions)]
    {
        assert_dtypes(field.data_type())
    }

    let columns = mmap_columns(store, md.columns(), &field.name);
    // Only a field stored in exactly one column chunk has statistics we can attribute to it.
    // Copy the metadata reference out before `columns` is moved into the deserializer.
    let stats_column = (columns.len() == 1).then(|| columns[0].0);

    let iter = mmap::to_deserializer(columns, field.clone(), remaining_rows, Some(chunk_size))?;

    let mut series = if remaining_rows < md.num_rows() {
        array_iter_to_series(iter, field, Some(remaining_rows))
    } else {
        array_iter_to_series(iter, field, None)
    }?;

    // No (readable) statistics for this field: return the series without metadata.
    let Some(Ok(stats)) = stats_column.and_then(|col_md| col_md.statistics()) else {
        return Ok(series);
    };

    let series_trait = series.as_ref();

    // Attach statistics only for known (polars dtype, parquet physical type) pairs; any
    // other combination is left without metadata.
    macro_rules! match_dtypes_into_metadata {
        ($(($dtype:pat, $phystype:pat) => ($stats:ident, $pldtype:ty),)+) => {
            match (series_trait.dtype(), stats.physical_type()) {
                $(
                ($dtype, $phystype) => {
                    series.try_set_metadata(
                        ToMetadata::<$pldtype>::to_metadata(stats.$stats())
                    );
                })+
                _ => {},
            }
        };
    }

    use {DataType as D, PhysicalType as P};
    match_dtypes_into_metadata! {
        (D::Boolean, P::Boolean  ) => (expect_as_boolean, BooleanType),
        (D::UInt8,   P::Int32    ) => (expect_as_int32,   UInt8Type  ),
        (D::UInt16,  P::Int32    ) => (expect_as_int32,   UInt16Type ),
        (D::UInt32,  P::Int32    ) => (expect_as_int32,   UInt32Type ),
        (D::UInt64,  P::Int64    ) => (expect_as_int64,   UInt64Type ),
        (D::Int8,    P::Int32    ) => (expect_as_int32,   Int8Type   ),
        (D::Int16,   P::Int32    ) => (expect_as_int32,   Int16Type  ),
        (D::Int32,   P::Int32    ) => (expect_as_int32,   Int32Type  ),
        (D::Int64,   P::Int64    ) => (expect_as_int64,   Int64Type  ),
        (D::Float32, P::Float    ) => (expect_as_float,   Float32Type),
        (D::Float64, P::Double   ) => (expect_as_double,  Float64Type),
        (D::String,  P::ByteArray) => (expect_as_binary,  StringType ),
        (D::Binary,  P::ByteArray) => (expect_as_binary,  BinaryType ),
    }

    Ok(series)
}
/// Collects the arrays yielded by `iter` into a [`Series`] described by `field`.
///
/// With `num_rows = Some(limit)`, arrays are pulled only until at least `limit` rows have
/// been gathered. The last array is kept whole, so the result can hold more than `limit`
/// rows. With `None` the iterator is drained completely. If no array is produced, an empty
/// array of `field`'s dtype is used instead.
pub(super) fn array_iter_to_series(
    iter: ArrayIter,
    field: &ArrowField,
    num_rows: Option<usize>,
) -> PolarsResult<Series> {
    let chunks = if let Some(limit) = num_rows {
        let mut collected = Vec::with_capacity(2);
        let mut rows_seen = 0;
        for arr in iter {
            let arr = arr?;
            rows_seen += arr.len();
            collected.push(arr);
            if rows_seen >= limit {
                break;
            }
        }
        collected
    } else {
        iter.collect::<PolarsResult<Vec<_>>>()?
    };

    if chunks.is_empty() {
        let empty = new_empty_array(field.data_type.clone());
        return Series::try_from((field, empty));
    }
    Series::try_from((field, chunks))
}
/// Appends one constant column per hive partition series to `df`.
///
/// Each series in `hive_partition_columns` is expanded to `num_rows` rows by repeating the
/// value at index 0. Does nothing when `hive_partition_columns` is `None`.
pub(crate) fn materialize_hive_partitions(
    df: &mut DataFrame,
    hive_partition_columns: Option<&[Series]>,
    num_rows: usize,
) {
    let Some(hive_columns) = hive_partition_columns else {
        return;
    };
    for hive_column in hive_columns {
        let expanded = hive_column.new_from_index(0, num_rows);
        // SAFETY: `with_column_unchecked` bypasses the regular column checks.
        // NOTE(review): callers must pass `num_rows` equal to `df`'s height and make sure the
        // hive column names do not clash with existing columns — confirm at each call site.
        unsafe { df.with_column_unchecked(expanded) };
    }
}
/// Decodes row groups `row_group_start..row_group_end` into [`DataFrame`]s, picking the
/// implementation from `parallel`:
/// * `Columns` / `None`: row groups are visited one after another (columns may be decoded in
///   parallel for `Columns`);
/// * any other strategy: row groups themselves are decoded in parallel.
///
/// `previous_row_count` and `remaining_rows` are updated in place by the chosen
/// implementation.
#[allow(clippy::too_many_arguments)]
fn rg_to_dfs(
    store: &mmap::ColumnStore,
    previous_row_count: &mut IdxSize,
    row_group_start: usize,
    row_group_end: usize,
    remaining_rows: &mut usize,
    file_metadata: &FileMetaData,
    schema: &ArrowSchemaRef,
    predicate: Option<&dyn PhysicalIoExpr>,
    row_index: Option<RowIndex>,
    parallel: ParallelStrategy,
    projection: &[usize],
    use_statistics: bool,
    hive_partition_columns: Option<&[Series]>,
) -> PolarsResult<Vec<DataFrame>> {
    match parallel {
        ParallelStrategy::Columns | ParallelStrategy::None => {
            rg_to_dfs_optionally_par_over_columns(
                store,
                previous_row_count,
                row_group_start,
                row_group_end,
                remaining_rows,
                file_metadata,
                schema,
                predicate,
                row_index,
                parallel,
                projection,
                use_statistics,
                hive_partition_columns,
            )
        },
        _ => rg_to_dfs_par_over_rg(
            store,
            row_group_start,
            row_group_end,
            previous_row_count,
            remaining_rows,
            file_metadata,
            schema,
            predicate,
            row_index,
            projection,
            use_statistics,
            hive_partition_columns,
        ),
    }
}
/// Sequentially decodes row groups `row_group_start..row_group_end`, producing one
/// [`DataFrame`] per row group that is actually read.
///
/// With [`ParallelStrategy::Columns`] the projected columns of each row group are decoded
/// in parallel on [`POOL`]; otherwise they are decoded one after another.
///
/// Row groups rejected by `read_this_row_group` (consulted only when `use_statistics` is
/// set) are skipped but still advance `previous_row_count`. Every row group that is read
/// takes `min(remaining_rows, num_rows)` rows, decrements `remaining_rows` by that amount,
/// and the loop stops once `remaining_rows` hits 0.
#[allow(clippy::too_many_arguments)]
fn rg_to_dfs_optionally_par_over_columns(
    store: &mmap::ColumnStore,
    previous_row_count: &mut IdxSize,
    row_group_start: usize,
    row_group_end: usize,
    remaining_rows: &mut usize,
    file_metadata: &FileMetaData,
    schema: &ArrowSchemaRef,
    predicate: Option<&dyn PhysicalIoExpr>,
    row_index: Option<RowIndex>,
    parallel: ParallelStrategy,
    projection: &[usize],
    use_statistics: bool,
    hive_partition_columns: Option<&[Series]>,
) -> PolarsResult<Vec<DataFrame>> {
    let mut dfs = Vec::with_capacity(row_group_end - row_group_start);

    for rg_idx in row_group_start..row_group_end {
        let md = &file_metadata.row_groups[rg_idx];
        let current_row_count = md.num_rows() as IdxSize;

        let pruned = use_statistics && !read_this_row_group(predicate, md, schema)?;
        if pruned {
            *previous_row_count += current_row_count;
            continue;
        }

        #[cfg(debug_assertions)]
        {
            assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err())
        }

        let projection_height = md.num_rows().min(*remaining_rows);
        let chunk_size = md.num_rows();

        // Shared by the parallel and the sequential path below.
        let load_column = |&column_i: &usize| {
            column_idx_to_series(column_i, md, projection_height, schema, store, chunk_size)
        };

        let columns = match parallel {
            ParallelStrategy::Columns => POOL.install(|| {
                projection
                    .par_iter()
                    .map(load_column)
                    .collect::<PolarsResult<Vec<_>>>()
            })?,
            _ => projection
                .iter()
                .map(load_column)
                .collect::<PolarsResult<Vec<_>>>()?,
        };

        *remaining_rows -= projection_height;

        // SAFETY: every column was decoded from the same row group with the same
        // `projection_height` — presumably yielding equal lengths; names come from the schema.
        let mut df = unsafe { DataFrame::new_no_checks(columns) };
        if let Some(rc) = &row_index {
            df.with_row_index_mut(&rc.name, Some(*previous_row_count + rc.offset));
        }
        materialize_hive_partitions(&mut df, hive_partition_columns, projection_height);
        apply_predicate(&mut df, predicate, true)?;

        *previous_row_count += current_row_count;
        dfs.push(df);

        if *remaining_rows == 0 {
            break;
        }
    }

    Ok(dfs)
}
/// Decodes row groups `row_group_start..row_group_end` in parallel on [`POOL`], producing
/// one [`DataFrame`] per row group that is actually read, in row-group order.
///
/// Row offsets and per-row-group heights are assigned up front in a sequential pass:
/// - every row group in the range advances `previous_row_count` by its full row count;
/// - every row group consumes `min(remaining_rows, num_rows)` from `remaining_rows`,
///   whether or not it is read later.
///
/// Row groups with height 0, or rejected by `read_this_row_group` (only consulted when
/// `use_statistics` is set), yield no frame.
#[allow(clippy::too_many_arguments)]
fn rg_to_dfs_par_over_rg(
    store: &mmap::ColumnStore,
    row_group_start: usize,
    row_group_end: usize,
    previous_row_count: &mut IdxSize,
    remaining_rows: &mut usize,
    file_metadata: &FileMetaData,
    schema: &ArrowSchemaRef,
    predicate: Option<&dyn PhysicalIoExpr>,
    row_index: Option<RowIndex>,
    projection: &[usize],
    use_statistics: bool,
    hive_partition_columns: Option<&[Series]>,
) -> PolarsResult<Vec<DataFrame>> {
    // Sequential pass: the first row index and the number of rows to take per row group.
    let selected = file_metadata
        .row_groups
        .iter()
        .skip(row_group_start)
        .take(row_group_end - row_group_start);
    let mut row_groups = Vec::new();
    for rg_md in selected {
        let num_rows = rg_md.num_rows();
        let projection_height = num_rows.min(*remaining_rows);
        row_groups.push((rg_md, projection_height, *previous_row_count));
        *previous_row_count += num_rows as IdxSize;
        *remaining_rows -= projection_height;
    }

    // Parallel pass: decode every row group that still has rows to deliver.
    let dfs = POOL.install(|| {
        row_groups
            .into_par_iter()
            .map(|(md, projection_height, row_count_start)| {
                let nothing_to_read = projection_height == 0
                    || (use_statistics && !read_this_row_group(predicate, md, schema)?);
                if nothing_to_read {
                    return Ok(None);
                }

                #[cfg(debug_assertions)]
                {
                    assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err())
                }

                let chunk_size = md.num_rows();
                let columns = projection
                    .iter()
                    .map(|&column_i| {
                        column_idx_to_series(
                            column_i,
                            md,
                            projection_height,
                            schema,
                            store,
                            chunk_size,
                        )
                    })
                    .collect::<PolarsResult<Vec<_>>>()?;

                // SAFETY: all columns come from the same row group, read with the same
                // `projection_height` — presumably equal lengths; names come from the schema.
                let mut df = unsafe { DataFrame::new_no_checks(columns) };
                if let Some(rc) = &row_index {
                    df.with_row_index_mut(&rc.name, Some(row_count_start + rc.offset));
                }
                materialize_hive_partitions(&mut df, hive_partition_columns, projection_height);
                apply_predicate(&mut df, predicate, false)?;
                Ok(Some(df))
            })
            .collect::<PolarsResult<Vec<_>>>()
    })?;

    Ok(dfs.into_iter().flatten().collect())
}
/// Reads a parquet file from `reader` into a single [`DataFrame`], taking at most `limit`
/// rows.
///
/// * `projection`: column indices into `reader_schema`; `None` selects every column.
/// * `metadata`: reused when provided, otherwise read from `reader`.
/// * `parallel`: `Auto` resolves to `RowGroups` when there are more row groups than
///   projected columns or than threads in [`POOL`], and to `Columns` otherwise; `Columns`
///   over a single projected column is downgraded to `None`.
///
/// Returns an empty frame with the projected schema when `limit` is 0 or when no row group
/// produced a frame.
#[allow(clippy::too_many_arguments)]
pub fn read_parquet<R: MmapBytesReader>(
    mut reader: R,
    mut limit: usize,
    projection: Option<&[usize]>,
    reader_schema: &ArrowSchemaRef,
    metadata: Option<FileMetaDataRef>,
    predicate: Option<&dyn PhysicalIoExpr>,
    mut parallel: ParallelStrategy,
    row_index: Option<RowIndex>,
    use_statistics: bool,
    hive_partition_columns: Option<&[Series]>,
) -> PolarsResult<DataFrame> {
    let empty_df = || {
        materialize_empty_df(
            projection,
            reader_schema,
            hive_partition_columns,
            row_index.as_ref(),
        )
    };

    if limit == 0 {
        return Ok(empty_df());
    }

    let file_metadata = match metadata {
        Some(md) => md,
        None => read::read_metadata(&mut reader).map(Arc::new)?,
    };
    let n_row_groups = file_metadata.row_groups.len();

    // Kept alive (not `_`) until the end of the function when there are several row groups.
    let _sc = if n_row_groups > 1 {
        #[cfg(feature = "dtype-categorical")]
        {
            Some(polars_core::StringCacheHolder::hold())
        }
        #[cfg(not(feature = "dtype-categorical"))]
        {
            Some(0u8)
        }
    } else {
        None
    };

    let materialized_projection: Cow<[usize]> = match projection {
        Some(cols) => Cow::Borrowed(cols),
        None => Cow::Owned((0..reader_schema.len()).collect()),
    };

    parallel = match parallel {
        ParallelStrategy::Auto => {
            let prefer_row_groups = n_row_groups > materialized_projection.len()
                || n_row_groups > POOL.current_num_threads();
            if prefer_row_groups {
                ParallelStrategy::RowGroups
            } else {
                ParallelStrategy::Columns
            }
        },
        other => other,
    };
    if materialized_projection.len() == 1 && matches!(parallel, ParallelStrategy::Columns) {
        parallel = ParallelStrategy::None;
    }

    let reader = ReaderBytes::from(&reader);
    let store = mmap::ColumnStore::Local(reader.deref());

    let mut rows_before = 0;
    let dfs = rg_to_dfs(
        &store,
        &mut rows_before,
        0,
        n_row_groups,
        &mut limit,
        &file_metadata,
        reader_schema,
        predicate,
        row_index.clone(),
        parallel,
        &materialized_projection,
        use_statistics,
        hive_partition_columns,
    )?;

    if dfs.is_empty() {
        return Ok(empty_df());
    }
    accumulate_dataframes_vertical(dfs)
}
/// Row-group fetcher for local files. The bytes are obtained once from the reader, and
/// every fetch hands out a [`ColumnStore::Local`] view over those same bytes.
pub struct FetchRowGroupsFromMmapReader(
    // Bytes obtained from the reader in `.1`. Declared first so it is dropped before the
    // reader it may borrow from (struct fields drop in declaration order).
    ReaderBytes<'static>,
    // Owns the reader that `.0` was derived from; it must outlive `.0`.
    #[allow(dead_code)] Box<dyn MmapBytesReader>,
);

impl FetchRowGroupsFromMmapReader {
    /// Obtains the bytes of the file behind `reader`.
    ///
    /// # Panics
    /// Panics if `reader` is not backed by a file (`to_file()` returns `None`).
    pub fn new(mut reader: Box<dyn MmapBytesReader>) -> PolarsResult<Self> {
        assert!(reader.to_file().is_some());
        // SAFETY: the reader lives in a heap allocation that the returned struct keeps
        // owning, so its address stays stable even though the `Box` is moved into the struct.
        // Field `.0` (the bytes) drops before field `.1` (the reader), so the extended
        // `'static` borrow never outlives its referent. Previously the box was dropped at the
        // end of this function while the returned `ReaderBytes` could still refer to it.
        let reader_ptr = unsafe {
            std::mem::transmute::<&mut dyn MmapBytesReader, &'static mut dyn MmapBytesReader>(
                reader.as_mut(),
            )
        };
        let reader_bytes = get_reader_bytes(reader_ptr)?;
        Ok(FetchRowGroupsFromMmapReader(reader_bytes, reader))
    }

    /// Returns a view over all file bytes; the requested range is irrelevant locally.
    fn fetch_row_groups(&mut self, _row_groups: Range<usize>) -> PolarsResult<ColumnStore> {
        Ok(mmap::ColumnStore::Local(self.0.deref()))
    }
}
/// Source of row-group bytes used by the batched parquet reader.
pub enum RowGroupFetcher {
    /// Row groups obtained through a [`FetchRowGroupsFromObjectStore`] (`cloud` feature).
    #[cfg(feature = "cloud")]
    ObjectStore(FetchRowGroupsFromObjectStore),
    /// Row groups served by a [`FetchRowGroupsFromMmapReader`].
    Local(FetchRowGroupsFromMmapReader),
}

#[cfg(feature = "cloud")]
impl From<FetchRowGroupsFromObjectStore> for RowGroupFetcher {
    fn from(fetcher: FetchRowGroupsFromObjectStore) -> Self {
        Self::ObjectStore(fetcher)
    }
}

impl From<FetchRowGroupsFromMmapReader> for RowGroupFetcher {
    fn from(fetcher: FetchRowGroupsFromMmapReader) -> Self {
        Self::Local(fetcher)
    }
}
impl RowGroupFetcher {
    async fn fetch_row_groups(&mut self, _row_groups: Range<usize>) -> PolarsResult<ColumnStore> {
        match self {
            RowGroupFetcher::Local(f) => f.fetch_row_groups(_row_groups),
            #[cfg(feature = "cloud")]
            RowGroupFetcher::ObjectStore(f) => f.fetch_row_groups(_row_groups).await,
        }
    }
}
/// Returns the exclusive end of the row-group range that starts at `row_group_start` and is
/// needed to cover `limit` rows.
///
/// Row groups are taken in order from `row_group_start` up to `row_group_end`, clamped to
/// `row_groups.len()`. Taking stops as soon as the rows accumulated so far reach `limit`.
/// When nothing is taken (e.g. `limit == 0` or an empty range), `row_group_start` is
/// returned.
pub(super) fn compute_row_group_range(
    row_group_start: usize,
    row_group_end: usize,
    limit: usize,
    row_groups: &[RowGroupMetaData],
) -> usize {
    let upper = row_group_end.min(row_groups.len());
    let mut end = row_group_start;
    let mut rows_covered = 0;
    for (rg_i, rg) in row_groups
        .iter()
        .enumerate()
        .take(upper)
        .skip(row_group_start)
    {
        if rows_covered >= limit {
            break;
        }
        end = rg_i + 1;
        rows_covered += rg.num_rows();
    }
    end
}
/// Reads a parquet file in batches of row groups, buffering decoded frames between calls to
/// `next_batches`.
pub struct BatchedParquetReader {
    // NOTE(review): this field is read in `next_batches`; the `dead_code` allow may be stale.
    /// Fetches the bytes for a range of row groups.
    #[allow(dead_code)]
    row_group_fetcher: RowGroupFetcher,
    /// Rows still allowed to be read; decremented as row groups are decoded.
    limit: usize,
    /// Column indices (into `schema`) to decode.
    projection: Arc<[usize]>,
    /// Arrow schema of the file.
    schema: ArrowSchemaRef,
    /// Parquet file metadata (row groups and their column chunks).
    metadata: FileMetaDataRef,
    /// Optional predicate applied to decoded frames and, with `use_statistics`, used to
    /// skip row groups.
    predicate: Option<Arc<dyn PhysicalIoExpr>>,
    /// Optional row-index column to add to each frame.
    row_index: Option<RowIndex>,
    /// Running row offset passed to the decoder as the previous row count.
    rows_read: IdxSize,
    /// Index of the next row group to schedule.
    row_group_offset: usize,
    /// Total number of row groups in the file.
    n_row_groups: usize,
    /// Decoded frames not yet handed out.
    chunks_fifo: VecDeque<DataFrame>,
    /// Decoding strategy; `Auto` has already been resolved in `new`.
    parallel: ParallelStrategy,
    /// Frames at least twice this tall are split before being buffered.
    chunk_size: usize,
    /// Whether row-group statistics may be used to skip row groups.
    use_statistics: bool,
    /// Hive partition values materialized as constant columns.
    hive_partition_columns: Option<Arc<[Series]>>,
    /// Whether `next_batches` has already returned at least one batch.
    has_returned: bool,
}
impl BatchedParquetReader {
    /// Creates a reader that decodes the row groups described by `metadata` in batches.
    ///
    /// * `projection`: column indices into `schema`; `None` selects every column.
    /// * `limit`: maximum number of rows to read.
    /// * `chunk_size`: frames at least twice this tall are split before being returned.
    /// * `parallel`: `Auto` resolves to `RowGroups` when there are more row groups than
    ///   projected columns or than threads in [`POOL`], and to `Columns` otherwise; `Columns`
    ///   with a single projected column is downgraded to `None`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        row_group_fetcher: RowGroupFetcher,
        metadata: FileMetaDataRef,
        schema: ArrowSchemaRef,
        limit: usize,
        projection: Option<Vec<usize>>,
        predicate: Option<Arc<dyn PhysicalIoExpr>>,
        row_index: Option<RowIndex>,
        chunk_size: usize,
        use_statistics: bool,
        hive_partition_columns: Option<Vec<Series>>,
        mut parallel: ParallelStrategy,
    ) -> PolarsResult<Self> {
        let n_row_groups = metadata.row_groups.len();
        let projection = projection
            .map(Arc::from)
            .unwrap_or_else(|| (0usize..schema.len()).collect::<Arc<[_]>>());
        parallel = match parallel {
            ParallelStrategy::Auto => {
                if n_row_groups > projection.len() || n_row_groups > POOL.current_num_threads() {
                    ParallelStrategy::RowGroups
                } else {
                    ParallelStrategy::Columns
                }
            },
            _ => parallel,
        };
        if let (ParallelStrategy::Columns, true) = (parallel, projection.len() == 1) {
            parallel = ParallelStrategy::None;
        }
        Ok(BatchedParquetReader {
            row_group_fetcher,
            limit,
            projection,
            schema,
            metadata,
            row_index,
            rows_read: 0,
            predicate,
            row_group_offset: 0,
            n_row_groups,
            chunks_fifo: VecDeque::with_capacity(POOL.current_num_threads()),
            parallel,
            chunk_size,
            use_statistics,
            hive_partition_columns: hive_partition_columns.map(Arc::from),
            has_returned: false,
        })
    }

    /// Returns `true` once the row limit has been exhausted.
    pub fn limit_reached(&self) -> bool {
        self.limit == 0
    }

    /// Arrow schema of the file being read.
    pub fn schema(&self) -> &ArrowSchemaRef {
        &self.schema
    }

    /// Returns `true` once every row group has been scheduled.
    pub fn is_finished(&self) -> bool {
        self.row_group_offset >= self.n_row_groups
    }

    /// Returns `true` if advancing by `n` row groups moves past the last row group.
    ///
    /// NOTE(review): strictly greater-than, so a batch ending exactly on the last row group
    /// returns `false` — confirm this is what callers expect.
    pub fn finishes_this_batch(&self, n: usize) -> bool {
        self.row_group_offset + n > self.n_row_groups
    }

    /// Decodes up to `n` further row groups (when fewer than `n` frames are buffered) and
    /// hands out up to `n` buffered frames.
    ///
    /// When a batch produces no frames, a single empty frame with the projected schema is
    /// returned instead, so callers still see the schema. `Ok(None)` signals that nothing is
    /// left.
    pub async fn next_batches(&mut self, n: usize) -> PolarsResult<Option<Vec<DataFrame>>> {
        // Limit exhausted and something already returned: only drain the buffer.
        if self.limit == 0 && self.has_returned {
            return if self.chunks_fifo.is_empty() {
                Ok(None)
            } else {
                let n_drainable = std::cmp::min(n, self.chunks_fifo.len());
                Ok(Some(self.chunks_fifo.drain(..n_drainable).collect()))
            };
        }

        let mut skipped_all_rgs = false;
        if self.row_group_offset < self.n_row_groups && self.chunks_fifo.len() < n {
            let row_group_start = self.row_group_offset;
            // Don't fetch row groups that lie beyond the remaining limit.
            let row_group_end = compute_row_group_range(
                row_group_start,
                row_group_start + n,
                self.limit,
                &self.metadata.row_groups,
            );

            let store = self
                .row_group_fetcher
                .fetch_row_groups(row_group_start..row_group_end)
                .await?;

            let dfs = match store {
                ColumnStore::Local(_) => rg_to_dfs(
                    &store,
                    &mut self.rows_read,
                    row_group_start,
                    row_group_end,
                    &mut self.limit,
                    &self.metadata,
                    &self.schema,
                    self.predicate.as_deref(),
                    self.row_index.clone(),
                    self.parallel,
                    &self.projection,
                    self.use_statistics,
                    self.hive_partition_columns.as_deref(),
                ),
                #[cfg(feature = "async")]
                ColumnStore::Fetched(b) => {
                    // Decode on the rayon pool (blocking in place when already on it) and
                    // hand the results, plus the updated counters, back through a channel.
                    let store = ColumnStore::Fetched(b);
                    let (tx, rx) = tokio::sync::oneshot::channel();

                    let mut rows_read = self.rows_read;
                    let mut limit = self.limit;
                    let row_index = self.row_index.clone();
                    let predicate = self.predicate.clone();
                    let schema = self.schema.clone();
                    let metadata = self.metadata.clone();
                    let parallel = self.parallel;
                    let projection = self.projection.clone();
                    let use_statistics = self.use_statistics;
                    let hive_partition_columns = self.hive_partition_columns.clone();
                    let f = move || {
                        let dfs = rg_to_dfs(
                            &store,
                            &mut rows_read,
                            row_group_start,
                            row_group_end,
                            &mut limit,
                            &metadata,
                            &schema,
                            predicate.as_deref(),
                            row_index,
                            parallel,
                            &projection,
                            use_statistics,
                            hive_partition_columns.as_deref(),
                        );
                        tx.send((dfs, rows_read, limit)).unwrap();
                    };

                    if POOL.current_thread_index().is_some() {
                        tokio::task::block_in_place(f);
                    } else {
                        POOL.spawn(f);
                    };

                    let (dfs, rows_read, limit) = rx.await.unwrap();
                    self.rows_read = rows_read;
                    self.limit = limit;
                    dfs
                },
            }?;

            self.row_group_offset += n;

            // Nothing read at all so far (e.g. `limit == 0`, or only empty row groups):
            // return an empty frame so the caller still sees the schema. Record that a batch
            // was returned; otherwise the `limit == 0` early return above never triggers and
            // the empty frame is emitted again on every call.
            if self.rows_read == 0 && dfs.is_empty() {
                self.has_returned = true;
                return Ok(Some(vec![materialize_empty_df(
                    Some(&self.projection),
                    self.schema.as_ref(),
                    self.hive_partition_columns.as_deref(),
                    self.row_index.as_ref(),
                )]));
            }

            skipped_all_rgs |= dfs.is_empty();
            for mut df in dfs {
                // Frames at least twice `chunk_size` tall are split into `n_splits` pieces.
                let n_splits = df.height() / self.chunk_size;
                if n_splits > 1 {
                    for df in split_df(&mut df, n_splits, false) {
                        self.chunks_fifo.push_back(df)
                    }
                } else {
                    self.chunks_fifo.push_back(df)
                }
            }
        } else {
            skipped_all_rgs = !self.has_returned;
        };

        if self.chunks_fifo.is_empty() {
            if skipped_all_rgs {
                self.has_returned = true;
                Ok(Some(vec![materialize_empty_df(
                    Some(self.projection.as_ref()),
                    &self.schema,
                    self.hive_partition_columns.as_deref(),
                    self.row_index.as_ref(),
                )]))
            } else {
                Ok(None)
            }
        } else {
            // Pop up to `n` frames (with `n == 0` the whole buffer is drained).
            let mut chunks = Vec::with_capacity(n);
            let mut i = 0;
            while let Some(df) = self.chunks_fifo.pop_front() {
                chunks.push(df);
                i += 1;
                if i == n {
                    break;
                }
            }

            self.has_returned = true;
            Ok(Some(chunks))
        }
    }

    /// Turns this reader into an iterator-like adapter fetching `batches_per_iter` frames
    /// per underlying `next_batches` call.
    #[cfg(feature = "async")]
    pub fn iter(self, batches_per_iter: usize) -> BatchedParquetIter {
        BatchedParquetIter {
            batches_per_iter,
            inner: self,
            current_batch: vec![].into_iter(),
        }
    }
}
#[cfg(feature = "async")]
/// Adapter over [`BatchedParquetReader`] that yields one [`DataFrame`] at a time.
pub struct BatchedParquetIter {
    /// Number of frames requested from `inner` per `next_batches` call.
    batches_per_iter: usize,
    /// The underlying batched reader.
    inner: BatchedParquetReader,
    /// Frames from the most recent batch that have not been yielded yet.
    current_batch: std::vec::IntoIter<DataFrame>,
}
#[cfg(feature = "async")]
impl BatchedParquetIter {
    /// Yields the next buffered frame, refilling the buffer from the inner reader once it is
    /// empty. Returns `None` when the reader reports no further batches, or when a freshly
    /// fetched batch contains no frames.
    pub(crate) async fn next_(&mut self) -> Option<PolarsResult<DataFrame>> {
        if let Some(df) = self.current_batch.next() {
            return Some(Ok(df));
        }

        let batch = match self.inner.next_batches(self.batches_per_iter).await {
            Ok(Some(batch)) => batch,
            Ok(None) => return None,
            Err(e) => return Some(Err(e)),
        };

        self.current_batch = batch.into_iter();
        self.current_batch.next().map(Ok)
    }
}