Convolution with WebGPU

Mix.install(
  [
    {:nx, "~> 0.9.2"},
    {:scidata, "~> 0.1.11"},
    {:exla, "~> 0.9.2"},
    {:kino, "~> 0.14.2"},
    {:stb_image, "~> 0.6.9"},
    {:axon, "~> 0.7.0"},
    {:kino_vega_lite, "~> 0.1.11"},
    {:nx_signal, "~> 0.2.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
defmodule D do
  use Kino.JS
  use Kino.JS.Live

  def html() do
    """
    <!-- Minimal markup reconstructed from the ids used in main.js -->
    <video id="gpuinvid" width="256" height="256" autoplay muted playsinline></video>
    <canvas id="gpucanvas" width="256" height="256"></canvas>
    """
  end

  def start(), do: Kino.JS.Live.new(__MODULE__, html())
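
  # Kino.JS.Live handshake: init/2 stores the HTML in the server context,
  # and handle_connect/1 sends it to each connecting client, where it
  # arrives as the `html` argument of init(ctx, html) in main.js.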

  @impl true
  def handle_connect(ctx), do: {:ok, ctx.assigns.html, ctx}

  @impl true
  def init(html, ctx), do: {:ok, assign(ctx, html: html)}

  asset "main.js" do
    """
    export async function init(ctx,html) {
      ctx.root.innerHTML = html;

      const width = 256;
      const height = 256;

      const adapter = await navigator.gpu.requestAdapter();
      if (!adapter) {
        throw new Error("WebGPU is not available in this browser");
      }
      console.log(adapter.info);
      const device = await adapter.requestDevice();
      const gpuCanvas = document.getElementById("gpucanvas");
      const gpuCanvasCtx = gpuCanvas.getContext("webgpu");
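
      // Frames are staged through a 2D offscreen canvas: drawImage captures
      // the current video frame and getImageData exposes its raw RGBA bytes
      // for upload; willReadFrequently hints the browser to optimize readback.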
      const offscreen = new OffscreenCanvas(width, height);
      const offscreenCtx = offscreen.getContext('2d', { willReadFrequently: true });

      const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
      gpuCanvasCtx.configure({
        device,
        format: presentationFormat,
        usage: GPUTextureUsage.RENDER_ATTACHMENT
      });


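      // Two textures: the camera frame is uploaded into inputTexture
      // (COPY_DST) and sampled by the compute shader; the blurred result is
      // written into outputTexture as a storage texture, then sampled by the
      // render pass (hence STORAGE_BINDING | TEXTURE_BINDING).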
      const inputTexture = device.createTexture({
        size: [width, height],
        format: "rgba8unorm",
        usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_DST
      });


      const outputTexture = device.createTexture({
        size: [width, height],
        format: "rgba8unorm",
        usage: GPUTextureUsage.STORAGE_BINDING | GPUTextureUsage.TEXTURE_BINDING
      });

      const sampler = device.createSampler({
        magFilter: 'linear',
        minFilter: 'linear'
      });

      const fragmentShaderCode = `
        @group(0) @binding(0) var outputTexture: texture_2d<f32>;
        @group(0) @binding(2) var outputSampler: sampler;

        @fragment
        fn main(@builtin(position) pos: vec4f) -> @location(0) vec4f {
          let size = vec2f(256.0, 256.0);
          let uv = pos.xy / size;
          return textureSample(outputTexture, outputSampler, uv);
        }
      `;

      const computeShaderCode = `
        @group(0) @binding(0) var inputTexture: texture_2d<f32>;
        @group(0) @binding(1) var outputTexture: texture_storage_2d<rgba8unorm, write>;

        /*
        const kernelSize: i32 = 5;
        const kernel: array<f32, 5> = array<f32, 5>(
          0.1, 0.15, 0.5, 0.15, 0.1
        );
        */

        // 9-tap 1D Gaussian; the nested loop below applies it along x and y,
        // which is equivalent to convolving with the 2D outer-product kernel.
        const kernelSize: i32 = 9;
        const kernel: array<f32, 9> = array<f32, 9>(
          0.000229, 0.005977, 0.060598, 0.241732, 0.382928, 0.241732, 0.060598, 0.005977, 0.000229
        );

        @compute @workgroup_size(8, 8)
        fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
          let coords = vec2<i32>(global_id.xy);
          let texSize = textureDimensions(inputTexture);

          if (coords.x >= i32(texSize.x) || coords.y >= i32(texSize.y)) {
            return;
          }

          var color = vec4(0.0, 0.0, 0.0, 0.0);

          // 2D Gaussian blur
          for (var y: i32 = -kernelSize / 2; y <= kernelSize / 2; y = y + 1) {
            for (var x: i32 = -kernelSize / 2; x <= kernelSize / 2; x = x + 1) {
              let offsetCoords = coords + vec2(x, y);
              let clampedCoords = clamp(
                offsetCoords,
                vec2(0),
                vec2(i32(texSize.x - 1), i32(texSize.y - 1))
              );
              let weight = kernel[x + kernelSize / 2] * kernel[y + kernelSize / 2];
              color += weight * textureLoad(inputTexture, clampedCoords, 0);
            }
          }

          textureStore(outputTexture, coords, color);
        }
      `;

      // Bind group for the compute (blur) pass: raw camera texture in,
      // blurred storage texture out.
      const bindGroupLayout = device.createBindGroupLayout({
        entries: [
          {
            binding: 0,
            visibility: GPUShaderStage.COMPUTE,
            texture: { sampleType: "float" }
          },
          {
            binding: 1,
            visibility: GPUShaderStage.COMPUTE,
            storageTexture: { access: "write-only", format: "rgba8unorm" }
          }
        ]
      });

      const computeBindGroup = device.createBindGroup({
        layout: bindGroupLayout,
        entries: [
          { binding: 0, resource: inputTexture.createView() },
          { binding: 1, resource: outputTexture.createView() }
        ]
      });

      // The render pass needs its own bind group: it samples the *blurred*
      // texture. Reusing the compute bind group would display the raw camera
      // frame (its binding 0 is the input texture) and would also bind the
      // write-only storage texture inside a render pass.
      const renderBindGroupLayout = device.createBindGroupLayout({
        entries: [
          {
            binding: 0,
            visibility: GPUShaderStage.FRAGMENT,
            texture: { sampleType: "float" }
          },
          {
            binding: 2,
            visibility: GPUShaderStage.FRAGMENT,
            sampler: { type: "filtering" }
          }
        ]
      });

      const renderBindGroup = device.createBindGroup({
        layout: renderBindGroupLayout,
        entries: [
          { binding: 0, resource: outputTexture.createView() },
          { binding: 2, resource: sampler }
        ]
      });
      const computePipeline = device.createComputePipeline({
        layout: device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] }),
        compute: {
          module: device.createShaderModule({ code: computeShaderCode }),
          entryPoint: "main"
        }
      });

      const renderPipeline = device.createRenderPipeline({
        layout: device.createPipelineLayout({ bindGroupLayouts: [renderBindGroupLayout] }),
        vertex: {
          module: device.createShaderModule({
            code: `
              // Fullscreen quad: two triangles covering clip space, so the
              // fragment shader runs once per canvas pixel.
              @vertex
              fn main(@builtin(vertex_index) vertexIndex: u32) -> @builtin(position) vec4f {
                var pos = array(
                  vec2f(-1.0, -1.0),
                  vec2f(1.0, -1.0),
                  vec2f(-1.0, 1.0),
                  vec2f(-1.0, 1.0),
                  vec2f(1.0, -1.0),
                  vec2f(1.0, 1.0)
                );
                return vec4f(pos[vertexIndex], 0.0, 1.0);
              }
            `
          }),
          entryPoint: 'main'
        },
        fragment: {
          module: device.createShaderModule({ code: fragmentShaderCode }),
          entryPoint: 'main',
          targets: [{ format: presentationFormat }]
        },
        primitive: { topology: 'triangle-list' }
      });


      const video = document.getElementById("gpuinvid");
      const stream = await navigator.mediaDevices.getUserMedia({
        video: { width, height }
      });
      video.srcObject = stream;
      // Wait for playback so the first drawImage call has frame data.
      await video.play();

      function processFrame() {
        // Stage the current video frame and read back its RGBA bytes.
        offscreenCtx.drawImage(video, 0, 0, width, height);
        const imageData = offscreenCtx.getImageData(0, 0, width, height);

        device.queue.writeTexture(
          { texture: inputTexture },
          imageData.data,
          { bytesPerRow: width * 4 },
          { width, height }
        );

        const commandEncoder = device.createCommandEncoder();

        // Compute pass: one invocation per pixel, dispatched in 8x8 workgroups
        const computePass = commandEncoder.beginComputePass();
        computePass.setPipeline(computePipeline);
        computePass.setBindGroup(0, computeBindGroup);
        computePass.dispatchWorkgroups(Math.ceil(width / 8), Math.ceil(height / 8));
        computePass.end();

        // Render pass: draw the blurred texture to the visible canvas
        const renderPassDescriptor = {
          colorAttachments: [{
            view: gpuCanvasCtx.getCurrentTexture().createView(),
            clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 },
            loadOp: 'clear',
            storeOp: 'store'
          }]
        };

        const renderPass = commandEncoder.beginRenderPass(renderPassDescriptor);
        renderPass.setPipeline(renderPipeline);
        renderPass.setBindGroup(0, renderBindGroup);
        renderPass.draw(6);
        renderPass.end();

        device.queue.submit([commandEncoder.finish()]);
        requestAnimationFrame(processFrame);
      }

      processFrame();
    }
    """
  end
end
D.start()
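
As a quick sanity check in the notebook itself (a sketch; the tap values are copied from the WGSL shader above), the 9-tap kernel is a normalized Gaussian: its taps sum to 1, and the weight applied in the shader's nested loop is the outer product of the 1D taps.

kernel =
  Nx.tensor([
    0.000229, 0.005977, 0.060598, 0.241732, 0.382928,
    0.241732, 0.060598, 0.005977, 0.000229
  ])

# 2D weights implied by `kernel[x + kernelSize / 2] * kernel[y + kernelSize / 2]`
kernel_2d = Nx.outer(kernel, kernel)

# Both sums should be ~1.0, so the blur preserves overall brightness
{Nx.sum(kernel), Nx.sum(kernel_2d)}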