Demo 1: Fine-tuning
Mix.install(
[
{:axon, "~> 0.5"},
{:bumblebee, "~> 0.5"},
{:nx, "~> 0.8.0", override: true},
{:exla, "~> 0.8.0", override: true},
{:kino, "~> 0.14"},
{:kino_flame, github: "hugobarauna/kino_flame"},
{:flame, "~> 0.5.1"},
{:explorer, "~> 0.9.1"},
{:kino_vega_lite, "~> 0.1.13"},
{:vega_lite, "~> 0.1.9"},
{:table_rex, "~> 4.0", override: true},
{:kino_explorer, "~> 0.1.23"},
{:number, "~> 1.0"}
],
system_env: [
XLA_TARGET: "cuda12",
AWS_ACCESS_KEY_ID: System.get_env("LB_LIVEBOOK_DEMOS_ACCESS_KEY_ID"),
AWS_SECRET_ACCESS_KEY: System.get_env("LB_LIVEBOOK_DEMOS_SECRET_ACCESS_KEY"),
AWS_REGION: "auto",
AWS_ENDPOINT_URL_S3: "https://fly.storage.tigris.dev"
],
config: [
nx: [
default_backend: EXLA.Backend,
default_defn_options: [compiler: EXLA, client: :cuda]
# default_defn_options: [compiler: EXLA]
]
]
)
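A note on the setup: the `LB_*` variables are Livebook secrets surfaced as environment variables, and `XLA_TARGET: "cuda12"` assumes the machines compiling EXLA have an NVIDIA GPU; the commented-out `default_defn_options` line is the CPU-only alternative.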
Initial setup
require Explorer.DataFrame, as: DataFrame
require Explorer.Series, as: Series
alias VegaLite, as: Vl
require Logger
import Kino.Shorts
defmodule RemoteMacOsNotifier do
def notify(message) do
:erpc.call(livebook_node(), System, :cmd, [
"osascript",
[
"-e",
"display notification \"#{message}\" with title \"Livebook\" sound name \"Pop\""
]
])
end
defp livebook_node() do
lb_node =
Node.list(:hidden)
|> Enum.filter(&String.contains?(Atom.to_string(&1), "127.0.0.1"))
|> List.first()
Node.connect(lb_node)
lb_node
end
end
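As an optional smoke test, trigger a notification from the attached runtime (this assumes Livebook itself is running on macOS, since the notifier shells out to `osascript` on the Livebook node):
# Optional: should pop a macOS notification via the hidden Livebook node.
RemoteMacOsNotifier.notify("Notifier wired up")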
repo = "google-bert/bert-base-cased"
sequence_length = 512
repo_cache_dir_name = String.replace(repo, "/", "--")
repo_cache_dir = "#{Bumblebee.cache_dir()}/huggingface/#{repo_cache_dir_name}"
paths = for file <- File.ls!(repo_cache_dir), do: "#{repo_cache_dir}/#{file}"
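These cached files are shipped to the FLAME machines below via `copy_paths`, so each machine skips re-downloading the model. Note that `File.ls!/1` raises if the repository has not been cached yet; on a fresh machine, a minimal way to populate the cache first:
# Run once on a fresh machine (downloads the weights and tokenizer into
# Bumblebee.cache_dir()), then re-evaluate the cell above.
{:ok, _model_info} = Bumblebee.load_model({:hf, repo})
{:ok, _tokenizer} = Bumblebee.load_tokenizer({:hf, repo})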
Cluster configuration
Kino.start_child(
{FLAME.Pool,
name: :training_pool,
code_sync: [
start_apps: true,
sync_beams: Kino.beam_paths(),
compress: true,
copy_paths: paths,
verbose: true
],
min: 0,
max: 16,
max_concurrency: 1,
boot_timeout: :timer.minutes(3),
idle_shutdown_after: :timer.minutes(1),
timeout: :infinity,
track_resources: true,
log: :info,
backend:
{FLAME.FlyBackend,
cpu_kind: "performance",
cpus: 4,
memory_mb: 32768,
gpu_kind: "l40s",
gpus: 1,
env: %{
"LIVEBOOK_COOKIE" => Node.get_cookie(),
"AWS_ACCESS_KEY_ID" => System.fetch_env!("AWS_ACCESS_KEY_ID"),
"AWS_ENDPOINT_URL_S3" => System.fetch_env!("AWS_ENDPOINT_URL_S3"),
"AWS_REGION" => System.fetch_env!("AWS_REGION"),
"AWS_SECRET_ACCESS_KEY" => System.fetch_env!("AWS_SECRET_ACCESS_KEY"),
"XLA_TARGET" => System.fetch_env!("XLA_TARGET")
}}}
)
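With `min: 0`, no machine exists until the first `FLAME.call/3`. A quick smoke test that boots one machine and returns its node name (a sketch; the machine shuts down again after the idle timeout):
# Boots a Fly machine on demand and runs the function there.
FLAME.call(:training_pool, fn -> Node.self() end, timeout: :infinity)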
RemoteMacOsNotifier.notify("Cluster configurado ✅")
Dataset
dataset =
DataFrame.from_csv!("s3://livebook-demos/cannabinoid-edibles-expanded.csv",
config: [
endpoint: "https://fly.storage.tigris.dev"
]
)
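A quick look at the data; the loader below assumes an `abstract` text column and an integer `class` label column:
# Peek at the first rows to confirm the expected columns are present.
DataFrame.head(dataset, 5)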
defmodule Cannabinoid do
def load(dataframe, tokenizer, opts \\ []) do
dataframe
|> stream()
|> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
end
def stream(df) do
xs = df["abstract"]
ys = df["class"]
xs
|> Explorer.Series.to_enum()
|> Stream.zip(Explorer.Series.to_enum(ys))
end
def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
tokenizer = Bumblebee.configure(tokenizer, length: sequence_length)
stream
|> Stream.chunk_every(batch_size)
|> Stream.map(fn batch ->
{text, labels} = Enum.unzip(batch)
tokenized = Bumblebee.apply_tokenizer(tokenizer, text)
{tokenized, Nx.stack(labels)}
end)
end
end
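For example, assuming a tokenizer loaded from the same repo (`peek_tokenizer` below is just for illustration), each element of the resulting stream is a `{tokenized_inputs, labels}` tuple ready for `Axon.Loop.run/4`:
# Materialize a single small batch to verify the shapes.
{:ok, peek_tokenizer} = Bumblebee.load_tokenizer({:hf, repo})

[{inputs, labels}] =
  dataset
  |> Cannabinoid.load(peek_tokenizer, batch_size: 2, sequence_length: sequence_length)
  |> Enum.take(1)

{Nx.shape(inputs["input_ids"]), Nx.shape(labels)}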
Training
defmodule Demo do
def load_datasets(df, tokenizer, batch_size, sequence_length) do
Logger.log(:info, "Loading dataset")
dataset_size = DataFrame.n_rows(df)
train_size = floor(dataset_size * 0.8)
test_size = dataset_size - train_size
train_df = DataFrame.head(df, train_size)
test_df = DataFrame.tail(df, test_size)
train_data =
Cannabinoid.load(train_df, tokenizer,
batch_size: batch_size,
sequence_length: sequence_length
)
test_data =
Cannabinoid.load(test_df, tokenizer,
batch_size: batch_size,
sequence_length: sequence_length
)
%{train: train_data, test: test_data}
end
def load_model(repo) do
{:ok, spec} =
Bumblebee.load_spec({:hf, repo},
architecture: :for_sequence_classification
)
spec = Bumblebee.configure(spec, num_labels: 2)
{:ok, %{model: model, params: params} = bumblebee_model} =
Bumblebee.load_model({:hf, repo},
spec: spec,
backend: EXLA.Backend
)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repo})
logits_model = Axon.nx(model, & &1.logits)
%{params: params, model: logits_model, tokenizer: tokenizer, bumblebee_model: bumblebee_model}
end
def build_loop(logits_model, lr, plots, run) do
[loss_plot_handler, accuracy_plot_handler, precision_plot_handler, recall_plot_handler] =
for metric <- ~w[loss accuracy precision recall],
do: plot_handler(plots[metric], metric, run)
logits_model
|> Axon.Loop.trainer(
&Axon.Losses.categorical_cross_entropy(&1, &2,
reduction: :mean,
from_logits: true,
sparse: true
),
Polaris.Optimizers.adam(learning_rate: lr),
log: 1
)
|> Axon.Loop.metric(
&Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true),
"accuracy"
)
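# Recall and precision expect hard class predictions, so the logits are
# converted with softmax + argmax before computing each metric.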
|> Axon.Loop.metric(
fn y_true, y_pred ->
y_pred = y_pred |> Axon.Activations.softmax() |> Nx.argmax(axis: -1)
Axon.Metrics.recall(y_true, y_pred)
end,
"recall"
)
|> Axon.Loop.metric(
fn y_true, y_pred ->
y_pred = y_pred |> Axon.Activations.softmax() |> Nx.argmax(axis: -1)
Axon.Metrics.precision(y_true, y_pred)
end,
"precision"
)
|> Axon.Loop.handle_event(:iteration_completed, loss_plot_handler)
|> Axon.Loop.handle_event(:iteration_completed, accuracy_plot_handler)
|> Axon.Loop.handle_event(:iteration_completed, recall_plot_handler)
|> Axon.Loop.handle_event(:iteration_completed, precision_plot_handler)
|> Axon.Loop.handle_event(:iteration_completed, &notify/1)
|> Axon.Loop.checkpoint(event: :epoch_completed)
end
def notify(state) do
if state.iteration == 0 do
remote_runtime_node =
Node.list(:hidden)
|> Enum.filter(&String.contains?(Atom.to_string(&1), "remote_runtime"))
|> List.first()
:erpc.call(remote_runtime_node, RemoteMacOsNotifier, :notify, ["Started to plot line chart"])
end
{:continue, state}
end
def train(loop, train_data, params) do
Logger.log(:info, "Training model")
Axon.Loop.run(loop, train_data, params,
epochs: 1,
compiler: EXLA,
strict?: false,
debug: true
)
end
def test(model, test_data, trained_model_state) do
Logger.log(:info, "Testing model")
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(
&Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true),
"accuracy"
)
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA, strict?: false)
end
defp plot_handler(plot, metric, run) do
fn %{
metrics: metrics,
handler_metadata: handler_metadata
} = state ->
unless Map.has_key?(metrics, metric) do
raise ArgumentError,
"invalid metric to plot, key #{inspect(metric)} not present in metrics"
end
plot_metadata_key = "plot_#{metric}"
plot_metadata = Map.get(handler_metadata, plot_metadata_key, %{})
{iteration, plot_metadata} = absolute_iteration(plot_metadata)
Kino.VegaLite.push(plot, %{
"step" => iteration,
metric => Nx.to_number(metrics[metric]),
"run" => run
})
next_handler_metadata = Map.put(handler_metadata, plot_metadata_key, plot_metadata)
{:continue, %{state | handler_metadata: next_handler_metadata}}
end
end
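# Axon's `state.iteration` restarts at every epoch, so track our own absolute
# step in the handler metadata to keep the chart's x-axis continuous.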
defp absolute_iteration(plot_metadata) do
case plot_metadata do
%{"absolute_iteration" => iteration} ->
{iteration, Map.put(plot_metadata, "absolute_iteration", iteration + 1)}
%{} ->
{0, %{"absolute_iteration" => 1}}
end
end
end
# Demo.load_model(repo)
Hyperparameter search: grid search
grid =
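# Enum.take(50..20//-2, 4) yields [50, 48, 46, 44], i.e. learning rates from
# 5.0e-4 down to 4.4e-4, combined with four batch sizes (16 configurations).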
for lr <- Enum.take(50..20//-2, 4), batch_size <- [4, 8, 16, 32] do
%{batch_size: batch_size, lr: lr * 1.0e-5}
end
|> Stream.with_index()
|> Enum.map(fn {params, idx} -> Map.put(params, :run, "run_#{idx}") end)
length(grid)
experiments = read_number("Number of experiments for the grid search", default: 16)
Kino.nothing()
grid = Enum.take(grid, experiments)
dataframe =
dataset
|> DataFrame.shuffle(seed: 1)
|> DataFrame.head(300)
plots =
for metric <- ~w[loss accuracy precision recall],
into: %{},
do:
{metric,
Vl.new(width: 400, height: 300)
|> Vl.mark(:line)
|> Vl.encode_field(:x, "step", type: :quantitative)
|> Vl.encode_field(:y, metric, type: :quantitative, tooltip: true)
|> Vl.encode_field(:color, "run", type: :nominal)
|> Kino.VegaLite.new()}
stream =
Task.async_stream(
grid,
fn %{lr: lr, run: run, batch_size: batch_size} ->
FLAME.call(
:training_pool,
fn ->
%{params: params, model: model, tokenizer: tokenizer} = Demo.load_model(repo)
loop = Demo.build_loop(model, lr, plots, run)
datasets = Demo.load_datasets(dataframe, tokenizer, batch_size, sequence_length)
trained_model_state =
Demo.train(loop, datasets[:train], params)
%{
trained_model_state: Nx.backend_transfer(trained_model_state),
run: run,
lr: lr
}
end,
timeout: :infinity
)
end,
max_concurrency: length(grid),
timeout: :infinity,
ordered: false
)
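Note that `Task.async_stream/3` is lazy: no FLAME machine boots until the stream is consumed by `Enum.to_list/1` below. Rendering the plots first ensures the charts are already on screen when the runs start pushing data points.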
plots
|> Map.values()
|> Kino.Layout.grid(columns: 2)
trained_runs =
stream
|> Stream.map(fn {:ok, val} -> val end)
|> Enum.to_list()
RemoteMacOsNotifier.notify("Training finished ✅")
trained_runs
|> List.first()
# |> Map.keys()
# |> Map.delete(:trained_model_state)
number_of_weights =
trained_runs
|> List.first()
|> then(fn run -> run.trained_model_state end)
|> Enum.flat_map(fn {_, param} ->
_tensors = Map.values(param)
end)
|> Enum.reduce(0, fn tensor, acc ->
acc + Nx.size(tensor)
end)
|> then(fn number_of_weights ->
Number.Human.number_to_human(number_of_weights) <> " of params in each of the #{experiments} trained model(s)"
end)