import * as THREE from 'three';

// Length of the fragment shader's `uCellData` array. The same value is
// interpolated into the GLSL (`NUM_OPACITY`) and used to size the default
// uniform value, so the two cannot drift apart.
const MAX_CELL_ENTRIES = 128;

/**
 * Shader definition (uniforms + GLSL sources) for an image that is distorted
 * by a displacement texture and partially replaced, cell by cell, with a
 * second "hover" texture.
 *
 * Fragment pipeline, as implemented in `main()`:
 *  1. Build a reveal mask by summing, for every `uCellData` entry, a rounded
 *     square centred on that entry's grid cell, scaled by the entry's opacity.
 *  2. Optionally rescale the UVs when `uMobile == 1.0`.
 *  3. Sample `uTexture` at UVs offset by `uGrid.rg * 0.01`.
 *  4. Sample `uTextureHover` at the undistorted UVs; where its alpha is
 *     <= 0.5 an animated RGB gradient (driven by `uTime`) is used instead.
 *  5. Mix the two by the mask, or output the remapped displacement texture
 *     itself when `uDisplacement >= 0.5`. Output alpha is forced to 1.0.
 *
 * Uniforms:
 *  - uTexture             sampler2D, base image.
 *  - uTextureHover        sampler2D, image shown inside masked cells.
 *  - uGrid                sampler2D, displacement map (rg channels are read).
 *  - uGridSize            vec2; only `.x` is read by the active code path
 *                         (cells are square), `.y` only by unused helpers.
 *  - uContainerResolution vec2, container size; divided into gl_FragCoord
 *                         together with the `PR` define.
 *  - uMouse               vec2; currently only feeds a value that is computed
 *                         but never used in `main()`.
 *  - uDisplacement        float, >= 0.5 switches to the displacement view.
 *  - uImageResolution     vec2, declared but not referenced by the shader body.
 *  - uRGBshift            vec2, declared but not referenced by the shader body.
 *  - uCellData            vec3[MAX_CELL_ENTRIES]; xy = cell column/row,
 *                         z = opacity multiplier for that cell's mask.
 *  - uTime                float, animates the fallback gradient.
 *  - uMobile              float flag, 1.0 enables the aspect rescale.
 *
 * Requirements not satisfied by this object:
 *  - The fragment shader reads a `PR` preprocessor constant (described in the
 *    shader as the device pixel ratio). It is not defined here, so the
 *    material using this shader must supply it (e.g. via `defines`).
 *
 * NOTE(review): in the `uMobile` branch one arm compares against 16/9 while
 * the other divides by 9/16. For container aspects between 9/16 and 16/9 the
 * resulting scale is > 1, pushing UVs outside [0, 1]. Confirm the intended
 * video aspect before changing it.
 */
export const DeformationShader = {
    uniforms: {
        // The three sampler2D uniforms start without a texture. `null` is the
        // placeholder three.js's own shader definitions use for samplers; the
        // previous Vector4 defaults were not valid sampler values.
        uTexture: { value: null },
        uTextureHover: { value: null },
        uGrid: { value: null },
        uGridSize: { value: new THREE.Vector2() },
        uContainerResolution: { value: new THREE.Vector2(window.innerWidth, window.innerHeight) },
        uMouse: { value: new THREE.Vector2() },
        uDisplacement: { value: 0 },
        uImageResolution: { value: new THREE.Vector2(0, 0) },
        uRGBshift: { value: new THREE.Vector2(0.02, 0.0) },
        // One Vector3 per shader array slot. A zeroed entry has opacity
        // (z) 0 and therefore contributes nothing to the reveal mask.
        uCellData: {
            value: Array.from({ length: MAX_CELL_ENTRIES }, () => new THREE.Vector3()),
        },
        uTime: { value: 0 },
        uMobile: { value: 0 },
    },
    // Standard model-view-projection transform; forwards `uv` as `vUv`.
    vertexShader: `
        varying vec2 vUv;

        void main()
        {
            vec4 modelPosition = modelMatrix * vec4(position, 1.0);
            vec4 viewPosition = viewMatrix * modelPosition;
            vec4 projectedPosition = projectionMatrix * viewPosition;
            gl_Position = projectedPosition;    

            vUv=uv;
        }
    `,
    // See the pipeline description in the JSDoc above. Several helpers
    // (square, random, rgbShift, cellCenter(vec2), cellGridCoords,
    // viewGridedImage) are only reachable from commented-out debug code.
    fragmentShader: `
        struct CellData {
            int id;
            float opacity;
        };

        struct Cell {
            vec2 id;
            float opacity;
        };

        const int NUM_OPACITY = ${MAX_CELL_ENTRIES};

        uniform sampler2D uTexture;
        uniform sampler2D uTextureHover;
        uniform sampler2D uGrid;
        varying vec2 vUv;

        uniform vec2 uContainerResolution;
        uniform vec2 uGridSize;
        uniform vec2 uMouse;
        uniform float uDisplacement;
        uniform float uTime;
        uniform float uMobile;
        uniform vec2 uImageResolution;
        uniform vec2 uRGBshift;
        uniform vec3 uCellData[NUM_OPACITY];
        
        float finalMask = 0.0;

        vec2 coverUvs(vec2 imageRes,vec2 containerRes)
        {
            float imageAspectX = imageRes.x/imageRes.y;
            float imageAspectY = imageRes.y/imageRes.x;
            
            float containerAspectX = containerRes.x/containerRes.y;
            float containerAspectY = containerRes.y/containerRes.x;

            vec2 ratio = vec2(
                min(containerAspectX / imageAspectX, 1.0),
                min(containerAspectY / imageAspectY, 1.0)
            );

            vec2 newUvs = vec2(
                vUv.x * ratio.x + (1.0 - ratio.x) * 0.5,
                vUv.y * ratio.y + (1.0 - ratio.y) * 0.5
            );

            return newUvs;
        }

        float square(in vec2 _st, in float _size, in float blurriness){
            float maskX = step(abs(_st.x), _size);
            float maskY = step(abs(_st.y), _size);
            float size = maskX * maskY;
            return smoothstep(size - (size * blurriness), size + (size * blurriness), dot(size, size) * 4.0);
        }

        float random (vec2 st) {
            return fract(sin(dot(st.xy,
                                vec2(12.9898,78.233)))*
                43758.5453123);
        }

        vec3 rgbShift (vec2 Uv, vec4 displacement) {
            vec2 redUvs = Uv;
            vec2 blueUvs = Uv;
            vec2 greenUvs = Uv;

            vec2 shift = displacement.rg*0.001;

            float displacementStrengh=length(displacement.rg);
            displacementStrengh = clamp(displacementStrengh,0.,2.);
            
            float redStrengh = 1.+displacementStrengh*0.25;
            redUvs += shift*redStrengh;    
            
            float blueStrengh = 1.+displacementStrengh*1.5;
            blueUvs += shift*blueStrengh; 
            
            float greenStrengh = 1.+displacementStrengh*2.;
            greenUvs += shift*greenStrengh;
            
            float red = texture2D(uTexture,redUvs).r;
            float blue = texture2D(uTexture,blueUvs).b;    
            float green = texture2D(uTexture,greenUvs).g;     

            return vec3(red, green, blue);
        }

        vec2 cellSize(vec2 grid) {
            return 1.0 / grid;
        }

        vec2 cellCenter(vec2 mousePos) {
            vec2 gridSize = vec2(uGridSize.x);
            vec2 cellSize = -cellSize(gridSize);
            vec2 cellCoords = floor(mousePos * gridSize);
            vec2 cellPos = cellCoords * cellSize;
            vec2 cellCenter = cellPos + cellSize * 0.5;
            return -cellCenter;
        }

        vec2 cellCenter(int x, int y) {
            vec2 gridSize = vec2(uGridSize.x);
            vec2 cellSize = 1.0 / gridSize;
            vec2 cellPos = vec2(x, y) * cellSize;
            vec2 cellCenter = cellPos + cellSize * 0.5;
            return -cellCenter;
        }

        vec2 cellGridCoords(vec2 mousePos) {
            vec2 gridSize = vec2(uGridSize.x, uGridSize.x);
            vec2 cellSize = 1.0 / gridSize;
            vec2 cellCoords = floor((mousePos + cellSize) * gridSize);
            if(-cellCoords.y >= uGridSize.y - 1.0) { 
                cellCoords.x += 1.0;
            }
            return cellCoords;
        }

        vec2 stPos() {
            // We manage the device ratio by passing PR constant
            vec2 conRes = uContainerResolution;
            float pixelRatio = float(PR);
            vec2 res = conRes * pixelRatio;
            vec2 st = gl_FragCoord.xy / res.xy;

            // Keep the good ratio of the coordinates
            st.y *= conRes.y / conRes.x; 

            return st;
        }

        void viewGridedImage(vec2 st) {
            float cellSizeX = cellSize(uGridSize).x;

            for (int x = 0; x < int(uGridSize.x); x++) {
                for (int y = 0; y < int(uGridSize.y); y++) {
                    int cellID = x + y * int(uGridSize.x); // Cell index

                    vec2 objectPos = st + cellCenter(x, y);
                    float c = square(objectPos, cellSizeX * 0.5, 4.);

                    // Accumulate the reveal mask
                    finalMask += smoothstep(0.4, 0.5, c);
                }
            }
        }

        float roundedSquareBorder(vec2 _st, float _size, float radius) {
            vec2 d = abs(_st) - (_size - radius);
            float dist = 2.0 * length(max(d, 0.0)) - radius;

            return 0.5 - smoothstep(0.0, 0.1, dist);
        }

        // float roundedBorder(vec2 _st, float _size) {
        //     float maskX = step(abs(_st.x), _size);
        //     float maskY = step(abs(_st.y), _size);
        //     float size = maskX * maskY;
        //     float radius = 0.00001;
        //     _size = 0.0001;
        //     vec2 d = abs(_st) - vec2(_size - radius);
        //     float dist = length(max(d, 0.0)) - radius;

        //     // Use smoothstep to create an anti-aliased border effect
        //     float bluriness = 0.5;
        //     return 0.5 - smoothstep(0.0, bluriness, dist);
        // }

        void main()
        {
            vec2 newUvs = coverUvs(vec2(1.),vec2(1.));
            
            vec2 squareUvs = coverUvs(vec2(1.),vec2(1.));

            // The horizontal(s) and vertical(t) texture coordinates
            vec2 st = stPos();
            
            // We readjust the mouse coordinates
            float consRatio = -(uContainerResolution.y / uContainerResolution.x) * 0.5;
            vec2 mouse = vec2(uMouse.x * -0.5 - 0.5, consRatio * (uMouse.y + 1.0));

            vec2 objectPos = st + cellCenter(mouse);
            vec2 cellSize = cellSize(uGridSize) * 1.0;
            float c = square(objectPos, cellSize.x * 0.5, 4.);
            // finalMask += smoothstep(0.4, 0.5, c);

            // Calculate grid cell coordinates
            vec2 cellIndex = floor(vUv * uGridSize);

            float cellOpacity = 1.0;
            // if(mod(cellIndex.y + cellIndex.x + 1.0, 2.0) == 0.0) {
            //     cellOpacity = 0.1;
            // }
            // for (int i = 0; i < NUM_OPACITY; i++) {
            //     vec2 storedId = uCellData[i].xy; // Extract stored cell ID

            //     if (cellIndex == storedId) {
            //         cellOpacity = 0.0; // Apply opacity from data
            //         break; // Stop checking once found
            //     }
            // }

            
            for (int i = 0; i < NUM_OPACITY; i++) {
                vec3 cellData = uCellData[i];
                ivec2 cellID = ivec2(int(cellData.x), int(cellData.y));
                vec2 cellPos = st + cellCenter(cellID.x, cellID.y);
                float c = roundedSquareBorder(cellPos, cellSize.x * 0.5, 0.01);
                float opacity = cellData.z;

                finalMask += smoothstep(0.4, 0.5, c) * opacity;
            }

            // To view the masked image in a grided format
            // viewGridedImage(st);

            if (uMobile == 1.0) {
                float outAR = uContainerResolution.x / uContainerResolution.y;
                if (outAR > 16.0 / 9.0) {
                    // Phone screen is wider than the video
                    float scale = (16.0 / 9.0) / outAR;
                    newUvs.y = newUvs.y * scale + (1.0 - scale) * 0.5;
                }
                else {
                    float scale = outAR / (9.0 / 16.0);
                    newUvs.x = newUvs.x * scale + (1.0 - scale) * 0.5;
                }
            }

            vec4 image = texture2D(uTexture,newUvs);
            vec4 displacement = texture2D(uGrid,newUvs);
            
            vec2 finalUvs = newUvs - displacement.rg*0.01;
            
            vec4 finalImage = texture2D(uTexture,finalUvs);
            
            // Cannot use finalUv because it distorts the image
            vec4 finalImageHover = texture2D(uTextureHover, vec2(newUvs.x, newUvs.y)); 

            // Calculate gradient position based on time to animate the gradient
            float gradientPos = mod(uTime * 0.1 + newUvs.x * 0.3, 1.0);

            // Blend colors based on gradient position
            vec3 gradientColor = vec3(
                0.5 + 0.5 * sin(2.0 * 3.14159 * (gradientPos + 0.0)),  // Red channel
                0.5 + 0.5 * sin(2.0 * 3.14159 * (gradientPos + 0.33)), // Green channel (offset by 1/3)
                0.5 + 0.5 * sin(2.0 * 3.14159 * (gradientPos + 0.66))  // Blue channel (offset by 2/3)
            );

            // Blend the gradient color with the texture color
            // finalImageHover.rgb = mix(finalImageHover.rgb, gradientColor, 0.5);  // Adjust 0.5 to control blend amount
            // vec4 combinedColor = finalImageHover + vec4(gradientColor, 0.0);
            vec4 combinedColor = finalImageHover.a > 0.5 ? finalImageHover : vec4(gradientColor, 1.0);
            combinedColor = clamp(combinedColor, 0.0, 1.0);
            
            //rgb shift
            // finalImage.rgb = rgbShift(finalUvs, displacement);

            vec4 maskedImage = mix(finalImage, combinedColor, finalMask);

            vec4 visualDisplacement = displacement;
            visualDisplacement*=0.5;
            visualDisplacement+=0.5;    
            
            vec4 final = step(0.5,uDisplacement)*visualDisplacement + (1.-step(0.5,uDisplacement))*maskedImage;

            if(cellOpacity != 1.0) {
                final *= cellOpacity;
            } else {
                final.a = 1.0; 
            }
            gl_FragColor = final; // White cells, transparency applied

            //#include <tonemapping_fragment>
            //#include <colorspace_fragment>
        }
    `
};
