/*
 
  KEMTAI
  
*/


import {RecoverCoords,  RecoverCoordsInputs, KernelConfig, KernelFunc, TensorInfo} from '../../tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {GPGPUProgram} from '../gpgpu_math';
import {reshape} from './Reshape';
import {argMax} from './ArgMax';

/**
 * GPGPU program that combines a per-keypoint heatmap argmax with an offset
 * tensor to produce a flat `[batch, num_points * 2]` float output.
 *
 * Inputs are bound by name, in order: `Heatmap`, `Offset`, `ArgMax`, and all
 * are read as packed textures. `ArgMax[b, p]` must hold the flat index
 * `i1 * height + i2` of the maximum of `Heatmap[b, :, :, p]`, where `i1`/`i2`
 * index heatmap dims 1/2.
 *
 * For output element `(b, 2p + k)`:
 *  - k == 0: offset channel `2p + 1` plus the dim-2 index `i2`;
 *  - k == 1: offset channel `2p` plus the dim-1 index `i1`.
 * The raw offset is decoded as `(raw - 0.5) * 6.0`. If the heatmap value at the
 * argmax location is not greater than 0.35, the element is `-100` instead.
 * NOTE(review): the 0.35 threshold, 6.0 scale and -100 sentinel are hard-coded
 * here; their meaning comes from the model and should be confirmed there.
 */
export class RecoverCoordsProgram implements GPGPUProgram {
  variableNames = ['Heatmap', 'Offset', 'ArgMax'];
  outputShape: number[] = [];
  userCode: string;
  packedOutput = false;
  packedInputs = true;

  /**
   * @param batch Batch size (heatmap dim 0).
   * @param width Size of heatmap dim 1 (not needed by the shader; kept for
   *     API compatibility).
   * @param height Size of heatmap dim 2; used to unflatten the argmax index.
   * @param num_points Number of keypoint channels (heatmap dim 3).
   */
  constructor(batch: number, width: number, height: number,
              num_points: number) {
    this.outputShape = [batch, num_points * 2];

    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int point = coords[1] / 2;
        // Odd output columns hold the second value of each coordinate pair.
        bool isY = mod(float(coords[1]), 2.0) != 0.0;

        int idx1D = int(getChannel(getArgMax(batch, point), vec2(batch, point)));
        // The flat index over (dim1, dim2) is idx_x * height + idx_y, so it
        // must be unflattened with the size of dim 2. Dividing by the size
        // of dim 1 is only correct for square heatmaps.
        int idx_x = idx1D / ${height};
        int idx_y = idx1D - idx_x * ${height};
        float p = getChannel(getHeatmap(batch, idx_x, idx_y, point), vec2(idx_y, point));
        // The offset texel holds channels 2p and 2p + 1; !isY selects 2p + 1
        // for even output columns and 2p for odd ones.
        float d = (getChannel(getOffset(batch, idx_x, idx_y, point * 2), vec2(idx_y, !isY)) - 0.5) * 6.0;
        setOutput(p > 0.35 ? (d + (isY ? float(idx_x) : float(idx_y))) : -100.);
      }
    `;
  }
}



/**
 * WebGL implementation of the RecoverCoords kernel.
 *
 * For each batch item and keypoint channel of `heatmap`, this finds the
 * spatial position of the maximum value. It then runs `RecoverCoordsProgram`
 * to read `offset` at that position and build the output coordinates.
 *
 * @param args.inputs.heatmap Rank-4 tensor, destructured as
 *     `[batch, width, height, numPoints]`.
 * @param args.inputs.offset Offset tensor read by the program at channels
 *     `2p` / `2p + 1`; presumably `[batch, width, height, 2 * numPoints]` —
 *     TODO confirm against the op definition.
 * @param args.backend WebGL backend used to run the kernels.
 * @returns float32 tensor of shape `[batch, 2 * numPoints]`.
 * @throws Error if `heatmap` is not rank 4.
 */
export const recoverCoords = (args: {
  inputs: RecoverCoordsInputs,
  backend: MathBackendWebGL
}): TensorInfo => {
  const {inputs, backend} = args;
  const {heatmap, offset} = inputs;

  if (heatmap.shape.length !== 4) {
    throw new Error(
        `RecoverCoords expects a rank-4 heatmap, got shape [${heatmap.shape}]`);
  }
  const [batch, width, height, numPoints] = heatmap.shape;

  // Collapse only the two spatial dims and reduce along them (axis 1).
  // Each batch item then gets its own argmax, yielding [batch, numPoints]
  // directly with flat indices i1 * height + i2. Reducing over a
  // [-1, numPoints] view would mix batches and fail for batch > 1.
  const heatmapFlat = reshape({
    inputs: {x: heatmap},
    backend,
    attrs: {shape: [batch, width * height, numPoints]}
  });
  let heatmapArgMax: TensorInfo | undefined;
  try {
    heatmapArgMax =
        argMax({inputs: {x: heatmapFlat}, backend, attrs: {axis: 1}});

    const program = new RecoverCoordsProgram(batch, width, height, numPoints);
    return backend.runWebGLProgram(
        program, [heatmap, offset, heatmapArgMax], 'float32');
  } finally {
    // Release intermediates even if argMax or the program throws.
    backend.disposeIntermediateTensorInfo(heatmapFlat);
    if (heatmapArgMax !== undefined) {
      backend.disposeIntermediateTensorInfo(heatmapArgMax);
    }
  }
};

/**
 * Registration entry that binds the `RecoverCoords` kernel name to its WebGL
 * implementation, `recoverCoords`.
 */
export const recoverCoordsConfig: KernelConfig = {
  kernelName: RecoverCoords,
  backendName: 'webgl',
  // `recoverCoords` takes a narrower argument type than the generic
  // `KernelFunc` signature, so it is widened through an intermediate cast.
  kernelFunc: recoverCoords as unknown as KernelFunc,
};
