import { WebGLHelper, glsl } from '../WebglHelper'

/**
 * Render stage that smooths a segmentation mask with a joint bilateral filter.
 *
 * For each output texel, the mask alpha values of nearby texels are averaged.
 * Each sample is weighted by a Gaussian of its spatial distance from the
 * centre texel and by a Gaussian of its colour distance (in the input frame)
 * from the centre colour. The result is written as `vec4(0, 0, 0, alpha)` into
 * `outputTexture` through a framebuffer owned by this stage.
 */
export default class JointBilateralFilter {
    /**
     * @param {WebGL2RenderingContext} gl - Rendering context. The fragment shader
     *     is GLSL ES 3.00 (`#version 300 es`), so a WebGL2 context is required.
     * @param {WebGLShader} vertexShader - Shared vertex shader; presumably emits
     *     `v_texCoord` for a full-screen quad (TODO confirm against caller).
     * @param {WebGLBuffer} positionBuffer - Forwarded to
     *     `WebGLHelper.createPiplelineStageProgram`.
     * @param {WebGLBuffer} texCoordBuffer - Forwarded to
     *     `WebGLHelper.createPiplelineStageProgram`.
     * @param {WebGLTexture} inputTexture - Bound to texture unit 1 in `render()`,
     *     i.e. it is sampled as `u_segmentationMask` (not as the camera frame).
     * @param {WebGLTexture} outputTexture - Attached as COLOR_ATTACHMENT0 of this
     *     stage's framebuffer; receives the filtered mask in its alpha channel.
     * @param {{width: number, height: number}} canvas - Only its size is read; it
     *     sets the viewport and the texel size used by the kernel.
     */
    constructor(
        gl,
        vertexShader,
        positionBuffer,
        texCoordBuffer,
        inputTexture,
        outputTexture,
        canvas
    ) {

        this.gl = gl;
        this.inputTexture = inputTexture;

        const fragmentShaderSource = glsl`#version 300 es

            precision highp float;

            uniform sampler2D u_inputFrame;
            uniform sampler2D u_segmentationMask;
            uniform vec2 u_texelSize;
            uniform float u_step;
            uniform float u_radius;
            uniform float u_offset;
            uniform float u_sigmaTexel;
            uniform float u_sigmaColor;

            in vec2 v_texCoord;

            out vec4 outColor;

            float gaussian(float x, float sigma) {
                // The 1e-6 term keeps the coefficient finite when sigma is 0.
                float coeff = -0.5 / (sigma * sigma * 4.0 + 1.0e-6);
                return exp((x * x) * coeff);
            }

            void main() {
                vec2 centerCoord = v_texCoord;
                vec3 centerColor = texture(u_inputFrame, centerCoord).rgb;
                float newVal = 0.0;
                float totalWeight = 0.0;

                // Subsample kernel space: visit every u_step texels of the
                // [-u_radius, u_radius] window, starting u_offset texels in.
                for (float i = -u_radius + u_offset; i <= u_radius; i += u_step) {
                    for (float j = -u_radius + u_offset; j <= u_radius; j += u_step) {
                        vec2 coord = centerCoord + vec2(j, i) * u_texelSize;
                        vec3 frameColor = texture(u_inputFrame, coord).rgb;
                        float outVal = texture(u_segmentationMask, coord).a;

                        float spaceWeight = gaussian(distance(centerCoord, coord), u_sigmaTexel);
                        float colorWeight = gaussian(distance(centerColor, frameColor), u_sigmaColor);
                        float weight = spaceWeight * colorWeight;

                        totalWeight += weight;
                        newVal += weight * outVal;
                    }
                }

                // Every weight can underflow to 0 (e.g. u_sigmaColor near 0 while
                // u_offset skips the centre texel). Fall back to the unfiltered
                // mask value instead of producing 0/0 = NaN.
                float centerVal = texture(u_segmentationMask, centerCoord).a;
                newVal = totalWeight > 0.0 ? newVal / totalWeight : centerVal;

                outColor = vec4(vec3(0.0), newVal);
            }
        `;

        this.outputWidth = canvas.width;
        this.outputHeight = canvas.height;
        this.texelWidth = 1 / this.outputWidth;
        this.texelHeight = 1 / this.outputHeight;

        this.fragmentShader = WebGLHelper.compileShader(
            this.gl,
            this.gl.FRAGMENT_SHADER,
            fragmentShaderSource
        );
        this.program = WebGLHelper.createPiplelineStageProgram(
            this.gl,
            vertexShader,
            this.fragmentShader,
            positionBuffer,
            texCoordBuffer
        );
        this.inputFrameLocation = this.gl.getUniformLocation(this.program, 'u_inputFrame');
        this.segmentationMaskLocation = this.gl.getUniformLocation(this.program, 'u_segmentationMask');
        this.texelSizeLocation = this.gl.getUniformLocation(this.program, 'u_texelSize');
        this.stepLocation = this.gl.getUniformLocation(this.program, 'u_step');
        this.radiusLocation = this.gl.getUniformLocation(this.program, 'u_radius');
        this.offsetLocation = this.gl.getUniformLocation(this.program, 'u_offset');
        this.sigmaTexelLocation = this.gl.getUniformLocation(this.program, 'u_sigmaTexel');
        this.sigmaColorLocation = this.gl.getUniformLocation(this.program, 'u_sigmaColor');

        this.frameBuffer = this.gl.createFramebuffer();
        this.gl.bindFramebuffer(this.gl.FRAMEBUFFER, this.frameBuffer);
        this.gl.framebufferTexture2D(
            this.gl.FRAMEBUFFER,
            this.gl.COLOR_ATTACHMENT0,
            this.gl.TEXTURE_2D,
            outputTexture,
            0
        );

        this.gl.useProgram(this.program);
        // Sampler uniforms select texture units: frame on 0, mask on 1.
        this.gl.uniform1i(this.inputFrameLocation, 0);
        this.gl.uniform1i(this.segmentationMaskLocation, 1);
        this.gl.uniform2f(this.texelSizeLocation, this.texelWidth, this.texelHeight);

        // Unset float uniforms default to 0, and u_step == 0 would make the
        // shader loop forever, so always start from a valid configuration.
        this.updateSigmaSpace(0);
        this.updateSigmaColor(0);

    }

    /**
     * Draws the filter into this stage's framebuffer.
     *
     * Only texture unit 1 (the mask, `this.inputTexture`) is bound here. Unit 0
     * (`u_inputFrame`) is not; presumably the caller binds the camera frame to
     * TEXTURE0 beforehand (verify against the pipeline). The framebuffer is
     * left bound afterwards.
     */
    render() {
        this.gl.viewport(0, 0, this.outputWidth, this.outputHeight);
        this.gl.useProgram(this.program);
        this.gl.activeTexture(this.gl.TEXTURE1);
        this.gl.bindTexture(this.gl.TEXTURE_2D, this.inputTexture);
        this.gl.bindFramebuffer(this.gl.FRAMEBUFFER, this.frameBuffer);
        this.gl.drawArrays(this.gl.TRIANGLE_STRIP, 0, 4);
    }

    /**
     * Sets the spatial extent of the kernel.
     *
     * `sigmaSpace` is expressed relative to a 256x144 frame and scaled up by the
     * larger of the width/height ratios of the actual output. It is used both as
     * the kernel radius (in texels) and, times the texel size, as the spatial
     * Gaussian sigma. Larger kernels are sampled more sparsely: the step grows
     * with sqrt(sigma) and the first sample is offset by half a step.
     *
     * Negative or non-finite values are treated as 0. That degenerates the
     * kernel to a single sample, rather than an empty kernel that would make
     * the shader divide 0 by 0.
     *
     * @param {number} sigmaSpace - Spatial sigma at the 256x144 reference size.
     */
    updateSigmaSpace(sigmaSpace) {
        const scaledSigma = sigmaSpace * Math.max(
            this.outputWidth / 256,
            this.outputHeight / 144
        );
        const sigma = Number.isFinite(scaledSigma) && scaledSigma > 0 ? scaledSigma : 0;

        const kSparsityFactor = 0.66; // Higher is more sparse.
        const sparsity = Math.max(1, Math.sqrt(sigma) * kSparsityFactor);
        const step = sparsity;
        const radius = sigma;
        const offset = step > 1 ? step * 0.5 : 0;
        const sigmaTexel = Math.max(this.texelWidth, this.texelHeight) * sigma;

        this.gl.useProgram(this.program);
        this.gl.uniform1f(this.stepLocation, step);
        this.gl.uniform1f(this.radiusLocation, radius);
        this.gl.uniform1f(this.offsetLocation, offset);
        this.gl.uniform1f(this.sigmaTexelLocation, sigmaTexel);
    }

    /**
     * Sets the colour-similarity Gaussian sigma (in the same units as
     * `distance()` between the RGB values the shader samples).
     *
     * @param {number} sigmaColor - Colour sigma passed unchanged to the shader.
     */
    updateSigmaColor(sigmaColor) {
        this.gl.useProgram(this.program);
        this.gl.uniform1f(this.sigmaColorLocation, sigmaColor);
    }

    /**
     * Deletes the framebuffer, program and fragment shader created by this
     * stage. The vertex shader, buffers and textures passed in are not deleted.
     */
    cleanUp() {
        this.gl.deleteFramebuffer(this.frameBuffer);
        this.gl.deleteProgram(this.program);
        this.gl.deleteShader(this.fragmentShader);
    }

}
