import { FC } from 'react';
import { BasicLineChart } from '@/components/charts';
import { LearningCurveWrapper } from './styles';

/**
 * Props accepted by `LearningCurve`.
 *
 * Every series is optional. Each array is plotted in order, one point per
 * element.
 */
interface LearningCurveProps {
  // Series drawn in the chart titled "loss".
  /** Loss values for the training series. */
  trainLoss?: number[];
  /** Loss values for the validation series. */
  validLoss?: number[];

  // Series drawn in the chart titled "accuracy".
  /** Accuracy values for the training series. */
  trainAccuracy?: number[];
  /** Accuracy values for the validation series. */
  validAccuracy?: number[];
}

// Scale factor applied to values before plotting and divided back out for
// display; 10^4 keeps four decimal places.
const BASE = 10 ** 4;

/**
 * Renders two side-by-side line charts, one titled "loss" and one titled
 * "accuracy", each comparing a training series with a validation series.
 *
 * All props are optional; a series that is not supplied is passed to the
 * chart with `undefined` data.
 */
const LearningCurve: FC<LearningCurveProps> = ({
  trainLoss,
  trainAccuracy,
  validLoss,
  validAccuracy
}) => {
  // Both charts share the same box: fixed height, half of the wrapper's width.
  const chartStyle = { height: '400px', width: '50%' };

  // Values are plotted multiplied by BASE and rounded to an integer.
  // `toFixed()` without arguments returns that integer as a string.
  const toScaled = (values?: number[]) =>
    values?.map((value) => (value * BASE).toFixed());

  // Inverse of `toScaled`, used when a value is shown to the user. Division
  // also coerces a numeric string back to a number.
  // NOTE(review): the chart library may hand the tooltip formatter the raw
  // (string) series value rather than a number — confirm.
  const fromScaled = (value: number) => value / BASE;

  /**
   * Builds the chart option object for one metric.
   *
   * @param title - Text shown as the chart title.
   * @param trainData - Values for the training series.
   * @param validData - Values for the validation series.
   * @returns The option object handed to `BasicLineChart`.
   */
  const getOption = (
    title: string,
    trainData?: number[],
    validData?: number[]
  ) => {
    // Size the category axis from the longer series so the axis still gets
    // categories when only validation data is supplied, or when the two
    // series differ in length. With no data at all the axis data stays
    // undefined.
    const pointCount = Math.max(
      trainData?.length ?? 0,
      validData?.length ?? 0
    );
    const categories =
      trainData || validData
        ? Array.from({ length: pointCount }, (_, index) => index)
        : undefined;

    return {
      tooltip: {
        trigger: 'axis',
        valueFormatter: fromScaled
      },
      title: {
        text: title,
        textStyle: {
          color: '#38445E',
          fontSize: '14px'
        }
      },
      legend: {
        data: ['训练集', '验证集'],
        bottom: 0
      },
      yAxis: {
        type: 'value',
        axisLabel: {
          formatter: fromScaled
        }
      },
      xAxis: {
        type: 'category',
        boundaryGap: false,
        data: categories
      },
      series: [
        {
          name: '训练集',
          data: toScaled(trainData),
          itemStyle: {
            color: '#4D64D3'
          }
        },
        {
          name: '验证集',
          data: toScaled(validData),
          itemStyle: {
            color: '#FAAD14'
          }
        }
      ]
    };
  };

  return (
    <LearningCurveWrapper>
      <BasicLineChart
        style={chartStyle}
        option={getOption('loss', trainLoss, validLoss)}
      />
      <BasicLineChart
        style={chartStyle}
        option={getOption('accuracy', trainAccuracy, validAccuracy)}
      />
    </LearningCurveWrapper>
  );
};

export default LearningCurve;
