
import { tf } from './my-tf'
import logger from '@kemtai/logger'
import { a7_192 } from "./utils"
import { GPU } from '@kemtai/utils';

//import Config from "../config"

import { sleep } from '@kemtai/utils';
import SkeletonModel from './skeleton-model';

/** Root URL under which the model files and `config.json` are served. */
const modelBaseUrl = 'https://models.api.kemtai.com';
// Alternative base URL pointing at a local server (kept for development):
//const modelBaseUrl = 'http://localhost:4000/p/models'

/** Version path segment inserted into every model URL. */
const modelVer = 'v6';

/**
 * Shape of the remote `config.json` document consumed by `ModelConfig`.
 */
type ModelConfigData = {
    /** Per-model-id value interpolated as the leading path segment of the model URL. */
    modelVersions: Record<string, string[]>;
    /** Ids of every known model. */
    models: string[];
    /** Ids of the models treated as "small" when building the load order. */
    smallModels: string[];
    /** `[keyword, group]` pairs; a GPU whose name contains `keyword` is assigned to `group`. */
    gpu2group: Array<[string, string]>;
    /** `[minFps, group]` pairs, walked in order while the measured FPS exceeds `minFps`. */
    fps2group: Array<[number, string]>;
    /** Model ids to use for each group name. */
    group2model: Record<string, string[]>;
}

/**
 * Reports whether the TensorFlow.js environment flag
 * `WEBGL_RENDER_FLOAT32_CAPABLE` is off.
 *
 * @returns `true` when the flag reads `false`, otherwise `false`.
 */
export function gpu16bit(): boolean {
    const float32Capable = tf.ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE')
    return !float32Capable
}

/**
 * Fetches the remote model configuration and decides which models the
 * current device should load, and in which order.
 *
 * `init()` downloads `config.json`; every other async method waits (via
 * `ready()`) until that download has succeeded.
 */
class ModelConfig {
    /** Parsed `config.json`; stays `undefined` until `init()` succeeds. */
    config: ModelConfigData | undefined = undefined
    //version : string  | undefined = undefined
    /** Copy of `config.modelVersions`, consulted by `mid2path()`. */
    modelVersions: { [key: string]: string[] } | undefined = undefined
    //readonly gpu16bit : boolean = gpu16bit()

    /**
     * Builds the absolute URL of a file inside the versioned model directory.
     *
     * @param url Path relative to `<base>/32/<version>/`.
     */
    modelUrl(url: string): string {
        // Only the "32" directory is used; the 16-bit variant below is disabled.
        return `${modelBaseUrl}/32/${modelVer}/${url}`
        /*
        if(this.gpu16bit) {
            return `${modelBaseUrl}/16/${modelVer}/${url}`
        } else {
            return `${modelBaseUrl}/32/${modelVer}/${url}`
        }
        */
    }

    /** URL of the remote configuration document. */
    configUrl(): string {
        return this.modelUrl("config.json")
    }

    /**
     * Maps a model id to the URL of its `model.json`.
     *
     * @param mid Model id.
     * @returns The model URL, or `""` when the config is not loaded yet or
     *          has no version entry for `mid`.
     */
    mid2path(mid: string): string {
        const version = this.modelVersions?.[mid]
        if (!version) {
            return ""
        }
        // NOTE(review): `version` is typed string[] and is interpolated as-is
        // (arrays join with ","); presumably the entries are single-valued — confirm.
        return this.modelUrl(`${version}/${mid}/model.json`)
    }

    /**
     * Downloads and stores the configuration, retrying every 200 ms until it
     * succeeds. Never rejects.
     */
    async init(): Promise<void> {
        while (true) { // never give up
            try {
                const resp = await fetch(this.configUrl())
                if (!resp.ok) {
                    // fetch() resolves for 4xx/5xx too; an error page (even a JSON
                    // one) must not be stored as the configuration.
                    throw new Error(`config request failed with HTTP ${resp.status}`)
                }
                const data: ModelConfigData = await resp.json()
                //this.version = this.config!.version
                // Read modelVersions first so a null/invalid body throws (and is
                // retried) before `config` is published to `ready()` waiters.
                this.modelVersions = data.modelVersions
                this.config = data
                return
            } catch {
                // Deliberate best-effort: wait a little and try again.
                await sleep(200)
            }
        }
    }

    /** Resolves once `init()` has stored a configuration (polls every 200 ms). */
    async ready(): Promise<void> {
        while (!this.config) {
            await sleep(200)
        }
    }


    /** All model ids, or `[]` when the configuration is not loaded yet. */
    allModelsSync(): string[] {
        return this.config?.models ?? []
    }

    /** All model ids, waiting for the configuration if necessary. */
    async allModels(): Promise<string[]> {
        await this.ready()
        return this.config!.models
    }

    /**
     * Estimates inference speed by timing a single prediction of the
     * `a7_192()` model on random input, after one untimed warm-up run.
     *
     * @returns Estimated frames per second (always finite).
     */
    async checkFPSNew(): Promise<number> {
        // NOTE(review): the model itself is not disposed here because it is not
        // visible whether a7_192() returns a fresh or a shared instance — confirm.
        const model = a7_192()
        await sleep(1)
        const inputShape: tf.Shape = model.getLayer(undefined, 0).inputSpec[0].shape!
        await sleep(1)
        const input = tf.randomUniform([1, inputShape[1]!, inputShape[2]!, inputShape[3]!], -1, 1) as tf.Tensor4D

        // Runs one prediction, waits for its data and returns the elapsed
        // milliseconds; the output tensor is always released.
        const timedPredict = (): number => {
            const started = Date.now()
            const output = model.predict(input) as tf.Tensor<tf.Rank>
            try {
                output.dataSync()
                return Date.now() - started
            } finally {
                output.dispose()
            }
        }

        try {
            await sleep(1)
            timedPredict() // warm-up run, result ignored
            await sleep(1)

            // Date.now() has 1 ms resolution; clamp so a sub-millisecond run
            // yields 1000 fps instead of Infinity.
            const elapsedMs = Math.max(timedPredict(), 1)
            return 1000 / elapsedMs
        } finally {
            input.dispose()
        }
    }



    /**
     * Picks the model list best suited to this device.
     *
     * A GPU whose name contains one of the `gpu2group` keywords is mapped to
     * its group directly; otherwise the group is derived from a measured FPS
     * using `fps2group`.
     *
     * @returns The model ids of the selected group, or `[]` when the group has
     *          no entry in `group2model`.
     */
    async selectBestModels(): Promise<string[]> {
        //console.log("selectModels:", GPU(), this.gpu16bit)
        await this.ready()
        const config = this.config!

        // if (Config._model) {
        //     return [Config._model]
        // }
        const gpu = GPU()
        if (gpu) {
            for (const [kw, group] of config.gpu2group) {
                if (gpu.includes(kw)) {
                    logger.event("gpuSelectModel", { gpu, mode: "known", group })
                    //console.log("gpuSelectModel:",group,gpu)
                    return config.group2model[group] ?? []
                }
            }
        }

        const fps = await this.checkFPSNew()
        let group = "unknown"
        // Keep the last group whose threshold the measured FPS exceeds; the
        // walk stops at the first threshold that is not exceeded.
        for (const [minFps, gid] of config.fps2group) {
            if (fps > minFps) {
                group = gid
            } else {
                break
            }
        }
        console.log("gpuSelectModelFPS:", group, fps)

        logger.event("gpuSelectModel", { gpu, mode: "unknown-gpu", group, fps })
        // The default "unknown" group (or any other) may be absent from the config.
        return config.group2model[group] ?? []
    }

    /**
     * Produces the ordered list of models to load: the best models for this
     * device, with already-cached ones first, and a small model prepended
     * when none of the best models is cached or small.
     */
    async selectModels(): Promise<string[]> {
        const allModels = await this.allModels()
        const cachedModels = new Set(await this.getCachedModels())
        const goodModels = await this.selectBestModels()

        // Fall back to the first two known models; slice() avoids `undefined`
        // entries when fewer than two models exist.
        const smallModels = this.config!.smallModels ?? allModels.slice(0, 2)
        //console.log(`*selectModels2*: goodModels: ${goodModels} cachedModels:${cachedModels}`)

        const goodCached = goodModels.filter(mid => cachedModels.has(mid))
        const goodUncached = goodModels.filter(mid => !cachedModels.has(mid))

        // if some of good models are already cached, just reorder the list
        if (goodCached.length > 0) {
            return [...goodCached, ...goodUncached]
        }

        // if some small models in the good list, do nothing
        if (goodModels.some(mid => smallModels.includes(mid))) {
            return goodModels
        }

        // add a small model to the list, preferring one that is already cached
        const smallCached = smallModels.filter(mid => cachedModels.has(mid))
        const candidates = smallCached.length > 0 ? smallCached : smallModels
        const bestSmallModel = candidates[candidates.length - 1]
        if (bestSmallModel === undefined) {
            // No small model is known at all; nothing to prepend.
            return goodModels
        }
        return [bestSmallModel, ...goodModels]
    }

    /**
     * Lists the known model ids whose local counterpart (as computed by
     * `SkeletonModel.remoteToLocal`) appears in `tf.io.listModels()`.
     */
    async getCachedModels(): Promise<string[]> {
        const loadedModels = await tf.io.listModels()
        const allModels = await this.allModels()
        return allModels.filter(mid => {
            const local = SkeletonModel.remoteToLocal(this.mid2path(mid))
            //console.log(`getLoaded: ${mid} ${isLoaded}`)
            return local in loadedModels
        })
    }

}


// Module-level singleton shared by all importers.
const modelConfig = new ModelConfig();

// Start fetching the configuration immediately, without awaiting it;
// consumers wait for completion through `modelConfig.ready()`.
void modelConfig.init();

export default modelConfig;

