
Demo 1: Fine tune

sorry-dave/fine_tune_demo.livemd


Mix.install(
  [
    {:axon, "~> 0.5"},
    {:bumblebee, "~> 0.5"},
    {:nx, "~> 0.8.0", override: true},
    {:exla, "~> 0.8.0", override: true},
    {:kino, "~> 0.14"},
    {:kino_flame, github: "hugobarauna/kino_flame"},
    {:flame, "~> 0.5.1"},
    {:explorer, "~> 0.9.1"},
    {:kino_vega_lite, "~> 0.1.13"},
    {:vega_lite, "~> 0.1.9"},
    {:table_rex, "~> 4.0", override: true},
    {:kino_explorer, "~> 0.1.23"},
    {:number, "~> 1.0"}
  ],
  system_env: [
    XLA_TARGET: "cuda12",
    AWS_ACCESS_KEY_ID: System.get_env("LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID"),
    AWS_SECRET_ACCESS_KEY: System.get_env("LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY"),
    AWS_REGION: "auto",
    AWS_ENDPOINT_URL_S3: "https://fly.storage.tigris.dev"
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA, client: :cuda]
      # on a machine without a GPU, use: default_defn_options: [compiler: EXLA]
    ]
  ]
)

Initial setup

require Explorer.DataFrame, as: DataFrame
require Explorer.Series, as: Series

alias VegaLite, as: Vl

require Logger

import Kino.Shorts
defmodule RemoteMacOsNotifier do
  # Sends a macOS notification by RPC-ing `osascript` to the local Livebook
  # node, so it also works when called from a remote FLAME machine.
  def notify(message) do
    :erpc.call(livebook_node(), System, :cmd, [
      "osascript",
      [
        "-e",
        "display notification \"#{message}\" with title \"Livebook\" sound name \"Pop\""
      ]
    ])
  end

  # The Livebook server node is the hidden node registered on 127.0.0.1,
  # which distinguishes it from the remote runtime nodes.
  defp livebook_node() do
    lb_node =
      Node.list(:hidden)
      |> Enum.filter(&String.contains?(Atom.to_string(&1), "127.0.0.1"))
      |> List.first()

    Node.connect(lb_node)
    lb_node
  end
end
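A one-liner confirms the round trip before the notifier is relied on later (this assumes Livebook itself is running on macOS):

RemoteMacOsNotifier.notify("Notifier wired up ✅")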
repo = "google-bert/bert-base-cased"
sequence_length = 512

# Locate the locally cached Hugging Face files for the repo so they can be
# copied to the FLAME machines instead of being downloaded again there.
repo_cache_dir_name = String.replace(repo, "/", "--")
repo_cache_dir = "#{Bumblebee.cache_dir()}/huggingface/#{repo_cache_dir_name}"

paths = for file <- File.ls!(repo_cache_dir), do: "#{repo_cache_dir}/#{file}"

Cluster setup

Kino.start_child(
  {FLAME.Pool,
   name: :training_pool,
   code_sync: [
     start_apps: true,
     sync_beams: Kino.beam_paths(),
     compress: true,
     copy_paths: paths,
     verbose: true
   ],
   min: 0,
   max: 16,
   max_concurrency: 1,
   boot_timeout: :timer.minutes(3),
   idle_shutdown_after: :timer.minutes(1),
   timeout: :infinity,
   track_resources: true,
   log: :info,
   backend:
     {FLAME.FlyBackend,
      cpu_kind: "performance",
      cpus: 4,
      memory_mb: 32768,
      gpu_kind: "l40s",
      gpus: 1,
      env: %{
        "LIVEBOOK_COOKIE" => Node.get_cookie(),
        "AWS_ACCESS_KEY_ID" => System.fetch_env!("AWS_ACCESS_KEY_ID"),
        "AWS_ENDPOINT_URL_S3" => System.fetch_env!("AWS_ENDPOINT_URL_S3"),
        "AWS_REGION" => System.fetch_env!("AWS_REGION"),
        "AWS_SECRET_ACCESS_KEY" => System.fetch_env!("AWS_SECRET_ACCESS_KEY"),
        "XLA_TARGET" => System.fetch_env!("XLA_TARGET")
      }}}
)
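Before launching training, a minimal smoke test (hypothetical, not part of the original flow) verifies the pool can boot a machine and run code on it; the first call may take a few minutes while the machine boots:

FLAME.call(:training_pool, fn -> node() end)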
RemoteMacOsNotifier.notify("Cluster configured ✅")

Dataset

dataset =
  DataFrame.from_csv!("s3://livebook-demos/cannabinoid-edibles-expanded.csv",
    config: [
      endpoint: "https://fly.storage.tigris.dev"
    ]
  )
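Before splitting, it is worth glancing at the label balance. `Explorer.DataFrame.frequencies/2` counts rows per value of the `class` column (the column names here match the ones the loader below expects):

DataFrame.frequencies(dataset, ["class"])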
defmodule Cannabinoid do
  # Turns the dataframe into a stream of {tokenized_inputs, labels} batches.
  def load(dataframe, tokenizer, opts \\ []) do
    dataframe
    |> stream()
    |> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
  end

  # Streams {abstract, class} pairs row by row.
  def stream(df) do
    xs = df["abstract"]
    ys = df["class"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    # Pad/truncate every example to a fixed length so all batches share a shape
    tokenizer = Bumblebee.configure(tokenizer, length: sequence_length)

    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {text, labels} = Enum.unzip(batch)
      tokenized = Bumblebee.apply_tokenizer(tokenizer, text)
      {tokenized, Nx.stack(labels)}
    end)
  end
end
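To sanity-check the data pipeline locally, here is a small sketch (the batch size of 2 is arbitrary, and it assumes the tokenizer loads fine on the local runtime):

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repo})

{inputs, labels} =
  dataset
  |> Cannabinoid.load(tokenizer, batch_size: 2, sequence_length: sequence_length)
  |> Enum.at(0)

# input_ids should have shape {2, 512}
Nx.shape(inputs["input_ids"])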

Training

defmodule Demo do
  def load_datasets(df, tokenizer, batch_size, sequence_length) do
    Logger.log(:info, "Loading dataset")

    dataset_size = DataFrame.n_rows(df)
    train_size = floor(dataset_size * 0.8)
    test_size = dataset_size - train_size
    train_df = DataFrame.head(df, train_size)
    test_df = DataFrame.tail(df, test_size)

    train_data =
      Cannabinoid.load(train_df, tokenizer,
        batch_size: batch_size,
        sequence_length: sequence_length
      )

    test_data =
      Cannabinoid.load(test_df, tokenizer,
        batch_size: batch_size,
        sequence_length: sequence_length
      )

    %{train: train_data, test: test_data}
  end

  def load_model(repo) do
    {:ok, spec} =
      Bumblebee.load_spec({:hf, repo},
        architecture: :for_sequence_classification
      )

    spec = Bumblebee.configure(spec, num_labels: 2)

    {:ok, %{model: model, params: params} = bumblebee_model} =
      Bumblebee.load_model({:hf, repo},
        spec: spec,
        backend: EXLA.Backend
      )

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repo})

    # The Bumblebee model returns a map of outputs; train on the logits only
    logits_model = Axon.nx(model, & &1.logits)

    %{params: params, model: logits_model, tokenizer: tokenizer, bumblebee_model: bumblebee_model}
  end

  def build_loop(logits_model, lr, plots, run) do
    [loss_plot_handler, accuracy_plot_handler, precision_plot_handler, recall_plot_handler] =
      for metric <- ~w[loss accuracy precision recall],
          do: plot_handler(plots[metric], metric, run)

    logits_model
    |> Axon.Loop.trainer(
      &Axon.Losses.categorical_cross_entropy(&1, &2,
        reduction: :mean,
        from_logits: true,
        sparse: true
      ),
      Polaris.Optimizers.adam(learning_rate: lr),
      log: 1
    )
    |> Axon.Loop.metric(
      &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true),
      "accuracy"
    )
    |> Axon.Loop.metric(
      fn y_true, y_pred ->
        y_pred = y_pred |> Axon.Activations.softmax() |> Nx.argmax(axis: -1)
        Axon.Metrics.recall(y_true, y_pred)
      end,
      "recall"
    )
    |> Axon.Loop.metric(
      fn y_true, y_pred ->
        y_pred = y_pred |> Axon.Activations.softmax() |> Nx.argmax(axis: -1)
        Axon.Metrics.precision(y_true, y_pred)
      end,
      "precision"
    )
    |> Axon.Loop.handle_event(:iteration_completed, loss_plot_handler)
    |> Axon.Loop.handle_event(:iteration_completed, accuracy_plot_handler)
    |> Axon.Loop.handle_event(:iteration_completed, recall_plot_handler)
    |> Axon.Loop.handle_event(:iteration_completed, precision_plot_handler)
    |> Axon.Loop.handle_event(:iteration_completed, &notify/1)
    |> Axon.Loop.checkpoint(event: :epoch_completed)
  end

  def notify(state) do
    IO.inspect(state, label: "STATE CALLED FROM notify")

    if state.iteration == 0 do
      remote_runtime_node =
        Node.list(:hidden)
        |> Enum.filter(&amp;String.contains?(Atom.to_string(&amp;1), "remote_runtime"))
        |> List.first()

      :erpc.call(remote_runtime_node, RemoteMacOsNotifier, :notify, ["Started to plot line chart"])
    end

    {:continue, state}
  end

  def train(loop, train_data, params) do
    Logger.log(:info, "Training model")

    Axon.Loop.run(loop, train_data, params,
      epochs: 1,
      compiler: EXLA,
      strict?: false,
      debug: true
    )
  end

  def test(model, test_data, trained_model_state) do
    Logger.log(:info, "Testing model")

    model
    |> Axon.Loop.evaluator()
    |> Axon.Loop.metric(
      &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true),
      "accuracy"
    )
    |> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA, strict?: false)
  end

  defp plot_handler(plot, metric, run) do
    fn %{
         metrics: metrics,
         handler_metadata: handler_metadata
       } = state ->
      unless Map.has_key?(metrics, metric) do
        raise ArgumentError,
              "invalid metric to plot, key #{inspect(metric)} not present in metrics"
      end

      plot_metadata_key = "plot_#{metric}"
      plot_metadata = Map.get(handler_metadata, plot_metadata_key, %{})

      {iteration, plot_metadata} = absolute_iteration(plot_metadata)

      Kino.VegaLite.push(plot, %{
        "step" => iteration,
        metric => Nx.to_number(metrics[metric]),
        "run" => run
      })

      next_handler_metadata = Map.put(handler_metadata, plot_metadata_key, plot_metadata)

      {:continue, %{state | handler_metadata: next_handler_metadata}}
    end
  end

  defp absolute_iteration(plot_metadata) do
    case plot_metadata do
      %{"absolute_iteration" => iteration} ->
        {iteration, Map.put(plot_metadata, "absolute_iteration", iteration + 1)}

      %{} ->
        {0, %{"absolute_iteration" => 1}}
    end
  end
end
# Demo.load_model(repo)

Hyperparameter search: grid search

grid =
  for lr <- Enum.take(50..20//-2, 4), batch_size <- [4, 8, 16, 32] do
    %{batch_size: batch_size, lr: lr * 1.0e-5}
  end
  |> Stream.with_index()
  |> Enum.map(fn {params, idx} -> Map.put(params, :run, "run_#{idx}") end)
length(grid)
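The comprehension pairs four learning rates (5.0e-4 down to 4.4e-4) with four batch sizes, 16 combinations in total. The first entries look like this:

Enum.take(grid, 2)
#=> [
#=>   %{batch_size: 4, lr: 5.0e-4, run: "run_0"},
#=>   %{batch_size: 8, lr: 5.0e-4, run: "run_1"}
#=> ]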
experiments = read_number("Number of experiments for the grid search", default: 16)

Kino.nothing()
grid = Enum.take(grid, experiments)
# Shuffle deterministically and subsample so the demo trains quickly
dataframe =
  dataset
  |> DataFrame.shuffle(seed: 1)
  |> DataFrame.head(300)
plots =
  for metric <- ~w[loss accuracy precision recall],
      into: %{},
      do:
        {metric,
         Vl.new(width: 400, height: 300)
         |> Vl.mark(:line)
         |> Vl.encode_field(:x, "step", type: :quantitative)
         |> Vl.encode_field(:y, metric, type: :quantitative, tooltip: true)
         |> Vl.encode_field(:color, "run", type: :nominal)
         |> Kino.VegaLite.new()}
stream =
  Task.async_stream(
    grid,
    fn %{lr: lr, run: run, batch_size: batch_size} ->
      FLAME.call(
        :training_pool,
        fn ->
          %{params: params, model: model, tokenizer: tokenizer} = Demo.load_model(repo)
          loop = Demo.build_loop(model, lr, plots, run)
          datasets = Demo.load_datasets(dataframe, tokenizer, batch_size, sequence_length)

          trained_model_state =
            Demo.train(loop, datasets[:train], params)

          %{
            trained_model_state: Nx.backend_transfer(trained_model_state),
            run: run,
            lr: lr
          }
        end,
        timeout: :infinity
      )
    end,
    max_concurrency: length(grid),
    timeout: :infinity,
    ordered: false
  )
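Note that `Task.async_stream/3` is lazy: no FLAME machine boots until the stream is consumed by `Enum.to_list/1` below, which is why the plots grid can be rendered first and then fill in live as training progresses.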
plots 
|> Map.values()
|> Kino.Layout.grid(columns: 2)
trained_runs =
  stream
  |> Stream.map(fn {:ok, val} -> val end)
  |> Enum.to_list()

RemoteMacOsNotifier.notify("Training finished ✅")
trained_runs
|> List.first()
# |> Map.keys()
# |> Map.delete(:trained_model_state)
number_of_weights =
  trained_runs
  |> List.first()
  |> then(fn run -> run.trained_model_state end)
  |> Enum.flat_map(fn {_layer, params} -> Map.values(params) end)
  |> Enum.reduce(0, fn tensor, acc -> acc + Nx.size(tensor) end)
  |> then(fn number_of_weights ->
    Number.Human.number_to_human(number_of_weights) <>
      " params in each of the #{experiments} trained model(s)"
  end)