import { CustomBlock } from '../custom-block';
import * as Blockly from 'blockly/core';
import { outdent } from 'outdent';

/**
 * Blockly block that emits a `<model>.compile({...})` call for a
 * TensorFlow.js model, letting the user pick the optimizer, learning rate,
 * loss function and an optional metrics expression.
 */
export class CompileModelBlock implements CustomBlock {
  type = 'tfjs_compile_model';

  /**
   * Initializes `block` from a JSON definition with:
   *  - MODEL: variable field naming the model to compile (default `model`);
   *  - OPTIMIZER: dropdown of `sgd` / `momentum` / `adam`;
   *  - RATE: numeric learning rate (default 0.01);
   *  - LOSSES: dropdown of meanSquaredError / meanAbsoluteError /
   *    categoricalCrossentropy;
   *  - METRICS: value input for the metrics expression.
   * The block can connect to statements above and below it.
   *
   * @param block - the Blockly block instance to initialize via `jsonInit`.
   */
  defineBlock(block) {
    const blockJson = {
      style: 'tfjs_blocks',
      message0: '%{BKY_TFJS_MODEL_COMPILE_MODEL_MESSAGE_0}',
      args0: [
        {
          type: 'field_variable',
          name: 'MODEL',
          variable: 'model'
        },
      ],
      message1: '%{BKY_TFJS_MODEL_COMPILE_MODEL_MESSAGE_1}',
      args1: [
        {
          type: 'field_dropdown',
          name: 'OPTIMIZER',
          options: [
            ['sgd', 'sgd'],
            ['momentum', 'momentum'],
            ['adam', 'adam']
          ]
        },
        {
          type: 'field_number',
          name: 'RATE',
          value: 0.01,
        }
      ],
      message2: '%{BKY_TFJS_MODEL_COMPILE_MODEL_MESSAGE_2}',
      args2: [
        {
          type: 'field_dropdown',
          name: 'LOSSES',
          options: [
            ['%{BKY_TFJS_MODEL_COMPILE_MODEL_MEAN_SQUARED_ERROR}', 'meanSquaredError'],
            ['%{BKY_TFJS_MODEL_COMPILE_MODEL_MEAN_ABSOLUTE_ERROR}', 'meanAbsoluteError'],
            ['%{BKY_TFJS_MODEL_COMPILE_MODEL_CATEGORICAL_CROSSENTROPY}', 'categoricalCrossentropy'],
          ]
        }
      ],
      message3: '%{BKY_TFJS_MODEL_COMPILE_MODEL_MESSAGE_3}',
      args3: [
        {
          type: 'input_value',
          name: 'METRICS',
        }
      ],

      previousStatement: null,
      nextStatement: null,
    };

    block.jsonInit(blockJson);
  }

  /**
   * Generates the JavaScript statement that compiles the selected model.
   *
   * @param block - the block whose fields/inputs are read.
   * @returns source text of the form
   *   `<model>.compile({ optimizer: tf.train.<opt>(...), loss: '<loss>', metrics: <metrics> });`
   *   where `<metrics>` falls back to `[]` when nothing is plugged into METRICS.
   */
  toJavaScriptCode(block) {
    // Momentum coefficient for tf.train.momentum, which (unlike sgd/adam)
    // requires it as a second argument; the block exposes no field for it,
    // so a conventional default is used.
    const DEFAULT_MOMENTUM = 0.9;

    const model = block.getField('MODEL').getText();
    const optimizer = block.getFieldValue('OPTIMIZER');
    const rate = block.getFieldValue('RATE');
    const losses = block.getFieldValue('LOSSES');
    const metrics = Blockly.JavaScript.valueToCode(block, 'METRICS', Blockly.JavaScript.ORDER_NONE) || '[]';

    // Without the extra argument the generated tf.train.momentum(rate) call
    // would build an optimizer with an undefined momentum coefficient.
    const optimizerArgs = optimizer === 'momentum'
      ? `${rate}, ${DEFAULT_MOMENTUM}`
      : `${rate}`;

    return outdent`
    ${model}.compile({
      optimizer: tf.train.${optimizer}(${optimizerArgs}),
      loss: '${losses}',
      metrics: ${metrics},
    });
    `;
  }
}
