
import {  tf } from './my-tf'
import { BBox, KeypointsArray } from "@kemtai/keypoints"
import { sleep } from '@kemtai/utils'
import { ModelManip } from './model-manip'
import modelTest from './modelTest'


tf.enableProdMode()


// Input-normalization constants used by normaizeImage(): pixel -> (pixel - offset) * factor.
// NOTE: with factor = -118.7/127.5 this maps 0 -> +118.7 and 255 -> -118.7, i.e. the sign is
// inverted and the range is NOT [-1, 1]. The commented-out 1/127.5 line below is the [-1, 1]
// variant. Presumably the deployed model was trained with the current scaling — confirm
// before changing either value.
// Both stay undefined when creating the scalars throws (see the catch below); normaizeImage()
// then throws instead of normalizing.
let offset : tf.Scalar | undefined = undefined
let factor : tf.Scalar | undefined = undefined


try {
    offset = tf.scalar(127.5);
    factor = tf.scalar(-118.7/127.5);
    //const factor = tf.scalar(1/127.5);
} catch {
    // Deliberate best effort: tensor creation can fail when tf is unusable in this
    // environment ("no tf"); module loading must not fail because of it.
}



/**
 * Applies the model's input normalization, `(img - offset) * factor`, using the
 * module-level `offset` / `factor` scalars.
 *
 * The arithmetic runs inside `tf.tidy`, so the intermediate `img - offset` tensor is
 * released even when the caller is not inside a tidy scope of its own. Only the
 * returned tensor survives; `img` is not disposed.
 *
 * @param img image tensor to normalize.
 * @returns a new tensor with the normalization applied.
 * @throws Error when the module-level scalars could not be created at load time.
 */
function normaizeImage(img:tf.Tensor3D):tf.Tensor3D
{
    if (offset === undefined || factor === undefined) {
        throw new Error("No TF!!!!")
    }
    // Copy into consts: narrowing of module-level `let`s does not carry into the closure.
    const shift = offset
    const scale = factor
    return tf.tidy(() => img.sub(shift).mul(scale))
}


/**
 * Median of a list of numbers.
 *
 * Sorts a numeric copy of the input, so the caller's array is left untouched.
 * A bare `arr.sort()` compares elements as strings ([10, 9, 2] -> [10, 2, 9]).
 *
 * @param arr values to take the median of.
 * @returns the middle value (mean of the two middle values for even lengths),
 *          or NaN when `arr` is empty.
 */
function median(arr:number[]): number {
    if (arr.length === 0) {
        return NaN
    }
    const sorted = [...arr].sort((a, b) => a - b)
    const mid = Math.floor(sorted.length / 2)
    return sorted.length % 2 === 0
        ? (sorted[mid - 1] + sorted[mid]) / 2 // even count: average the two middle values
        : sorted[mid]
}

/**
 * Geometry of padding an h x w image to a centred square, used to map model
 * coordinates back onto the unpadded image.
 *
 * The leading pad is rounded down so it is always a whole number of pixels, even
 * when the size difference is odd. The square padding itself must place the same
 * amount before the image, with the odd remainder after it.
 *
 * @param h image height in pixels.
 * @param w image width in pixels.
 * @returns [size, padX, padY]: side of the padded square, and the number of
 *          columns (left) / rows (top) inserted before the image.
 */
function getPaddingParams(h:number,w:number) : [number,number,number]
{
    const size = Math.max(h, w)
    const padX = Math.floor((size - w) / 2)
    const padY = Math.floor((size - h) / 2)
    return [size, padX, padY]
}


/**
 * Zero-pads an [h, w, c] image tensor to a square, centring the original content.
 *
 * `tf.pad` needs integer paddings. When the size difference is odd, the extra
 * row/column goes after the image: floor(diff/2) before, the rest after.
 *
 * Returns `img` itself (not a copy) when it is already square, so disposing the
 * result in that case also disposes the input.
 *
 * @param img image tensor to pad.
 * @returns the square tensor.
 */
function padToSquare(img : tf.Tensor3D) : tf.Tensor3D
{
    const [h, w] = img.shape
    if (h === w) {
        return img
    }
    const diff = Math.abs(h - w)
    const before = Math.floor(diff / 2)
    const after = diff - before
    return h < w
        ? tf.pad(img, [[before, after], [0, 0], [0, 0]]) // wide image: add rows
        : tf.pad(img, [[0, 0], [before, after], [0, 0]]) // tall image: add columns
}

/*
export class PredictHistory {
    kp    : Keypoints|null=null
    count : number=0

    get bbox():BBox|null{
        if (this.kp) {
            let bb = this.kp.bbox()
            if (bb) {
                return bb.square().enlarge(1.2)
            }
        }
        return null
    }

    check(kp:Keypoints|null): boolean {
        if (!kp) {
            return false
        }
        if (kp.points().length<7) {
            return false
        }
        if (this.count > 10) {
            return false;
        }
        if (this.kp) {
            let new_points:Set<string> = new Set(kp.points())
            let missing = this.kp.points().filter(p => !new_points.has(p));
            if (missing.length>0) {
                return false;
            }
        }
        return true;
    }


    newPrediction(kp:Keypoints|null) {
        if (this.check(kp)) {
            this.kp = kp
            this.count ++
        } else {
            this.kp = null
            this.count = 0
        }
    }
}
*/

/**
 * Wrapper around a tf.js LayersModel that predicts keypoints.
 *
 * The model is cached in IndexedDB (`localPath`), loaded onto the tf backend,
 * and run on images. Its output is converted into keypoint coordinates in the
 * input image's pixel space and written into the reused `result` buffer.
 */
export class SkeletonModel {
    /** URL the model is downloaded from; a trailing `/model.json` is stripped when deriving localPath. */
    remotePath : string;
    /** `indexeddb://` key the model is cached under (see remoteToLocal). */
    localPath  : string;
    /** Loaded network, or null before a successful load and after dispose(). */
    model : tf.LayersModel | null = null;
    /** Number of calls to predict / predictNew / predictNew2 that reached the model (predictNice does not count). */
    count : number = 0;
    // Geometry defaults; initSize() overwrites all of them from the loaded model.
    /** Side length of the square input the network expects. */
    protected InputImageSize :number = 256;
    /** Input size divided by heatmap size (pixels per heatmap cell). */
    protected SizeReduction :number = 16;
    protected NRows:number = 16;
    protected NCols:number = 16;
    /** Number of keypoint channels in the heatmap output. */
    protected NPoints:number = 111;
    /** Output buffer; every prediction resets and refills it and returns this same object. */
    result : KeypointsArray = new KeypointsArray()

    constructor (remotePath:string){
        this.remotePath = remotePath;
        this.localPath  = SkeletonModel.remoteToLocal(remotePath)
        //console.log("***",this.remotePath,this.localPath)
    }

    /** Reads input/output geometry from the loaded model (first layer's input spec and the "heatmap" layer's output shape). */
    private initSize() {
        const model = this.model
        if (model === null) {
            throw new Error("skeleton-model : initSize called without a loaded model")
        }
        const inputShape : tf.Shape = model.getLayer(undefined,0).inputSpec[0].shape!
        const inputSize:number = inputShape[1]!
        const outputShape = model.getLayer("heatmap").outputShape as tf.Shape
        const outputSize = outputShape[1]!

        this.InputImageSize = inputSize
        this.SizeReduction = inputSize/outputSize
        this.NRows =  outputSize
        this.NCols =  outputSize
        this.NPoints = outputShape[3]!
    }

    /** Frees the loaded network's weights, if any, and marks the model as not ready. */
    dispose(){
        if (this.model){
            this.model.dispose()
            this.model = null
        }
    }

    /** Installs `next` as the active network and disposes any previously loaded one, so reloading does not leak weights. */
    private replaceModel(next: tf.LayersModel) {
        const previous = this.model
        this.model = next
        if (previous !== null && previous !== next) {
            previous.dispose()
        }
    }

    /**
     * Maps a model URL to its IndexedDB cache key. Scheme and host become
     * `indexeddb://` and a trailing `/model.json` is dropped, e.g.
     * `https://host/models/pose/model.json` -> `indexeddb://models/pose`.
     */
    static remoteToLocal(model:string) : string {
        // `..` matches any two characters — in practice the `//` after the scheme.
        const p1 = model.replace(/https?:..[^\/]+\//,"indexeddb://")
        const p2 = p1.replace(/\/model.json$/,"")
        return p2;
    }

    /**
     * Ensures the model is cached under localPath, downloading it from remotePath if needed.
     * @returns true when the model is present in the cache afterwards.
     */
    async copyToLocal(): Promise<boolean> {
        const cached = await tf.io.listModels()
        if (this.localPath in cached){
            //console.log(`Found local ${this.localPath}`)
            return true
        }

        //console.log("Loading remote model from network",this.localPath)
        await tf.io.copyModel(this.remotePath, this.localPath);

        const cachedAfterCopy = await tf.io.listModels()
        return this.localPath in cachedAfterCopy
    }

    /** Drops the loaded network and the cached copy, then downloads the model again. */
    async reloadModel() {
        console.log("Reload",this.localPath)
        this.dispose()
        try {
            await tf.io.removeModel(this.localPath)
        } catch (e) {
            // removeModel rejects when nothing is stored under localPath. That must not
            // abort the caller's retry loop (loadLocal calls this from its catch block).
            console.warn(`skeleton-model : could not remove ${this.localPath}:`, e)
        }
        await this.copyToLocal()
        await sleep(10)
    }

    /**
     * Loads the model and validates it on modelTest's test image.
     *
     * Each attempt loads the cached copy and counts the keypoints predicted on the
     * test image. If modelTest.check() rejects that count, the cache is refreshed
     * from the network and the next attempt follows.
     *
     * @param fuse  forwarded to loadLocal(); when true the network is transformed with ModelManip.fuseResiduals().
     * @param retry number of load-and-validate attempts.
     * @returns true once a loaded model passes the check. Otherwise the cached model is
     *          loaded once more (if possible) and false is returned — presumably so
     *          callers still have a network to run; confirm intent.
     */
    async loadModel(fuse:boolean=true, retry:number=3) {
        const image = await modelTest.getTestImage()
        await sleep(0)

        // Result not needed: loadLocal() reports (and retries) load failures itself.
        await this.copyToLocal()
        await sleep(0)

        for(let i=0;i<retry;i++) {
            const loaded = await this.loadLocal(fuse)
            await sleep(0)
            if (!loaded) {
                continue
            }

            const kpa = await this.predictImageNice(image)
            await sleep(0)

            const n = kpa ? kpa.keypoints.points().length : 0
            //console.log(`loadModel: got ${n} points (${this.localPath})`)
            if (modelTest.check(n)) {
                return true
            }

            //console.log(`loadModel: got ${n} points instead ${modelTest.nPoints} (retry# ${i})`)
            await this.reloadModel()
        }
        await sleep(0)
        await this.loadLocal(fuse)
        return false
    }

    /**
     * Loads the cached model from localPath onto the tf backend, optionally fusing it,
     * and reads its geometry. On failure the cache is refreshed and loading is retried
     * up to 3 times.
     *
     * @param fuse when true the network is transformed with ModelManip.fuseResiduals()
     *             and the unfused original is disposed.
     * @returns true on success, false after all attempts failed.
     */
    async loadLocal(fuse:boolean=true) {
        console.log(`skeleton-model : loading local model: ${this.localPath} to GPU`)

        for (let retry = 0 ; retry < 3; retry ++) {
            // Non-null only while the freshly loaded model is not yet owned by `this.model`.
            let origModel: tf.LayersModel | null = null
            try {
                await sleep(0)
                origModel = await tf.loadLayersModel(this.localPath);
                await sleep(0)

                let loaded: tf.LayersModel = origModel
                if (fuse) {
                    loaded = new ModelManip(origModel).fuseResiduals()
                    origModel.dispose()
                }
                origModel = null
                this.replaceModel(loaded)

                await sleep(0)
                console.log(`skeleton-model : model successfully loaded`)
                this.initSize()
                return true
            } catch (e) {
                if (origModel !== null) {
                    // Loading succeeded but fusing failed: free the unfused copy.
                    origModel.dispose()
                }
                console.error(`skeleton-model :failed loading model retry#${retry}:`,e)
                await this.reloadModel()
            }
        }
        console.error(`skeleton-model :failed loading model ${this.remotePath}`)
        return false
    }


    /** Loads the model directly from remotePath (no IndexedDB cache, no fusing, no validation). */
    async loadModelNode() {
        //console.log(`skeleton-model : loading model: ${this.remotePath} ...`)
        this.replaceModel(await tf.loadLayersModel(this.remotePath))
        //console.log(`skeleton-model : model successfully loaded`)
        this.initSize()
        return true
    }


    /** True when a network is loaded. */
    isModelReady() {
        return this.model !== null
    }

    /*
    hpredictImageData = async (im:ImageData, h:PredictHistory) => {
        const t = tf.browser.fromPixels(im)
        const kp = await this.hpredict(t,h)
        t.dispose()
        //debug(tf.memory())
        return kp
    }
    */

    /** Synchronous prediction on a browser image (via predictNew); the temporary pixel tensor is always freed. */
    predictImage =  (im:ImageData|ImageBitmap) => {
        const t = tf.browser.fromPixels(im)
        try {
            return this.predictNew(t)
        } finally {
            t.dispose()
        }
    }

    /** Asynchronous prediction on a browser image (via predictNice), yielding to the event loop between steps. */
    async predictImageNice(im:ImageData|ImageBitmap) {
        await sleep(0)
        const t = await tf.browser.fromPixelsAsync(im)
        try {
            await sleep(0)
            const kp = await this.predictNice(t)
            await sleep(0)
            return kp
        } finally {
            t.dispose()
        }
    }


    /*
    newHistory(){
        return new PredictHistory()
    }


    hpredict =  (t:tf.Tensor3D, h:PredictHistory) => {
        let bb = h.bbox
        let kp:Keypoints|null = null
        if(bb){
            const [Y,X,__] = t.shape
            bb = bb.overlap(0,0,X,Y)
            kp =  this.predict(t,bb)

        } else {
            kp =  this.predict(t)
        }

        h.newPrediction(kp)
        return kp
    }
    */

    /** Padding geometry of the region fed to the network: the whole image, or only `slice_box` when given. */
    private paddingParamsFor(im: tf.Tensor3D, slice_box?: BBox): [number, number, number] {
        return slice_box
            ? getPaddingParams(slice_box.iHeight, slice_box.iWidth)
            : getPaddingParams(im.shape[0], im.shape[1])
    }

    /**
     * Crops to `slice_box` (if given), pads to square, normalizes and resizes `im`
     * into a batch of shape [1, InputImageSize, InputImageSize, 3].
     * Intermediate tensors are not freed here — call this inside tf.tidy.
     * This method never disposes `im`.
     */
    private buildInputBatch(im: tf.Tensor3D, slice_box?: BBox): tf.Tensor {
        const region = slice_box
            ? im.slice([slice_box.iy1, slice_box.ix1], [slice_box.iHeight, slice_box.iWidth])
            : im
        const padded = padToSquare(region)
        const normalized = normaizeImage(padded)
        const resized = normalized.resizeNearestNeighbor([this.InputImageSize, this.InputImageSize]).toFloat()
        // Single-element batch so it can be passed to the model.
        return resized.reshape([1, this.InputImageSize, this.InputImageSize, 3])
    }

    /**
     * Runs the model and decodes the raw heatmap/offset outputs on the CPU (recoverCoords).
     * @returns `this.result`, or null when no model is loaded.
     */
    predict(im:tf.Tensor3D, slice_box:undefined|BBox=undefined) {
        const model = this.model
        if (model === null) {
            return null
        }

        this.count ++;

        const [padded_size, pad_x, pad_y] = this.paddingParamsFor(im, slice_box)

        const output_tensors = tf.tidy(
            () => model.predict(this.buildInputBatch(im, slice_box))
        ) as tf.Tensor<tf.Rank.R4>[];

        try {
            const heatmap = output_tensors[0].arraySync()
            const offsets = output_tensors[1].arraySync()
            this.recoverCoords(heatmap[0], offsets[0], padded_size, pad_x, pad_y, slice_box)
        } finally {
            tf.dispose(output_tensors)
        }
        return this.result
    }


    /**
     * Asynchronous whole-image prediction that yields to the event loop between stages.
     * Coordinates are decoded on the backend (tf.recoverCoords) and read back with
     * `array()`. Does not increment `count` and never disposes `im`.
     * @returns `this.result`, or null when no model is loaded.
     */
    async predictNice(im:tf.Tensor3D){
        const model = this.model
        if (model === null) {
            return null
        }

        const [padded_size, pad_x, pad_y] = getPaddingParams(im.shape[0],im.shape[1])
        await sleep(0)

        // tidy frees every preprocessing intermediate but keeps `im`, which was created outside it.
        const batched = tf.tidy(() => this.buildInputBatch(im))
        await sleep(0)

        const res = tf.tidy(() => {
            const [H,O] = model.predictOnBatch(batched) as tf.Tensor4D[]
            return tf.recoverCoords(H, O) as tf.Tensor2D
        })
        batched.dispose()
        await sleep(0)

        try {
            const rawCoords = await res.array()
            this.recoverCoords2(rawCoords[0], padded_size, pad_x, pad_y, undefined)
        } finally {
            tf.dispose(res)
        }
        await sleep(0)
        return this.result
    }


    /**
     * Synchronous prediction that decodes coordinates on the backend (tf.recoverCoords)
     * and reads them back with `dataSync()`.
     * @returns `this.result`, or null when no model is loaded.
     */
    predictNew(im:tf.Tensor3D, slice_box:undefined|BBox=undefined) {
        const model = this.model
        if (model === null) {
            return null
        }

        this.count ++;

        const [padded_size, pad_x, pad_y] = this.paddingParamsFor(im, slice_box)

        const res = tf.tidy(() => {
            const [H,O] = model.predictOnBatch(this.buildInputBatch(im, slice_box)) as tf.Tensor4D[]
            return tf.recoverCoords(H, O) as tf.Tensor2D
        })

        try {
            this.recoverCoords3(res.dataSync(), padded_size, pad_x, pad_y, slice_box)
        } finally {
            tf.dispose(res)
        }
        return this.result
    }


    /** Same as predictNew, but reads the decoded coordinates back asynchronously with `array()`. */
    async predictNew2(im:tf.Tensor3D, slice_box:undefined|BBox=undefined) {
        const model = this.model
        if (model === null) {
            return null
        }

        this.count ++;

        const [padded_size, pad_x, pad_y] = this.paddingParamsFor(im, slice_box)

        const res = tf.tidy(() => {
            const [H,O] = model.predictOnBatch(this.buildInputBatch(im, slice_box)) as tf.Tensor4D[]
            return tf.recoverCoords(H, O) as tf.Tensor2D
        })

        try {
            const rawCoords = await res.array()
            this.recoverCoords2(rawCoords[0], padded_size, pad_x, pad_y, slice_box)
        } finally {
            tf.dispose(res)
        }
        return this.result
    }


    /*
    recover_coords2(heatmap:number[][][], offset:number[][][], size:number, pad_dx:number, pad_dy:number):Keypoints {
        const factor = size / this.InputImageSize;
        const result = new Keypoints();
        let point,x,y,p:number;
        let max_y, max_x, max_p:number;
        for (point = 0; point < this.NPoints; point++) {
            max_y=0;
            max_x=0;
            max_p = heatmap[0][0][point];

            for(y = 0; y < this.NRows; y++) {
                for(x = 0; x < this.NCols; x++) {
                p = heatmap[y][x][point]
                if (p > max_p) {
                    max_y=y; max_x=x; max_p = p
                }
                }
            }
            if  (max_p > 0.2) {
                const dx = offset[max_y][max_x][point*2+1]
                const dy = offset[max_y][max_x][point*2]

                const res_x = (max_x+dx) * this.SizeReduction * factor - pad_dx;
                const res_y = (max_y+dy) * this.SizeReduction * factor - pad_dy;
                const res_point = new Point(res_x,res_y);

                result.set(AllModelPoints[point], res_point);
            }
        }
        return result;
    }
    */


    /**
     * Decodes raw heatmap + offset outputs into `this.result`.
     *
     * For each keypoint channel it takes the heatmap arg-max cell (first maximum wins).
     * Keypoints whose peak value is <= 0.35 are skipped. Otherwise the cell is refined
     * with its offset pair and scaled back to input-image pixels.
     *
     * @param heatmap [row][col][point] scores.
     * @param offset  [row][col][2*point + (0: y, 1: x)] offsets.
     * @param size    side of the padded square (see getPaddingParams).
     * @param pad_dx  left padding; shifted by slice_box.x1 when a slice box is given.
     * @param pad_dy  top padding; shifted by slice_box.y1 when a slice box is given.
     */
    recoverCoords(heatmap:number[][][], offset:number[][][], size:number, pad_dx:number, pad_dy:number, slice_box:BBox | undefined) {
        const factor = size / this.InputImageSize;

        this.result.reset()
        if (slice_box) {
            pad_dx -= slice_box.x1
            pad_dy -= slice_box.y1
        }

        for (let point = 0; point < this.NPoints; point++) {
            let max_y = 0
            let max_x = 0
            let max_p = heatmap[0][0][point]

            for (let y = 0; y < this.NRows; y++) {
                for (let x = 0; x < this.NCols; x++) {
                    const p = heatmap[y][x][point]
                    if (p > max_p) {
                        max_y = y; max_x = x; max_p = p
                    }
                }
            }
            if (max_p > 0.35) {
                // (v - 0.5) * 6 maps [0, 1] to [-3, 3] cells — presumably the network emits [0, 1] offsets; confirm.
                const dx = (offset[max_y][max_x][point*2+1] - 0.5)*6.0
                const dy = (offset[max_y][max_x][point*2]- 0.5)*6.0

                const res_x = (max_x+dx) * this.SizeReduction * factor - pad_dx;
                const res_y = (max_y+dy) * this.SizeReduction * factor - pad_dy;
                this.result.set(point,res_x,res_y)
            }
        }
    }

    /** Array form of recoverCoords3 (used with `Tensor.array()` output). */
    recoverCoords2(coords:number[], size:number, pad_dx:number, pad_dy:number, slice_box:BBox | undefined) {
        this.recoverCoords3(coords, size, pad_dx, pad_dy, slice_box)
    }

    /**
     * Converts decoded (x, y) pairs — `coords[2*point]`, `coords[2*point+1]`, in heatmap
     * cells — into input-image pixels and stores them in `this.result`.
     * Pairs with either value <= -50 are skipped (presumably a "not detected" sentinel
     * from tf.recoverCoords — confirm).
     */
    recoverCoords3(coords:ArrayLike<number>, size:number, pad_dx:number, pad_dy:number, slice_box:BBox | undefined) {
        const factor = size / this.InputImageSize;

        this.result.reset()
        if (slice_box) {
            pad_dx -= slice_box.x1
            pad_dy -= slice_box.y1
        }

        for (let point = 0; point < this.NPoints; point++) {
            const x = coords[point*2]
            const y = coords[point*2+1]
            if  (x > -50 && y > -50 ) {
                const res_x = x * this.SizeReduction * factor - pad_dx;
                const res_y = y * this.SizeReduction * factor - pad_dy;
                this.result.set(point,res_x,res_y)
            }
        }
    }

    /** Random int32 image tensor of shape [n1, n2, 3] with values in [0, 255). */
    dummyInput(n1=222,n2=222):tf.Tensor3D {
        return tf.randomUniform([n1, n2, 3],0,255,'int32')
    }

    /** Times one predict() on a fresh random 222x222 image; returns milliseconds. */
    async dummyPredict()  {
        const started = Date.now();
        const dummyInputTensor = this.dummyInput();
        try {
            this.predict(dummyInputTensor);
        } finally {
            dummyInputTensor.dispose();
        }
        //debug(`dummy predict took ${t} ms`)
        return Date.now()-started;
    }

    /** Times one predict() on a caller-owned tensor (not disposed here); returns milliseconds. */
    async dummyPredict0(dummyInputTensor:tf.Tensor3D) {
        const started = Date.now();
        this.predict(dummyInputTensor);
        //debug(`dummy predict took ${t} ms`)
        return Date.now()-started;
    }
    /*
    async speedTest0(n:number = 5) {
        const dummyInputTensor = this.dummyInput();
        let ts:number[]=[];
        for(let i=0;i<n;i++) {
            ts.push(await this.dummyPredict0(dummyInputTensor));
        }
        const t = median(ts)
        let fps = Math.round(1000/t);
        console.log(`average of ${n} predicts took ${t} ms; ===>  ${fps}FPS`)

        dummyInputTensor.dispose();

        return fps;
    }

    async speedTest1(n:number = 5) {
        let t = 0;
        for(let i=0;i<n;i++) {
            t += await this.dummyPredict();
        }
        let fps = Math.round(1000/(t/n));
        console.log(`average of ${n} predicts took ${t/n} ms; ===>  ${fps}FPS`)
        return fps;
    }


    async speedTest2( n : number = 5) {
        const dummyInputTensor = this.dummyInput(480,640);
        let tot : number = 0;
        for(let i=0; i<n; i++) {
            const t = await this.dummyPredict0(dummyInputTensor)
            //console.log(`${i+1}/${n} : ${t} ms`)
            tot += t;
        }
        return tot / n;
    }
    */

}

export default SkeletonModel;
