Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us
Notesclub

train

02_spam/nbs/train.livemd

train

Mix.install([
  {:exla, "~> 0.5.1"},
  {:nx, "~> 0.5.1"},
  {:explorer, "~> 0.5.0"},
  {:axon, "~> 0.5.1"},
  {:bumblebee, "~> 0.3.0"}
])

Section

# Make EXLA the default Nx backend for the tensors created in this notebook.
Nx.default_backend(EXLA.Backend)
# Input CSVs for training and evaluation.
path_train = "./data/combined/train.csv"
path_test = "./data/combined/test.csv"
# Destination for the serialized parameters (only referenced by the commented-out
# export code at the end of the notebook).
path_output = "./model"
# Hugging Face repository id of the checkpoint to fine-tune.
model_id = "bert-base-cased"
# Number of CSV rows per batch and the `:length` handed to the tokenizer.
batch_size = 32
sequence_length = 64
# Learning rate passed to the Adam optimizer.
lr = 5.0e-5
# Load the checkpoint's spec with the sequence-classification architecture...
{:ok, spec} =
  Bumblebee.load_spec({:hf, model_id},
    architecture: :for_sequence_classification
  )

# ...and configure it for two labels before loading the model with that spec.
spec = Bumblebee.configure(spec, num_labels: 2)
{:ok, model} = Bumblebee.load_model({:hf, model_id}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_id})
defmodule Spam do
  @moduledoc """
  Turns a CSV of messages and labels into batches of tokenized model inputs.
  """

  # Used when the caller gives no `:batch_size`. `Stream.chunk_every/2` rejects
  # `nil`, so without a fallback the `opts \\ []` default of `load/3` would crash.
  @default_batch_size 32

  @doc """
  Reads the CSV at `path` and returns a lazy stream of tokenized batches.

  The data frame's columns are renamed to `"text"` and `"label"` before the rows
  are streamed (presumably the file must have exactly two columns, text first —
  confirm against the data).

  ## Options

    * `:batch_size` - number of rows per batch. Defaults to `#{@default_batch_size}`.
    * `:sequence_length` - forwarded as `:length` to `Bumblebee.apply_tokenizer/3`;
      `nil` when omitted.

  See `tokenize_and_batch/4` for the shape of the emitted elements.
  """
  def load(path, tokenizer, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, @default_batch_size)
    sequence_length = Keyword.get(opts, :sequence_length)

    path
    |> Explorer.DataFrame.from_csv!(header: true)
    |> Explorer.DataFrame.rename(["text", "label"])
    |> stream()
    |> tokenize_and_batch(tokenizer, batch_size, sequence_length)
  end

  @doc """
  Returns a lazy stream of `{text, label}` tuples built from the `"text"` and
  `"label"` columns of `df`.
  """
  def stream(df) do
    xs = df["text"]
    ys = df["label"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  @doc """
  Groups `stream` into chunks of `batch_size` rows and tokenizes each chunk.

  Every element of the returned stream is either `{tokenized, labels}`, where
  `labels` is the stacked label tensor, or `{:error, batch}` when tokenizing that
  chunk raised a `RuntimeError`. Consumers that feed the stream to a training
  loop must filter out or otherwise handle the `{:error, batch}` elements.
  """
  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&tokenize_batch(&1, tokenizer, sequence_length))
  end

  # Tokenizes one chunk of `{text, label}` rows; reports and tags the chunk
  # instead of raising when the tokenizer fails with a RuntimeError.
  defp tokenize_batch(batch, tokenizer, sequence_length) do
    {text, labels} = Enum.unzip(batch)

    # Debugging aids kept from the original notebook: they print every batch.
    IO.inspect(batch, label: "batch")
    IO.inspect(text, label: "text")
    IO.inspect(labels, label: "labels")

    tokenized = Bumblebee.apply_tokenizer(tokenizer, text, length: sequence_length)
    {tokenized, Nx.stack(labels)}
  rescue
    e in RuntimeError ->
      IO.puts("An error occurred: #{Exception.message(e)}")
      # `batch` is a list of tuples; interpolating it directly raises because
      # String.Chars cannot render it, so it has to go through inspect/1.
      IO.puts("Failed input: #{inspect(batch)}")
      {:error, batch}
  end
end
# Lazy streams of tokenized batches; nothing is tokenized until they are enumerated.
train_data =
  Spam.load(path_train, tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )

test_data =
  Spam.load(path_test, tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )

# Peek at the first training batch.
Enum.take(train_data, 1)

# From here on `model` is rebound to the `:model` entry of what
# `Bumblebee.load_model/2` returned, and `params` to its `:params` entry.
%{model: model, params: params} = model

model

# Fetch one sample batch a single time and reuse it for both shape checks;
# every enumeration of the lazy stream tokenizes the batch again.
[{input, _}] = Enum.take(train_data, 1)
Axon.get_output_shape(model, input)

# Wrap the model so its output is only the `logits` field.
logits_model = Axon.nx(model, & &1.logits)
Axon.get_output_shape(logits_model, input)
# Cross-entropy over raw logits with integer (sparse) class labels, averaged
# over the batch.
loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

optimizer = Axon.Optimizers.adam(lr)
accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

# Build the training loop once; it is the loop that is actually run below
# (previously an identical trainer was constructed a second time and this one
# was left unused).
loop =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")

# Peek at a few evaluation batches.
Enum.take(test_data, 4)

# Materialize only the first 10 batches of each split, which keeps the run short.
train_data = Enum.take(train_data, 10)
test_data = Enum.take(test_data, 10)

# Fine-tune starting from the loaded `params` for a single epoch.
trained_model_state =
  Axon.Loop.run(loop, train_data, params, epochs: 1, compiler: EXLA, strict?: false)

# Evaluate the fine-tuned parameters on the held-out batches.
logits_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(accuracy, "accuracy")
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
# %{model: model, params: params} = trained_model_state
# params_ser = Nx.serialize(params)
# File.write!(path_output, params_ser)