import { Component, Output, EventEmitter, ViewChild } from '@angular/core';
import * as tf from '@tensorflow/tfjs';
import { Word2Vec } from './model';
import {
  preprocessText,
  generateFeaturesSkipGram,
  generateLabelsSkipGram,
  generateFeaturesCBOW,
  generateLabelsCBOW,
} from './preprocessing';
import { ChartConfiguration, ChartType } from 'chart.js';
import { BaseChartDirective } from 'ng2-charts';

@Component({
  selector: 'app-train-embedding-models',
  templateUrl: './train-embedding-models.component.html',
  styleUrls: ['./train-embedding-models.component.css'],
})
export class TrainEmbeddingModelsComponent {
  // Event emitter to emit words and trained embeddings to WordEmbeddingsComponent
  @Output()
  embeddingsTrained = new EventEmitter<object>();

  // Chart that plots the accuracy while the neural network is being trained
  @ViewChild(BaseChartDirective) chart?: BaseChartDirective;

  lineChartData: ChartConfiguration['data'] = {
    datasets: [
      {
        data: [],
        label: 'Genauigkeit',
      },
    ],
    labels: [],
  };

  lineChartOptions: ChartConfiguration['options'] = {
    responsive: true,
    animation: false,
    elements: {
      line: {
        tension: 0,
      },
    },
  };

  lineChartType: ChartType = 'line';

  // Completion flags for the three steps of this component:
  // [0] text loaded, [1] training data generated, [2] training started
  steps: boolean[] = new Array(3).fill(false);

  currentlyLoading: boolean = false;

  // Variables related to loading the text data
  useCustomText: boolean = false;
  textLoaded: boolean = false;
  customText: string;
  text: string;
  textPreview: string;
  hideWarn: boolean = false;

  // Variables related to the selection of the model
  modelSelection: string;
  model: Word2Vec;
  playCBOWAnimation: boolean = false;
  playSkipGramAnimation: boolean = false;

  // Initial hyperparameter selections
  selectedBatchSize: number = 32;
  selectedNumEpochs: number = 1000;
  selectedEmbeddingDim: number = 32;

  // Available hyperparameter selection options
  availableEpochs: number[] = [100, 1000, 2500, 5000, 10000];
  availableBatchSizes: number[] = [8, 16, 32, 64, 128];
  availableEmbeddingDims: number[] = [32, 64, 128, 256];

  // Variables related to generating the training data (filled by preprocessText)
  sentences: string[];
  dictionary: any;
  vocabSize: number;
  windowSize: number;

  // Tensors holding the features and labels required for training.
  // They stay undefined/null until loadModelData has run.
  features: tf.Tensor;
  labels: tf.Tensor;

  // Variables used to keep track of training progress
  trainingRunning: boolean = false;
  trainingDone: boolean = false;
  currentEpoch: number = 0;

  // Variable to store the trained embeddings
  embeddings: any;

  /**
   * Stores a shortened preview of `this.text` in `this.textPreview`.
   *
   * @param numWords Maximum number of space-separated words to keep. If the
   *   text has more words, the preview is truncated and suffixed with '...'.
   */
  generateTextPreview(numWords: number): void {
    // Guard against a missing text (e.g. custom text selected but never entered)
    const text = this.text ?? '';
    const words = text.split(' ');
    this.textPreview =
      words.length > numWords ? words.slice(0, numWords).join(' ') + '...' : text;
  }

  /**
   * Loads either the user-supplied custom text or the bundled example text,
   * builds a preview and preprocesses it into sentences, a dictionary, the
   * vocabulary size and the window size.
   *
   * @throws Error if the example text cannot be fetched (non-2xx response);
   *   network and preprocessing errors propagate unchanged.
   */
  async loadText(): Promise<void> {
    this.currentlyLoading = true;

    try {
      if (this.useCustomText) {
        this.text = this.customText;
      } else {
        const response = await fetch('assets/word-embeddings/example.txt');
        // Without this check an HTTP error page would be used as training text
        if (!response.ok) {
          throw new Error(
            `Failed to load example text: ${response.status} ${response.statusText}`
          );
        }
        this.text = await response.text();
      }

      this.generateTextPreview(50);

      // Preprocessing of the text by defining a context window size (2)
      [this.sentences, this.dictionary, this.vocabSize, this.windowSize] =
        await preprocessText(this.text, 2);

      this.textLoaded = true;
      this.steps[0] = true;
    } finally {
      // Reset even on failure so the UI does not stay stuck in a loading state
      this.currentlyLoading = false;
    }
  }

  /**
   * Generates the training tensors for the selected model.
   * Tensors from a previous call are disposed first.
   *
   * @param model Either 'cbow' or 'skip-gram'. Any other value only clears
   *   the previous tensors.
   */
  async loadModelData(model: string): Promise<void> {
    this.currentlyLoading = true;
    this.modelSelection = model;

    try {
      // Release tensors from a previous run so they do not leak backend memory
      this.disposeTrainingData();

      if (this.modelSelection === 'cbow') {
        const features = await generateFeaturesCBOW(
          this.sentences,
          this.dictionary,
          this.vocabSize,
          this.windowSize
        );
        const labels = await generateLabelsCBOW(this.sentences, this.dictionary);

        this.features = tf.tidy(() => {
          return tf.tensor(features);
        });

        this.labels = tf.tidy(() => {
          return tf.oneHot(labels, this.vocabSize);
        });
      } else if (this.modelSelection === 'skip-gram') {
        const features: number[] = await generateFeaturesSkipGram(
          this.sentences,
          this.dictionary,
          this.windowSize
        );

        const labels: number[] = await generateLabelsSkipGram(
          this.sentences,
          this.dictionary,
          this.windowSize
        );

        // Both skip-gram inputs and targets are one-hot encoded word indices
        this.features = tf.tidy(() => {
          return tf.oneHot(features, this.vocabSize);
        });

        this.labels = tf.tidy(() => {
          return tf.oneHot(labels, this.vocabSize);
        });
      }

      this.steps[1] = true;
    } finally {
      this.currentlyLoading = false;
    }
  }

  // Disposes the current feature/label tensors (if any) and clears both references.
  private disposeTrainingData(): void {
    if (this.features != null) {
      this.features.dispose();
      this.features = null;
    }

    if (this.labels != null) {
      this.labels.dispose();
      this.labels = null;
    }
  }

  // Generate a fresh instance of the model class using the current vocabulary
  // size and the selected embedding dimension.
  // NOTE(review): a previous Word2Vec instance is simply replaced, not disposed.
  loadModel(): void {
    this.model = new Word2Vec(this.vocabSize, this.selectedEmbeddingDim);
  }

  // Asks a running fit() to stop after the current epoch.
  // This does nothing if no model has been created yet.
  stopTraining(): void {
    if (this.model?.model) {
      this.model.model.stopTraining = true;
    }
  }

  /**
   * Trains a freshly created model on the generated features/labels, plots
   * the accuracy every 10 epochs and finally emits the learned embeddings.
   *
   * Does nothing while training is already running or before training data
   * has been generated.
   */
  async trainModel(): Promise<void> {
    // `== null` covers both the initial `undefined` and the `null` set on dispose
    if (this.trainingRunning || this.features == null || this.labels == null) {
      return;
    }

    this.loadModel();
    this.trainingRunning = true;
    this.steps[2] = true;

    this.lineChartData.labels = [];
    this.lineChartData.datasets[0].data = [];

    try {
      await this.model.model.fit(this.features, this.labels, {
        epochs: this.selectedNumEpochs,
        batchSize: this.selectedBatchSize,
        shuffle: false,
        callbacks: {
          onEpochEnd: async (ep, status) => {
            // Only update chart every 10 epochs
            if (ep % 10 === 0) {
              this.lineChartData.labels.push(ep.toString());
              this.lineChartData.datasets[0].data.push(status?.['acc']);
              this.chart?.update();
              // Progress display advances in steps of 10 epochs
              this.currentEpoch = ep + 10;
              // Yield to the browser so the chart and UI can repaint
              await tf.nextFrame();
            }
          },
        },
      });
    } finally {
      // Always release the running flag so a failed fit() does not block retries
      this.trainingRunning = false;
    }

    this.embeddings = await this.model.getEmbeddings();
    this.trainingDone = true;

    this.submitEmbeddings();
  }

  // Emit the trained embeddings and the corresponding vocabulary words
  submitEmbeddings(): void {
    this.embeddingsTrained.emit({
      words: Object.keys(this.dictionary),
      embeddings: this.embeddings,
    });
  }

  constructor() {}
}
