#version 330 core
// Interpolated from the vertex stage. main() uses .x as the horizontal texture
// coordinate and .y as the normalized vertical position within [low, high].
in vec4 vertex_color;
out vec4 fragColor;

// Packed lookup texture. For each column (x), row 0 holds the per-channel
// entry count (read back as texture*255), and the following 4*VT_QUERIES rows
// hold values and thresholds, each split into a high-byte row and a low-byte
// row (decoded by q_val and q_thresh).
uniform sampler2D tex;

// Scalar read by get_color; presumably 0 = dark theme, 1 = light theme — confirm.
uniform float light_mode;

// Per-channel blend factor in get_color between the nominal channel color (0)
// and the pdf color-graded palette (1).
uniform vec4 color_grading;

// Per-channel multiplier applied to the estimated pdf.
uniform vec4 scale;
// Per-channel data range; vertex_color.y in [0,1] is mapped onto [low, high].
uniform vec4 low;
uniform vec4 high;
// Per-channel multiplier on the layer color (presumably 0 or 1 to hide/show).
uniform vec4 visible;

// Nominal RGB color of each channel.
uniform vec3 ch1_col;
uniform vec3 ch2_col;
uniform vec3 ch3_col;
uniform vec3 ch4_col;

// Texture layout constants. Declared const so every global initializer is a
// constant expression; GLSL ES and later desktop GLSL versions reject
// non-constant global initializers, so the non-const form was not portable.
// Number of query slots in each packed row group.
const float VT_QUERIES = 256.;
// Total texture height in rows: one count row plus four row groups.
const float N = 4.*VT_QUERIES+1.;

// Texture-coordinate offset of exactly one texel row.
const vec2 dy = vec2(0., 1./N);

// Decodes the i-th stored threshold (per channel) for column pos.x.
// `pos` is expected at the center of row 0 (callers pass vec2(x, 0.5/N)).
// The high byte lives in row 1+2*VT_QUERIES+i and the low byte in row
// 1+3*VT_QUERIES+i. The combined value high + low/255 (in [0, 1+1/255]) is
// mapped affinely to roughly [-1.55, 1.65].
// NOTE(review): the /255 byte combination and the 255/80, -1.55 constants must
// match the host-side encoder — confirm there.
vec4 q_thresh(float i, vec2 pos) {
    float hi_row = 1. + 2.*VT_QUERIES + i;
    float lo_row = 1. + 3.*VT_QUERIES + i;
    vec4 packed = texture(tex, pos + hi_row*dy) + texture(tex, pos + lo_row*dy)/255.;
    return 255./80. * packed - 1.55;
}

// Decodes the i-th stored CDF value (per channel) for column pos.x.
// `pos` is expected at the center of row 0 (callers pass vec2(x, 0.5/N)).
// The high byte is in row 1+i and the low byte is VT_QUERIES rows further down.
// Returns high + low/255.
vec4 q_val(float i, vec2 pos) {
    vec4 coarse = texture(tex, pos + (1. + i)*dy);
    vec4 fine = texture(tex, pos + (1. + VT_QUERIES + i)*dy);
    return coarse + fine/255.;
}

// Linearly interpolates the stored value curve of column x between entries
// i and i+1 at query position y. All four channels are computed; callers keep
// the component they need.
// The weight is NOT clamped, so a y outside [thresh(i), thresh(i+1)] is
// linearly extrapolated from this segment. The 1e-6 term guards against
// division by zero when two thresholds coincide.
vec4 get_cdf_helper(float x, float i, float y) {
    vec2 origin = vec2(x, 0.5/N);

    vec4 t_lo = q_thresh(i, origin);
    vec4 t_hi = q_thresh(i+1., origin);
    vec4 weight = (y - t_lo)/(abs(t_hi - t_lo) + 1e-6);

    return mix(q_val(i, origin), q_val(i+1., origin), weight);
}

// Piecewise-linear CDF lookup for four channels at once.
//   x : horizontal texture coordinate selecting the column.
//   y : per-channel query position (same units as the decoded thresholds).
// For each channel, this binary-searches the threshold list for the segment
// containing y. It then interpolates the value list in that segment using
// get_cdf_helper.
vec4 get_cdf(float x, vec4 y) {
    vec2 pos = vec2(x, 0.5/N);

    // Row 0 stores the per-channel entry count (presumably a normalized byte).
    // Round after rescaling: the unorm->float conversion and the *255 can land
    // just below an integer (e.g. 9.99999). A fractional upper bound makes the
    // floor-midpoint search below stall one segment early, so the last segment
    // [n-2, n-1] would never be selected.
    vec4 n = floor(texture(tex, pos) * 255. + 0.5);

    // Per-channel search interval [low0, high0] over entry indices.
    vec4 low0 = vec4(0., 0., 0., 0.);
    vec4 high0 = max(n - vec4(1.,1.,1.,1.), vec4(0.,0.,0.,0.));

    // Binary search for the segment with thresh(low0) <= y. Each pass shrinks
    // the interval to at most ceil(width/2). With n <= 255 the width collapses
    // to 1 after 8 passes, so 9 passes leave slack.
    for (int i = 0; i < 9; i++) {
        vec4 i0 = floor((low0 + high0)/2.);

        // Each channel probes its own index; keep only that channel's component.
        vec4 thresh0 = vec4(
            q_thresh(i0.x, pos).x,
            q_thresh(i0.y, pos).y,
            q_thresh(i0.z, pos).z,
            q_thresh(i0.w, pos).w
        );

        // y >= thresh: raise the lower bound. y <= thresh: lower the upper
        // bound. On an exact hit, both bounds collapse onto i0.
        low0 = mix(low0, i0, step(thresh0, y));
        high0 = mix(high0, i0, step(y, thresh0));
    }

    // Interpolate within segment [low0, low0+1] for each channel separately.
    return vec4(
        get_cdf_helper(x, low0.x, y.x).x,
        get_cdf_helper(x, low0.y, y.y).y,
        get_cdf_helper(x, low0.z, y.z).z,
        get_cdf_helper(x, low0.w, y.w).w
    );
}

// Maps one channel's pdf estimate to an RGBA layer color.
//   nom_color : the channel's nominal RGB color.
//   pdf       : scaled pdf estimate (may be negative; alphas are clamped).
//   mode      : blend factor between the nominal-color layer, whose alpha is
//               clamp(pdf, 0, 1) (mode 0), and the color-graded palette (mode 1).
// The palette is banded on log2(16*pdf)/3. Each band interpolates with the
// fractional part t, running green -> blue -> red -> yellow -> gray level lm,
// and stays at lm above 4. Below band 0, the color fades from a light_mode-
// dependent gray toward green, with alpha ramping up from 16*pdf = 0.05.
vec4 get_color(vec3 nom_color, float pdf, float mode) {
    // Brightness of the top of the palette; lower in light mode.
    float lm = 1.-0.25*light_mode;

    // 1e-4 floor keeps log2 finite for zero/negative pdf values.
    float log_pdf = log2(max(16.*pdf, 1e-4))/3.;
    float t = fract(log_pdf);

    vec4 cgraded;
    if (log_pdf > 4.) {
        cgraded = vec4(lm, lm, lm, 1.);
    } else if (log_pdf > 3.) {
        cgraded = vec4(lm + (1.-t)*(1.-lm), lm, lm*t, 1.);
    } else if (log_pdf > 2.) {
        cgraded = vec4(1., lm*t, 0., 1.);
    } else if (log_pdf > 1.) {
        cgraded = vec4(t, 0.2*(1.-t), 1.-t, 1.);
    } else if (log_pdf > 0.) {
        cgraded = vec4(0., 0.2*t + (1.-t), t, 1.);
    } else {
        // Linear (not logarithmic) ramp for the faint region 16*pdf <= 1.
        t = 16.*pdf;
        cgraded = vec4(
           light_mode*(1.-t),
           light_mode*(1.-t) + t,
           light_mode*(1.-t),
           clamp((t - 0.05)/0.95, 0., 1.)
        );
    }
    return vec4(nom_color, clamp(pdf, 0., 1.)) * (1.-mode) + cgraded * mode;
}

// Fragment entry point. It estimates each channel's pdf at this fragment's
// vertical position by a finite difference of get_cdf. It then colors each
// channel and alpha-composites them, with channel 1 at the bottom and
// channel 4 on top.
void main() {
//    fragColor = vec4(texture(tex, vertex_color.xy).xyz, 1.); return; // Debug

    // Query position in each channel's data units.
    vec4 y_at = low + (high-low) * (vertex_color.y);
    // Finite-difference half-width: 0.5% of the visible range, clamped.
    vec4 half_w = clamp((high-low) * 5e-3, 3e-4, 10e-3);

    // NOTE(review): the difference spans 2*half_w but is divided by half_w,
    // giving twice the CDF slope. This is presumably absorbed by `scale` and
    // the -0.4 offset — confirm before changing.
    vec4 cdf_hi = get_cdf(vertex_color.x, y_at+half_w);
    vec4 cdf_lo = get_cdf(vertex_color.x, y_at-half_w);
    vec4 pdf = (cdf_hi - cdf_lo)/half_w - 0.4;

    // Zero out above top range (y >= 1).
    pdf = mix(pdf, vec4(0.), step(vec4(1.), y_at));
    pdf *= scale;

    // One color layer per channel, masked by `visible`.
    vec4 layer[4] = vec4[4](
        get_color(ch1_col, pdf.x, color_grading.x) * visible.x,
        get_color(ch2_col, pdf.y, color_grading.y) * visible.y,
        get_color(ch3_col, pdf.z, color_grading.z) * visible.z,
        get_color(ch4_col, pdf.w, color_grading.w) * visible.w
    );

    // "Over" compositing into a premultiplied accumulator, bottom to top.
    vec4 acc = vec4(layer[0].xyz, 1.) * layer[0].w;
    for (int k = 1; k < 4; k++) {
        acc = vec4(layer[k].xyz, 1.) * layer[k].w + (1.-layer[k].w) * acc;
    }

    fragColor = acc;
}
