import * as tf from '@tensorflow/tfjs';

/**
 * Simplified Word2Vec network built with TensorFlow.js.
 *
 * Architecture: a bias-free dense projection from a `vocabSize`-wide input to
 * an `embeddingDim`-wide hidden layer, followed by a bias-free softmax dense
 * layer back over the `vocabSize` vocabulary entries. Inputs are presumably
 * one-hot word vectors (TODO: confirm against the training code).
 */
export class Word2Vec {
  /** Compiled model; built once in the constructor. */
  model: tf.LayersModel;

  /**
   * @param vocabSize - Number of words in the vocabulary (input and output width).
   * @param embeddingDim - Width of each word embedding (hidden-layer units).
   * @throws RangeError if either argument is not a positive integer.
   */
  constructor(vocabSize: number, embeddingDim: number) {
    Word2Vec.assertPositiveInteger(vocabSize, 'vocabSize');
    Word2Vec.assertPositiveInteger(embeddingDim, 'embeddingDim');
    this.model = this.createModel(vocabSize, embeddingDim);
  }

  /** Throws a RangeError unless `value` is a positive integer. */
  private static assertPositiveInteger(value: number, name: string): void {
    if (!Number.isInteger(value) || value <= 0) {
      throw new RangeError(`${name} must be a positive integer, got ${value}`);
    }
  }

  /**
   * Builds and compiles the simplified Word2Vec architecture as a sequential model.
   *
   * Layers (neither has a bias term):
   *  1. dense, `embeddingDim` units, input shape `[vocabSize]`, linear activation;
   *  2. dense, `vocabSize` units, softmax activation.
   *
   * The model is compiled with the `sgd` optimizer, categorical cross-entropy
   * loss, and an accuracy metric.
   */
  private createModel(vocabSize: number, embeddingDim: number): tf.LayersModel {
    const model = tf.sequential();

    // Input -> hidden: this layer's kernel is what getEmbeddings() returns.
    model.add(
      tf.layers.dense({
        units: embeddingDim,
        inputShape: [vocabSize],
        useBias: false,
      }),
    );

    // Hidden -> output: a probability distribution over the vocabulary.
    model.add(
      tf.layers.dense({
        units: vocabSize,
        useBias: false,
        activation: 'softmax',
      }),
    );

    model.compile({
      optimizer: 'sgd',
      loss: 'categoricalCrossentropy',
      metrics: 'accuracy',
    });

    return model;
  }

  /**
   * Returns the input-to-hidden weight matrix, used as the word embeddings.
   *
   * The first dense layer's kernel has shape `[vocabSize, embeddingDim]`, so
   * row `i` is the embedding of vocabulary index `i`.
   *
   * @returns The kernel downloaded as a nested `number[][]`.
   * @throws Error if the model has no layers, or if the first layer has no rank-2 kernel.
   */
  public async getEmbeddings(): Promise<number[][]> {
    const hiddenLayer = this.model.layers[0];
    if (hiddenLayer === undefined) {
      throw new Error('Word2Vec model has no layers');
    }

    const kernel = hiddenLayer.getWeights()[0];
    if (kernel === undefined || kernel.rank !== 2) {
      throw new Error('Word2Vec hidden layer has no rank-2 kernel');
    }

    // The rank was checked at runtime above, which the compiler cannot follow.
    return (kernel as tf.Tensor2D).array();
  }
}
