Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

tfjs

stage2/tfjs.livemd

tfjs

Mix.install([
  {:kino, "~> 0.10.0"}
])

Section

defmodule KinoTF.Sin do
  @moduledoc """
  A `Kino.JS` widget that trains a TensorFlow.js model in the browser to
  approximate `sin(x)` and plots the target and the predicted curve on a canvas.

  The caller supplies the model definition as JavaScript source (see `new/1`).
  Sampling the training data, fitting the model and drawing are done by the
  `main.js` asset defined in this module.
  """

  use Kino.JS

  @doc """
  Builds the widget.

  Despite its name, `html` is a string of JavaScript source. It is evaluated
  in the browser with two names in scope:

    * `model` - an empty `tf.sequential()` model
    * `tf` - the TensorFlow.js namespace

  The code is expected to add layers to `model` and call `model.compile/1`.
  Because the string is passed to `eval`, only trusted code must be given.
  """
  @spec new(String.t()) :: Kino.JS.t()
  def new(html) do
    Kino.JS.new(__MODULE__, html)
  end

  asset "main.js" do
    """
    import "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js";

    // Number of sample points; one point is plotted per horizontal pixel.
    const SAMPLE_COUNT = 200;
    // Sample i is taken at x = i / X_SCALE.
    const X_SCALE = 30.0;
    // Side length of the square canvas, used for both the CSS box and the bitmap.
    const CANVAS_SIZE = 400;

    // Strokes `data` as a polyline, one sample per pixel column.
    // `n` selects the vertical lane: the baseline is at y = n * 20 + 100 and
    // each value is scaled by 30 pixels (positive values go up).
    function draw(ctx, data, n) {
      if (data.length === 0) {
        return;
      }
      const hc = n * 20 + 100;
      // Start a fresh path so an earlier curve is not stroked a second time.
      ctx.beginPath();
      ctx.strokeStyle = "#f00";
      ctx.lineWidth = 1;
      // Begin at the first sample rather than at the baseline.
      ctx.moveTo(0, hc - data[0] * 30);
      for (let i = 1; i < data.length; i++) {
        ctx.lineTo(i, hc - data[i] * 30);
      }
      ctx.stroke();
    }

    // Kino.JS entry point. `html` is the JavaScript source that configures `model`.
    export function init(ctx, html) {
      const canvas = document.createElement("canvas");
      // Size the bitmap as well as the CSS box; otherwise the bitmap keeps the
      // default 300x150 and the drawing is stretched non-uniformly.
      canvas.width = CANVAS_SIZE;
      canvas.height = CANVAS_SIZE;
      canvas.style.width = CANVAS_SIZE + "px";
      canvas.style.height = CANVAS_SIZE + "px";
      ctx.root.appendChild(canvas);

      const ctx2 = canvas.getContext("2d");
      ctx2.clearRect(0, 0, canvas.width, canvas.height);

      tf.setBackend("cpu");

      // Sample sin(x) at SAMPLE_COUNT points; keep a plain array for plotting.
      const xBuffer = tf.buffer([SAMPLE_COUNT, 1]);
      const yBuffer = tf.buffer([SAMPLE_COUNT, 1]);
      const target = [];
      for (let i = 0; i < SAMPLE_COUNT; i++) {
        const x = i / X_SCALE;
        const y = Math.sin(x);
        target.push(y);
        xBuffer.set(x, i, 0);
        yBuffer.set(y, i, 0);
      }

      // The evaluated code refers to this local by the name `model`.
      const model = tf.sequential();
      // NOTE(security): `html` is executed as code. It must only ever come from
      // the notebook author, never from untrusted input.
      eval(html);

      // Show the target curve immediately; training may take a while.
      draw(ctx2, target, 0);

      const xs = xBuffer.toTensor();
      const ys = yBuffer.toTensor();
      model
        .fit(xs, ys, {
          batchSize: SAMPLE_COUNT,
          epochs: 6000
        })
        .then((d) => {
          // Report the loss of the last epoch, i.e. of the trained model.
          const losses = d.history.loss;
          alert("loss = " + losses[losses.length - 1]);

          // Predict all samples in one batch and release the result tensor.
          // The output has one value per sample because `fit` requires the
          // model output to match the [SAMPLE_COUNT, 1] targets.
          const prediction = model.predict(xs);
          const predicted = Array.from(prediction.dataSync());
          prediction.dispose();
          draw(ctx2, predicted, 1);
        })
        .catch((error) => {
          console.error("KinoTF.Sin: training failed", error);
        })
        .finally(() => {
          xs.dispose();
          ys.dispose();
        });
    }
    """
  end
end

# Render the widget with a model definition written in JavaScript.
# The string below is not Elixir: main.js evaluates it in the browser with
# `model` (an empty tf.sequential()) and `tf` in scope.
# It defines a dense network 1 input -> 40 tanh units -> 1 tanh output,
# compiled with the Adam optimizer and mean-squared-error loss.
KinoTF.Sin.new("""
model.add(tf.layers.dense({
	units: 40,
	activation: 'tanh',
	inputShape: [1]
}));
model.add(tf.layers.dense({
	units: 1,
	activation: 'tanh'
}));
model.compile({
	optimizer: 'adam',
	loss: 'meanSquaredError'
});
""")