Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Day 8

day8.livemd

Day 8

Mix.install([
  {:nx, "~> 0.9.2"}
])

Setup

# Load the raw puzzle text from `data/day8.txt` inside the directory named by
# the `LB_AOC_DIR` environment variable. Both bang functions raise when the
# variable or the file is missing, so a misconfigured notebook fails loudly.
input =
  "LB_AOC_DIR"
  |> System.fetch_env!()
  |> Path.join("data/day8.txt")
  |> File.read!()

nil
nil

Solve

defmodule TensorUtils do
  @moduledoc """
  Helpers for locating elements of two-dimensional `Nx` tensors.
  """

  @doc """
  Returns the `{row, col}` coordinates of the elements of `tensor` selected by `predicate`.

  `predicate` is called once with the whole tensor and must return a mask that
  `Nx.select/3` can apply to a tensor of the same shape. Coordinates are
  returned in the order produced by `Nx.to_flat_list/1`.

  `tensor` must have exactly two dimensions.
  """
  def where(tensor, predicate) when is_function(predicate, 1) do
    selected = predicate.(tensor)
    shape = {_height, _width} = Nx.shape(tensor)

    # Coordinates rejected by the mask are overwritten with -1, which can never
    # be a real index, so they can be dropped after flattening.
    rejected = Nx.broadcast(-1, shape)

    [row_ids, col_ids] =
      for axis <- [0, 1] do
        selected
        |> Nx.select(Nx.iota(shape, axis: axis), rejected)
        |> Nx.to_flat_list()
      end

    for {r, c} <- Enum.zip(row_ids, col_ids), r != -1 and c != -1, do: {r, c}
  end

  @doc """
  Returns the `{row, col}` coordinates of the elements of `tensor` equal to `value`.
  """
  def where_equals(tensor, value) do
    where(tensor, fn candidate -> Nx.equal(candidate, value) end)
  end
end
{:module, TensorUtils, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:where_equals, 2}}
defmodule Day8 do
  @moduledoc """
  Day 8 solver: counts the antinodes produced by pairs of same-frequency antennas.

  Every character of the grid other than `.` is treated as an antenna whose
  frequency is that character. Positions are `{row, col}` tuples.
  """

  @typep position :: {integer(), integer()}

  @doc """
  Parses the puzzle text into a two-dimensional tensor of character codes.

  The input is split on `"\\n"`; empty lines are skipped and every remaining
  line becomes one row of the tensor.
  """
  @spec parse(String.t()) :: Nx.Tensor.t()
  def parse(input) do
    input
    |> String.split("\n", trim: true)
    |> Enum.map(&String.to_charlist/1)
    |> Nx.tensor()
  end

  @doc ~S"""
  Counts the distinct in-bounds antinodes where, for each pair of same-frequency
  antennas, the antinode lies one antenna-to-antenna step beyond either antenna.

  ## Examples

      iex> Day8.part1("............\n........0...\n.....0......\n.......0....\n....0.......\n......A.....\n............\n............\n........A...\n.........A..\n............\n............\n")
      14
  """
  @spec part1(String.t()) :: non_neg_integer()
  def part1(input) do
    board = parse(input)
    shape = Nx.shape(board)

    antinodes =
      for {{r1, c1}, {r2, c2}} <- board |> antenna_groups() |> ordered_pairs(),
          # Mirror the second antenna through the first: a1 + (a1 - a2).
          antinode = {2 * r1 - r2, 2 * c1 - c2},
          in_bounds?(antinode, shape),
          uniq: true,
          do: antinode

    length(antinodes)
  end

  @doc ~S"""
  Counts the distinct in-bounds antinodes where, for each pair of same-frequency
  antennas, antinodes repeat at every antenna-to-antenna step beyond either
  antenna, and the paired antennas themselves also count.

  ## Examples

      iex> Day8.part2("............\n........0...\n.....0......\n.......0....\n....0.......\n......A.....\n............\n............\n........A...\n.........A..\n............\n............\n")
      34
  """
  @spec part2(String.t()) :: non_neg_integer()
  def part2(input) do
    board = parse(input)
    shape = Nx.shape(board)

    board
    |> antenna_groups()
    |> ordered_pairs()
    |> Enum.flat_map(fn {{r1, c1}, {r2, c2} = start} ->
      {dr, dc} = {r2 - r1, c2 - c1}

      # Walk away from the first antenna, starting ON the second one. The
      # reversed pair is also visited, so both antennas and both directions
      # end up covered.
      start
      |> Stream.iterate(fn {r, c} -> {r + dr, c + dc} end)
      |> Enum.take_while(&in_bounds?(&1, shape))
    end)
    |> Enum.uniq()
    |> length()
  end

  # Returns one list of antenna positions per frequency found on the board.
  # Frequencies are converted to plain integers so grouping never depends on
  # how a backend represents tensor structs.
  @spec antenna_groups(Nx.Tensor.t()) :: [[position()]]
  defp antenna_groups(board) do
    board
    |> TensorUtils.where(&Nx.not_equal(&1, ?.))
    |> Enum.group_by(fn {row, col} -> Nx.to_number(board[row][col]) end)
    |> Map.values()
  end

  # Returns every ordered pair of distinct positions that share a frequency.
  @spec ordered_pairs([[position()]]) :: [{position(), position()}]
  defp ordered_pairs(groups) do
    for positions <- groups, first <- positions, second <- positions, first != second do
      {first, second}
    end
  end

  # Returns true when the position lies inside a board of the given `{rows, cols}` shape.
  @spec in_bounds?(position(), {pos_integer(), pos_integer()}) :: boolean()
  defp in_bounds?({row, col}, {rows, cols}) do
    row >= 0 and row < rows and col >= 0 and col < cols
  end
end
{:module, Day8, <<70, 79, 82, 49, 0, 0, 25, ...>>, {:part2, 1}}