import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
import {
  CSS2DRenderer,
  CSS2DObject,
} from 'three/examples/jsm/renderers/CSS2DRenderer';
import { ElementRef, Injectable, NgZone } from '@angular/core';

@Injectable({ providedIn: 'root' })
export class EmbeddingProjectorService {
  // Variables that store the elements to render and control the scene
  private canvas: HTMLCanvasElement;
  private renderer: THREE.WebGLRenderer;
  private orthographicCamera: any;
  private perspectiveCamera: any;
  private scene: THREE.Scene;
  private controls: any;
  private raycaster: THREE.Raycaster;

  // Pointer position in normalised device coordinates (-1..1), written by onPointerMove()
  private pointer: THREE.Vector2 = new THREE.Vector2();

  // Variables to display a label above the hovered embedding
  private labelElement: any;
  private labelRenderer: any;
  private label: any;

  // Hover state: index of the highlighted instance (-1 = none), plus the colour and
  // (2D only) the un-enlarged matrix that instance had before it was highlighted
  private previousInstanceId: number = -1;
  private matrix: THREE.Matrix4 = new THREE.Matrix4();
  private previousMatrix: THREE.Matrix4 = new THREE.Matrix4();
  private previousColor: THREE.Color = new THREE.Color();

  // Current dimension of the visualisation; selects the camera and intersection handler
  private mode: '2D' | '3D' = '2D';

  // Embeddings and their words; instance i of the mesh is labelled with words[i]
  private embeddings: any;
  private words: string[];

  // Listeners registered through addListener(), keyed by event name (one callback per name)
  private events: Map<string, any> = new Map();

  private mesh: THREE.InstancedMesh;
  public color: THREE.Color = new THREE.Color();

  // Handle of the pending requestAnimationFrame callback, or null when no loop has been started
  private frameId: number = null;

  public constructor(private ngZone: NgZone) {}

  /**
   * Empties the scene and disposes the geometry and material(s) of every mesh
   * and line that was in it.
   *
   * It does not cancel the render loop or remove the window listeners
   * (presumably so the scene can be repopulated with load2D()/load3D() — TODO confirm).
   * NOTE(review): nothing in this class stops the loop on teardown; confirm callers handle that.
   */
  public destroy(): void {
    if (this.scene) {
      while (this.scene.children.length > 0) {
        const child = this.scene.children[0];
        this.scene.remove(child);

        child.traverse((obj) => {
          if (obj instanceof THREE.Mesh || obj instanceof THREE.Line) {
            obj.geometry.dispose();
            // A mesh may carry an array of materials; dispose each one
            const materials = Array.isArray(obj.material)
              ? obj.material
              : [obj.material];
            materials.forEach((material) => material.dispose());
          }
        });
      }
    }

    if (this.mesh) {
      this.mesh.geometry.dispose();
    }

    // The highlighted instance index referred to the mesh that was just removed
    this.previousInstanceId = -1;
  }

  // Adds a line from the origin to `end` in the given colour to the scene
  private addAxisLine(end: THREE.Vector3, color: number): void {
    const points: THREE.Vector3[] = [new THREE.Vector3(0, 0, 0), end];
    const geometry = new THREE.BufferGeometry().setFromPoints(points);
    const material = new THREE.LineBasicMaterial({ color });
    this.scene.add(new THREE.Line(geometry, material));
  }

  /**
   * Adds a 2D coordinate system: a red line along +x and a green line along +y,
   * each `axisLength` long and starting at the origin.
   *
   * The line geometries/materials stay alive while in the scene; destroy() disposes them.
   */
  generateCoordinateSystem(axisLength: number): void {
    this.addAxisLine(new THREE.Vector3(axisLength, 0, 0), 0xff0000);
    this.addAxisLine(new THREE.Vector3(0, axisLength, 0), 0x00ff00);
  }

  /**
   * Adds a 3D coordinate system: three lines of length `axisLength` from the origin.
   *
   * NOTE(review): red runs along +z, green along +x and blue along +y here, whereas
   * the 2D system draws red along +x — confirm this colour mapping is intended.
   */
  generateCoordinateSystem3D(axisLength: number): void {
    this.addAxisLine(new THREE.Vector3(0, 0, axisLength), 0xff0000);
    this.addAxisLine(new THREE.Vector3(axisLength, 0, 0), 0x00ff00);
    this.addAxisLine(new THREE.Vector3(0, axisLength, 0), 0x0000ff);
  }

  /**
   * Replaces `this.mesh` with an instanced mesh of flat circles (radius 2), one per
   * embedding, placed at (e[0], e[1]) * multiplier on the z = 1 plane and coloured
   * from its position. The mesh is added to the scene.
   *
   * @param multiplier scale applied to the raw embedding coordinates
   */
  generateEmbeddings(multiplier: number): void {
    const geometry = new THREE.CircleGeometry(2, 32);
    const material = new THREE.MeshBasicMaterial();
    this.mesh = new THREE.InstancedMesh(
      geometry,
      material,
      this.embeddings.length
    );
    // Any highlighted instance index refers to the previous mesh, not this one
    this.previousInstanceId = -1;

    this.scene.add(this.mesh);

    const dummy = new THREE.Object3D();

    for (let i = 0; i < this.embeddings.length; i++) {
      const x: number = this.embeddings[i][0] * multiplier;
      const y: number = this.embeddings[i][1] * multiplier;
      const z = 1;

      dummy.position.set(x, y, z);
      dummy.updateMatrix();

      this.color.setRGB(
        x / multiplier + 0.5,
        y / multiplier + 0.5,
        z / multiplier + 0.5
      );

      this.mesh.setMatrixAt(i, dummy.matrix);
      this.mesh.setColorAt(i, this.color);
    }
  }

  /**
   * Replaces `this.mesh` with an instanced mesh of spheres (radius 1, scaled by 2),
   * one per embedding, placed at (e[0], e[1], e[2]) * multiplier and coloured with
   * e * 0.5 + 0.5 per channel. The mesh is added to the scene.
   *
   * @param multiplier scale applied to the raw embedding coordinates
   */
  generateEmbeddings3D(multiplier: number) {
    const geometry = new THREE.SphereGeometry(1, 16, 8);
    const material = new THREE.MeshBasicMaterial();
    this.mesh = new THREE.InstancedMesh(
      geometry,
      material,
      this.embeddings.length
    );
    // Any highlighted instance index refers to the previous mesh, not this one
    this.previousInstanceId = -1;

    this.scene.add(this.mesh);

    const dummy = new THREE.Object3D();
    dummy.scale.set(2, 2, 2);

    for (let i = 0; i < this.embeddings.length; i++) {
      const embedding = this.embeddings[i];

      dummy.position.set(
        embedding[0] * multiplier,
        embedding[1] * multiplier,
        embedding[2] * multiplier
      );
      dummy.updateMatrix();

      this.color.setRGB(
        embedding[0] * 0.5 + 0.5,
        embedding[1] * 0.5 + 0.5,
        embedding[2] * 0.5 + 0.5
      );

      this.mesh.setMatrixAt(i, dummy.matrix);
      this.mesh.setColorAt(i, this.color);
    }
  }

  /**
   * Creates the CSS2D label renderer (an absolutely positioned overlay appended to
   * the canvas's parent) and the label object shown above a hovered embedding.
   * The label is added to the scene by load2D()/load3D().
   */
  initLabelRenderer(): void {
    // Remove the overlay left behind by an earlier call (e.g. a repeated init())
    if (this.labelRenderer) {
      this.labelRenderer.domElement.remove();
    }

    // Create the label renderer
    this.labelRenderer = new CSS2DRenderer();
    this.labelRenderer.setSize(
      this.canvas.clientWidth,
      this.canvas.clientHeight
    );
    this.labelRenderer.domElement.style.position = 'absolute';
    this.labelRenderer.domElement.style.top = '0px';
    this.labelRenderer.domElement.style.pointerEvents = 'none';
    this.canvas.parentElement.appendChild(this.labelRenderer.domElement);

    // Create the label element
    this.labelElement = document.createElement('span');
    this.labelElement.textContent = '';
    this.labelElement.style.color = 'white';
    this.labelElement.style.backgroundColor = 'black';
    this.labelElement.style.zIndex = '1';
    this.labelElement.style.border = '2px solid white';
    this.labelElement.style.borderRadius = '4px';
    this.labelElement.style.padding = '8px';
    this.labelElement.style.cursor = 'pointer';

    // Wrap the element in a CSS2DObject so it can be positioned in the scene
    this.label = new CSS2DObject(this.labelElement);
  }

  /**
   * Stores new embeddings and words. They are used the next time
   * load2D()/load3D() builds the mesh; the current mesh is not changed.
   */
  public updateWordEmbeddings(embeddings: number[][], words: string[]) {
    this.embeddings = embeddings;
    this.words = words;
  }

  /**
   * Sets up renderer, scene, both cameras, controls and label overlay for the
   * given canvas, loads the 2D view and starts the render loop.
   *
   * @param canvas     canvas to render into; its parent receives the label overlay
   * @param embeddings one coordinate array per word (2 values used in 2D, 3 in 3D)
   * @param words      label for each embedding, index-aligned with `embeddings`
   */
  public init(
    canvas: ElementRef<HTMLCanvasElement>,
    embeddings: number[][],
    words: string[]
  ): void {
    // Release the controls and renderer of an earlier init(); the service is a root
    // singleton, so init() may run again (e.g. for a newly created canvas)
    if (this.controls) {
      this.controls.dispose();
    }
    if (this.renderer) {
      this.renderer.dispose();
    }

    this.canvas = canvas.nativeElement;
    this.words = words;
    this.embeddings = embeddings;

    this.renderer = new THREE.WebGLRenderer({
      canvas: this.canvas,
    });
    this.renderer.setSize(this.canvas.clientWidth, this.canvas.clientHeight);
    // render() clears the counters itself once per frame
    this.renderer.info.autoReset = false;

    this.scene = new THREE.Scene();

    // Create two cameras for 2D and 3D use
    this.perspectiveCamera = new THREE.PerspectiveCamera(
      45,
      this.canvas.clientWidth / this.canvas.clientHeight,
      1,
      10000
    );

    this.orthographicCamera = new THREE.OrthographicCamera(
      this.canvas.clientWidth / -2,
      this.canvas.clientWidth / 2,
      this.canvas.clientHeight / 2,
      this.canvas.clientHeight / -2,
      1,
      100
    );

    // Initialize with 2D orthographic camera
    this.orthographicCamera.position.set(0, 0, 100);
    this.orthographicCamera.lookAt(this.scene.position);
    this.orthographicCamera.updateMatrix();

    this.scene.add(this.orthographicCamera);

    this.controls = new OrbitControls(
      this.orthographicCamera,
      this.renderer.domElement
    );

    this.initLabelRenderer();

    this.load2D();

    this.raycaster = new THREE.Raycaster();

    this.animate();
  }

  /**
   * Switches to the 2D view: orthographic camera, pan/zoom without rotation,
   * and adds the 2D axes, circle embeddings and label to the scene.
   */
  public load2D() {
    if (this.mode === '3D') {
      this.controls.reset();
      this.scene.remove(this.perspectiveCamera);
      this.scene.add(this.orthographicCamera);

      this.controls.object = this.orthographicCamera;

      this.orthographicCamera.position.set(0, 0, 100);
      this.orthographicCamera.lookAt(this.scene.position);
      this.orthographicCamera.updateMatrix();
    }

    this.controls.enablePan = true;
    this.controls.minZoom = 1;
    this.controls.maxZoom = 10;
    this.controls.enableRotate = false;

    this.generateCoordinateSystem(200);
    this.generateEmbeddings(200);
    this.scene.add(this.label);

    this.mode = '2D';
  }

  /**
   * Switches to the 3D view: perspective camera placed at (1000, 1000, 1000),
   * orbit controls with rotation, and adds the 3D axes, sphere embeddings and
   * label to the scene.
   */
  public load3D() {
    if (this.mode === '2D') {
      this.controls.reset();
      this.scene.remove(this.orthographicCamera);
      this.scene.add(this.perspectiveCamera);

      this.controls.object = this.perspectiveCamera;
    }

    this.perspectiveCamera.position.set(1000, 1000, 1000);
    this.perspectiveCamera.lookAt(this.scene.position);
    this.perspectiveCamera.updateMatrix();

    this.controls.enablePan = true;
    this.controls.minDistance = 7.5;
    this.controls.maxDistance = 750;
    this.controls.enableRotate = true;

    this.generateCoordinateSystem3D(200);
    this.generateEmbeddings3D(200);
    this.scene.add(this.label);

    this.mode = '3D';
  }

  /**
   * Starts the render loop outside Angular's zone (after DOMContentLoaded if the
   * document is still loading) and registers the window pointermove/resize listeners.
   * A loop started by an earlier call is cancelled first so loops do not stack up.
   */
  public animate(): void {
    if (this.frameId !== null) {
      window.cancelAnimationFrame(this.frameId);
      this.frameId = null;
    }

    this.ngZone.runOutsideAngular(() => {
      if (document.readyState !== 'loading') {
        this.render();
      } else {
        this.addListener(window, 'DOMContentLoaded', () => {
          this.render();
        });
      }

      this.addListener(window, 'pointermove', (e) => {
        this.onPointerMove(e);
      });

      this.addListener(window, 'resize', () => {
        this.onWindowResize();
      });
    });
  }

  /**
   * Registers `callback` for `event` on `element` and remembers it by event name
   * so removeListener() can detach it later. Only one callback per event name is
   * tracked, so the callback it replaces is detached from `element` first.
   */
  addListener(element: any, event: string, callback: any): void {
    const existing = this.events.get(event);
    if (existing) {
      element.removeEventListener(event, existing);
    }
    this.events.set(event, callback);
    element.addEventListener(event, callback);
  }

  /** Detaches the callback stored for `event` from `element` and forgets it. */
  removeListener(element: any, event: string) {
    element.removeEventListener(event, this.events.get(event));
    this.events.delete(event);
  }

  /**
   * Resizes both cameras, the WebGL renderer and the label renderer on window resize.
   * Both cameras are updated so that switching modes later does not show a stale
   * projection. Zero sizes (e.g. a hidden parent) are ignored to avoid NaN projections.
   */
  public onWindowResize(): void {
    // Width and height are determined by the first child of the canvas's parent
    const sizeSource = this.canvas.parentElement.children[0];
    const width = sizeSource.clientWidth;
    const height = sizeSource.clientHeight;

    if (width === 0 || height === 0) {
      return;
    }

    this.orthographicCamera.left = width / -2;
    this.orthographicCamera.right = width / 2;
    this.orthographicCamera.top = height / 2;
    this.orthographicCamera.bottom = height / -2;
    this.orthographicCamera.updateProjectionMatrix();

    this.perspectiveCamera.aspect = width / height;
    this.perspectiveCamera.updateProjectionMatrix();

    // Update the size of the renderer and the canvas
    this.renderer.setSize(width, height);
    this.labelRenderer.setSize(width, height);
  }

  /** Converts the pointer position to normalised device coordinates relative to the canvas. */
  public onPointerMove(e: PointerEvent): void {
    const rect: DOMRect = this.canvas.getBoundingClientRect();
    this.pointer.x =
      ((e.clientX - rect.left) / (rect.right - rect.left)) * 2 - 1;
    this.pointer.y =
      -((e.clientY - rect.top) / (rect.bottom - rect.top)) * 2 + 1;
  }

  /**
   * One frame of the render loop: schedules the next frame, updates hover state
   * and controls, then renders labels and scene with the camera of the current mode.
   */
  public render(): void {
    this.frameId = window.requestAnimationFrame(() => {
      this.render();
    });

    // init() disables autoReset, so the per-frame counters are cleared here
    this.renderer.info.reset();

    const is2D = this.mode === '2D';

    if (is2D) {
      this.handleIntersection();
    } else {
      this.handleIntersection3D();
    }

    this.controls.update();

    const camera = is2D ? this.orthographicCamera : this.perspectiveCamera;
    this.labelRenderer.render(this.scene, camera);
    this.renderer.render(this.scene, camera);
  }

  /**
   * Puts the highlighted instance (if any) back to its stored colour and, when
   * `restoreMatrix` is true, its stored un-enlarged matrix.
   * Returns false when nothing was highlighted, so callers can skip buffer uploads.
   */
  private restoreHighlightedInstance(restoreMatrix: boolean): boolean {
    if (this.previousInstanceId === -1) {
      return false;
    }
    this.mesh.setColorAt(this.previousInstanceId, this.previousColor);
    if (restoreMatrix) {
      this.mesh.setMatrixAt(this.previousInstanceId, this.previousMatrix);
    }
    return true;
  }

  /**
   * 2D hover handling: highlights the circle under the pointer (white, scaled by
   * 1.2 and lifted to z = 1.1) and shows its word above it; restores the previously
   * highlighted circle. Instance buffers are only flagged for upload when they change.
   */
  handleIntersection(): void {
    this.raycaster.setFromCamera(this.pointer, this.orthographicCamera);

    const hits = this.raycaster.intersectObject(this.mesh);

    if (hits.length === 0) {
      if (this.restoreHighlightedInstance(true)) {
        this.mesh.instanceColor.needsUpdate = true;
        this.mesh.instanceMatrix.needsUpdate = true;
      }
      this.labelElement.style.opacity = '0';
      this.previousInstanceId = -1;
      return;
    }

    const instanceId = hits[0].instanceId;
    if (instanceId === undefined || instanceId === this.previousInstanceId) {
      return;
    }

    this.restoreHighlightedInstance(true);

    this.mesh.getMatrixAt(instanceId, this.matrix);
    const x = this.matrix.elements[12];
    const y = this.matrix.elements[13];

    // Lift the label further above the circle the more the camera is zoomed out
    this.label.position.set(
      x,
      y + Math.log(10 / this.orthographicCamera.zoom) * 10 + 5,
      this.matrix.elements[14]
    );
    this.labelElement.textContent = this.words[instanceId];
    this.labelElement.style.opacity = '1';

    this.mesh.getColorAt(instanceId, this.previousColor);
    this.mesh.setColorAt(instanceId, this.color.setRGB(1, 1, 1));

    // Remember the un-enlarged matrix (z = 1, as built by generateEmbeddings) so it
    // can be restored, then draw the hovered circle slightly larger and in front
    const dummy = new THREE.Object3D();
    dummy.position.set(x, y, 1);
    dummy.updateMatrix();
    this.previousMatrix.copy(dummy.matrix);

    dummy.position.z = 1.1;
    dummy.scale.setScalar(1.2);
    dummy.updateMatrix();
    this.mesh.setMatrixAt(instanceId, dummy.matrix);

    this.previousInstanceId = instanceId;

    this.mesh.instanceColor.needsUpdate = true;
    this.mesh.instanceMatrix.needsUpdate = true;
  }

  /**
   * 3D hover handling: colours the sphere under the pointer white and shows its
   * word above it, offset by the camera's distance to the orbit target; restores
   * the previously highlighted sphere's colour.
   */
  handleIntersection3D(): void {
    const distance: number = this.perspectiveCamera.position.distanceTo(
      this.controls.target
    );

    this.raycaster.setFromCamera(this.pointer, this.perspectiveCamera);

    const hits = this.raycaster.intersectObject(this.mesh);

    if (hits.length === 0) {
      if (this.restoreHighlightedInstance(false)) {
        this.mesh.instanceColor.needsUpdate = true;
      }
      this.previousInstanceId = -1;
      this.labelElement.style.opacity = '0';
      return;
    }

    const instanceId = hits[0].instanceId;
    if (instanceId === undefined || instanceId === this.previousInstanceId) {
      return;
    }

    this.mesh.getMatrixAt(instanceId, this.matrix);
    this.label.position.set(
      this.matrix.elements[12],
      this.matrix.elements[13] + 5 + distance / 50,
      this.matrix.elements[14]
    );

    this.labelElement.textContent = this.words[instanceId];
    this.labelElement.style.opacity = '1';

    this.restoreHighlightedInstance(false);
    this.mesh.getColorAt(instanceId, this.previousColor);
    this.mesh.setColorAt(instanceId, this.color.setRGB(1, 1, 1));
    this.previousInstanceId = instanceId;

    this.mesh.instanceColor.needsUpdate = true;
  }
}
