import * as tf from "@tensorflow-models/coco-ssd";
import { objectClasses } from "./object-classes";

/** Set to true to print this module's internal progress messages. */
const DEBUG = false;

/**
 * Writes `message` to the console, but only when `DEBUG` is enabled;
 * otherwise it is a no-op.
 */
function log(message: unknown) {
  if (!DEBUG) {
    return;
  }
  console.log(message);
}

/**
 * Most recent detection accepted by `getDetection`. Starts as an empty
 * placeholder: no class name, zero score, and a zero-sized box at (0, 0).
 */
let currentDetection: IDetectedObject = {
  name: "",
  score: 0,
  boundingBox: {
    x: 0,
    y: 0,
    width: 0,
    height: 0,
  },
};

/** Model instance assigned by `loadModel`; undefined until a load succeeds. */
let model: tf.ObjectDetection;

/**
 * Loads a COCO-SSD object-detection model and stores it in the module-level
 * `model` used by `getDetection`.
 *
 * Load failures are logged with `console.error` instead of being rethrown,
 * so the returned promise always resolves. On failure, the previously stored
 * model (if any) is left untouched.
 *
 * @param modelType - base network passed to `tf.load`; defaults to
 *   "lite_mobilenet_v2".
 */
export const loadModel = async (
  modelType: tf.ObjectDetectionBaseModel = "lite_mobilenet_v2"
): Promise<void> => {
  log("loading ML model");
  try {
    model = await tf.load({ base: modelType });
  } catch (error) {
    console.error(error);
  }
};

/**
 * Starts an asynchronous detection on `frame` and immediately returns the
 * most recent accepted detection.
 *
 * The call does not wait for the model. The returned value is whatever
 * `currentDetection` held when the call was made: the result of an earlier
 * frame, or the initial empty placeholder. When the pending `model.detect`
 * resolves, the first result whose class passes `filterDetections` replaces
 * `currentDetection` for later calls.
 *
 * If `loadModel()` has not completed successfully yet, no detection is
 * started and the current value is returned unchanged.
 *
 * @param frame - image source passed to `model.detect`.
 * @param numberOfDetections - maximum number of boxes requested from the
 *   model; the first one whose class passes the filter is kept.
 * @returns the last accepted detection (possibly from a previous frame).
 */
export const getDetection = (
  frame: ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement,
  numberOfDetections: number = 1
): IDetectedObject => {
  log("getting detection");

  // `model` is only assigned once loadModel() succeeds. Calling detect on
  // undefined would throw synchronously, bypassing the promise's .catch.
  if (model === undefined) {
    log("model not loaded yet");
    return currentDetection;
  }

  // For <video>, width/height are the element's display attributes. The
  // pixels the model reads (and thus its bbox coordinates) are sized by
  // videoWidth/videoHeight.
  let frameWidth = frame.width;
  let frameHeight = frame.height;
  if ("videoWidth" in frame) {
    frameWidth = frame.videoWidth;
    frameHeight = frame.videoHeight;
  }

  model
    .detect(frame, numberOfDetections)
    .then((results) => {
      const result = results.find((candidate) =>
        filterDetections(candidate.class)
      );
      if (result) {
        currentDetection = convertDetectionData(
          frameWidth,
          frameHeight,
          result
        );
      }
    })
    // Pass the handler itself so it is actually invoked with the error.
    .catch(handleModelError);

  return currentDetection;
};

/**
 * Converts one raw detection from `model.detect` into an `IDetectedObject`.
 *
 * The box width and height are copied through unchanged. On each axis, the
 * reported position is `frame size / 2 + bbox coordinate + box size / 2`.
 * This treats the bbox coordinate as an offset from the frame centre and
 * yields the box's centre.
 *
 * NOTE(review): coco-ssd documents `bbox` as `[x, y, width, height]` with
 * (x, y) the top-left corner in input-pixel coordinates. That differs from
 * the centre-offset reading used here. Confirm which coordinate space the
 * consumers of `boundingBox` expect.
 *
 * @param frameWidth - width of the frame the detection was made on.
 * @param frameHeight - height of the frame the detection was made on.
 * @param result - a single detection returned by `model.detect`.
 * @returns the class name, score, and converted bounding box.
 */
const convertDetectionData = (
  frameWidth: number,
  frameHeight: number,
  result: tf.DetectedObject
): IDetectedObject => {
  const [bboxX, bboxY, width, height] = result.bbox;

  // Same arithmetic, in the same order, for both axes.
  const centreOnAxis = (frameSize: number, coord: number, size: number) =>
    frameSize * 0.5 + coord + size * 0.5;

  return {
    name: result.class,
    score: result.score,
    boundingBox: {
      x: centreOnAxis(frameWidth, bboxX, width),
      y: centreOnAxis(frameHeight, bboxY, height),
      width,
      height,
    },
  };
};

/**
 * Reports whether a detected class name is one of the entries in
 * `objectClasses.keys`; `getDetection` only keeps results that pass.
 */
const filterDetections = (name: string) => {
  const { keys: allowedClasses } = objectClasses;
  return allowedClasses.includes(name);
};

/**
 * Error handler for model failures. Logs the error and, when no model has
 * been stored yet, kicks off loading the default model.
 */
const handleModelError = (error: unknown) => {
  console.error(error);
  if (model !== undefined) {
    return;
  }
  // loadModel catches and logs its own failures, so it is not awaited here.
  void loadModel();
};

/**
 * A single detection as reported by this module.
 */
export interface IDetectedObject {
  /** Class label of the detection; empty string for the initial placeholder. */
  name: string;
  /** Confidence score copied from the model's result. */
  score: number;
  /** Box position and size as produced by `convertDetectionData`. */
  boundingBox: IBoundingBox;
}

/**
 * Axis-aligned rectangle attached to a detection.
 *
 * NOTE(review): the meaning of x/y (centre vs. corner, and which origin)
 * is determined by `convertDetectionData`; confirm there before relying on it.
 */
export interface IBoundingBox {
  /** Horizontal position of the box. */
  x: number;
  /** Vertical position of the box. */
  y: number;
  /** Box width, copied from the model's bbox. */
  width: number;
  /** Box height, copied from the model's bbox. */
  height: number;
}
