import { useQuery } from '@apollo/client'
import { LinearProgress, styled } from '@mui/material'
import {
  DataGridPro,
  GridPaginationModel,
  GridSlots,
  useGridApiRef
} from '@mui/x-data-grid-pro'
import { useAtom } from 'jotai'
import { startCase } from 'lodash-es'
import { useMemo, useState } from 'react'

import { Typography } from '@/components'
import { NoResults } from '@/components/NoResults'
import {
  GET_COLUMNS_FOR_INSIGHTS_DASHBOARD,
  GET_DATASETS_LIST_FOR_REPORT
} from '@/datasets/queries/get_datasets_for_insights_dashboard'
import { AGGREGATION_TYPE_MEAN_PLUS_MINUS_STD } from '@/insights/charts'
import { InsightsTableFilters } from '@/insights/controls/InsightsTableFilters'
import {
  DATASET_TABLE_ROW_PROPERTY_DENY_LIST,
  FIXED_COLUMN_PROPERTIES,
  PROPERTY_TYPE_TO_GRID_COL_TYPE
} from '@/insights/home/insights_slice'

import { visualizationAtomFamily } from '../../../store/report.molecule'
import { TableConfig } from '../../../types'

/** Heading component used for the (optionally editable) table title. */
const Title = Typography.Title

const COLUMN_HEADER_HEIGHT = 80
const PAGE_SIZE_OPTIONS = [25, 50, 100]

/** Props accepted by {@link ReportTable}. */
type ReportTableProps = {
  /** Dataset IDs sent as `dataset_ids` to the report rows query. */
  datasetIds: string[]
  /** Fallback table configuration, used when the visualization atom has no config. */
  config: TableConfig
  /** Organization scope; both queries are skipped until this is set. */
  organizationId?: string
  /** Workspace scope; both queries are skipped until this is set. */
  workspaceId?: string
  /** Fallback title, shown when the visualization atom has no title. */
  title?: string
  /** Key into `visualizationAtomFamily`; the component defaults it to `''`. */
  tableId?: string
  /** Enables the editable title and filter controls; the component defaults it to `true`. */
  showEditControls?: boolean
}
/** Shape of a single entry in a report row's `properties` list. */
type ReportRowProperty = { key: string; value?: unknown }

/**
 * Rounds numbers to three decimal places for display; any other value is
 * returned unchanged.
 */
// TODO: add more sophisticated rounding
const roundForDisplay = (value: unknown): unknown =>
  typeof value === 'number' ? Number(value.toFixed(3)) : value

/**
 * Converts a report row's `{ key, value }` property list into grid cell values.
 *
 * - Properties listed in `DATASET_TABLE_ROW_PROPERTY_DENY_LIST` are dropped.
 * - `<key>_std` properties never get a cell of their own; when one exists it is
 *   folded into `<key>` as `"value ± std"`.
 * - Numeric values (and std values) are rounded to three decimal places.
 *
 * @param properties - the row's properties; `null`/`undefined` yields `{}`.
 * @returns a map from property key to display value.
 */
const toRowCells = (
  properties: ReadonlyArray<ReportRowProperty> | null | undefined
): Record<string, unknown> => {
  const visible = (properties ?? []).filter(
    ({ key }) => !DATASET_TABLE_ROW_PROPERTY_DENY_LIST.includes(key)
  )

  // Index std values once instead of scanning the list for every property.
  // The first occurrence wins, matching a linear `find`.
  const stdByKey = new Map<string, unknown>()
  for (const { key, value } of visible) {
    if (key.endsWith('_std') && !stdByKey.has(key)) stdByKey.set(key, value)
  }

  const cells: Record<string, unknown> = {}
  for (const { key, value } of visible) {
    if (key.endsWith('_std')) continue

    // If adding more aggregation types, need to update this
    const roundedValue = roundForDisplay(value)
    const roundedStd = roundForDisplay(stdByKey.get(`${key}_std`))

    // Combine value with std if it exists, otherwise use plain value
    cells[key] =
      roundedStd != null ? `${roundedValue} ± ${roundedStd}` : roundedValue
  }
  return cells
}

/**
 * Renders a report visualization as a server-paginated data grid.
 *
 * Column selection and the optional group-by property come from the
 * visualization atom keyed by `tableId`, falling back to the `config` prop when
 * the atom has no config. Without grouping, the fixed columns come first,
 * followed by the selected dataset columns in config order. With grouping, only
 * the selected dataset columns are shown, with the group-by column first and
 * then pinned columns.
 *
 * Both queries are skipped until `organizationId` and `workspaceId` are set.
 */
export const ReportTable = (props: ReportTableProps) => {
  const {
    datasetIds,
    workspaceId,
    organizationId,
    config,
    title: tableTitle,
    tableId = '',
    showEditControls = true
  } = props

  const [visualization, setVisualization] = useAtom(
    visualizationAtomFamily({ id: tableId })
  )

  const apiRef = useGridApiRef()

  // The persisted visualization config takes precedence over the prop.
  const { columns, groupByProperty } = (visualization.config ||
    config) as TableConfig

  const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({
    pageSize: PAGE_SIZE_OPTIONS[0],
    page: 0
  })

  const { data: datasetData, loading } = useQuery(
    GET_DATASETS_LIST_FOR_REPORT,
    {
      variables: {
        organization_id: organizationId!,
        workspace_ids: [workspaceId!],
        dataset_ids: datasetIds,
        // NOTE(review): this sends the page *index*; confirm the API expects
        // an index here rather than a row offset (page * pageSize).
        page_from: paginationModel.page,
        page_size: paginationModel.pageSize,
        properties_sort: [],
        group_by: groupByProperty?.key ?? '',
        aggregate_by: AGGREGATION_TYPE_MEAN_PLUS_MINUS_STD.key // temporarily hardcoded
        // TODO: Filter columns based on config instead of fetching all and hiding them on FE
        // columns: columns?? []
      },
      // NOTE(review): `!datasetIds` is never true for an array (even `[]` is
      // truthy), so an empty selection still issues the query.
      skip: !organizationId || !workspaceId || !datasetIds
    }
  )

  const { data: columnData } = useQuery(GET_COLUMNS_FOR_INSIGHTS_DASHBOARD, {
    variables: {
      organization_id: organizationId!,
      workspace_ids: [workspaceId!]
    },
    skip: !organizationId || !workspaceId
  })

  // Memoized so the grid receives a stable `columns` reference between renders.
  const columnDefs = useMemo(() => {
    const datasetColumns = columnData?.get_datasets_list.columns ?? []
    const selectedColumns = columns ?? []
    const fixedFields = new Set<string>(
      FIXED_COLUMN_PROPERTIES.map(({ field }) => field)
    )

    // Dynamic columns: user-selected dataset properties that are not already
    // rendered as one of the fixed columns.
    const selectedDatasetColumns = datasetColumns.filter(
      ({ key }) => !fixedFields.has(key) && selectedColumns.includes(key)
    )

    if (groupByProperty) {
      const groupKey = groupByProperty.key
      return selectedDatasetColumns
        .sort((a, b) => {
          // Pin group by property to the left
          if (a.key === groupKey) return -1
          if (b.key === groupKey) return 1

          // Sort by pinned status
          return a.pinned === b.pinned ? 0 : a.pinned ? -1 : 1
        })
        .map(({ key, property_type }) => {
          const type = property_type
            ? PROPERTY_TYPE_TO_GRID_COL_TYPE[property_type]
            : 'string'
          return {
            width: 150,
            field: key,
            headerName: startCase(key),
            // `!= null` (not `!== null`) so an undefined cell yields null
            // instead of an Invalid Date.
            valueGetter:
              type === PROPERTY_TYPE_TO_GRID_COL_TYPE.date
                ? (value: Nullable<string>) =>
                    value != null ? new Date(value) : null
                : undefined,
            type,
            flex: 1
          }
        })
    }

    // Without dataset columns the ungrouped table shows no columns at all.
    if (datasetColumns.length === 0) return []

    // Drop the fixed cycle_count column when the datasets don't expose it.
    const hasCycleCount = datasetColumns.some(
      ({ key }) => key === 'cycle_count'
    )
    const fixedColumns = [...FIXED_COLUMN_PROPERTIES]
    const cycleCountIndex = fixedColumns.findIndex(
      column => column.field === 'cycle_count'
    )
    if (cycleCountIndex !== -1 && !hasCycleCount) {
      fixedColumns.splice(cycleCountIndex, 1)
    }

    return fixedColumns.concat(
      selectedDatasetColumns
        // The filter above guarantees every key is in `selectedColumns`, so
        // this orders the columns exactly as listed in the table config.
        .sort(
          (a, b) =>
            selectedColumns.indexOf(a.key) - selectedColumns.indexOf(b.key)
        )
        .map(({ key, property_type }) => {
          const type = property_type
            ? PROPERTY_TYPE_TO_GRID_COL_TYPE[property_type]
            : 'string'
          return {
            minWidth: 150,
            width: 150,
            field: key,
            headerName: startCase(key),
            valueGetter:
              type === PROPERTY_TYPE_TO_GRID_COL_TYPE.date
                ? value => (value != null ? new Date(value) : null)
                : undefined,
            type,
            flex: 1
          }
        })
    )
  }, [columnData, columns, groupByProperty])

  // Flatten each row's property list into cell values; recomputed only when
  // the query result changes.
  const validDatasets = useMemo(
    () =>
      (datasetData?.get_datasets_list_for_report.rows ?? []).map(
        ({ id, properties }) => ({ id, ...toRowCells(properties) })
      ),
    [datasetData]
  )

  return (
    <>
      <div className='flex items-center justify-between mb-4'>
        <Title
          level={5}
          className='[&>button]:text-sm flex flex-row items-baseline gap-x-1'
          editable={
            showEditControls
              ? {
                  onChange: (newTitle: string) => {
                    // Ignore empty titles so the heading never disappears.
                    if (newTitle === '') return
                    setVisualization(prev => ({ ...prev, title: newTitle }))
                  }
                }
              : false
          }
        >
          {visualization.title || tableTitle}
        </Title>
        {showEditControls && (
          <InsightsTableFilters tableId={tableId} tableRef={apiRef} />
        )}
      </div>
      <Styled_Dataset_Datagrid
        apiRef={apiRef}
        columnHeaderHeight={COLUMN_HEADER_HEIGHT}
        columns={columnDefs}
        density='compact'
        disableColumnMenu
        disableDensitySelector
        disableRowSelectionOnClick
        filterDebounceMs={500}
        hideFooterSelectedRowCount
        loading={loading}
        onPaginationModelChange={setPaginationModel}
        pageSizeOptions={PAGE_SIZE_OPTIONS}
        pagination
        paginationMode='server'
        paginationModel={paginationModel}
        // NOTE(review): with paginationMode='server', rowCount should be the
        // server-side total. Using the current page's length means the grid
        // never offers a next page. Wire this to the query's total count
        // once the schema exposes one.
        rowCount={validDatasets.length}
        rows={validDatasets}
        slots={{
          loadingOverlay: LinearProgress as GridSlots['loadingOverlay'],
          noResultsOverlay: NoResults,
          noRowsOverlay: NoResults
        }}
        slotProps={{
          pagination: {
            rowsPerPageOptions: PAGE_SIZE_OPTIONS.map(size => ({
              value: size,
              label: `${size} rows per page`
            })),
            labelRowsPerPage: null
          }
        }}
      />
    </>
  )
}

/**
 * `DataGridPro` styled for report tables.
 *
 * - Height is capped at 400px.
 * - The header band is gray with rounded top corners.
 * - Rows get an info-tinted hover.
 * - The pagination toolbar is borderless, with the displayed-rows label pushed
 *   to the right.
 *
 * `rowCount` is read from the grid's own props. When it is 0 the grid keeps a
 * 180px minimum height, presumably so the no-rows overlay has room.
 */
const Styled_Dataset_Datagrid = styled(DataGridPro)(({ theme, rowCount }) => ({
  maxHeight: '400px',
  maxWidth: '100%',
  minHeight: rowCount === 0 ? '180px' : 'auto',
  borderRadius: theme.shape.borderRadius,

  '.MuiDataGrid-columnHeaders': {
    background: theme.palette.gray[100],
    // Reset all corners first, then round only the top ones. The shorthand
    // must precede the longhands, otherwise it overrides them.
    borderRadius: 0,
    borderTopLeftRadius: theme.shape.borderRadius,
    borderTopRightRadius: theme.shape.borderRadius,
    border: 'none',
    '[role=row]': {
      background: 'transparent'
    }
  },
  '.MuiDataGrid-columnHeader': {
    background: 'transparent'
  },
  '.MuiDataGrid-main': {
    overflowY: 'auto'
  },
  '.MuiDataGrid-row:hover': {
    background: theme.palette.info[50]
  },
  '.MuiDataGrid-footerContainer': {
    border: 0,
    borderTop: `1px solid ${theme.palette.divider}`
  },
  '.MuiTablePagination-root': {
    width: '100%',
    '.MuiTablePagination-toolbar': {
      border: 0,
      '.MuiInputBase-root': {
        border: 0,
        background: 'transparent',
        boxShadow: 'none',
        margin: 0
      },
      '.MuiSelect-select': {
        border: 0
      },
      // Hide the flexible spacer and let the displayed-rows label grow and
      // right-align instead.
      '.MuiTablePagination-spacer': {
        display: 'none'
      },
      '.MuiTablePagination-displayedRows': {
        flex: 1,
        textAlign: 'right'
      }
    }
  }
}))
