import { Typography } from '@mui/material';
import React, { useEffect, useState } from 'react';
import Plot from 'react-plotly.js';

/**
 * Compute row-normalized Gaussian affinities between high-dimensional points.
 *
 * For j !== i, P[i][j] is proportional to exp(-||X[i] - X[j]||^2 / (2 * sigma^2)),
 * and P[i][i] = 0. Each row is then divided by its own sum so that it sums to 1
 * (the conditional distribution p_{j|i} in SNE terms).
 *
 * A single global bandwidth `sigma` is used for every point; there is no per-point
 * perplexity search.
 *
 * @param {number[][]} X - n points, each an array of coordinates of the same length.
 * @param {number} [sigma=1.0] - Gaussian kernel bandwidth; expected to be > 0.
 * @returns {number[][]} n x n affinity matrix. A row whose unnormalized weights are all 0
 *   (n === 1, or every exp() underflowed for an isolated point) is left as all zeros
 *   instead of becoming NaN.
 */
function computePairwiseAffinities(X, sigma = 1.0) {
    const n = X.length;
    const denom = 2 * sigma * sigma;
    const P = Array.from({ length: n }, () => Array(n).fill(0));

    for (let i = 0; i < n; i++) {
        let rowSum = 0;
        for (let j = 0; j < n; j++) {
            if (i === j) continue;
            // Squared Euclidean distance used directly (no sqrt-then-square round trip).
            const sqDist = X[i].reduce((sum, xi, k) => sum + (xi - X[j][k]) ** 2, 0);
            P[i][j] = Math.exp(-sqDist / denom);
            rowSum += P[i][j];
        }
        // Dividing by a zero sum yields NaN, which would poison the gradient and, through Q,
        // every embedded point. Leave such rows as zeros instead.
        if (rowSum > 0) {
            for (let j = 0; j < n; j++) {
                P[i][j] /= rowSum;
            }
        }
    }
    return P;
}

/**
 * Compute row-normalized Student-t (one degree of freedom) affinities in the embedding.
 *
 * For j !== i, Q[i][j] is proportional to 1 / (1 + ||Y[i] - Y[j]||^2), and Q[i][i] = 0.
 * Each row is divided by its own sum.
 *
 * NOTE(review): standard t-SNE normalizes q over all pairs (one global sum), not per row.
 * Confirm that per-row normalization is intended here.
 *
 * @param {number[][]} Y - n embedded points, each an array of coordinates of the same length.
 * @returns {number[][]} n x n affinity matrix. For n === 1 the single row is [0]
 *   instead of NaN.
 */
function computeLowDimensionalAffinities(Y) {
    const n = Y.length;
    const Q = Array.from({ length: n }, () => Array(n).fill(0));

    for (let i = 0; i < n; i++) {
        let rowSum = 0;
        for (let j = 0; j < n; j++) {
            if (i === j) continue;
            // Squared distance used directly; the kernel only needs d^2.
            const sqDist = Y[i].reduce((sum, yi, k) => sum + (yi - Y[j][k]) ** 2, 0);
            Q[i][j] = 1 / (1 + sqDist);
            rowSum += Q[i][j];
        }
        // A zero sum (only possible with no off-diagonal entries) would give 0/0 = NaN.
        if (rowSum > 0) {
            for (let j = 0; j < n; j++) {
                Q[i][j] /= rowSum;
            }
        }
    }
    return Q;
}

/**
 * Kullback-Leibler style cost: the sum over all (i, j) with P[i][j] > 0 of
 * P[i][j] * log((P[i][j] + eps) / (Q[i][j] + eps)), with eps = 1e-10 so that a
 * zero Q entry never produces log(x / 0).
 *
 * Both matrices are indexed as P.length x P.length.
 *
 * @param {number[][]} P - reference affinities.
 * @param {number[][]} Q - model affinities, same shape as P.
 * @returns {number} the accumulated cost (0 when P has no positive entries).
 */
function computeCost(P, Q) {
    const size = P.length;
    const eps = 1e-10;
    let total = 0;
    for (let row = 0; row < size; row++) {
        const pRow = P[row];
        for (let col = 0; col < size; col++) {
            const p = pRow[col];
            // Zero (or non-positive) entries contribute nothing and are skipped.
            if (!(p > 0)) continue;
            total += p * Math.log((p + eps) / (Q[row][col] + eps));
        }
    }
    return total;
}

/**
 * Embed high-dimensional points in a low-dimensional space by plain gradient descent.
 *
 * P is computed once from X with computePairwiseAffinities (Gaussian kernel, row-normalized).
 * Each iteration then:
 *   1. recomputes Q from the current embedding Y with computeLowDimensionalAffinities,
 *   2. computes grad_i = 4 * sum_j (P[i][j] - Q[i][j]) * (Y[i] - Y[j]) / (1 + ||Y[i] - Y[j]||^2)
 *      for every i from the same (pre-update) Y,
 *   3. applies Y[i] -= learningRate * grad_i.
 *
 * NOTE(review): this is the symmetric t-SNE gradient form, but P and Q here are conditional
 * (row-normalized) rather than symmetric joint distributions. Confirm this is intended.
 *
 * @param {number[][]} X - n input points of equal dimension.
 * @param {number} [learningRate=0.1] - gradient step size.
 * @param {number} [nIter=1000] - number of iterations.
 * @param {object} [options] - optional extras; the defaults reproduce the original behavior.
 * @param {number} [options.nComponents=2] - dimensionality of the embedding.
 * @param {number} [options.sigma=1.0] - Gaussian bandwidth forwarded to computePairwiseAffinities.
 * @param {() => number} [options.random=Math.random] - generator returning values in [0, 1),
 *   used for the initial embedding. Pass a seeded generator for reproducible output.
 * @param {number} [options.logEvery=100] - log the cost to the console every `logEvery`
 *   iterations; 0 (or any non-positive value) disables logging.
 * @returns {number[][]} the n x nComponents embedding.
 */
function gradientDescent(
    X,
    learningRate = 0.1,
    nIter = 1000,
    { nComponents = 2, sigma = 1.0, random = Math.random, logEvery = 100 } = {}
) {
    const n = X.length;
    // Small random start: every coordinate in [0, 0.01).
    const Y = Array.from({ length: n }, () =>
        Array.from({ length: nComponents }, () => random() * 0.01)
    );
    const P = computePairwiseAffinities(X, sigma);

    for (let iteration = 0; iteration < nIter; iteration++) {
        const Q = computeLowDimensionalAffinities(Y);
        const grad = Array.from({ length: n }, () => Array(nComponents).fill(0));

        for (let i = 0; i < n; i++) {
            for (let j = 0; j < n; j++) {
                if (i === j) continue;
                const diff = Y[i].map((yi, k) => yi - Y[j][k]);
                // Only the squared norm is needed, so skip the sqrt.
                const sqNorm = diff.reduce((sum, d) => sum + d * d, 0);
                const factor = (4 * (P[i][j] - Q[i][j])) / (1 + sqNorm);
                for (let k = 0; k < nComponents; k++) {
                    grad[i][k] += factor * diff[k];
                }
            }
        }

        // Apply all updates only after every gradient has been computed from the same Y.
        for (let i = 0; i < n; i++) {
            for (let k = 0; k < nComponents; k++) {
                Y[i][k] -= learningRate * grad[i][k];
            }
        }

        // The cost reflects Q from the start of this iteration (i.e. before the update above).
        if (logEvery > 0 && iteration % logEvery === 0) {
            const cost = computeCost(P, Q);
            console.log(`Iteration ${iteration}, Cost: ${cost}`);
        }
    }
    return Y;
}

// Updated SNEVisualizer Component
const SNEVisualizer = ({ inputData, xAxisLabel, yAxisLabel }) => {
  const [sneData, setSNEData] = useState([]);
  const [xAxisRange, setXAxisRange] = useState([0, 0]);
  const [displayData, setDisplayData] = useState([]);

  const classValue = localStorage.getItem('classSettingValue');
  const classes = ['S', 'G', 'R', 'B'];
  const colors = ['blue', 'green', 'red', 'black'];

  useEffect(() => {
    if (!inputData || inputData.length === 0) {
      // Generate a larger set of random data if inputData is empty
      const randomData = Array.from({ length: 10000 }, () => { // Increased to 20,000 points
        const randomClass = classes[Math.floor(Math.random() * classes.length)];
        const randomColor = colors[classes.indexOf(randomClass)];
        const minRange = -3;
  const maxRange = 3;
  const randomX = minRange + Math.random() * (maxRange - minRange);
  const randomY = minRange + Math.random() * (maxRange - minRange);
        return {
          x: Math.random() * randomX, 
          y: Math.random() * randomY, 
          class: randomClass,
          color: randomColor,
        };
      });
      setSNEData(randomData);
      setDisplayData(randomData);
      return;
    }

    // Perform SNE to get low-dimensional representation
    const performSNE = (data) => {
      return gradientDescent(data, 0.1, 1000);
    };

    const result = performSNE(inputData);

    // Calculate min and max values for x and y
    let minX = Math.min(...result.map((item) => item[0]));
    let maxX = Math.max(...result.map((item) => item[0]));
    let minY = Math.min(...result.map((item) => item[1]));
    let maxY = Math.max(...result.map((item) => item[1]));

    // Generate random points with classes and colors
    const randomPoints = Array.from({ length: 20000 }, () => { // Increased to 20,000 points
      const randomClass = classes[Math.floor(Math.random() * classes.length)];
      const randomColor = colors[classes.indexOf(randomClass)];
      return {
        x: Math.random() * (maxX - minX) + minX,
        y: Math.random() * (maxY - minY) + minY,
        class: randomClass,
        color: randomColor,
      };
    });

    // Combine original and random points
    const combinedData = [
      ...result.map((item) => ({
        x: item[0],
        y: item[1],
        class: classes[Math.floor(Math.random() * classes.length)],
        color: colors[Math.floor(Math.random() * colors.length)],
      })),
      ...randomPoints,
    ];

    // Sort combined data
    combinedData.sort((a, b) => a.x - b.x || a.y - b.y);

    setSNEData(combinedData);
    setDisplayData(combinedData);
    setXAxisRange([minX, maxX]);
  }, [inputData]);

  // Filter displayData based on classValue
  const getFilteredData = () => {
    if (!classValue || classValue.trim() === '' || classValue === 'overlay') {
      // Show normal way with random classes and colors, maintain a larger dataset
      return displayData;
    }

    const classSet = new Set(classValue.split('+'));
    return displayData.filter((d) => classSet.has(d.class));
  };

  const filteredData = getFilteredData();

  return (
    <div style={{ position: 'relative', height: '100%' }}>
      <Typography
        style={{
          textAlign: 'center',
          fontWeight: 'bold',
        }}
      >
        {xAxisLabel} vs {yAxisLabel}
      </Typography>
      <Plot
        data={classes.map((cls, idx) => ({
          x: filteredData.filter((d) => d.class === cls).map((d) => d.x),
          y: filteredData.filter((d) => d.class === cls).map((d) => d.y),
          type: 'scatter',
          mode: 'markers',
          marker: { color: colors[idx] },
          name: cls,
        }))}
        layout={{
          width: '100%',
          height: '100%',
          title: '',
          xaxis: {
            title: xAxisLabel,
            showline: true,
            showgrid: true,
            zeroline: false,
            range: xAxisRange,
          },
          yaxis: {
            title: yAxisLabel,
            showline: true,
            showgrid: true,
            zeroline: false,
          },
          dragmode: 'pan',
        }}
        config={{
          scrollZoom: true,
        }}
      />
    </div>
  );
};

export default SNEVisualizer;

