import { useRef, useEffect } from 'react';
import { Box, useColorModeValue } from '@chakra-ui/react';
import { Chart, 
    ScatterController, 
    LinearScale, 
    PointElement, 
    Title, Tooltip, Legend, Point, LegendItem } from 'chart.js';
import { useQuery } from 'react-query';
import { ModelsService } from '../../client/services/ModelsService';


// Chart.js is tree-shakeable: a chart can only use controllers, scales, elements and plugins
// that have been registered. This registers the pieces the scatter chart in this module relies
// on: the scatter controller, linear scales, point elements, and the title/tooltip/legend plugins.
const scatterChartComponents = [ScatterController, LinearScale, PointElement, Title, Tooltip, Legend];
Chart.register(...scatterChartComponents);


/**
 * Palette cycled through when assigning one colour per model.
 * Entries must stay unique, otherwise two models can be drawn in the same colour.
 */
const COLOR_PALETTE = [
    '#00B3E6', '#FF6633', '#CC80CC', '#99FF99',
    '#E64D66', '#4DB380', '#FF4D4D', '#E6B333', '#6666FF',
    '#99E6E6', '#FFB399', '#B34D4D',
    '#3366E6', '#999966', '#FF33FF', '#FFFF99',
    '#80B300', '#809900', '#E6B3B3', '#6680B3', '#66991A',
    '#FF99E6', '#CCFF1A', '#FF1A66', '#E6331A', '#33FFCC',
    '#66994D', '#B366CC', '#4D8000', '#B33300',
    '#66664D', '#991AFF', '#E666FF', '#4DB3FF', '#1AB399',
    '#E666B3', '#33991A', '#CC9999', '#B3B31A', '#00E680',
    '#4D8066', '#809980', '#E6FF80', '#1AFF33', '#999933',
    '#FF3380', '#CCCC00', '#66E64D', '#4D80CC', '#9900B3',
] as const;

/** Chart.js point styles cycled through when assigning one marker shape per prompt. */
const POINT_SHAPES = ['circle', 'rect', 'triangle', 'rectRot', 'star'] as const;
type PointShape = (typeof POINT_SHAPES)[number];

/** Font size used for both axis titles and axis tick labels. */
const AXIS_FONT_SIZE = 16;

/**
 * Assigns a palette colour to every model name.
 *
 * Names in `preferredOrder` are coloured first, by their position in that list, so they keep
 * the same colour regardless of which results are plotted. Every other name in `plotted` gets
 * the next palette slot in order of first appearance. The palette wraps around once exhausted.
 *
 * @param preferredOrder - names that should own the first palette slots
 * @param plotted - names that actually appear in the plotted data (duplicates allowed)
 * @returns a map from each name to its colour
 */
function assignModelColors<K>(preferredOrder: readonly K[], plotted: readonly K[]): Map<K, string> {
    const colors = new Map<K, string>();
    preferredOrder.forEach((name, index) => {
        colors.set(name, COLOR_PALETTE[index % COLOR_PALETTE.length]);
    });
    // Names outside preferredOrder continue after the slots it reserved.
    let nextSlot = preferredOrder.length;
    for (const name of plotted) {
        if (!colors.has(name)) {
            colors.set(name, COLOR_PALETTE[nextSlot % COLOR_PALETTE.length]);
            nextSlot += 1;
        }
    }
    return colors;
}

/**
 * Assigns a point shape to every distinct prompt, in order of first appearance.
 *
 * Prompts repeat once per model in the plotted data, so shapes are keyed on the unique prompt
 * rather than on array position; otherwise distinct prompts could end up with the same shape.
 *
 * @param prompts - prompt texts as they appear in the data (duplicates allowed)
 * @returns a map from each distinct prompt to its shape
 */
function assignPromptShapes<K>(prompts: readonly K[]): Map<K, PointShape> {
    const shapes = new Map<K, PointShape>();
    for (const prompt of prompts) {
        if (!shapes.has(prompt)) {
            shapes.set(prompt, POINT_SHAPES[shapes.size % POINT_SHAPES.length]);
        }
    }
    return shapes;
}

/**
 * Builds the Chart.js options for the cost/score scatter plot.
 *
 * @param maxCost - largest cost in the data; when not positive (no data, all-zero or
 *   non-numeric costs) the x-axis maximum is left for Chart.js to choose
 * @param textColor - colour for axis titles and tick labels
 * @param gridColor - colour for grid lines
 */
function buildScatterOptions(maxCost: number, textColor: string, gridColor: string) {
    return {
        responsive: true,
        scales: {
            x: {
                title: {
                    display: true,
                    text: 'Cost (per 1M tokens)',
                    color: textColor,
                    font: { size: AXIS_FONT_SIZE },
                },
                ticks: {
                    color: textColor,
                    font: { size: AXIS_FONT_SIZE },
                },
                grid: {
                    color: gridColor,
                },
                min: 0,
                // ~10% headroom past the most expensive point so its marker is not clipped.
                max: maxCost > 0 ? maxCost * 1.1 : undefined,
            },
            y: {
                title: {
                    display: true,
                    text: 'Score',
                    color: textColor,
                    font: { size: AXIS_FONT_SIZE },
                },
                ticks: {
                    color: textColor,
                    font: { size: AXIS_FONT_SIZE },
                },
                grid: {
                    color: gridColor,
                },
                min: 0,
                max: 1,
            },
        },
        plugins: {
            tooltip: {
                // Anchor the tooltip on the point closest to the cursor.
                position: 'nearest' as const,
                // Caret on the bottom edge of the tooltip, so the box sits above the point.
                yAlign: 'bottom' as const,
                callbacks: {
                    // Show only the dataset label ("<model> - <prompt>") for the hovered point.
                    label: (context: { dataset: { label?: string } }) => context.dataset.label ?? '',
                },
            },
            legend: {
                display: true,
                position: 'right' as const,
                labels: {
                    usePointStyle: true,
                    // Chart.js reads the legend sort comparator from `labels.sort`.
                    sort: (a: LegendItem, b: LegendItem): number => a.text.localeCompare(b.text),
                },
            },
        },
    };
}

/**
 * Scatter plot of score against cost, one point per entry of `data`.
 *
 * Each `data` entry is destructured as `{ prompt_text, model_name, cost, score }`. `cost` is
 * plotted on the x axis and `score` on the y axis, which is fixed to 0–1.
 *
 * Points are coloured by `model_name` and shaped by `prompt_text`. Internal models from
 * `ModelsService.readInternalModels` take the first palette slots, so their colours do not
 * depend on which results are shown.
 *
 * The chart is rebuilt whenever the data, the internal model list or the colour mode changes.
 */
const ScatterPlot = ({ data }) => {
    const chartRef = useRef<Chart<"scatter", (number | Point | null)[], unknown> | null>(null);
    const canvasRef = useRef<HTMLCanvasElement>(null);
    const modelsQuery = useQuery('internal_models', () => ModelsService.readInternalModels({}));
    // Undefined until the query resolves; react-query keeps the reference stable afterwards,
    // so it is safe to use as an effect dependency.
    const internalModels = modelsQuery.data?.data;

    // Define colors based on the current color mode
    const textColor = useColorModeValue('#333333', '#ffffff');
    const gridColor = useColorModeValue('#cccccc', '#888888');

    useEffect(() => {
        const ctx = canvasRef.current?.getContext('2d');
        if (ctx) {
            const knownModelNames = (internalModels ?? []).map((model) => model.name);
            const colors = assignModelColors(knownModelNames, data.map(({ model_name }) => model_name));
            const shapes = assignPromptShapes(data.map(({ prompt_text }) => prompt_text));
            const maxCost = Math.max(...data.map(({ cost }) => cost));

            chartRef.current = new Chart(ctx, {
                type: 'scatter',
                data: {
                    // One single-point dataset per entry, so each gets its own legend item.
                    datasets: data.map(({ prompt_text, model_name, cost, score }) => ({
                        label: `${model_name} - ${prompt_text}`,
                        data: [{ x: cost, y: score }],
                        backgroundColor: colors.get(model_name),
                        pointStyle: shapes.get(prompt_text),
                        pointRadius: 10,
                        pointHoverRadius: 12,
                        pointHoverBorderColor: 'white',
                        pointHoverBorderWidth: 1,
                    })),
                },
                options: buildScatterOptions(maxCost, textColor, gridColor),
            });
        }

        // Destroy exactly once per chart, and clear the ref so it never holds a dead chart.
        return () => {
            chartRef.current?.destroy();
            chartRef.current = null;
        };
    }, [data, internalModels, textColor, gridColor]);

    return (
        <Box m={5}>
            <canvas ref={canvasRef}></canvas>
        </Box>
    );
};

export default ScatterPlot;
