import { BenchmarkResult, LensPerformanceCluster } from "./estimateLensPerformanceCluster";
import { createFramebuffer, createProgram, createTexture, promiseSync, setUniform1i, setUniform4f } from "./webglUtils";

/**
 * Vertex shader for the benchmark quad. It forwards the 2D `pos` attribute unchanged as the clip-space position, so
 * no per-vertex work is done. All of the measured arithmetic happens in the fragment shader.
 */
const vertexSource = `#version 300 es

precision mediump float;
precision mediump int;

in vec2 pos;

void main() {
    gl_Position = vec4(pos, 0.0, 1.0);
}
`;

/**
 * Arithmetic-bound fragment shader used to measure throughput.
 *
 * Each pass of the outer loop expands REPEAT_32 into 32 consecutive `r = r * v1 + v0` updates on a vec4. runBenchmark
 * counts each update as 8 flops (4 multiplies + 4 adds). The loop runs LOOP_COUNT times per fragment. All inputs come
 * from uniforms, presumably so the compiler cannot constant-fold the arithmetic. The result is written to
 * `fragColor`, presumably so the work cannot be discarded as dead code.
 *
 * NOTE(review): LOOP_COUNT is a `mediump int`. GLSL ES 3.00 only guarantees a mediump int range of roughly ±2^15, so
 * confirm that the loop counts runBenchmark sends always stay inside it.
 */
const fragmentSource = `#version 300 es

precision mediump float;
precision mediump int;

uniform int LOOP_COUNT;
uniform vec4 v0;
uniform vec4 v1;
uniform vec4 v2;

out vec4 fragColor;

#define REPEAT_2(x) x; x
#define REPEAT_4(x) REPEAT_2(x); REPEAT_2(x)
#define REPEAT_8(x) REPEAT_4(x); REPEAT_4(x)
#define REPEAT_16(x) REPEAT_8(x); REPEAT_8(x)
#define REPEAT_32(x) REPEAT_16(x); REPEAT_16(x)

void main() {
    vec4 r = v2;
    for (int i = 0; i < LOOP_COUNT; i++) {
        REPEAT_32(r = r * v1 + v0);
    }
    fragColor = r;
}
`;

/** Width and height, in pixels, of the offscreen render target the benchmark draws into. */
const width = 1024,
    height = 1024;
/** Wall-clock budget, in milliseconds, for the sequence of timed draws issued by runBenchmark. */
const budgetMs = 300;
/**
 * runBenchmark stops increasing the per-draw loop count once it reaches this value. The final increase may overshoot
 * it, so this is a growth threshold rather than a hard cap.
 */
const maxLoopCount = 1000;

/**
 * Creates and binds everything the gflops benchmark draws with:
 * - a `width`×`height` render target (texture + framebuffer);
 * - a vertex buffer holding the four corners of clip space;
 * - the benchmark shader program, with its constant `v0`/`v1`/`v2` uniforms set.
 *
 * On return, the framebuffer, array buffer and program are bound, and the `pos` attribute is enabled on `gl`.
 *
 * @returns The program (runBenchmark needs it to set `LOOP_COUNT` per draw) and a function that deletes every GL
 * object created here. Calling that function more than once is harmless.
 * @throws If a GL object cannot be created or the program lacks the `pos` attribute, and also if `createProgram`
 * throws. Any objects created before the failure are deleted before the error propagates.
 */
function prepareBenchmark(gl: WebGL2RenderingContext): { program: WebGLProgram; cleanupBenchmark: () => void } {
    // One deleter per GL object created so far, in creation order.
    const disposers: Array<() => void> = [];
    // Deletes in reverse creation order: program, buffer, framebuffer, texture. splice(0) empties the list, so a
    // repeated call does nothing.
    const cleanupBenchmark = (): void => {
        for (const dispose of disposers.splice(0).reverse()) {
            dispose();
        }
    };

    try {
        const texture = createTexture(gl, width, height);
        disposers.push(() => gl.deleteTexture(texture));
        const framebuffer = createFramebuffer(gl, texture);
        disposers.push(() => gl.deleteFramebuffer(framebuffer));
        gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);

        const buffer = gl.createBuffer();
        if (!buffer) {
            throw new Error("Failed to create WebGLBuffer.");
        }
        disposers.push(() => gl.deleteBuffer(buffer));
        gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
        // Four corners of clip space. runBenchmark draws them as a TRIANGLE_FAN, which covers the whole viewport.
        gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([-1, 1, -1, -1, 1, -1, 1, 1]), gl.STATIC_DRAW);

        gl.viewport(0, 0, width, height);
        gl.disable(gl.CULL_FACE);
        gl.disable(gl.DEPTH_TEST);

        const program = createProgram(gl, vertexSource, fragmentSource);
        disposers.push(() => gl.deleteProgram(program));
        gl.useProgram(program);

        const posLocation = gl.getAttribLocation(program, "pos");
        // getAttribLocation returns -1 when the attribute is missing. Enabling -1 would only raise a GL error, and the
        // benchmark would then time draws that render nothing.
        if (posLocation < 0) {
            throw new Error('Failed to find attribute "pos" in the benchmark program.');
        }
        gl.enableVertexAttribArray(posLocation);
        gl.vertexAttribPointer(posLocation, 2, gl.FLOAT, false, 0, 0);

        setUniform4f(gl, program, "v0", [1.15, 1.23, 1.47, 1.84]);
        setUniform4f(gl, program, "v1", [1.65, 1.22, 1.69, 1.04]);
        setUniform4f(gl, program, "v2", [1.05, 1.3, 1.55, 1.23]);

        return { program, cleanupBenchmark };
    } catch (error) {
        // Do not leak the objects created before the failure.
        cleanupBenchmark();
        throw error;
    }
}

/**
 * Repeatedly draws the benchmark quad with a growing LOOP_COUNT until `budgetMs` has elapsed. The best observed
 * throughput is converted to gflops.
 *
 * Expects prepareBenchmark to have already bound the render target, the geometry and `program`.
 *
 * @returns Peak measured throughput in gflops (10^9 flops per second).
 * @throws If no draw ever produced a non-zero elapsed time, e.g. because the clock is too coarse to time any draw.
 */
async function runBenchmark(gl: WebGL2RenderingContext, program: WebGLProgram): Promise<number> {
    // Let previously queued GL work finish first (promiseSync presumably waits on the GPU), so it is not timed below.
    await promiseSync(gl);

    const start = performance.now();
    // Per-draw throughput samples, in shader loop iterations per millisecond.
    const loopsPerMs: number[] = [];

    // loopCount is kept an integer. setUniform1i presumably forwards to gl.uniform1i, whose GLint argument is
    // truncated. A fractional count would make the GPU do less work than the throughput below assumes.
    let loopCount = 20;
    while (true) {
        setUniform1i(gl, program, "LOOP_COUNT", loopCount);

        const iterationStart = performance.now();
        gl.drawArrays(gl.TRIANGLE_FAN, 0, 4);
        await promiseSync(gl);
        const iterationEnd = performance.now();

        const duration = iterationEnd - iterationStart;
        // performance.now() may be coarsened, so a draw can appear to take 0 ms. Such a sample carries no
        // information, and dividing by it would give an infinite throughput and an infinite next loop count.
        const measurable = duration > 0;
        if (measurable) {
            loopsPerMs.push(loopCount / duration);
        }

        const remainingBudgetMs = budgetMs - (iterationEnd - start);
        if (remainingBudgetMs < 0) break;

        // Growth stops once maxLoopCount is reached, although the final step may overshoot it.
        if (loopCount < maxLoopCount) {
            if (!measurable) {
                loopCount *= 2;
            } else if (remainingBudgetMs < duration) {
                loopCount += 10;
            } else {
                // Assuming draw time scales linearly with loop count, the next draw takes about
                // 0.6 × remainingBudgetMs longer than this one.
                loopCount += Math.floor((0.6 * loopCount * remainingBudgetMs) / duration);
            }
        }
    }

    if (loopsPerMs.length === 0) {
        throw new Error("GFLOPS benchmark could not measure any draw: elapsed time was always 0 ms.");
    }

    // Each loop iteration is 32 vec4 multiply-adds of 8 flops each, run for every pixel of the width×height target.
    const flopsPerMs = Math.max(...loopsPerMs) * (8 * 32) * width * height;
    // flops per millisecond / 1e6 = 10^9 flops per second.
    return flopsPerMs / 1e6;
}

/**
 * Cluster centres derived from historical benchmark results collected on end-user devices. Each key is a gflops
 * score, and each value is the performance rating assigned to that cluster.
 *
 * To rate a new gflops measurement, choose the key numerically closest to it and use that key's rating.
 */
export const gflopsClusterCenters = new Map<number, LensPerformanceCluster>()
    .set(34, 1)
    .set(134, 2)
    .set(385, 3)
    .set(783, 4)
    .set(1484, 5)
    .set(2313, 6);

/**
 * This benchmark is the same one that runs on non-web (e.g. native mobile) platforms. Its results are comparable to
 * those gathered there, so they can be clustered against that historical data to determine a performance rating.
 *
 * The GL objects created for the run are deleted before this function settles, whether the run succeeds or fails.
 *
 * @returns A result named "gflops" whose value is the measured peak throughput in gflops.
 *
 * @internal
 */
export async function benchmarkGflops(gl: WebGL2RenderingContext): Promise<BenchmarkResult> {
    const { program, cleanupBenchmark } = prepareBenchmark(gl);
    try {
        const gflops = await runBenchmark(gl, program);
        return { name: "gflops", value: gflops };
    } finally {
        // Release the texture, framebuffer, buffer and program even if the timed run throws or rejects.
        cleanupBenchmark();
    }
}
