import {  tf } from './my-tf'

/**
 * Rebuilds a tf.js LayersModel layer by layer, optionally fusing
 * `SeparableConv2D -> Add` residual patterns into the SeparableConv2D layer.
 *
 * The graph structure (predecessors/successors of every layer, and the names
 * of the layers with no predecessors / no successors) is extracted once in the
 * constructor from the layers' inbound and outbound nodes.
 */
export class ModelManip {
    /** Layer name -> names of the layers feeding it (collected from all its inbound nodes). */
    predeccesors : Map<string,string[]> = new Map<string,string[]>();
    /** Layer name -> names of the layers it feeds (one entry per outbound node). */
    successors : Map<string,string[]> = new Map<string,string[]>();
    /** Names of layers with no predecessors. */
    inputs : string[]  = []
    /** Names of layers with no successors. */
    outputs : string[] = []
    /** The model being analysed and rebuilt. */
    model : tf.LayersModel
    /** Name of a layer fused away by fuseResiduals() -> name of the layer that absorbed it. */
    replacedBy : Map<string,string> = new Map<string,string>();
    /** Original layer name -> re-created layer, for the most recent copy()/fuseResiduals(). */
    newLayers = new Map<string,tf.layers.Layer>()

    /**
     * Records the predecessor and successor layer names of every layer in
     * `model`, and collects the layers with no predecessors into `inputs` and
     * the layers with no successors into `outputs`.
     *
     * NOTE(review): a layer applied more than once (shared layer) gets the
     * predecessors of all its inbound nodes merged into one list, and
     * copy()/fuseResiduals() would feed them all in a single apply() call —
     * such models are presumably unsupported; confirm.
     */
    constructor(model: tf.LayersModel) {
        this.model = model

        for (const layer of model.layers){
            const predeccesors: string[] = []
            for (const node of layer.inboundNodes){
                for (const inlayer of node.inboundLayers){
                    predeccesors.push(inlayer.name)
                }
            }

            const successors: string[] = []
            for (const node of layer.outboundNodes){
                successors.push(node.outboundLayer.name)
            }

            if (predeccesors.length === 0){
                // nothing feeds this layer: treat it as an input layer
                this.inputs.push(layer.name)
            }
            if (successors.length === 0){
                // nothing consumes this layer: treat it as an output layer
                this.outputs.push(layer.name)
            }

            this.predeccesors.set(layer.name, predeccesors)
            this.successors.set(layer.name, successors)
        }
    }

    /**
     * Rebuilds the model unchanged: every layer is re-created from its class
     * name and config, applied to the re-created copies of its predecessors,
     * and loaded with the original layer's weights.
     *
     * Relies on `this.model.layers` listing each layer after all of its
     * predecessors, so their rebuilt outputs exist when the layer is applied.
     *
     * @returns a new LayersModel with the same graph and weights.
     */
    copy():tf.LayersModel{
        this.newLayers = new Map<string,tf.layers.Layer>()

        for (const layer of this.model.layers){
            const newLayer = tf.layers.deserialize({className:layer.getClassName(), config:layer.getConfig()}) as tf.layers.Layer
            this.newLayers.set(layer.name, newLayer)

            const layerInputs: tf.SymbolicTensor[] = this.predeccesors.get(layer.name)!.map(
                name => this.newLayers.get(name)!.output as tf.SymbolicTensor)
            // layers without predecessors are not applied; their `.output` is
            // expected to exist right after deserialization (input layers)
            if (layerInputs.length > 0) {
                newLayer.apply(layerInputs)
            }
        }

        return this.assembleModel(name => name)
    }

    /** Returns true when `b` is a direct predecessor of layer `a`. */
    isPredeccesor(a:string,b:string):boolean {
        return this.predeccesors.get(a)!.includes(b)
    }

    /**
     * Returns true when `layerName` is an `Add` layer with exactly two inputs,
     * one of which directly feeds the other (a skip connection around a
     * single layer).
     */
    fusableAdd(layerName:string):boolean {
        const layer = this.model.getLayer(layerName)
        if ("Add" !== layer.getClassName()){
            return false
        }

        const predeccesors :string[] = this.predeccesors.get(layerName)!
        if (predeccesors.length !== 2){
            return false
        }
        const [pred1,pred2] = predeccesors
        return this.isPredeccesor(pred1,pred2) || this.isPredeccesor(pred2,pred1)
    }

    /**
     * Returns the output tensors of the rebuilt predecessors of `layerName`.
     * A predecessor recorded in `replacedBy` is substituted by the layer that
     * absorbed it. All predecessors must already be present in `newLayers`.
     */
    getInputs(layerName:string):tf.SymbolicTensor[] {
        return this.predeccesors.get(layerName)!.map(name => {
            const source = this.replacedBy.get(name) ?? name
            return this.newLayers.get(source)!.output as tf.SymbolicTensor
        })
    }

    /**
     * Rebuilds the model, fusing residual patterns: a SeparableConv2D whose
     * only successor is an Add accepted by fusableAdd() is re-created with
     * `residual: true` in its config, the Add is dropped, and every consumer
     * of the Add — including the model output list — is wired to the
     * SeparableConv2D instead.
     *
     * NOTE(review): assumes the SeparableConv2D class registered via the
     * project's `tf` module understands the `residual` config flag and adds
     * its input to its output — confirm against that implementation.
     *
     * @returns a new LayersModel with the fused graph and the original weights.
     */
    fuseResiduals():tf.LayersModel {
        this.newLayers = new Map<string,tf.layers.Layer>()

        for (const layer of this.model.layers){
            if (this.replacedBy.has(layer.name)){
                // this Add was absorbed by its SeparableConv2D predecessor
                continue
            }

            const lconfig = layer.getConfig()
            if (layer.getClassName() === "SeparableConv2D") {
                const successors = this.successors.get(layer.name)!
                if (successors.length === 1 && this.fusableAdd(successors[0])) {
                    this.replacedBy.set(successors[0], layer.name)
                    lconfig.residual = true
                }
            }

            const newLayer = tf.layers.deserialize({className:layer.getClassName(), config:lconfig}) as tf.layers.Layer
            this.newLayers.set(layer.name, newLayer)

            const layerInputs = this.getInputs(layer.name)
            if (layerInputs.length > 0) {
                newLayer.apply(layerInputs)
            }
        }

        // a fused Add may itself be a model output: emit the fusing layer's output instead
        return this.assembleModel(name => this.replacedBy.get(name) ?? name)
    }

    /**
     * Builds a tf.model from the layers in `newLayers`, then copies each
     * original layer's weights into its rebuilt counterpart.
     *
     * @param resolve maps an original layer name to the name of the rebuilt
     *   layer that now produces its output.
     */
    private assembleModel(resolve: (name: string) => string): tf.LayersModel {
        const outputOf = (name: string): tf.SymbolicTensor =>
            this.newLayers.get(resolve(name))!.output as tf.SymbolicTensor

        const model = tf.model({inputs: this.inputs.map(outputOf), outputs: this.outputs.map(outputOf)})

        for (const layer of this.model.layers){
            // layers absorbed by fusion have no rebuilt counterpart and are skipped
            const rebuilt = this.newLayers.get(layer.name)
            if (rebuilt !== undefined) {
                rebuilt.setWeights(layer.getWeights())
            }
        }

        return model
    }
}
