Modeling Everything with Transformers

Mix.install([
  {:bumblebee, "~> 0.1.0", github: "elixir-nx/bumblebee"},
  {:axon, "~> 0.5"},
  {:exla, "~> 0.5"},
  {:nx, "~> 0.5", override: true},
  {:kino, "~> 0.8"},
  {:kino_bumblebee, "~> 0.1", github: "livebook-dev/kino_bumblebee"}
])

Nx.global_default_backend(EXLA.Backend)
{EXLA.Backend, []}
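
With EXLA set as the global default backend, tensors in this notebook are allocated and computed through XLA rather than Nx's pure-Elixir BinaryBackend. As a quick sanity check (a minimal sketch, not part of the original notebook):

Nx.tensor([1, 2, 3])
# Inspecting the result should report EXLA.Backend rather than Nx.BinaryBackend.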

Zero-shot Classification with BART

{:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-mnli"})

08:58:21.407 [debug] the following PyTorch parameters were unused:

  * model.decoder.version
  * model.encoder.version
  * model.shared.weight

{:ok,
 %Bumblebee.Text.BartTokenizer{
   tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 50265,
     continuing_subword_prefix: "",
     dropout: nil,
     end_of_word_suffix: "",
     fuse_unk: false,
     model_type: "bpe",
     unk_token: nil
   ]>,
   special_tokens: %{
     bos: "<s>",
     cls: "<s>",
     eos: "</s>",
     mask: "<mask>",
     pad: "<pad>",
     sep: "</s>",
     unk: "<unk>"
   }
 }}
IO.inspect(model.model)
#Axon<
  inputs: %{"attention_head_mask" => {12, 16}, "attention_mask" => {nil, nil}, "cache" => nil, "cross_attention_head_mask" => {12, 16}, "decoder_attention_head_mask" => {12, 16}, "decoder_attention_mask" => {nil, nil}, "decoder_input_embeddings" => {nil, nil, 1024}, "decoder_input_ids" => {nil, nil}, "decoder_position_ids" => {nil, nil}, "encoder_hidden_state" => {nil, nil, 1024}, "input_embeddings" => {nil, nil, 1024}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "container_100"
  nodes: 2359
>
labels = ["New booking", "Update booking", "Cancel booking", "Refund"]

zero_shot_serving =
  Bumblebee.Text.zero_shot_classification(
    model,
    tokenizer,
    labels
  )
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<3.95287396/1 in Bumblebee.Text.ZeroShotClassification.zero_shot_classification/4>,
  client_preprocessing: #Function<4.95287396/1 in Bumblebee.Text.ZeroShotClassification.zero_shot_classification/4>,
  client_postprocessing: #Function<5.95287396/3 in Bumblebee.Text.ZeroShotClassification.zero_shot_classification/4>,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_size: nil],
  defn_options: []
}
input = "I need to book a new flight"

Nx.Serving.run(zero_shot_serving, input)
%{
  predictions: [
    %{label: "New booking", score: 0.5991652011871338},
    %{label: "Update booking", score: 0.3455488979816437},
    %{label: "Refund", score: 0.028283976018428802},
    %{label: "Cancel booking", score: 0.027001921087503433}
  ]
}
inputs = [
  "I want to change my existing flight",
  "I want to cancel my current flight",
  "I demand my money back"
]

Nx.Serving.run(zero_shot_serving, inputs)
[
  %{
    predictions: [
      %{label: "New booking", score: 0.43927058577537537},
      %{label: "Update booking", score: 0.4268641471862793},
      %{label: "Cancel booking", score: 0.10792690515518188},
      %{label: "Refund", score: 0.02593844011425972}
    ]
  },
  %{
    predictions: [
      %{label: "Cancel booking", score: 0.5605528950691223},
      %{label: "Refund", score: 0.3020733594894409},
      %{label: "Update booking", score: 0.09756755083799362},
      %{label: "New booking", score: 0.03980622440576553}
    ]
  },
  %{
    predictions: [
      %{label: "Refund", score: 0.913806140422821},
      %{label: "Cancel booking", score: 0.04736287519335747},
      %{label: "Update booking", score: 0.02491646446287632},
      %{label: "New booking", score: 0.013914537616074085}
    ]
  }
]
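
In an application you would usually act only on the top-scoring label. A minimal sketch using the serving above (the top_label helper is ours, not part of Bumblebee):

top_label = fn input ->
  # Run the serving and keep only the highest-scoring label.
  %{predictions: predictions} = Nx.Serving.run(zero_shot_serving, input)
  Enum.max_by(predictions, & &1.score).label
end

top_label.("I demand my money back")
# Given the scores above, this returns "Refund".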

Generating Text

{:ok, model} = Bumblebee.load_model({:hf, "microsoft/DialoGPT-medium"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})

08:58:23.161 [debug] the following PyTorch parameters were unused:

  * lm_head.weight
  * transformer.h.0.attn.bias
  * transformer.h.1.attn.bias
  * transformer.h.10.attn.bias
  * transformer.h.11.attn.bias
  * transformer.h.12.attn.bias
  * transformer.h.13.attn.bias
  * transformer.h.14.attn.bias
  * transformer.h.15.attn.bias
  * transformer.h.16.attn.bias
  * transformer.h.17.attn.bias
  * transformer.h.18.attn.bias
  * transformer.h.19.attn.bias
  * transformer.h.2.attn.bias
  * transformer.h.20.attn.bias
  * transformer.h.21.attn.bias
  * transformer.h.22.attn.bias
  * transformer.h.23.attn.bias
  * transformer.h.3.attn.bias
  * transformer.h.4.attn.bias
  * transformer.h.5.attn.bias
  * transformer.h.6.attn.bias
  * transformer.h.7.attn.bias
  * transformer.h.8.attn.bias
  * transformer.h.9.attn.bias

{:ok,
 %Bumblebee.Text.Gpt2Tokenizer{
   tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 50257,
     continuing_subword_prefix: "",
     dropout: nil,
     end_of_word_suffix: "",
     fuse_unk: false,
     model_type: "bpe",
     unk_token: nil
   ]>,
   special_tokens: %{
     bos: "<|endoftext|>",
     eos: "<|endoftext|>",
     pad: "<|endoftext|>",
     unk: "<|endoftext|>"
   }
 }}
serving =
  Bumblebee.Text.conversation(model, tokenizer,
    max_new_tokens: 100,
    compile: [batch_size: 1, sequence_length: 1000],
    defn_options: [compiler: EXLA]
  )
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<0.123941517/1 in Bumblebee.Text.Conversation.conversation/3>,
  client_preprocessing: #Function<1.123941517/1 in Bumblebee.Text.Conversation.conversation/3>,
  client_postprocessing: #Function<2.123941517/3 in Bumblebee.Text.Conversation.conversation/3>,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_size: 1],
  defn_options: [compiler: EXLA]
}
frame = Kino.Frame.new()

controls = [message: Kino.Input.text("New Message")]
form = Kino.Control.form(controls, submit: "Send Message", reset_on_submit: [:message])

form
|> Kino.Control.stream()
|> Kino.listen(nil, fn %{data: %{message: message}}, history ->
  Kino.Frame.append(frame, Kino.Markdown.new("**Me:** #{message}"))
  %{text: text, history: history} = Nx.Serving.run(serving, %{text: message, history: history})
  Kino.Frame.append(frame, Kino.Markdown.new("**Bot:** #{text}"))
  {:cont, history}
end)

Kino.Layout.grid([frame, form], gap: 16)
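
The conversation serving can also be invoked directly, without the Kino form. A minimal sketch mirroring what the listener above does: :history starts as nil, and each call returns an updated history to thread into the next turn.

response = Nx.Serving.run(serving, %{text: "Hello there!", history: nil})
response.text
# Continue the same conversation by threading the history back in:
# Nx.Serving.run(serving, %{text: "How are you?", history: response.history})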

Classifying Images

{:ok, model_info} = Bumblebee.load_model({:hf, "google/vit-base-patch16-224"})

{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "google/vit-base-patch16-224"})

serving =
  Bumblebee.Vision.image_classification(model_info, featurizer,
    top_k: 1,
    compile: [batch_size: 1],
    defn_options: [compiler: EXLA]
  )
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<1.63124952/1 in Bumblebee.Vision.ImageClassification.image_classification/3>,
  client_preprocessing: #Function<2.63124952/1 in Bumblebee.Vision.ImageClassification.image_classification/3>,
  client_postprocessing: #Function<3.63124952/3 in Bumblebee.Vision.ImageClassification.image_classification/3>,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_size: 1],
  defn_options: [compiler: EXLA]
}
image_input = Kino.Input.image("Image", size: {224, 224})
form = Kino.Control.form([image: image_input], submit: "Run")
frame = Kino.Frame.new()

form
|> Kino.Control.stream()
|> Stream.filter(& &1.data.image)
|> Kino.listen(fn %{data: %{image: image}} ->
  Kino.Frame.render(frame, Kino.Markdown.new("Running..."))

  image =
    image.data
    |> Nx.from_binary(:u8)
    |> Nx.reshape({image.height, image.width, 3})

  output = Nx.Serving.run(serving, image)

  output.predictions
  |> Enum.map(&{&1.label, &1.score})
  |> Kino.Bumblebee.ScoredList.new()
  |> then(&Kino.Frame.render(frame, &1))
end)

Kino.Layout.grid([form, frame], boxed: true, gap: 16)
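
The serving itself only needs an image tensor, so the form is optional. A throwaway sketch (not in the original notebook) that feeds it a synthetic {height, width, 3} u8 tensor; a real photo would of course classify more meaningfully, and we assume the featurizer takes care of resizing and normalization:

# Any HWC u8 tensor is accepted as input.
fake_image = Nx.iota({224, 224, 3}, type: :u8)
Nx.Serving.run(serving, fake_image)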

Fine-tuning Pre-trained Models

{:ok, spec} =
  Bumblebee.load_spec({:hf, "distilbert-base-cased"},
    module: Bumblebee.Text.Distilbert,
    architecture: :for_sequence_classification
  )

spec = Bumblebee.configure(spec, num_labels: 5)

{:ok, %{model: model, params: params}} =
  Bumblebee.load_model(
    {:hf, "distilbert-base-cased"},
    spec: spec
  )

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "distilbert-base-cased"})

08:58:24.188 [debug] the following parameters were missing:

  * sequence_classification_head.output.kernel
  * sequence_classification_head.output.bias
  * pooler.output.kernel
  * pooler.output.bias


08:58:24.188 [debug] the following PyTorch parameters were unused:

  * vocab_layer_norm.bias
  * vocab_layer_norm.weight
  * vocab_projector.bias
  * vocab_projector.weight
  * vocab_transform.bias
  * vocab_transform.weight

{:ok,
 %Bumblebee.Text.DistilbertTokenizer{
   tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 28996,
     continuing_subword_prefix: "##",
     max_input_chars_per_word: 100,
     model_type: "bpe",
     unk_token: "[UNK]"
   ]>,
   special_tokens: %{cls: "[CLS]", mask: "[MASK]", pad: "[PAD]", sep: "[SEP]", unk: "[UNK]"}
 }}
batch_size = 32
max_length = 128

train_data =
  File.stream!("yelp_review_full_csv/train.csv")
  |> Stream.chunk_every(batch_size)
  |> Stream.map(fn inputs ->
    {labels, reviews} =
      inputs
      |> Enum.map(fn line ->
        # parts: 2 guards against reviews that themselves contain the quoted
        # delimiter; trim the trailing newline and the surrounding quotes.
        [label, review] = String.split(line, "\",\"", parts: 2)
        {String.trim(label, "\""), review |> String.trim() |> String.trim("\"")}
      end)
      |> Enum.unzip()

    labels = labels |> Enum.map(&String.to_integer/1) |> Nx.tensor()
    tokens = Bumblebee.apply_tokenizer(tokenizer, reviews, length: max_length)

    {tokens, labels}
  end)
#Stream<[
  enum: #Stream<[
    enum: %File.Stream{
      path: "yelp_review_full_csv/train.csv",
      modes: [:raw, :read_ahead, :binary],
      line_or_bytes: :line,
      raw: true
    },
    funs: [#Function<3.6935098/1 in Stream.chunk_while/4>]
  ]>,
  funs: [#Function<48.6935098/1 in Stream.map/2>]
]>
Enum.take(train_data, 1)
[
  {%{
     "attention_mask" => #Nx.Tensor<
       s64[32][128]
       EXLA.Backend
       [
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...],
         ...
       ]
     >,
     "input_ids" => #Nx.Tensor<
       s64[32][128]
       EXLA.Backend
       [
         [101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111, 1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037, 1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107, 3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, ...],
         ...
       ]
     >,
     "token_type_ids" => #Nx.Tensor<
       s64[32][128]
       EXLA.Backend
       [
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[32]
     EXLA.Backend
     [5, 2, 4, 4, 1, 5, 5, 1, 2, 3, 1, 1, 4, 2, 5, 5, 5, 5, 5, 5, 4, 3, 2, 5, 1, 1, 1, 2, 2, 4, 2, 2]
   >}
]
Axon.get_output_shape(model, %{"input_ids" => Nx.template({32, 128}, :s64)})
%{attentions: #Axon.None<...>, hidden_states: #Axon.None<...>, logits: {32, 5}}
model = Axon.nx(model, fn %{logits: logits} -> logits end)
#Axon<
  inputs: %{"attention_head_mask" => {6, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "nx_0"
  nodes: 407
>
optimizer = Axon.Optimizers.adamw(5.0e-5)

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    from_logits: true,
    sparse: true,
    reduction: :mean
  )

trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA)

08:58:24.354 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 850, accuracy: 0.0936031 loss: 0.9988262
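
Once training finishes, trained_model_state holds the fine-tuned parameters, and Axon.predict/4 runs inference with them. A minimal sketch (the example review is ours; the result is the index of the highest-scoring of the five classes, which maps back to a star rating):

review = "The food was fantastic and the service was quick."
tokens = Bumblebee.apply_tokenizer(tokenizer, [review], length: max_length)

model
|> Axon.predict(trained_model_state, tokens, compiler: EXLA)
|> Nx.argmax(axis: -1)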