import React, { useMemo } from 'react';
import cx from 'classnames';
import { Box, makeStyles } from '@material-ui/core';
import { Typography } from '@clef/client-library';
import {
  AggregatedConfusionMatrix,
  RegisteredModel,
  SplitConfusionMatrices,
} from '@clef/shared/types';
import { useGetConfusionMatrixQuery } from '@/serverStore/modelAnalysis';
import useGetDefectNameById from '@/hooks/defect/useGetDefectNameById';
import { Theme } from '@material-ui/core';
import LoadingProgress from '../LoadingProgress';

/**
 * Styles for a single confusion-matrix cell.
 *
 * `transparency` is used as the alpha channel of the filled-cell background
 * (`tableCellBackground`). When it is omitted, the background falls back to the
 * opaque version of the same blue.
 */
const useTableCellStyle = makeStyles<Theme, { transparency?: number }>(theme => ({
  emptyCell: {
    width: theme.spacing(20),
    height: theme.spacing(4),
  },
  zeroCell: {
    color: theme.palette.grey[400],
  },
  tableCell: {
    display: 'flex',
    justifyContent: 'center',
    alignItems: 'center',
    width: theme.spacing(20),
    height: theme.spacing(12),
    background: '#FAFCFE',
  },
  tableCellBackground: {
    // Compare against undefined rather than using truthiness so that an explicit alpha of 0
    // is honoured instead of silently rendering a fully opaque cell.
    // '#0167DC' is the same colour as rgb(1, 103, 220).
    background: props =>
      props.transparency === undefined
        ? '#0167DC'
        : `rgba(1, 103, 220, ${props.transparency})`,
  },
  cursorPointer: {
    cursor: 'pointer',
  },
}));

/**
 * Layout styles for the confusion-matrix table. They cover:
 * - the ground-truth caption column on the left;
 * - the matrix body;
 * - the precision row underneath;
 * - the recall column on the right;
 * - the rotated prediction captions along the bottom.
 */
const useStyles = makeStyles<Theme>(theme => {
  const { grey } = theme.palette;
  // Same width as a matrix cell, so precision cells and rotated captions line up with columns.
  const cellWidth = theme.spacing(20);
  // Same height as a matrix cell, so first/last column captions line up with rows.
  const rowHeight = theme.spacing(12);
  const titleLineHeight = '16px';

  return {
    root: {
      padding: theme.spacing(4, 4, 25, 4),
      maxWidth: 'calc(100% - 110px)',
      display: 'flex',
      alignItems: 'center',
      overflowX: 'auto',
    },
    fixedHeight: {
      height: theme.spacing(7),
    },
    tableContainer: {
      marginTop: theme.spacing(2),
      fontSize: theme.spacing(3),
      fontFamily: 'Commissioner',
      display: 'flex',
      alignItems: 'flex-start',
    },
    groundTruthTitle: {
      fontWeight: 700,
      color: grey[400],
      paddingRight: theme.spacing(3),
      lineHeight: titleLineHeight,
    },
    firstColumn: {
      display: 'flex',
      flexDirection: 'column',
      textAlign: 'right',
      padding: theme.spacing(1, 0),
      color: grey[900],
    },
    firstColumnCaption: {
      display: 'flex',
      alignItems: 'center',
      justifyContent: 'flex-end',
      height: rowHeight,
      paddingRight: theme.spacing(3),
    },
    lastColumn: {
      display: 'flex',
      flexDirection: 'column',
      justifyContent: 'center',
      fontWeight: 500,
      color: grey[900],
      background: grey[100],
      padding: theme.spacing(1, 1, 0, 1),
      marginLeft: theme.spacing(3),
      borderRadius: '2px',
    },
    lastColumnCaption: {
      display: 'flex',
      alignItems: 'center',
      height: rowHeight,
      fontWeight: 500,
    },
    recallTitle: {
      color: grey[500],
      fontWeight: 600,
      lineHeight: titleLineHeight,
    },
    precisionTitle: {
      background: grey[100],
      color: grey[500],
      marginTop: theme.spacing(2),
      padding: theme.spacing(1, 3, 1, 0),
      fontWeight: 600,
      lineHeight: titleLineHeight,
    },
    precisionRow: {
      display: 'flex',
      alignItems: 'center',
    },
    precisionCell: {
      display: 'flex',
      justifyContent: 'center',
      marginTop: theme.spacing(2),
      width: cellWidth,
      height: theme.spacing(6),
      background: grey[100],
      fontWeight: 500,
      color: grey[900],
    },
    lastLabelRow: {
      display: 'flex',
      alignItems: 'center',
    },
    lastLabelRowCaption: {
      display: 'flex',
      justifyContent: 'flex-end',
      alignItems: 'center',
      width: cellWidth,
      transform: 'rotate(-45deg)',
      paddingRight: theme.spacing(1),
      position: 'relative',
      right: theme.spacing(5),
      transformOrigin: 'right',
      color: grey[900],
    },
    predictionTitle: {
      color: grey[400],
      fontWeight: 700,
      lineHeight: titleLineHeight,
    },
  };
});

/**
 * One cell of the confusion matrix.
 *
 * - `isNoLabelCell`: renders the greyed-out "--" placeholder.
 * - Otherwise, when `count` is given, renders the count. Positive counts get the blue
 *   background (alpha = `transparency`) and, if `onClick` is set, a pointer cursor.
 *   Clicking calls `onClick` only when the count is non-zero.
 * - Renders nothing when neither case applies.
 */
const TableCell: React.FC<{
  count?: number;
  transparency?: number;
  isNoLabelCell?: boolean;
  onClick?: () => void;
}> = ({ count, transparency, isNoLabelCell, onClick }) => {
  const classes = useTableCellStyle({ transparency });

  if (isNoLabelCell) {
    return (
      <Box className={cx(classes.tableCell, classes.zeroCell)}>
        <Typography>--</Typography>
      </Box>
    );
  }

  if (count === undefined) {
    return null;
  }

  const isFilled = count > 0;
  const handleClick = () => {
    if (count && onClick) {
      onClick();
    }
  };

  return (
    <Box
      className={cx(classes.tableCell, {
        [classes.zeroCell]: count === 0,
        [classes.tableCellBackground]: isFilled,
        [classes.cursorPointer]: isFilled && !!onClick,
      })}
      onClick={handleClick}
    >
      <Typography>{count}</Typography>
    </Box>
  );
};

/** Props for `ConfusionMatrixTable`. */
interface IProps {
  /** Model whose id is passed to the confusion-matrix query. */
  model?: RegisteredModel;
  /** Evaluation set passed to the confusion-matrix query. */
  evaluationSetId?: number;
  /** Threshold passed to the confusion-matrix query. */
  threshold?: number;
  /** Pre-computed matrices; when provided they take precedence over the query result. */
  confusionMatrices?: SplitConfusionMatrices;
  /** Called with the (ground-truth class id, predicted class id) of a clicked cell. */
  onClick?: (gtClassId: number, predClassId: number) => void;
}

/**
 * Formats `numerator / denominator` as a percentage with one decimal place (e.g. "87.5%").
 * Returns "--" when the denominator is 0, so classes with no samples don't render NaN.
 */
const formatRate = (numerator: number, denominator: number): string =>
  denominator === 0 ? '--' : `${((numerator / denominator) * 100).toFixed(1)}%`;

/**
 * Confusion matrix with ground-truth classes as rows and predicted classes as columns.
 * It also shows a per-predicted-class precision row and a per-ground-truth-class recall column.
 *
 * Matrix data comes from `confusionMatrices` when provided. Otherwise it comes from
 * `useGetConfusionMatrixQuery(model?.id, evaluationSetId, threshold)`. Class id 0 is the
 * "No label" row and the "No prediction" column. A cell's background opacity grows with the
 * rank of its count among all distinct non-zero counts.
 */
export const ConfusionMatrixTable: React.FC<IProps> = ({
  model,
  evaluationSetId,
  threshold,
  onClick,
  confusionMatrices,
}) => {
  const classes = useStyles();
  const { data: confusionMatrixData, isLoading: isConfusionMatrixDataLoading } =
    useGetConfusionMatrixQuery(model?.id, evaluationSetId, threshold);

  const getDefectNameById = useGetDefectNameById();
  const {
    gtDefectIds,
    predDefectIds,
    gtDefectMap,
    predictionDefectMap,
    confusionMatrixMap,
    countOpacity,
  } = useMemo(() => {
    // The id-0 row and column are always shown, even when no entry references them.
    const gtDefectMap = new Map<number, string>([[0, 'No label']]);
    const predictionDefectMap = new Map<number, string>([[0, 'No prediction']]);
    const countSet = new Set<number>();
    // Keyed by `${gtClassId}-${predClassId}`; only non-zero counts are stored.
    const confusionMatrixMap = new Map<string, number>();
    const addEntries = (confusionMatrix: AggregatedConfusionMatrix[]) => {
      confusionMatrix
        .filter(m => m.count > 0)
        .forEach(item => {
          const gtCaption = item.gtClassId ? getDefectNameById(item.gtClassId) : 'No label';
          const predictionCaption = item.predClassId
            ? getDefectNameById(item.predClassId)
            : 'No prediction';
          gtDefectMap.set(item.gtClassId, gtCaption);
          predictionDefectMap.set(item.predClassId, predictionCaption);
          confusionMatrixMap.set(`${item.gtClassId}-${item.predClassId}`, item.count);
          countSet.add(item.count);
        });
    };

    const { correct, misClassification, falseNegative, falsePositive } =
      confusionMatrices ?? confusionMatrixData?.splitConfusionMatrices ?? {};
    addEntries(correct?.data ?? []);
    addEntries(misClassification?.data ?? []);
    addEntries(falseNegative?.data ?? []);
    addEntries(falsePositive?.data ?? []);

    // Compute each distinct count's opacity once, instead of calling indexOf for every cell.
    // The smallest count gets 1/n and the largest gets 1.
    const orderedCounts = Array.from(countSet).sort((a, b) => a - b);
    const countOpacity = new Map<number, number>();
    orderedCounts.forEach((count, index) => {
      countOpacity.set(count, (index + 1) / orderedCounts.length);
    });

    return {
      // Descending id order, so the id-0 row/column is rendered last (assuming ids are >= 0).
      gtDefectIds: Array.from(gtDefectMap.keys()).sort((a, b) => b - a),
      predDefectIds: Array.from(predictionDefectMap.keys()).sort((a, b) => b - a),
      gtDefectMap,
      predictionDefectMap,
      confusionMatrixMap,
      countOpacity,
    };
  }, [confusionMatrixData, confusionMatrices, getDefectNameById]);

  /** Number of samples with ground truth `gtId` predicted as `predId` (0 when absent). */
  const getCount = (gtId: number, predId: number) =>
    confusionMatrixMap.get(`${gtId}-${predId}`) ?? 0;

  // The query result is only a fallback for `confusionMatrices`, so don't show a spinner
  // while it loads when the caller has already supplied the data.
  if (!confusionMatrices && evaluationSetId && isConfusionMatrixDataLoading) {
    return (
      <Box className={classes.root}>
        <LoadingProgress />
      </Box>
    );
  }

  return (
    <Box className={classes.root}>
      <Box className={classes.tableContainer}>
        <Box className={classes.firstColumn}>
          <Typography variant="body2" maxWidth={100} className={classes.groundTruthTitle}>
            {t('Ground truth')}
          </Typography>
          {gtDefectIds.map(gtId => (
            <Box key={`ground-truth-${gtId}`} className={classes.firstColumnCaption}>
              <Typography variant="body2" maxWidth={100}>
                {gtDefectMap.get(gtId)}
              </Typography>
            </Box>
          ))}
          <Typography variant="body2" className={classes.precisionTitle}>
            {t('Precision')}
          </Typography>
        </Box>
        <Box>
          <Box height={20} />
          {gtDefectIds.map(gtId => (
            <Box key={`row-${gtId}`} display="flex">
              {predDefectIds.map(predId => {
                const cellKey = `cell-${gtId}-${predId}`;
                // The (No label, No prediction) corner gets a placeholder instead of a count.
                if (gtId === 0 && predId === 0) {
                  return <TableCell key={cellKey} isNoLabelCell />;
                }
                const count = getCount(gtId, predId);
                return (
                  <TableCell
                    key={cellKey}
                    count={count}
                    // undefined for 0: countOpacity only holds non-zero counts.
                    transparency={countOpacity.get(count)}
                    onClick={onClick ? () => onClick(gtId, predId) : undefined}
                  />
                );
              })}
            </Box>
          ))}
          <Box className={classes.precisionRow}>
            {predDefectIds
              .filter(predId => predId !== 0)
              .map(predId => {
                // Precision = diagonal count / column total. The column total includes the
                // "No label" row.
                const predictedTotal = gtDefectIds.reduce(
                  (sum, gtId) => sum + getCount(gtId, predId),
                  0,
                );
                return (
                  <Typography className={classes.precisionCell} key={`precision-${predId}`}>
                    {formatRate(getCount(predId, predId), predictedTotal)}
                  </Typography>
                );
              })}
          </Box>
          <Box className={classes.lastLabelRow}>
            {predDefectIds.map(predId => (
              <Box key={`last-label-${predId}`} className={classes.lastLabelRowCaption}>
                <Typography maxWidth={100} variant="body2">
                  {predictionDefectMap.get(predId)}
                </Typography>
              </Box>
            ))}
          </Box>
        </Box>
        <Box display="flex" flexDirection="column">
          <Box className={classes.lastColumn}>
            <Typography variant="body2" className={classes.recallTitle}>
              {t('Recall')}
            </Typography>
            {gtDefectIds
              .filter(gtId => gtId !== 0)
              .map(gtId => {
                // Recall = diagonal count / row total. The row total includes the
                // "No prediction" column.
                const labeledTotal = predDefectIds.reduce(
                  (sum, predId) => sum + getCount(gtId, predId),
                  0,
                );
                return (
                  <Typography className={classes.lastColumnCaption} key={`recall-${gtId}`}>
                    {formatRate(getCount(gtId, gtId), labeledTotal)}
                  </Typography>
                );
              })}
          </Box>
          <Box display="flex" flexDirection="column">
            <Box height={80} />
            <Box className={classes.lastLabelRowCaption}>
              <Typography className={classes.predictionTitle} variant="body2">
                {t('Prediction')}
              </Typography>
            </Box>
          </Box>
        </Box>
      </Box>
    </Box>
  );
};
