import React, { useEffect, useRef, useState } from 'react';
import * as d3 from 'd3';

/** One heatmap cell: the score a model obtained on a single prompt. */
export interface HeatmapRow {
  /** Row key; rendered as a y-axis label. */
  model_name: string;
  /** Column key; rendered as a word-wrapped x-axis label. */
  prompt_text: string;
  /** Cell value, coloured on a [0, 1] green scale. null / undefined / NaN render as "-". */
  score?: number | null;
}

interface HeatmapProps {
  /** One entry per (model_name, prompt_text) cell. */
  data: readonly HeatmapRow[];
  /** Receives the height (getBBox) of the main chart SVG's drawn content after each redraw. */
  onHeightChange: (height: number) => void;
  /** While true the chart is not redrawn; whatever was drawn before stays on screen. */
  isLoading: boolean;
}

const FONT_SIZE = 16;
const FONT_SIZE_PX = `${FONT_SIZE}px`;
/** Vertical spacing between wrapped x-axis label lines, in ems. */
const LABEL_LINE_HEIGHT_EM = 1.1;
/** Gap between the widest y-axis label and the plot area, in px. */
const Y_LABEL_GAP = 25;
const MARGIN_TOP = 30;
const MARGIN_RIGHT = 30;
/** The x axis is drawn in its own SVG below the chart, so no bottom margin is reserved. */
const MARGIN_BOTTOM = 0;
/**
 * Fill for cells without a numeric score. Without it the cell gets no fill
 * (SVG default: black) and its black "-" label becomes invisible.
 */
const MISSING_FILL = "#eeeeee";

/** True when `score` is a real number; null, undefined and NaN count as "no score". */
function hasScore(score: number | null | undefined): score is number {
  return typeof score === "number" && !Number.isNaN(score);
}

/**
 * Word-wraps every SVG <text> element in `text` into <tspan> lines at most
 * `width` px wide. A single word wider than `width` still gets its own line.
 *
 * Every line keeps the element's original `y`; line k is shifted down by
 * k * LABEL_LINE_HEIGHT_EM ems on top of the element's original `dy` (in ems).
 * The elements must already be in the DOM so text can be measured.
 */
function wrap<PElement extends d3.BaseType>(
  text: d3.Selection<SVGTextElement, unknown, PElement, unknown>,
  width: number,
): void {
  text.each(function (this: SVGTextElement) {
    const label = d3.select(this);
    // Trim first: a leading space would otherwise produce an empty first "word".
    const content = label.text().trim();
    const words = content === "" ? [] : content.split(/\s+/);
    const y = label.attr("y");
    const dy = parseFloat(label.attr("dy")) || 0;

    const appendLine = (lineNumber: number) =>
      label
        .append("tspan")
        .attr("font-size", FONT_SIZE_PX)
        .attr("x", 0)
        .attr("y", y)
        .attr("dy", `${lineNumber * LABEL_LINE_HEIGHT_EM + dy}em`);

    label.text(null);
    let lineNumber = 0;
    let tspan = appendLine(lineNumber);
    let line: string[] = [];

    for (const word of words) {
      line.push(word);
      tspan.text(line.join(" "));
      const overflows = (tspan.node()?.getComputedTextLength() ?? 0) > width;
      // Only break when the line has more than one word; breaking on a lone
      // word would leave an empty tspan behind.
      if (overflows && line.length > 1) {
        line.pop();
        tspan.text(line.join(" "));
        line = [word];
        lineNumber += 1;
        tspan = appendLine(lineNumber).text(word);
      }
    }
  });
}

/**
 * Model-by-prompt score heatmap.
 *
 * The main SVG holds the cells, their score labels and the y axis (model names).
 * The x axis (prompt texts, word-wrapped to the column width) is drawn into a
 * separate SVG inside a vertically resizable, scrollable 50px container.
 * The chart is redrawn whenever `data`, `isLoading` or the main SVG's size changes.
 */
const Heatmap: React.FC<HeatmapProps> = ({ data, onHeightChange, isLoading }) => {
  const ref = useRef<SVGSVGElement>(null);
  const xAxisRef = useRef<SVGSVGElement>(null);
  const [dimensions, setDimensions] = useState({ width: 0, height: 0 });
  const [xAxisHeight, setXAxisHeight] = useState(0);

  // Always call the latest callback without redrawing whenever the parent
  // passes a new function identity.
  const onHeightChangeRef = useRef(onHeightChange);
  useEffect(() => {
    onHeightChangeRef.current = onHeightChange;
  }, [onHeightChange]);

  useEffect(() => {
    const svgNode = ref.current;
    const xAxisNode = xAxisRef.current;
    if (!svgNode || !xAxisNode) return;
    if (dimensions.width === 0 || dimensions.height === 0) return;
    if (isLoading || data.length === 0) return;

    const svg = d3.select(svgNode);
    const xAxisSvg = d3.select(xAxisNode);
    svg.selectAll("*").remove();
    xAxisSvg.selectAll("*").remove();

    // Measure the widest model name so the left margin fits the y-axis labels.
    // Names repeat once per prompt, so measure each distinct name only once.
    const modelNames = Array.from(new Set(data.map((d) => d.model_name)));
    const measure = svg.append("text").attr("font-size", FONT_SIZE_PX).style("opacity", 0);
    const maxLabelWidth = modelNames.reduce((widest, name) => {
      measure.text(name);
      return Math.max(widest, measure.node()?.getBBox().width ?? 0);
    }, 0);
    measure.remove();

    const margin = {
      top: MARGIN_TOP,
      right: MARGIN_RIGHT,
      bottom: MARGIN_BOTTOM,
      left: maxLabelWidth + Y_LABEL_GAP,
    };
    const width = dimensions.width - margin.left - margin.right;
    const height = dimensions.height - margin.top - margin.bottom;
    // Too small to draw: negative ranges would produce invalid rect sizes.
    if (width <= 0 || height <= 0) return;

    const g = svg.append("g").attr("transform", `translate(${margin.left},${margin.top})`);

    // scaleBand de-duplicates its domain, keeping first-occurrence order.
    const x = d3.scaleBand<string>()
      .domain(data.map((d) => d.prompt_text))
      .range([0, width])
      .padding(0.005);

    const y = d3.scaleBand<string>()
      .domain(modelNames)
      .range([0, height])
      .padding(0.05);

    const color = d3.scaleSequential(d3.interpolateGreens).domain([0, 1]);

    // Every row's keys are in the scale domains, so `?? 0` is only a
    // type-level fallback for scaleBand's `number | undefined` result.
    const cellX = (d: HeatmapRow) => x(d.prompt_text) ?? 0;
    const cellY = (d: HeatmapRow) => y(d.model_name) ?? 0;

    // Cells
    g.selectAll()
      .data(data)
      .enter()
      .append("rect")
      .attr("x", cellX)
      .attr("y", cellY)
      .attr("rx", 4)
      .attr("ry", 4)
      .attr("width", x.bandwidth())
      .attr("height", y.bandwidth())
      .style("fill", (d) => (hasScore(d.score) ? color(d.score) : MISSING_FILL));

    // Score label centred in each cell; 0 is a real score and shows as "0.00".
    g.selectAll()
      .data(data)
      .enter()
      .append("text")
      .text((d) => (hasScore(d.score) ? d.score.toFixed(2) : "-"))
      .attr("x", (d) => cellX(d) + x.bandwidth() / 2)
      .attr("y", (d) => cellY(d) + y.bandwidth() / 2)
      .attr("dy", ".35em") // vertically centre on the cell midpoint
      .attr("text-anchor", "middle")
      .style("fill", "black")
      .style("font-size", FONT_SIZE_PX);

    // X axis in its own SVG, shifted to line up with the plot area.
    xAxisSvg.append("g")
      .attr("transform", `translate(${margin.left}, 0)`)
      .call(d3.axisBottom(x))
      .selectAll<SVGTextElement, unknown>(".tick text")
      .call(wrap, x.bandwidth())
      .attr("dy", "-0.8em")
      .attr("x", -9);

    // Y axis
    g.append("g").call(d3.axisLeft(y));

    svg.selectAll("text").style("font-size", FONT_SIZE_PX);

    onHeightChangeRef.current(svgNode.getBBox().height);
    setXAxisHeight(xAxisNode.getBBox().height);
  }, [data, dimensions, isLoading]);

  useEffect(() => {
    const observer = new ResizeObserver((entries) => {
      for (const entry of entries) {
        const { width, height } = entry.contentRect;
        // Keep the previous object when nothing changed so the chart isn't redrawn needlessly.
        setDimensions((prev) =>
          prev.width === width && prev.height === height ? prev : { width, height },
        );
      }
    });
    if (ref.current) {
      observer.observe(ref.current);
    }
    return () => observer.disconnect();
  }, []);

  return (
    <>
      <svg ref={ref} style={{ width: "100%" }} />
      <div style={{ height: "50px", resize: "vertical", overflowY: "auto", width: "100%" }}>
        <svg ref={xAxisRef} style={{ height: xAxisHeight, width: "100%" }} />
      </div>
    </>
  );
};

export default Heatmap;
