import * as posenet from '@tensorflow-models/posenet';
import { Vector2D } from '@tensorflow-models/posenet/dist/types';

/**
 * Score threshold shared by this module: poses and keypoints scoring below it
 * are filtered out or treated as missing.
 */
const MIN_CONFIDENCE = 0.3;

/**
 * Loads the PoseNet model used by `detect`, configured with a ResNet50
 * architecture, output stride 32, a 257x200 input resolution and
 * `quantBytes: 2`.
 *
 * @returns the loaded model.
 */
export async function load(): Promise<posenet.PoseNet> {
  const model = await posenet.load({
    architecture: 'ResNet50',
    outputStride: 32,
    inputResolution: { width: 257, height: 200 },
    quantBytes: 2
  });
  return model;
}

/**
 * A named pose that is recognised when every one of its conditions holds.
 */
interface PoseRule {
  /** All conditions must match for the rule to match. */
  conditions: PoseCondition[];
  /** Label reported when the rule matches. */
  pose: string;
}

/**
 * Constrains the direction of the segment from `startPart` to `endPart`.
 */
interface PoseCondition {
  /** PoseNet keypoint name where the segment starts. */
  startPart: string;
  /** PoseNet keypoint name where the segment ends. */
  endPart: string;
  /** Expected direction of the segment, in degrees. */
  angle: number;
}

/**
 * Tolerance, in degrees, allowed on either side of a condition's `angle`
 * when matching a rule.
 */
const angleDelta = 15;

/**
 * Heuristic pose rules. Each condition constrains the direction of the
 * segment from `startPart` to `endPart` as measured by `horizontalAngle`:
 * 0° = end to the right of start, 90° = above, 180° = left, 270° = below
 * (in image coordinates, where y grows downward).
 *
 * A label may appear in several rules; matching any one of them is enough
 * for `poseEstimation` to report it. Part names must be PoseNet keypoint
 * names (keys of `posenet.partIds`): an unknown name yields a NaN angle, so
 * its rule can never match.
 */
const poseRules: PoseRule[] = [
  {
    conditions: [
      { startPart: 'leftHip', endPart: 'leftKnee', angle: 180 },
      { startPart: 'rightHip', endPart: 'rightKnee', angle: 0 },
    ],
    pose: 'squat'
  },
  {
    conditions: [
      { startPart: 'leftHip', endPart: 'leftKnee', angle: 270 },
      { startPart: 'rightHip', endPart: 'rightKnee', angle: 270 },
    ],
    pose: 'stand'
  },
  {
    conditions: [
      { startPart: 'leftHip', endPart: 'leftKnee', angle: 0 },
      // PoseNet names this keypoint 'leftAnkle'.
      { startPart: 'leftKnee', endPart: 'leftAnkle', angle: 270 },
    ],
    pose: 'sit'
  },
  {
    conditions: [
      { startPart: 'rightHip', endPart: 'rightKnee', angle: 180 },
      // PoseNet names this keypoint 'rightAnkle'.
      { startPart: 'rightKnee', endPart: 'rightAnkle', angle: 270 },
    ],
    pose: 'sit'
  },
  {
    conditions: [
      { startPart: 'leftElbow', endPart: 'leftWrist', angle: 90 }
    ],
    pose: 'liftLeft'
  },
  {
    conditions: [
      { startPart: 'rightElbow', endPart: 'rightWrist', angle: 90 }
    ],
    pose: 'liftRight'
  },
  {
    conditions: [
      { startPart: 'leftElbow', endPart: 'leftWrist', angle: 135 }
    ],
    pose: 'salute'
  },
  {
    conditions: [
      { startPart: 'rightElbow', endPart: 'rightWrist', angle: 45 }
    ],
    pose: 'salute'
  },
  {
    conditions: [
      { startPart: 'leftShoulder', endPart: 'leftHip', angle: 315 }
    ],
    pose: 'bow'
  },
  {
    conditions: [
      { startPart: 'leftShoulder', endPart: 'leftHip', angle: 225 }
    ],
    pose: 'bow'
  },
  {
    conditions: [
      { startPart: 'rightShoulder', endPart: 'rightHip', angle: 315 }
    ],
    pose: 'bow'
  },
  {
    conditions: [
      { startPart: 'rightShoulder', endPart: 'rightHip', angle: 225 }
    ],
    pose: 'bow'
  }
];

/**
 * Runs multi-person pose estimation on `input` and keeps only the poses whose
 * overall score is at least MIN_CONFIDENCE.
 *
 * The returned poses are modified in place: every keypoint scoring below
 * MIN_CONFIDENCE has its position overwritten with NaN, so later geometry
 * (for example `horizontalAngle`) treats it as missing.
 *
 * @param model a loaded PoseNet model (see `load`).
 * @param input the image or video frame to analyse.
 * @returns the confident poses, in the order the model produced them.
 */
export async function detect(model: posenet.PoseNet, input: HTMLImageElement | HTMLVideoElement):
  Promise<posenet.Pose[]> {
  const estimated = await model.estimateMultiplePoses(input, {
    flipHorizontal: false,
    maxDetections: 5,
    scoreThreshold: MIN_CONFIDENCE,
    nmsRadius: 20
  });

  const confident = estimated.filter(candidate => candidate.score >= MIN_CONFIDENCE);

  for (const pose of confident) {
    for (const keypoint of pose.keypoints) {
      if (keypoint.score < MIN_CONFIDENCE) {
        // Mark unreliable keypoints as missing rather than dropping them,
        // so keypoint indices still line up with posenet.partIds.
        keypoint.position.x = NaN;
        keypoint.position.y = NaN;
      }
    }
  }

  return confident;
}

/**
 * Draws `input` onto the `#output` canvas and overlays each prediction.
 *
 * For each prediction it draws:
 *  - a padded bounding box around its confident keypoints;
 *  - a label listing the poses returned by `poseEstimation`;
 *  - the confident keypoints as dots;
 *  - the skeleton edges from `posenet.getAdjacentKeyPoints`.
 *
 * The canvas is resized to `input.width` x `input.height` on every call.
 *
 * @param predictions poses to draw, typically the result of `detect`.
 * @param input the image or video frame the predictions were computed on.
 * @throws Error if there is no `#output` element or no 2D context for it.
 */
export function renderPredictions(predictions: posenet.Pose[], input: HTMLImageElement | HTMLVideoElement): void {
  const canvas = document.getElementById('output') as HTMLCanvasElement | null;
  if (canvas === null) {
    throw new Error("renderPredictions: no element with id 'output'");
  }
  const context = canvas.getContext('2d');
  if (context === null) {
    throw new Error('renderPredictions: could not get a 2D context for #output');
  }

  canvas.width = input.width;
  canvas.height = input.height;
  context.clearRect(0, 0, canvas.width, canvas.height);
  context.drawImage(input, 0, 0, input.width, input.height);

  const color = '#00ffff';
  const padding = 10;
  const keypointRadius = 12;

  context.font = '24px Arial';
  // Set both colours up front so every label uses the overlay colour.
  // Otherwise the first label would be drawn in the default fill colour.
  context.strokeStyle = color;
  context.fillStyle = color;

  predictions.forEach(prediction => {
    const keypoints = prediction.keypoints.filter(keypoint => keypoint.score >= MIN_CONFIDENCE);

    // With no confident keypoints the box would span +/-Infinity; skip it.
    if (keypoints.length > 0) {
      const xs = keypoints.map(keypoint => keypoint.position.x);
      const ys = keypoints.map(keypoint => keypoint.position.y);
      const minX = Math.min(...xs);
      const maxX = Math.max(...xs);
      const minY = Math.min(...ys);
      const maxY = Math.max(...ys);

      // Start a fresh path. Without it, rect() is appended to the previous
      // prediction's last skeleton segment, which would be stroked again.
      context.beginPath();
      context.lineWidth = 6;
      context.rect(minX - padding, minY - padding, maxX - minX + 2 * padding, maxY - minY + 2 * padding);
      // Call the module function directly: `this` is undefined inside a
      // module-level function, so `this.poseEstimation` would throw.
      context.fillText(poseEstimation(prediction).join(', '), Math.max(minX, 0), Math.max(minY, 0));
      context.stroke();
    }

    keypoints.forEach(keypoint => {
      context.beginPath();
      context.arc(keypoint.position.x, keypoint.position.y, keypointRadius, 0, 2 * Math.PI);
      context.fill();
    });

    context.lineWidth = 10;
    posenet.getAdjacentKeyPoints(prediction.keypoints, MIN_CONFIDENCE).forEach(([from, to]) => {
      context.beginPath();
      context.moveTo(from.position.x, from.position.y);
      context.lineTo(to.position.x, to.position.y);
      context.stroke();
    });
  });
}

/**
 * Lists the labels of every pose rule that `pose` satisfies.
 *
 * @param pose a detected pose.
 * @returns distinct labels, in the order their first matching rule appears.
 */
export function poseEstimation(pose: posenet.Pose): string[] {
  const labels = new Set<string>();
  for (const rule of poseRules) {
    if (isMatchPoseRule(pose, rule)) {
      labels.add(rule.pose);
    }
  }
  return Array.from(labels);
}

/**
 * Checks whether `pose` satisfies every condition of `poseRule`.
 *
 * For each condition, the direction of the segment from `startPart` to
 * `endPart` must be within `angleDelta` degrees (inclusive) of the
 * condition's angle. A condition whose angle is NaN fails the rule, for
 * example when a keypoint is unknown or its position is NaN.
 *
 * @returns true if all conditions match; true for a rule with no conditions.
 */
function isMatchPoseRule(pose: posenet.Pose, poseRule: PoseRule) {
  return poseRule.conditions.every(condition => {
    const angle = getHorizontalAngleOfConnectParts(pose, condition.startPart, condition.endPart);

    if (isNaN(angle)) {
      return false;
    }

    // Compare on the circle. Angles are reported in [0, 360), so a rule at
    // 0° must also accept e.g. 350°, which a linear range check misses.
    return angularDistance(angle, condition.angle) <= angleDelta;
  });
}

/** Smallest absolute difference between two directions in degrees, in [0, 180]. */
function angularDistance(a: number, b: number): number {
  const diff = Math.abs(a - b) % 360;
  return diff > 180 ? 360 - diff : diff;
}

/**
 * Direction, in degrees, of the segment between two named keypoints of `pose`.
 *
 * Keypoints are looked up by index via `posenet.partIds`. If either name does
 * not resolve to a keypoint, the result is NaN.
 *
 * @param pose the pose to read keypoints from.
 * @param startPart PoseNet keypoint name for the segment start.
 * @param endPart PoseNet keypoint name for the segment end.
 * @returns the value of `horizontalAngle` for the two positions, or NaN.
 */
function getHorizontalAngleOfConnectParts(pose: posenet.Pose, startPart: string, endPart: string) {
  const from = pose.keypoints[posenet.partIds[startPart]];
  const to = pose.keypoints[posenet.partIds[endPart]];

  if (from === undefined || to === undefined) {
    return NaN;
  }

  return horizontalAngle(from.position, to.position);
}

/**
 * Direction of the vector from `start` to `end`, in degrees, within [0, 360).
 *
 * The y difference is negated, so in image coordinates (y grows downward):
 * 0° means `end` is to the right, 90° above, 180° to the left and 270° below.
 *
 * @returns NaN if any coordinate of either point is NaN.
 */
function horizontalAngle(start: Vector2D, end: Vector2D): number {
  if (isNaN(start.x) || isNaN(start.y) || isNaN(end.x) || isNaN(end.y)) {
    return NaN;
  }

  const degrees = -Math.atan2(end.y - start.y, end.x - start.x) * 180 / Math.PI;

  // Fold [-180, 180] into [0, 360). The modulo also turns -0 into 0, and maps
  // a tiny negative angle (which rounds up to exactly 360) back to 0.
  return (degrees + 360) % 360;
}