"use client";

import { useEffect, useRef, useState, type CSSProperties } from "react";
import * as THREE from "three";

/** Props for the `NeuralBackground` component. */
type NeuralBackgroundProps = {
  /**
   * Three CSS color strings, parsed with `THREE.Color`. They tint the network
   * nodes and connection lines (the star field uses the first one). In the
   * CSS fallback they are exposed as `--nf-a`, `--nf-b` and `--nf-c`.
   *
   * Readonly so callers may pass `as const` tuples; mutable tuples still work.
   */
  colors: readonly [string, string, string];
};

/** One vertex of the background network graph. */
type NodePoint = {
  /** Links to other nodes, each with the weight passed to `connect`. */
  connections: { node: NodePoint; strength: number }[];
  /** Distance of `position` from the origin, where the root node sits. */
  distanceFromRoot: number;
  /** Shell index: 0 for the root, 1 and up for the concentric shells. */
  level: number;
  /** Position in the network's local space (the meshes are rotated when rendered). */
  position: THREE.Vector3;
  /** Base point size, uploaded as the node shader's `size` attribute. */
  size: number;
};

/**
 * Vertex shader for network nodes: drifts each point slightly over time and
 * pulses its size, scaling it with distance from the camera.
 */
const NODE_VERTEX_SHADER = `
  attribute float size;
  attribute vec3 color;
  varying vec3 vColor;
  uniform float uTime;
  void main() {
    vColor = color;
    vec3 p = position;
    p.x += sin(uTime * 0.35 + position.y * 0.08) * 0.25;
    p.y += cos(uTime * 0.28 + position.x * 0.07) * 0.2;
    vec4 mvPosition = modelViewMatrix * vec4(p, 1.0);
    float pulse = 0.8 + 0.25 * sin(uTime * 1.4 + position.z);
    gl_PointSize = size * pulse * (360.0 / -mvPosition.z);
    gl_Position = projectionMatrix * mvPosition;
  }
`;

/** Fragment shader for network nodes: a round sprite with a soft radial falloff. */
const NODE_FRAGMENT_SHADER = `
  varying vec3 vColor;
  void main() {
    vec2 center = gl_PointCoord - 0.5;
    float dist = length(center);
    if (dist > 0.5) discard;
    float alpha = 1.0 - smoothstep(0.0, 0.5, dist);
    vec3 glow = vColor * (1.0 + alpha * 0.75);
    gl_FragColor = vec4(glow, alpha * 0.9);
  }
`;

/**
 * Full-viewport animated "neural network" backdrop rendered with three.js.
 *
 * Draws a layered cloud of glowing nodes, the line segments connecting them,
 * and a sparse star field, all tinted from `colors`. If the renderer reports
 * no WebGL context, or scene setup / the first frame throws, the canvas is
 * hidden and a CSS fallback `<div>` is rendered instead, with the three
 * colors exposed as the `--nf-a`, `--nf-b` and `--nf-c` custom properties.
 */
export function NeuralBackground({ colors }: NeuralBackgroundProps) {
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
  const [useFallback, setUseFallback] = useState(false);
  // Depend on the color *values*, not the tuple's identity: a parent passing
  // an inline array literal would otherwise tear down and rebuild the whole
  // WebGL scene on every render.
  const [colorA, colorB, colorC] = colors;

  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas || useFallback) {
      return;
    }

    // Each resource registers its release as soon as it exists, so a failure
    // part-way through setup still frees everything acquired so far.
    const disposers: Array<() => void> = [];
    const disposeAll = () => {
      // Reverse acquisition order: listeners/rAF first, renderer last.
      while (disposers.length > 0) {
        disposers.pop()?.();
      }
    };
    function own<T extends { dispose(): void }>(resource: T): T {
      disposers.push(() => resource.dispose());
      return resource;
    }

    try {
      const scene = new THREE.Scene();
      scene.fog = new THREE.FogExp2(0x000000, 0.00135);

      const camera = new THREE.PerspectiveCamera(65, window.innerWidth / window.innerHeight, 0.1, 1000);
      camera.position.set(0, 7, 30);

      const renderer = own(
        new THREE.WebGLRenderer({
          alpha: true,
          antialias: true,
          canvas,
          powerPreference: "high-performance"
        })
      );
      if (!renderer.getContext()) {
        disposeAll();
        setUseFallback(true);
        return;
      }
      renderer.setClearColor(0x000000, 0);
      renderer.setPixelRatio(Math.min(window.devicePixelRatio || 1, 1.7));
      renderer.setSize(window.innerWidth, window.innerHeight);

      const palette = [colorA, colorB, colorC].map((color) => new THREE.Color(color));
      const nodes = createNetwork(palette);

      // Network nodes: additive, depth-write-free sprites driven by the shaders above.
      const nodeAttributes = buildNodeAttributes(nodes, palette);
      const nodeGeometry = own(new THREE.BufferGeometry());
      nodeGeometry.setAttribute("position", new THREE.Float32BufferAttribute(nodeAttributes.positions, 3));
      nodeGeometry.setAttribute("color", new THREE.Float32BufferAttribute(nodeAttributes.colors, 3));
      nodeGeometry.setAttribute("size", new THREE.Float32BufferAttribute(nodeAttributes.sizes, 1));
      const nodeMaterial = own(
        new THREE.ShaderMaterial({
          blending: THREE.AdditiveBlending,
          depthWrite: false,
          transparent: true,
          uniforms: { uTime: { value: 0 } },
          vertexShader: NODE_VERTEX_SHADER,
          fragmentShader: NODE_FRAGMENT_SHADER
        })
      );
      const nodeMesh = new THREE.Points(nodeGeometry, nodeMaterial);
      scene.add(nodeMesh);

      // Connections between nodes, one segment per undirected link.
      const segments = buildConnectionSegments(nodes, palette);
      const lineGeometry = own(new THREE.BufferGeometry());
      lineGeometry.setAttribute("position", new THREE.Float32BufferAttribute(segments.positions, 3));
      lineGeometry.setAttribute("color", new THREE.Float32BufferAttribute(segments.colors, 3));
      const lineMaterial = own(
        new THREE.LineBasicMaterial({
          blending: THREE.AdditiveBlending,
          opacity: 0.28,
          transparent: true,
          vertexColors: true
        })
      );
      const lineMesh = new THREE.LineSegments(lineGeometry, lineMaterial);
      scene.add(lineMesh);

      // Distant star field tinted with the first palette color.
      const starGeometry = own(new THREE.BufferGeometry());
      starGeometry.setAttribute("position", new THREE.Float32BufferAttribute(buildStarPositions(1200), 3));
      const starMaterial = own(
        new THREE.PointsMaterial({
          blending: THREE.AdditiveBlending,
          color: palette[0],
          opacity: 0.35,
          size: 0.16,
          transparent: true
        })
      );
      const stars = new THREE.Points(starGeometry, starMaterial);
      scene.add(stars);

      const clock = new THREE.Clock();
      let frame = 0;
      // Reads `frame` at cleanup time, i.e. the most recently scheduled frame.
      disposers.push(() => window.cancelAnimationFrame(frame));

      function animate() {
        frame = window.requestAnimationFrame(animate);
        const time = clock.getElapsedTime();
        nodeMaterial.uniforms.uTime.value = time;
        nodeMesh.rotation.y = time * 0.045;
        nodeMesh.rotation.x = Math.sin(time * 0.12) * 0.08;
        // Keep lines glued to the nodes they connect.
        lineMesh.rotation.copy(nodeMesh.rotation);
        stars.rotation.y = time * 0.012;
        renderer.render(scene, camera);
      }

      function handleResize() {
        camera.aspect = window.innerWidth / window.innerHeight;
        camera.updateProjectionMatrix();
        renderer.setSize(window.innerWidth, window.innerHeight);
      }

      window.addEventListener("resize", handleResize);
      disposers.push(() => window.removeEventListener("resize", handleResize));
      animate();
    } catch {
      // The backdrop is decorative: any setup or first-frame failure falls
      // back to the CSS gradient after releasing whatever was acquired.
      disposeAll();
      setUseFallback(true);
    }

    return disposeAll;
  }, [colorA, colorB, colorC, useFallback]);

  return (
    <>
      <canvas
        aria-hidden="true"
        className="neural-bg-canvas"
        ref={canvasRef}
        style={useFallback ? { display: "none" } : undefined}
      />
      {useFallback ? (
        <div
          aria-hidden="true"
          className="neural-bg-fallback"
          style={
            {
              "--nf-a": colorA,
              "--nf-b": colorB,
              "--nf-c": colorC
            } as CSSProperties
          }
        />
      ) : null}
    </>
  );
}

/**
 * Flattens nodes into per-vertex attribute arrays for the node point cloud.
 * Node colors cycle through `palette` by level (`level % palette.length`).
 */
function buildNodeAttributes(nodes: NodePoint[], palette: THREE.Color[]) {
  const positions: number[] = [];
  const colors: number[] = [];
  const sizes: number[] = [];
  for (const node of nodes) {
    positions.push(node.position.x, node.position.y, node.position.z);
    const color = palette[node.level % palette.length];
    colors.push(color.r, color.g, color.b);
    sizes.push(node.size);
  }
  return { colors, positions, sizes };
}

/**
 * Emits one line segment (two vertices) per undirected connection, skipping
 * pairs already emitted and links to nodes outside `nodes`.
 *
 * A segment takes the color of the endpoint visited first, with its level
 * clamped to the last palette entry. Note this differs from node coloring,
 * which cycles through the palette by level.
 */
function buildConnectionSegments(nodes: NodePoint[], palette: THREE.Color[]) {
  const positions: number[] = [];
  const colors: number[] = [];
  // O(1) index lookup instead of `nodes.indexOf` inside the edge loop.
  const indexByNode = new Map<NodePoint, number>(nodes.map((node, index) => [node, index]));
  const emitted = new Set<string>();

  nodes.forEach((node, index) => {
    for (const { node: other } of node.connections) {
      const otherIndex = indexByNode.get(other);
      if (otherIndex === undefined) continue;
      const key = `${Math.min(index, otherIndex)}-${Math.max(index, otherIndex)}`;
      if (emitted.has(key)) continue;
      emitted.add(key);
      positions.push(
        node.position.x,
        node.position.y,
        node.position.z,
        other.position.x,
        other.position.y,
        other.position.z
      );
      const color = palette[Math.min(node.level, palette.length - 1)];
      colors.push(color.r, color.g, color.b, color.r, color.g, color.b);
    }
  });

  return { colors, positions };
}

/**
 * Returns `count` random xyz triples in a spherical shell of radius 40–130
 * around the origin; directions are uniform because cos(phi) is drawn
 * uniformly from [-1, 1].
 */
function buildStarPositions(count: number) {
  const positions: number[] = [];
  for (let i = 0; i < count; i++) {
    const radius = THREE.MathUtils.randFloat(40, 130);
    const theta = Math.random() * Math.PI * 2;
    const phi = Math.acos(THREE.MathUtils.randFloatSpread(2));
    positions.push(
      radius * Math.sin(phi) * Math.cos(theta),
      radius * Math.sin(phi) * Math.sin(theta),
      radius * Math.cos(phi)
    );
  }
  return positions;
}

/**
 * Builds the node graph for the background.
 *
 * Layout: a root node at the origin plus five concentric shells. Shell `layer`
 * has radius `4.2 * layer` and `16 * layer` nodes, spread over the sphere with
 * a Fibonacci-sphere pattern.
 *
 * Links:
 * - Every shell node is linked to its 3 nearest nodes in the previous shell,
 *   with strength 0.7.
 * - A second pass offers each node its 3 nearest neighbours within one level
 *   and links each with probability 0.65, with strength 0.45.
 *
 * Node sizes and second-pass links use `Math.random`, so every call yields a
 * different graph.
 *
 * @param palette - Not used to build the graph and never mutated. It is kept
 *   so existing call sites continue to compile.
 * @returns All nodes: the root first, then each shell in order.
 */
function createNetwork(palette: THREE.Color[]) {
  const nodes: NodePoint[] = [];
  const root = createNode(new THREE.Vector3(0, 0, 0), 0, 2.4);
  nodes.push(root);

  for (let layer = 1; layer <= 5; layer++) {
    const radius = layer * 4.2;
    const count = layer * 16;
    // The previous shell is complete before this one starts, so collect it
    // once per shell rather than once per node.
    const previousLayer = nodes.filter((item) => item.level === layer - 1);

    for (let i = 0; i < count; i++) {
      // Fibonacci sphere: evenly spaced cos(phi), golden-angle steps in theta.
      const phi = Math.acos(1 - (2 * (i + 0.5)) / count);
      const theta = Math.PI * (1 + Math.sqrt(5)) * i;
      const position = new THREE.Vector3(
        radius * Math.sin(phi) * Math.cos(theta),
        radius * Math.sin(phi) * Math.sin(theta),
        radius * Math.cos(phi)
      );
      const node = createNode(position, layer, 0.8 + Math.random() * 0.9);
      node.distanceFromRoot = radius;
      nodes.push(node);

      // Sort a copy so every node sees the shell in insertion order (stable
      // tie-breaking), exactly as a fresh filter would.
      [...previousLayer]
        .sort((a, b) => position.distanceTo(a.position) - position.distanceTo(b.position))
        .slice(0, 3)
        .forEach((target) => connect(node, target, 0.7));
    }
  }

  for (const node of nodes) {
    const near = nodes
      .filter((item) => item !== node && item.level >= node.level - 1 && item.level <= node.level + 1)
      .sort((a, b) => node.position.distanceTo(a.position) - node.position.distanceTo(b.position))
      .slice(0, 3);
    near.forEach((target) => {
      if (Math.random() > 0.35) connect(node, target, 0.45);
    });
  }

  return nodes;
}

/**
 * Creates a node with no connections yet.
 *
 * `distanceFromRoot` is initialised to `position.length()`, its distance from
 * the origin. Callers may overwrite it afterwards.
 */
function createNode(position: THREE.Vector3, level: number, size: number): NodePoint {
  const distanceFromRoot = position.length();
  const connections: NodePoint["connections"] = [];
  return { connections, distanceFromRoot, level, position, size };
}

/**
 * Links two nodes symmetrically, recording `strength` on both sides.
 *
 * Does nothing if `a` and `b` are the same node, which would otherwise create
 * a self-loop entered twice. Also does nothing if the nodes are already linked.
 * Only `a`'s list is checked for an existing link. That suffices as long as
 * links are added only through this function, which always records both sides.
 */
function connect(a: NodePoint, b: NodePoint, strength: number) {
  if (a === b || a.connections.some(({ node }) => node === b)) {
    return;
  }
  a.connections.push({ node: b, strength });
  b.connections.push({ node: a, strength });
}
