import {  tf } from './my-tf'
import {sortBy} from "lodash";

import SkeletonModel from "./skeleton-model"
import { Smoother } from "./utils"
import modelConfig from "./model-config"
import { BBox, KeypointsArray } from "@kemtai/keypoints";

import { sleep } from "@kemtai/utils";
import logger from '@kemtai/logger'
//import { isMobile } from '@kemtai/utils';

//let lastPredictTime = Date.now()

/**
 * Per-model runtime statistics, keyed by model id.
 * `count` is the number of timed predictions and `fps` the measured frame rate.
 */
interface ModelInfoStats {
    [id: string]: { count: number; fps: number };
}

/**
 * Snapshot of the multi-model runtime state exposed to callers.
 */
export interface ModelInfo {
    /** Id of the currently selected model. */
    name: string;
    time: number;
    rawtime: number;
    /** Measured frames per second of the current model. */
    fps: number;
    /** True while no model is ready yet. */
    loading: boolean;
    /** Per-model statistics keyed by model id. */
    stats: ModelInfoStats;
}

/**
 * Base class for policies that decide which loaded model a MultiModel runs.
 * Concrete strategies implement updateModel(), which is expected to write
 * `mm.currentModel`.
 */
abstract class ModelSelectonStrategy {
    constructor(public mm: MultiModel) {}

    /** Number of models currently loaded in the owning MultiModel. */
    get modelsNum(): number {
        const { LoadedModelIds: ids } = this.mm
        return ids.length
    }

    /** Initial selection; by default simply one updateModel() step. */
    initModel(): void {
        this.updateModel()
    }

    /** Re-evaluate the selection (MultiModel calls this after every predict()). */
    abstract updateModel(): void;
}

/** Strategy that always selects one fixed model id, regardless of FPS. */
class MSSExact extends ModelSelectonStrategy {
    constructor(mm: MultiModel, public model: string) {
        super(mm)
    }

    /** Pin the owner's current model to the configured id. */
    updateModel() {
        const { mm, model } = this
        mm.currentModel = model
    }
}


/** Strategy that always selects the model at index 0 of LoadedModelIds. */
class MSSDefault extends ModelSelectonStrategy {
    updateModel() {
        const [first] = this.mm.LoadedModelIds
        this.mm.currentModel = first
    }
}

/*
class MSSExplore extends ModelSelectonStrategy {
    current : number = 0
    updateModel() {
        this.current = (this.current + 1) % this.modelsNum
        this.mm.currentModel = this.mm.LoadedModelIds[this.current]
    }
}

*/

/**
 * Tuning knobs for the FPS-driven model selection.
 * - bandMin / bandMax: FPS band. Above bandMax the selector leans toward the
 *   next model, below bandMin toward the previous one.
 * - minThr / maxThr: accumulated-score thresholds that trigger a downgrade / upgrade.
 */
class SelectonParams {
    bandMin: number
    bandMax: number
    minThr: number
    maxThr: number

    constructor(bandMin = 11, bandMax = 15, minThr = -1000, maxThr = 100) {
        this.bandMin = bandMin
        this.bandMax = bandMax
        this.minThr = minThr
        this.maxThr = maxThr
    }
}


/**
 * FPS-driven selection strategy.
 *
 * Keeps an index into mm.LoadedModelIds and moves it one step at a time:
 * - up while mm.FPS stays above params.bandMax;
 * - down while it stays below params.bandMin.
 * The surplus or deficit accumulates in `score`, and a switch happens only once
 * the score crosses params.maxThr or params.minThr. Inside the band the score
 * decays toward zero. After each switch, the next `pause` calls are skipped.
 * NOTE(review): higher indices are presumably heavier models — confirm.
 */
class MSSAutomatic extends ModelSelectonStrategy {
    params : SelectonParams
    current : number = 0   // index of the selected model in mm.LoadedModelIds
    score : number = 0     // accumulated FPS surplus (> 0) or deficit (< 0)
    pause : number = 10    // remaining updateModel() calls to skip

    constructor (mm:MultiModel, currentModelID? : string, params:SelectonParams = new SelectonParams() ) {
        super(mm)
        this.params = params
        // Continue from the caller's current model if it is loaded; otherwise stay at index 0.
        const startIdx = currentModelID ? mm.LoadedModelIds.indexOf(currentModelID) : -1
        if (startIdx >= 0) {
            this.current = startIdx
        }
    }

    /** Id of the currently selected model. */
    get mid() {
        return this.mm.LoadedModelIds[this.current]
    }

    /** Jump to model `idx` (ignored when out of range), resetting score and pause. */
    set(idx:number) {
        if (idx < 0 || idx >= this.modelsNum) {
            return
        }
        this.current = idx
        this.score = 0
        this.pause = 10
    }

    updateModel() {
        // Cool-down after a switch: skip evaluation while pause is positive.
        if (this.pause > 0) {
            this.pause--
            return
        }

        const { bandMin, bandMax, minThr, maxThr } = this.params
        const fps = this.mm.FPS
        const tooFast = fps > bandMax && this.current < this.modelsNum - 1
        const tooSlow = fps < bandMin && this.current > 0

        if (tooFast) {
            this.score += fps - bandMax
            if (this.score > maxThr) {
                this.set(this.current + 1)
            }
        } else if (tooSlow) {
            this.score -= bandMin - fps
            if (this.score < minThr) {
                this.set(this.current - 1)
            }
        } else {
            // Within the band (or no neighbour to move to): let the score fade out.
            this.score *= 0.995
        }

        this.mm.currentModel = this.mid
    }

    /**
     * Initial pick: the highest-index model whose mm.timers0 FPS is known
     * (truthy) and at least bandMin; index 0 if none qualifies.
     */
    initModel() {
        let chosen = 0
        this.mm.LoadedModelIds.forEach((id, idx) => {
            const fps = this.mm.timers0[id]?.FPS
            if (fps && fps >= this.params.bandMin) {
                chosen = idx
            }
        })
        this.set(chosen)
        this.mm.currentModel = this.mid
    }

}

/**
 * Running tally of boolean observations (MultiModel feeds it `useBox`).
 * Every 500th observation, emits a "BoxStats" event with the percentage of
 * `true` samples so far.
 */
class BoxStats {
    trueCount = 0
    falseCount = 0

    /** Record one observation and report on every multiple of 500 samples. */
    cnt(b:boolean) {
        if (b) this.trueCount += 1
        else this.falseCount += 1

        const total = this.trueCount + this.falseCount
        if (!(total > 0) || total % 500 !== 0) {
            return
        }
        const pct = (100 * this.trueCount) / total
        logger.event("BoxStats", {"BoxPct": pct, N: total})
    }
}

/**
 * Runs pose estimation with one of several SkeletonModel variants and switches
 * between them according to a selection policy (a ModelSelectonStrategy).
 *
 * Lifecycle:
 * - load() / loadAllModels() load models one by one. `ready` becomes true as
 *   soon as the first one succeeds, and `error` is set if none load.
 * - predict() / predictImage() run the current model, record per-model timings
 *   and let the strategy choose the model for the next frame.
 */
class MultiModel {
    LoadedModelIds : string[] = []
    private models: { [id: string] : SkeletonModel; } = {};
    /** Per-model timer of the model call in predict() (including the box update). */
    timers0 : { [id: string] : Smoother; }      = {};
    /** Per-model timer of the interval between consecutive predict() completions. */
    timers1 : { [id: string] : Smoother; }      = {};

    currentModel : string = ""
    strategy : ModelSelectonStrategy
    policy : string = "default"

    // Central square of a 640x480 frame; args presumably (x1, y1, x2, y2) — confirm against BBox.
    readonly centralSquareBBox = new BBox(80,0,560,480)
    useBox : boolean = false
    boxStats = new BoxStats()

    ready : boolean = false
    error : boolean = false

    private lastTime = 0

    constructor() {
        this.strategy = new MSSDefault(this)
    }

    /**
     * Crop to use for the next prediction: the central square when the previous
     * frame allowed it and the image is 640x480; otherwise undefined (full frame).
     */
    box(imShape:[number,number,number]) : BBox|undefined {
        const [h, w] = imShape
        if (w !== 640 || h !== 480) {
            this.useBox = false
        }
        return this.useBox ? this.centralSquareBBox : undefined
    }

    /**
     * Decide whether the next frame may use the central crop: only when the
     * detected keypoints' bbox lies horizontally within x in (120, 480).
     * NOTE(review): these margins are asymmetric relative to the 80..560 crop
     * (40px on the left, 80px on the right) — confirm this is intended.
     */
    updateBox(kp:KeypointsArray|null) {
        this.useBox = false
        const kpBox = kp?.bbox()
        if (kpBox && kpBox.x1 > 120 && kpBox.x2 < 480) {
            this.useBox = true
        }
        this.boxStats.cnt(this.useBox)
    }

    /**
     * Switch the model-selection policy.
     *
     * Accepted values:
     * - "automatic" / "slow": FPS band 6..10;
     * - "fast": FPS band 11..15;
     * - "default": first loaded model;
     * - the id of an already-loaded model: pinned exactly.
     *
     * Unknown values are logged and ignored: both `policy` and the active
     * strategy are left unchanged.
     */
    setPolicy(policy : string) {
        let strategy : ModelSelectonStrategy
        switch (policy) {
            // "explore" (MSSExplore) is currently disabled.
            case "automatic":
            case "slow":
                strategy = new MSSAutomatic(this, this.currentModel, new SelectonParams(6,10))
                break
            case "fast":
                strategy = new MSSAutomatic(this, this.currentModel, new SelectonParams(11,15))
                break
            case "default":
                strategy = new MSSDefault(this)
                break
            default:
                if (!this.LoadedModelIds.includes(policy)) {
                    console.error(`MultiModel.setPolicy: unknown policy:${policy}; ignoring...`)
                    return
                }
                strategy = new MSSExact(this, policy)
        }

        // Only commit once the policy is known to be valid.
        this.policy = policy
        this.strategy = strategy
        this.strategy.initModel()
    }

    /** Load every model listed by modelConfig.allModels(), then apply `policy`. */
    async loadAllModels(policy:string) : Promise<boolean> {
        const models = await modelConfig.allModels()
        return await this.loadModels(models, policy)
    }

    /**
     * Load a single model by id. Returns true if it is (or already was) loaded.
     * The model is registered in `models`, the timers and `LoadedModelIds` only
     * after a successful load, so a failed attempt leaves no partial state.
     */
    async loadModel(mid:string) : Promise<boolean> {
        if (this.LoadedModelIds.includes(mid)) {
            logger.log(`loadModel ${mid} already loaded`)
            return true
        }

        logger.log(`multi-model: Loading ${mid}`)
        const model = new SkeletonModel(modelConfig.mid2path(mid))
        const status = await model.loadModel()
        logger.log(`multi-model: ${mid} loaded success: ${status}`)
        if (!status) {
            return false
        }
        this.models[mid] = model
        await sleep(0)
        this.timers0[mid] = new Smoother()
        this.timers1[mid] = new Smoother()
        this.LoadedModelIds.push(mid)
        return true
    }

    /**
     * Load models for `policy`. If `policy` is itself a model id, only that
     * model is loaded; otherwise the set from modelConfig.selectModels() is loaded.
     */
    async load(policy : string = "automatic") : Promise<boolean> {
        const allModels = await modelConfig.allModels()
        if (allModels.includes(policy)) {
            return await this.loadModels([policy], policy)
        }
        const models  = await modelConfig.selectModels();
        console.log("multi-model.load",models,policy)
        return await this.loadModels(models, policy)
    }

    /** Resolve once loading has either produced a ready model or failed. */
    async isReady() : Promise<void> {
        while (!this.ready && !this.error) {
            await sleep(50)
        }
    }

    /**
     * Pre-copy the selected models via SkeletonModel.copyToLocal(), sequentially.
     * `policy` is currently unused; it is kept for API compatibility.
     */
    async warmUp(policy : string = "automatic") {
        const models  = await modelConfig.selectModels();
        for (const mid of models) {
            const sm = new SkeletonModel(modelConfig.mid2path(mid))
            await sm.copyToLocal()
        }
    }

    private async loadModels(models:string[], policy : string = "automatic") : Promise<boolean> {
        await modelConfig.ready()

        /////////////////////// DEBUG - simulating loading error.
        //this.error = true
        //return false
        //////////////////////////////////

        for (const mid of models) {
            let success = false
            try {
                success = await this.loadModel(mid)
            } catch (e) {
                // Treat an exception like a failed load: keep trying the remaining
                // models, and make sure `error` gets set below if none load.
                // Otherwise isReady() would poll forever.
                console.error(`multi-model.loadModels: failed to load ${mid}`, e)
            }
            if (!success) {
                continue
            }
            // Make the newly loaded model usable right away.
            this.strategy.initModel()
            this.ready = true
            await sleep(0)
        }

        if (!this.ready) {
            this.error = true
            return false
        }

        // Re-sort loaded models into the canonical modelConfig.allModels() order.
        const MODELS = await modelConfig.allModels()
        const k : {[m:string]:number} = {}
        for (let i = 0; i < MODELS.length; i++) {
            k[MODELS[i]] = i
        }
        // The iteratee must *return* the rank. The previous block-bodied arrow
        // `(m)=>{k[m]}` returned undefined, so the order was never changed.
        this.LoadedModelIds = sortBy(this.LoadedModelIds, [(m: string) => k[m]])

        this.setPolicy(policy)
        return true
    }

    /** Print the timers0/timers1 FPS of every loaded model to the console. */
    logFPS(){
        for (const mid of this.LoadedModelIds) {
            console.log(`multi-model: ${mid} FPS=${Math.round(this.timers0[mid].FPS)}/${Math.round(this.timers1[mid].FPS)} `)
        }
    }

    /** FPS of the current model from timers1; 1 when unknown or zero. */
    get FPS() {
        return this.timers1[this.currentModel]?.FPS || 1
    }

    async _predict(t:tf.Tensor3D, box?:BBox) : Promise<KeypointsArray | null>{
        return this.models[this.currentModel].predictNew(t, box)
        /*
        if (isMobile()){
            return await this.models[this.currentModel].predictNew2(t, box)
        } else {
            return this.models[this.currentModel].predictNew(t, box)
        }
        */
    }

    /**
     * Run the current model on `t` and return its keypoints (null if not ready).
     * Also updates the crop decision and the per-model timers, then lets the
     * strategy choose the model for the next call. Does not dispose `t`.
     */
    async predict(t:tf.Tensor3D) : Promise<KeypointsArray | null> {
        if (!this.ready) {
            console.log("multi-model:predict - model is not ready; return null")
            return null
        }
        if (!this.lastTime) {
            this.lastTime = Date.now()
        }
        // Pin the model for this frame: currentModel may change (e.g. via setPolicy)
        // while we await, and timings must be credited to the model that actually ran.
        const mid = this.currentModel
        const started = Date.now()
        const kp = await this._predict(t, this.box(t.shape))

        /////////////////////// DEBUG - simulating model error.
        //if (this.models[mid].count > 300 ) {
        //    kp = new KeypointsArray()
        //}
        //////////////////////////////////

        this.updateBox(kp)
        const time0 = Date.now() - started
        this.timers0[mid].add(time0)
        // Interval since the previous predict() completed.
        const time1 = Date.now() - this.lastTime
        this.timers1[mid].add(time1)
        this.lastTime = Date.now()
        this.strategy.updateModel()
        return kp
    }

    /**
     * Convert an image/video frame to a tensor and run predict() on it.
     * The tensor is disposed only after prediction has finished, even on error.
     */
    async predictImage(im:ImageData|ImageBitmap|HTMLVideoElement) : Promise<KeypointsArray | null> {
        const t = tf.browser.fromPixels(im)
        try {
            return await this.predict(t)
        } finally {
            t.dispose()
        }
    }

    /** timers1 statistics for models with more than 10 recorded samples. */
    get stats() : ModelInfoStats {
        const res : ModelInfoStats = {}
        for (const mid of this.LoadedModelIds) {
            const timer = this.timers1[mid]
            if (timer && timer.count > 10) {
                res[mid] = {
                    count: timer.count,
                    fps: timer.FPS,
                }
            }
        }
        return res
    }

    /** Snapshot of the current state for UI/telemetry. */
    get info() : ModelInfo {
        // currentModel starts as "" (not null/undefined), so `??` would never fall
        // back. Use `||` to fall back to the first registered timer.
        const key = this.currentModel || Object.keys(this.timers1)[0]
        const timer = this.timers1[key]
        return {
            name : this.currentModel,
            fps : this.FPS,
            time : timer?.val,
            rawtime : timer?.rawval,
            loading : !this.ready,
            stats : this.stats,
        }
    }
}


// Public entry point of this module: the multi-model pose estimator.
export default MultiModel;
