import { CustomBlock } from '../custom-block';
import * as Blockly from 'blockly/core';
import { outdent } from 'outdent';

/**
 * Blockly statement block that fetches one MNIST batch (test or train split)
 * and assigns the reshaped image tensor and the labels to two workspace
 * variables.
 *
 * NOTE(review): the DATA variable is presumably an MNIST data helper exposing
 * `nextTestBatch(n)` / `nextTrainBatch(n)` that return `{ xs, labels }`, with
 * `xs` holding 28x28 single-channel images — confirm against the runtime that
 * executes the generated code.
 */
export class MnistDataBatchBlock implements CustomBlock {
  type = 'tfjs_mnist_data_batch';

  /**
   * Declares the block's shape: three variable fields (TENSOR, LABELS, DATA),
   * a test/train dropdown (TYPE) and a batch-size number field (BATCH).
   *
   * @param block - the Blockly block instance to initialise via `jsonInit`.
   */
  defineBlock(block) {
    const blockJson = {
      style: 'tfjs_blocks',
      message0: '%{BKY_TFJS_MNIST_DATA_BATCH_MESSAGE_0}',
      args0: [
        {
          type: 'field_variable',
          name: 'TENSOR',
          variable: 'tensor',
        },
        {
          type: 'field_variable',
          name: 'LABELS',
          variable: 'labels',
        },
        {
          type: 'field_variable',
          name: 'DATA',
          variable: 'data',
        },
      ],
      message1: '%{BKY_TFJS_MNIST_DATA_BATCH_MESSAGE_1}',
      args1: [
        {
          type: 'field_dropdown',
          name: 'TYPE',
          options: [
            ['%{BKY_TFJS_MNIST_DATA_BATCH_TEST}', 'test'],
            ['%{BKY_TFJS_MNIST_DATA_BATCH_TRAIN}', 'train'],
          ],
        },
      ],
      message2: '%{BKY_TFJS_MNIST_DATA_BATCH_MESSAGE_2}',
      args2: [
        {
          type: 'field_number',
          name: 'BATCH',
          // The value is emitted both as the batch size and as the leading
          // reshape dimension, so restrict it to positive whole numbers (a
          // fractional size cannot be a tensor dimension).
          value: 1,
          min: 1,
          precision: 1,
        },
      ],
      previousStatement: null,
      nextStatement: null,
    };

    block.jsonInit(blockJson);
  }

  /**
   * Generates the JavaScript statement for this block.
   *
   * The emitted code runs inside `tf.tidy` (so the generated program needs a
   * `tf` binding in scope), calls `<data>.nextTestBatch(n)` or
   * `<data>.nextTrainBatch(n)`, reshapes `xs` to `[n, 28, 28, 1]` and
   * destructures `[xs, labels]` into the TENSOR and LABELS variables.
   *
   * NOTE(review): variable names are taken from the fields' display text and
   * interpolated as-is; a name that is not a valid JS identifier would yield
   * invalid code — consider routing them through the generator's name database.
   *
   * @param block - the Blockly block to generate code for.
   * @returns the generated statement source.
   */
  toJavaScriptCode(block) {
    const tensor = block.getField('TENSOR').getText();
    const labels = block.getField('LABELS').getText();
    const data = block.getField('DATA').getText();
    const dataType = block.getFieldValue('TYPE') === 'test' ? 'Test' : 'Train';
    const batchNum = block.getFieldValue('BATCH');

    return outdent`
    [${tensor}, ${labels}] = tf.tidy(() => {
      const d = ${data}.next${dataType}Batch(${batchNum});
      return [
        d.xs.reshape([${batchNum}, 28, 28, 1]),
        d.labels
      ];
    });
    `;
  }
}
