
LSTM Demo

demo.livemd


Introduction

In machine learning, the workflow is broadly the same regardless of the technique used:

  1. Read the file and clean the data: remove uninformative words (stopwords such as articles), special characters, etc., so that the input is as homogeneous as possible. Cleaner input gives the network a better chance to learn useful patterns.
  2. Vectorize the information, because neural networks operate on numbers (tensors), not raw text
  3. Build the model
  4. Train the model
  5. Evaluate it on the test data
  6. Finally, use it to make predictions on real data

About Neural networks, RNNs and LSTMs:

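For reference, one standard formulation of an LSTM cell at timestep t is below: x_t is the input vector, h_{t-1} and c_{t-1} are the previous hidden and cell states, ⊙ is element-wise multiplication, and σ is the gate's squashing function. In the Axon model built later, the `:gate` option (`:hard_sigmoid`) appears to play the role of σ and `:activation` (`:tanh`) the role of tanh; treat that mapping as our reading of the options, not a statement about Axon's internals.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

The cell state c_t is updated additively through the forget and input gates, which is what helps LSTMs carry information across long sequences better than a plain RNN.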

Problem definition

  • We want to classify complaints from a CSV file generated by a call center
  • The data contains each complaint as free text, plus a label specifying its type
  • We want a way to determine the type of future complaints automatically, without a human in the loop

Read and clean the data

# Regex sigils pass backslash escapes (\[, \s, \d) through to the regex engine;
# inside a plain "..." string Elixir interprets or drops them first.
symbols_to_be_replaced_by_space_regex = ~r"[/(){}\[\]\|@,;]"
symbols_to_be_removed_regex = ~r"[^a-z +_]"
censor_tokens_to_be_removed_regex = ~r"[^\s][@X/|\d{}]+[$\s.,\n]"
remaining_censor_tokens_to_be_removed_regex = ~r"X{2,}"

# A MapSet gives fast membership checks for `&1 in stopwords` in clean_text below
stopwords =
  "/data/english.stopwords"
  |> File.read!()
  |> String.split("\n")
  |> Enum.reject(&(&1 == ""))
  |> MapSet.new()

clean_text = fn text ->
  text
  |> String.replace(censor_tokens_to_be_removed_regex, "")
  |> String.replace(remaining_censor_tokens_to_be_removed_regex, "")
  |> String.downcase()
  |> String.replace(symbols_to_be_replaced_by_space_regex, " ")
  |> String.replace(symbols_to_be_removed_regex, "")
  |> String.split()
  |> Enum.reject(&(&1 in stopwords))
  |> Enum.join(" ")
end
#Function<42.3316493/1 in :erl_eval.expr/6>
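To see what the cleaning pipeline does to a single complaint, here is a self-contained version run on a made-up sentence. It is a sketch, assuming a tiny hypothetical stopword list in place of `/data/english.stopwords`, and it keeps only the `X{2,}` censor regex for brevity.

```elixir
# Hypothetical, tiny stopword list standing in for /data/english.stopwords.
stopwords = MapSet.new(["i", "the", "a", "my", "to", "on", "was", "of"])

replace_by_space = ~r"[/(){}\[\]\|@,;]"
remove_symbols = ~r"[^a-z +_]"
censored_runs = ~r"X{2,}"

clean_text = fn text ->
  text
  # Drop censored runs such as XX/XX/XXXX (leaves the slashes behind)
  |> String.replace(censored_runs, "")
  |> String.downcase()
  # Turn separators and brackets into spaces
  |> String.replace(replace_by_space, " ")
  # Remove anything that is not a lowercase letter, space, + or _
  |> String.replace(remove_symbols, "")
  |> String.split()
  |> Enum.reject(&MapSet.member?(stopwords, &1))
  |> Enum.join(" ")
end

cleaned = clean_text.("On XX/XX/XXXX I was charged a late fee (again) on my mortgage!")
```

Here `cleaned` is `"charged late fee again mortgage"`: the date placeholder, punctuation and stopwords are gone, leaving only the content words.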
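# Load the complaints, keep only the two labels we classify, clean the text, shuffle.
data =
  "/data/data.csv"
  |> File.stream!()
  |> CSV.decode!(separator: ?#)
  # Skip the header row
  |> Enum.drop(1)
  # Filter first so only the rows we keep go through clean_text
  |> Enum.filter(fn [_date, label, _input_text] ->
    label in ["Debt collection", "Mortgage"]
  end)
  |> Enum.map(fn [_date, label, input_text] ->
    %{label: label, input_text: clean_text.(input_text)}
  end)
  |> Enum.shuffle()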
[
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{input_text: "communication tactics", label: "Debt collection"},
  %{input_text: "monthly mortgage payment went although property tax went", label: "Mortgage"},
  %{
    input_text: "trying get modification existing mortgage + years every time submit information requested",
    label: "Mortgage"
  },
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{input_text: "disclosure verification debt", label: "Debt collection"},
  %{
    input_text: "health insurance called companies forst reported supposed tobe deleted resubmitted insurance",
    label: "Debt collection"
  },
  %{input_text: "communication tactics", label: "Debt collection"},
  %{
    input_text: "mortgage originally opened service released carrington mortgage heard issues transfer payments automated got transferred carrington visibility account reporting credit report like supposed faxed letter research department communicates fax mail finally started showing couple weeks ago showed days late finally got portal access carrington first payment reported",
    label: "Mortgage"
  },
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{
    input_text: "applied illinois hardest hit fund program file sent mortgage lender",
    label: "Mortgage"
  },
  %{
    input_text: "debt reporting credit file company indicated owed money correct lived apartment complex indicated could use refrigerator removed one actually came apartment moved take refrigerator trying charge appliance possession one initially apartment agreed using money owed company unacceptable",
    label: "Debt collection"
  },
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{
    input_text: "went online p ay one mortgage payments saw end transaction amount higher normal payment",
    label: "Mortgage"
  },
  %{input_text: "contd attempts collect debt owed", label: "Debt collection"},
  %{
    input_text: "already settled case vs rpm mortgage got mortgage rpm never received payout thinking many people might left settlement called hotline one good answer would left loan number issued",
    label: "Mortgage"
  },
  %{input_text: "taking threatening illegal action", label: "Debt collection"},
  %{input_text: "contd attempts collect debt owed", label: "Debt collection"},
  %{
    input_text: "car needed repair took charged repair car picked technician told could make payments order pick car needed pay would leave balance made payments got call today still owe lease repair bill company collecting monterey financial aware lease car repair thought making payments repairs told lease pay additional legal action would taken received messages work cell phone collection company would like resolved paid amount repair plan making additional payments feel like taken advantage high interest lease aware sign lease would like resolved closed feel need pay repairs",
    label: "Debt collection"
  },
  %{input_text: "contd attempts collect debt owed", label: "Debt collection"},
  %{input_text: "communication tactics", label: "Debt collection"},
  %{input_text: "communication tactics", label: "Debt collection"},
  %{input_text: "disclosure verification debt", label: "Debt collection"},
  %{input_text: "received letter stating owe dollars hospital", label: "Debt collection"},
  %{input_text: "false statements representation", label: "Debt collection"},
  %{
    input_text: "notice received stating increase mortgage payment approx would begin call made loancare",
    label: "Mortgage"
  },
  %{
    input_text: "trying get huge amount debt without filing bankruptcy requested cease desist constant calls email capital one two accounts amount debt accounts near charge standards days late therefore",
    label: "Debt collection"
  },
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{input_text: "taking threatening illegal action", label: "Debt collection"},
  %{
    input_text: "name trouble getting mortgage company admit mistakes paying insurance dues property mistake cost additional insurance fee add roughly month mortgage payment received call insurance company policy cancel since mortgage company pay invoice acct escrowed taxs insurance insurance agency said able get resolved attempts got involved called mortgage company said would pay insurance fee happen",
    label: "Mortgage"
  },
  %{input_text: "taking threatening illegal action", label: "Debt collection"},
  %{input_text: "communication tactics", label: "Debt collection"},
  %{input_text: "contd attempts collect debt owed", label: "Debt collection"},
  %{input_text: "keep trying", label: "Debt collection"},
  %{
    input_text: "year made payment admit days late date also made payment next month due noticed charged late fees month already made looked account applied payment principal payment instead toward payment due emailed called assured would fixed made payment due saw late fees saw never fixed payment made principal payment",
    label: "Debt collection"
  },
  %{
    input_text: "bayview took servicing mortgage behind attorney hired help us obtain loan modification since payments increased hardship situation due death primary supporter residence mortgage attorney told us",
    label: "Mortgage"
  },
  %{
    input_text: "years attempted numerous times fixed mortgage lower payment always said enough money show affordability moved stranger home generate extra money still denied excessive forbearance would nt problem would modified long time ago facing sale date scheduled lose home",
    label: "Mortgage"
  },
  %{
    input_text: "contacted bogus company wanting debt amount two companies",
    label: "Debt collection"
  },
  %{
    input_text: "refinanced mortgage early fall mortgage acquired wells fargo payments begin escrow balance transferred time closing monthly payments approx received word wells fargo couple months ago escrow account underfunded resulted owing could paid full paid time additional month mortgage seems grossly negligent",
    label: "Mortgage"
  },
  %{input_text: "contd attempts collect debt owed", label: "Debt collection"},
  %{
    input_text: "received annual escrow statement home mortgage fifth third bank statement indicated escrow shortage although changes homeowners insurance county property taxes",
    label: "Mortgage"
  },
  %{
    input_text: "around mortgage sold nationstar mortgage terms remain exactly arm interest rate change date",
    label: "Mortgage"
  },
  %{
    input_text: "recently made aware received credit mortgage payments wells fargo submitted original loan provider made partial payments totaling",
    label: "Mortgage"
  },
  %{
    input_text: "husband trying get mortgage company endorse return insurance claim check damage done roof interior house faxed invoices repair payments made contractors said unless get invoices showing everything paid full release check keep changing requirements us get check back every time talk one invoice still shows owed paid never got receipt stating paid full unable get invoice showing paid full confused insurance check supposed pays contractor mortgage company holding insurance check hostage individuals supposed pay contractors sure individuals money savings pay large sums money owed contractors suggested loancare call contractor doubt payment made would like get reimbursed took savings account pay contractors amounts paid insurance company check loancare holding give us amount approximately owe little loan house valued never missed late payment frustrating gone house inspections verify work completed work completed minor repairs reimbursed insurance company plan repairs dime summer determined contractor charging thought work cost loan permission share experience long personal info removed problem identity theft",
    label: "Mortgage"
  },
  %{
    input_text: "ocwen charge us claim court month behind waiting new modification ocwen claim home taken court lost bankruptcy claim six month shown drown lie never seen documents letter showing behind new bankruptcy filed last bankruptcy ocwen went court claiming take mortgage lies months behind area process first modification ocwen claim owe true owed see never sees go send court sent judge dont ask questions believe creditor say start complainting attorney generals office ocwen dont change anything continue charge fees insurance never know next company ocwen sent letter court lift stay modification remember complaints never except judge approve time bankruptcy ocwen claim payment original payment judge lower knew modification lower payments bankruptcy paying bankruptcy ocwen continues charging us taken escrow many month ago ocwen thing claiming owe behind credit us amount received bankruptcy paying continue get higher ocwen gotten set attorney began foreclosure home paid much claim bankruptcy send another area ocwen file claim insurance hail damage send check ocwen got recharged fees payout put escrow charge back see statement years fraudulent activities made many complaints correct problem attorney send many letter wife got particular judge nationality went taking bankruptcy claiming going modification got packs back months call clam working new modification lying deceitful bad practice got letters attorney ocwen investors question foreclosure company bad business requested talk president ceo person initially called talked wife things heard say thought phone hold could hear background talking another person taking property man day forward ive calling request speak wife durable power attorney take care business matter work state timei ocwen staff back still continue charge fees hard try pay claim court wad month behind wife call request send six months amounts would claim owe bankruptcy court refused give amount would would claiming behind month process 
modification ocwen accept payment come back claim owe money claiming statement dealing year taken fight claiming forbearance left balance thought way get good still happen fees charges want correct want take home sent harassment letters repo days respond ocwen company affiliationsthe thief different areas country purchased house years agethis long suffer distress mistreated mortgage company could anybody continue things keep getting away charge feed sued attorney general office lied claiming need send paper handwritten letter stating part lawsuit type letter send ocwen wait days call back thats need hand write letter center id call back act like funny joke come letters trying foreclose house right wrong really need get copy statement hope everything make sure company never practice state united states never",
    label: "Mortgage"
  },
  %{
    input_text: "son difficulty took emergency room base hospital hospital proper tool necessary treat immediately transported next city heard nothing charges received bill saying owed around immediately called billing department hospital spoke agent stated owed amount asked bill said doesnt know asked pay something know charges never received hospital bill military always paid informed contact contacted spoke agent named couldnt find charge said would removed hospital contact still sending bills calling immediately informed hospital billing department agent stated would look matter heard nothing months assumed issue resolved received another bill issue called hospitals billing department spoke another agent informed amount removed needed pay totally different amount bill couldnt figure numbers said paid pays portion called spoke another agent agent claimed amount taking son hospital network explained voluntarily take",
    label: "Debt collection"
  },
  %{input_text: "disclosure verification debt", label: "Debt collection"},
  %{
    input_text: "mortgage loan wells fargo bank struggling make payments result divorce settlement contact mortgage lender several times phone",
    label: "Mortgage"
  },
  %{
    input_text: "switched wireless provider returned devices ipad back using mailing coupons provided",
    ...
  },
  %{...},
  ...
]
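The sample output shows many short, repeated texts ("mortgage broker", "contd attempts collect debt owed"), and the test texts printed further down start with an empty string. A quick check of class balance, emptied rows and duplicates before splitting can be sketched with `Enum.frequencies_by/2`; the `sample` list here is a small hypothetical stand-in for `data`.

```elixir
# Small hypothetical stand-in for `data`.
sample = [
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{input_text: "communication tactics", label: "Debt collection"},
  %{input_text: "disclosure verification debt", label: "Debt collection"},
  %{input_text: "mortgage broker", label: "Mortgage"},
  %{input_text: "", label: "Debt collection"}
]

# Rows per label: a strong imbalance might call for stratified splitting.
label_counts = Enum.frequencies_by(sample, & &1.label)

# Rows whose text was emptied by cleaning carry no signal for the model.
empty_rows = Enum.filter(sample, &(&1.input_text == ""))

# Texts that appear more than once can end up in both the train and test sets.
duplicate_texts =
  sample
  |> Enum.frequencies_by(& &1.input_text)
  |> Enum.filter(fn {_text, count} -> count > 1 end)
  |> Enum.map(fn {text, _count} -> text end)
```

Run against the real `data`, the same three expressions show whether the repeated short complaints are numerous enough to inflate test accuracy.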
# 80/20 train/test split (data was already shuffled above)
split_index = round(length(data) * 0.8)
{train_data, test_data} = Enum.split(data, split_index)

train_labels = Enum.map(train_data, & &1.label)
train_input_texts = Enum.map(train_data, & &1.input_text)

test_labels = Enum.map(test_data, & &1.label)
test_input_texts = Enum.map(test_data, & &1.input_text)
["", "mortgage broker",
 "diversified consultant reporting credit never received anything owe anyways back got paid final bills dont even know last time sprint well years ago debt past statute limitations also",
 "contd attempts collect debt owed", "portfolio recovery associates", "hi",
 "received letter shipped standard mail", "trying work things servicer owned mortgage loan",
 "mortgage broker", "mortgage broker",
 "went van dyke mortgage worked officer named tried approve mortgage loan",
 "please advised third written request asking remove accounts listed remain credit report violation usc ss",
 "received official looking notice exterior label entitled credit card settlement inside company implied outstanding debt collections must contact within days avoid increase debt due interest possibly fines",
 "mortgage broker", "mortgage broker", "contd attempts collect debt owed",
 "selene finance accept payment additional principle monthly payment pre payment principle authorized va mortgage",
 "mortgage broker",
 "received collection credit control brought debt brought debt originator fraudulent credit card went court received dismissal account fraudulent supposed closed however keeps sold illegal attached report received portfolio proves account determined fraud victim identity theft needs removed immediately original balance stated however fees showing police report started happening",
 "old credit card went collections stegner stegner began collection attempts",
 "previously filed complaints provided false misleading information",
 "contd attempts collect debt owed", "contd attempts collect debt owed",
 "dear although conventional fixed rate mortgage wells fargo wf",
 "contacted mortgage servicer set biweekly payments", "contd attempts collect debt owed",
 "contd attempts collect debt owed", "contd attempts collect debt owed",
 "date occurence explained medical billing department debt paid processed incorrect department resulting collections credit attempted problem corrected funds paidin moved proper department successful offered pay debt company would remove negative information credit file told billing representative negative information would stay file statue limitations reached debt since reached pa statue limitations medical debt",
 "disclosure verification debt", "mortgage loan taken new servicing bank fifth third bank",
 "letter received claiming debt owed verified owed original creditor paypal",
 "filed bankruptcy st nd home filing bankruptcy continued make payments st discontinued making payments nd know would like sell home escrow company able get information clear nd mortgage",
 "original creditor amount stated owe requesting vod validation debt date disputed resolution account status updated requesting vod validation debt validated",
 "mortgage broker",
 "mortgage sold pnc bank sold contacted pnc bank stated required pay extra escrow company policy requiring us cushion reserve purchased loan could change terms originally agreed action seems predatory nature however threatened pay home could go foreclosure came agreement pay months added monthly mortgage payment",
 "trying sell property vacant lot mortgage paid bmo harris seven years ago acquired loan failed loan paid full seven years ago",
 "submitted numerous complaints ocwen ridiculous interest rate paying percent amount paying paying month time every month dollars go mortgage goes ocwen crazy asked interest rate reduced paying mortgage time giving money away company month understand considered illegal thru cfpb investigated understand guidelines given companies charge outrages ratesthis needs investigated continue file complaint something done company willing work consumers",
 "home damaged massive tornado hail sick work since used mortgage payments pay removal massive trees uprooted thrown home due safety concerns along repairs deck railings insurance company sent funds mortgage company mortgage company pay arrearages return funds pay keep giving different stories holding funds still thousands dollars repairs done",
 "submitting complaint regards recent change holder mortgage mortgage initially held sold select portfolio services sps",
 "communication tactics", "longtime customer ditech mortgage company recently",
 "received calls less minutes still coming wrote absolutely harassment dont go minute without getting call starting awful way conduct business",
 "dealing ocwen since trying get deed lieu found information online lawsuit currently going divorce exhusband decided best go deed lieu started process signed documents approved ocwen home appraisal followed ocwen monthly answer would get delay got paperwork stating approved house vacant time house already vacant home owners insurance canceled due vacant ocwen said would fine covered house insurance get hazardous insurance ocwen putting system kept getting notices house covered another delay ocwen call regards called said paperwork needed kept getting letters saying house covered called fought insurance department ocwen finally got insurance information entered system call every week every weeks check status deed lieu always get delay paperwork attorneys working never get reason delay get call ocwen notary contact regards signing paperwork later day get another phone call appraisal ran new appraisal need conducted prior signing documents called ocwen got another letter mail stating significant hazardous insurance house needed cover total cost owned property none listed letters received phone called also talked deed lieu department stated would expedite appraisal nothing ordered time called got answer time gave order number requesting appraisal order number called still appraisal order said would hear appraiser called ocwen back still ordered appraisal asked speak supervisor id afternoon contractor appraisal contacted met day appraisal done contacted bbb complaint months ocwen delaying complete deed lieu assistance program taken days beyond frustrated whits end get completed cant afford attorney already going lawyer custody child support divorce house name advice want taken care delaying call ocwen back today let know appraisal done friday said take business days get documents needs accepted uploaded system point attorneys begin preparing mortgage documents sign house still fighting ocwen try complete deed lieu assistance program",
 "applied mortgage loandepot direct online",
 "years old applied reverse mortgage wells fargo bank approved shortly closing informed getting reverse mortgage business went open market got approved wells fargo would give proper payoff amount reverse mortgage close applied approved companies able close due dispute amount owed",
 "mortgage broker", "communication tactics", "currently owe mortgage",
 "mortgage loan transferred rushmore loan management services llc time one month behind mortgage due escrow problem taxes unable make payment amount correct transferred rushmore loan management services llc today owe payments made payment rushmore loan management services llc amount account number instructed company attached letter shows writing send payment way paid additional money minutes called rushmore loan management services llc soon made payment reported mtcn number representative representative stated payment would post end business day since made website states well pm called rushmore loan management services llc told could take days post told law paid extra money get instructed youre accepting money hung numerous times trying call rushmore loan management services llc spoke supervisor named morning regards payment posted im told could take four days really ach routing account number transaction minutes paid extra spent hours time yesterday trying find correct information regards payment sent spoke representative day made payment late closed stated would post end business also told would give call today end business day update payment yet receive phone call told assigned single point contact named never called ceaseanddesist account sent via email right lift ceaseanddesist still received call single point contact found today whose name number got voicemail received false misleading information company",
 ...]
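`Enum.shuffle/1` draws from the process's `:rand` state, so the split changes on every run. A reproducible variant can be sketched by seeding `:rand` before shuffling; this assumes reseeding the notebook process is acceptable, and the seed and the toy `rows` list are arbitrary, hypothetical values.

```elixir
# Seed the process-local :rand state so Enum.shuffle/1 is repeatable.
:rand.seed(:exsss, {42, 42, 42})

# Toy rows shaped like `data` (hypothetical values).
rows =
  for i <- 1..10 do
    %{input_text: "complaint #{i}", label: if(rem(i, 2) == 0, do: "Mortgage", else: "Debt collection")}
  end

shuffled = Enum.shuffle(rows)
split_index = round(length(shuffled) * 0.8)
{train_rows, test_rows} = Enum.split(shuffled, split_index)
```

Reseeding with the same tuple and shuffling again yields the same order, so train and test sets stay identical across runs.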

Vectorize the information

defmodule Vectorizer do
  alias Tokenizers.Tokenizer
  alias Tokenizers.Encoding

  defstruct [:tokenizer, :label_to_id, :id_to_label]

  @sequence_length 100

  def init do
    {:ok, tokenizer} = Tokenizer.from_pretrained("bert-base-cased")
    %Vectorizer{tokenizer: tokenizer}
  end

  def sequence_length, do: @sequence_length

  def encode_texts_to_tensor(%Vectorizer{tokenizer: tokenizer}, texts) do
    texts
    |> Enum.map(fn text ->
      {:ok, tokenized} = Tokenizer.encode(tokenizer, text)

      tokenized
      |> Encoding.pad(@sequence_length)
      |> Encoding.truncate(@sequence_length)
      |> Encoding.get_ids()
    end)
    |> Nx.tensor()
  end

  def fit_to_labels(%Vectorizer{} = vectorizer, labels) do
    labels_with_indices =
      labels
      |> Enum.uniq()
      |> Enum.with_index()

    %{
      vectorizer
      | label_to_id: Enum.into(labels_with_indices, %{}),
        id_to_label:
          labels_with_indices
          |> Enum.map(fn {label, id} -> {id, label} end)
          |> Enum.into(%{})
    }
  end

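  # One-hot encode: compare each label id (shape {n, 1}) against
  # [0, 1, ..., k - 1] (shape {1, k}); broadcasting yields an {n, k} 0/1 tensor.
  def encode_labels_to_tensor(%Vectorizer{label_to_id: label_to_id} = vectorizer, labels) do
    labels
    |> Enum.map(&[Map.fetch!(label_to_id, &1)])
    |> Nx.tensor()
    |> Nx.equal(Nx.iota({1, length(unique_labels(vectorizer))}))
  end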

  def unique_labels(%Vectorizer{label_to_id: label_to_id}) do
    Map.keys(label_to_id)
  end
end

vectorizer = Vectorizer.init() |> Vectorizer.fit_to_labels(train_labels)
%Vectorizer{
  tokenizer: #Tokenizers.Tokenizer<[
    vocab_size: 28996,
    continuing_subword_prefix: "##",
    max_input_chars_per_word: 100,
    model_type: "bpe",
    unk_token: "[UNK]"
  ]>,
  label_to_id: %{"Debt collection" => 1, "Mortgage" => 0},
  id_to_label: %{0 => "Mortgage", 1 => "Debt collection"}
}
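`encode_texts_to_tensor/2` pads each encoding to 100 ids and then truncates it to 100, so every row has exactly `sequence_length` entries. Below is a plain-list sketch of that rule with no Tokenizers calls; it assumes the pad id is 0 (which matches the padded rows in the tensor output) and right-side padding and truncation. The module `PadSketch` is hypothetical and is not a reimplementation of `Encoding.pad` or `Encoding.truncate`.

```elixir
defmodule PadSketch do
  # Hypothetical helper mirroring the pad-then-truncate rule on a plain id list.
  @pad_id 0

  def pad_and_truncate(ids, len) do
    # Append pad ids if the sequence is short (no-op when it is long enough)...
    padded = ids ++ List.duplicate(@pad_id, max(len - length(ids), 0))
    # ...then keep only the first `len` ids.
    Enum.take(padded, len)
  end
end

# Short complaint: padded on the right with 0s.
short = PadSketch.pad_and_truncate([101, 16935, 24535, 102], 8)
# Long complaint: cut to its first 8 ids.
long = PadSketch.pad_and_truncate(Enum.to_list(1..12), 8)
```

Keeping the first ids means very long complaints (such as the Ocwen one in the data sample) lose their tail; whether that matters depends on how early the distinguishing words appear.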
unique_labels_count = Vectorizer.unique_labels(vectorizer) |> length()

input_x =
  Vectorizer.encode_texts_to_tensor(vectorizer, train_input_texts)
  |> Nx.reshape({:auto, Vectorizer.sequence_length()})

input_y =
  vectorizer
  |> Vectorizer.encode_labels_to_tensor(train_labels)

test_input_x =
  Vectorizer.encode_texts_to_tensor(vectorizer, test_input_texts)
  |> Nx.reshape({:auto, Vectorizer.sequence_length()})

test_input_y =
  vectorizer
  |> Vectorizer.encode_labels_to_tensor(test_labels)

{input_x, input_y}
{#Nx.Tensor<
   s64[17274][100]
   EXLA.Backend
   [
     [101, 16935, 24535, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   u8[17274][2]
   EXLA.Backend
   [
     [1, 0],
     [0, 1],
     [1, 0],
     [1, 0],
     [1, 0],
     [0, 1],
     [0, 1],
     [0, 1],
     [1, 0],
     [1, 0],
     [1, 0],
     [0, 1],
     [1, 0],
     [1, 0],
     [0, 1],
     [1, 0],
     [0, 1],
     [0, 1],
     [0, 1],
     [0, 1],
     [0, 1],
     [0, 1],
     [0, 1],
     [0, 1],
     ...
   ]
 >}
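The label tensor above is one-hot: each row has a 1 in the column of its label id. The sketch below shows, with plain lists, how such a row is built and how a softmax output row maps back to a label through the `id_to_label` map printed for the fitted vectorizer. The prediction values are made up for illustration; with Nx, `Nx.argmax(predictions, axis: 1)` would give the same indices.

```elixir
# id_to_label as printed for the fitted Vectorizer above.
id_to_label = %{0 => "Mortgage", 1 => "Debt collection"}
num_labels = map_size(id_to_label)

# One-hot row for a label id: 1 at position `id`, 0 elsewhere.
one_hot = fn id ->
  for j <- 0..(num_labels - 1)//1, do: if(j == id, do: 1, else: 0)
end

# Hypothetical softmax outputs (each row sums to 1).
predictions = [[0.91, 0.09], [0.2, 0.8]]

# argmax per row, then look the index up in id_to_label.
decode = fn rows ->
  Enum.map(rows, fn row ->
    {_prob, id} = row |> Enum.with_index() |> Enum.max_by(fn {prob, _id} -> prob end)
    Map.fetch!(id_to_label, id)
  end)
end

decoded = decode.(predictions)
```

Because `label_to_id` is built with `Enum.uniq/1` over the shuffled training labels, the id assigned to each label can change between runs, so the fitted vectorizer (or at least its `id_to_label` map) should be kept together with the trained model.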

Build model

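# Size the embedding by the tokenizer vocabulary: token ids go up to ~29k
# (see vocab_size above), so an input size of 100 cannot index them.
vocab_size = Tokenizers.Tokenizer.get_vocab_size(vectorizer.tokenizer)

model =
  Axon.input("complaints", shape: {nil, Vectorizer.sequence_length()})
  |> Axon.embedding(vocab_size, 100)
  # Axon.lstm returns {output_sequence, {cell, hidden}}; keep the full sequence
  |> Axon.lstm(120, activation: :tanh, gate: :hard_sigmoid)
  |> then(fn {output_sequence, {_new_cell, _new_hidden}} -> output_sequence end)
  |> Axon.dropout(rate: 0.25)
  |> Axon.lstm(60, activation: :tanh, gate: :hard_sigmoid)
  |> then(fn {output_sequence, {_new_cell, _new_hidden}} -> output_sequence end)
  # Keep only the last timestep: {batch, seq, 60} -> {batch, 60}
  |> Axon.nx(fn x -> x[[0..-1//1, -1]] end)
  |> Axon.dropout(rate: 0.35)
  |> Axon.dense(unique_labels_count, activation: :softmax)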
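`Axon.nx(fn x -> x[[0..-1//1, -1]] end)` keeps every batch row but only the last timestep of the second LSTM's output sequence, turning a `{batch, 100, 60}` tensor into `{batch, 60}`. The plain-list analogue below uses tiny made-up shapes (`{2, 3, 2}`) to show the indexing; it is an illustration, not Axon code.

```elixir
# Nested lists standing in for a {batch = 2, timesteps = 3, units = 2} tensor.
sequence_output = [
  [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
  [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
]

# Same idea as x[[0..-1//1, -1]]: all batch rows, last timestep of each.
last_steps = Enum.map(sequence_output, &List.last/1)
```

Because the texts are padded on the right (the encoded rows above end in long runs of 0), the last timestep is usually reached only after many pad tokens. Padding on the left, or masking the pad ids, are common ways to keep the final state closer to the real text; check the Tokenizers and Axon documentation for the options your versions provide.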

model
|> Axon.Display.as_table(input_x)
|> IO.puts()
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                                                            Model                                                                                                                                                            |
+======================================================================================+===========================================================+=========================================================+=====================+==========================================================================================+
| Layer                                                                                | Input Shape                                               | Output Shape                                            | Options             | Parameters                                                                               |
+======================================================================================+===========================================================+=========================================================+=====================+==========================================================================================+
| complaints ( input )                                                                 | []                                                        | {17274, 100}                                            | shape: {nil, 100}   |                                                                                          |
|                                                                                      |                                                           |                                                         | optional: false     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| embedding_0 ( embedding["complaints"] )                                              | [{17274, 100}]                                            | {17274, 100, 100}                                       |                     | kernel: f32[100][100]                                                                    |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm__c_hidden_state ( recurrent_state["embedding_0"] )                              | [{17274, 100, 100}]                                       | {17274, 1, 120}                                         | key: #Nx.Tensor<    |                                                                                          |
|                                                                                      |                                                           |                                                         |   u32[2]            |                                                                                          |
|                                                                                      |                                                           |                                                         |                     |                                                                                          |
|                                                                                      |                                                           |                                                         |   Nx.Defn.Expr      |                                                                                          |
|                                                                                      |                                                           |                                                         |   tensor a   u32[2] |                                                                                          |
|                                                                                      |                                                           |                                                         | >                   |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm__h_hidden_state ( recurrent_state["embedding_0"] )                              | [{17274, 100, 100}]                                       | {17274, 1, 120}                                         | key: #Nx.Tensor<    |                                                                                          |
|                                                                                      |                                                           |                                                         |   u32[2]            |                                                                                          |
|                                                                                      |                                                           |                                                         |                     |                                                                                          |
|                                                                                      |                                                           |                                                         |   Nx.Defn.Expr      |                                                                                          |
|                                                                                      |                                                           |                                                         |   tensor a   u32[2] |                                                                                          |
|                                                                                      |                                                           |                                                         | >                   |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm__hidden_state ( container {"lstm__c_hidden_state", "lstm__h_hidden_state"} )    | {}                                                        | {{17274, 1, 120}, {17274, 1, 120}}                      |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_0 ( lstm["embedding_0", "lstm__hidden_state"] )                                 | [{17274, 100, 100}, {{17274, 1, 120}, {17274, 1, 120}}]   | {{17274, 100, 120}, {{17274, 1, 120}, {17274, 1, 120}}} | activation: :tanh   | input_kernel: tuple{"f32[100][120]", "f32[100][120]", "f32[100][120]", "f32[100][120]"}  |
|                                                                                      |                                                           |                                                         | gate: :hard_sigmoid | hidden_kernel: tuple{"f32[120][120]", "f32[120][120]", "f32[120][120]", "f32[120][120]"} |
|                                                                                      |                                                           |                                                         | unroll: :dynamic    | bias: tuple{"f32[120]", "f32[120]", "f32[120]", "f32[120]"}                              |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_1_output_sequence ( elem["lstm_0"] )                                            | [{{17274, 100, 120}, {{17274, 1, 120}, {17274, 1, 120}}}] | {17274, 100, 120}                                       |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| dropout_0 ( dropout["lstm_1_output_sequence"] )                                      | [{17274, 100, 120}]                                       | {17274, 100, 120}                                       | rate: 0.25          |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_1_c_hidden_state ( recurrent_state["dropout_0"] )                               | [{17274, 100, 120}]                                       | {17274, 1, 60}                                          | key: #Nx.Tensor<    |                                                                                          |
|                                                                                      |                                                           |                                                         |   u32[2]            |                                                                                          |
|                                                                                      |                                                           |                                                         |                     |                                                                                          |
|                                                                                      |                                                           |                                                         |   Nx.Defn.Expr      |                                                                                          |
|                                                                                      |                                                           |                                                         |   tensor a   u32[2] |                                                                                          |
|                                                                                      |                                                           |                                                         | >                   |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_1_h_hidden_state ( recurrent_state["dropout_0"] )                               | [{17274, 100, 120}]                                       | {17274, 1, 60}                                          | key: #Nx.Tensor<    |                                                                                          |
|                                                                                      |                                                           |                                                         |   u32[2]            |                                                                                          |
|                                                                                      |                                                           |                                                         |                     |                                                                                          |
|                                                                                      |                                                           |                                                         |   Nx.Defn.Expr      |                                                                                          |
|                                                                                      |                                                           |                                                         |   tensor a   u32[2] |                                                                                          |
|                                                                                      |                                                           |                                                         | >                   |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_1_hidden_state ( container {"lstm_1_c_hidden_state", "lstm_1_h_hidden_state"} ) | {}                                                        | {{17274, 1, 60}, {17274, 1, 60}}                        |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_1 ( lstm["dropout_0", "lstm_1_hidden_state"] )                                  | [{17274, 100, 120}, {{17274, 1, 60}, {17274, 1, 60}}]     | {{17274, 100, 60}, {{17274, 1, 60}, {17274, 1, 60}}}    | activation: :tanh   | input_kernel: tuple{"f32[120][60]", "f32[120][60]", "f32[120][60]", "f32[120][60]"}      |
|                                                                                      |                                                           |                                                         | gate: :hard_sigmoid | hidden_kernel: tuple{"f32[60][60]", "f32[60][60]", "f32[60][60]", "f32[60][60]"}         |
|                                                                                      |                                                           |                                                         | unroll: :dynamic    | bias: tuple{"f32[60]", "f32[60]", "f32[60]", "f32[60]"}                                  |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| lstm_2_output_sequence ( elem["lstm_1"] )                                            | [{{17274, 100, 60}, {{17274, 1, 60}, {17274, 1, 60}}}]    | {17274, 100, 60}                                        |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| nx_0 ( nx["lstm_2_output_sequence"] )                                                | [{17274, 100, 60}]                                        | {17274, 60}                                             |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| dropout_1 ( dropout["nx_0"] )                                                        | [{17274, 60}]                                             | {17274, 60}                                             | rate: 0.35          |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| dense_0 ( dense["dropout_1"] )                                                       | [{17274, 60}]                                             | {17274, 2}                                              |                     | kernel: f32[60][2]                                                                       |
|                                                                                      |                                                           |                                                         |                     | bias: f32[2]                                                                             |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
| softmax_0 ( softmax["dense_0"] )                                                     | [{17274, 2}]                                              | {17274, 2}                                              |                     |                                                                                          |
+--------------------------------------------------------------------------------------+-----------------------------------------------------------+---------------------------------------------------------+---------------------+------------------------------------------------------------------------------------------+
Total Parameters: 159642
Total Parameters Memory: 638568 bytes
:ok
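The totals above can be checked by hand from the shapes in the table. The plain-Elixir sketch below needs no Nx. It recomputes the parameter count and memory, assuming each Axon LSTM layer holds four input kernels, four hidden kernels and four biases, one set per gate, as the `input_kernel`, `hidden_kernel` and `bias` tuples in the summary show. The `ParamCount` module is a hypothetical helper for illustration only.

```elixir
# Sketch: recompute the summary's parameter totals from the layer shapes.
# Assumes the LSTM layout shown in the table: 4 input kernels, 4 hidden kernels, 4 biases.
defmodule ParamCount do
  # One LSTM layer with `input` features and `units` hidden units:
  # 4 * (input x units kernel + units x units kernel + units bias)
  def lstm(input, units), do: 4 * (input * units + units * units + units)

  # Embedding kernel of shape rows x dim
  def embedding(rows, dim), do: rows * dim

  # Dense layer: input x units kernel plus one bias per unit
  def dense(input, units), do: input * units + units
end

breakdown = %{
  # kernel f32[100][100]
  "embedding_0" => ParamCount.embedding(100, 100),
  # 100 inputs -> 120 units
  "lstm_0" => ParamCount.lstm(100, 120),
  # 120 inputs -> 60 units
  "lstm_1" => ParamCount.lstm(120, 60),
  # 60 inputs -> 2 classes
  "dense_0" => ParamCount.dense(60, 2)
}

total = breakdown |> Map.values() |> Enum.sum()
# Every parameter is an f32, i.e. 4 bytes
bytes = total * 4

IO.inspect(breakdown, label: "per layer")
IO.puts("Total Parameters: #{total}")
IO.puts("Total Parameters Memory: #{bytes} bytes")
```

Running it should reproduce the 159642 parameters and 638568 bytes reported above. Most of that budget (106080 parameters) sits in the first LSTM layer.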

Train the model

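batch_size = 128
# Split the inputs and the labels into batches of 128 rows each
batched_input_x = Nx.to_batched(input_x, batch_size)
batched_input_y = Nx.to_batched(input_y, batch_size)

model_params =
  model
  # Minimize categorical cross-entropy on the softmax output using Adam (learning rate 0.1)
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.1))
  # Track accuracy alongside the loss
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  |> Axon.Loop.run(
    # Each element is an {input_batch, label_batch} tuple
    Stream.zip(batched_input_x, batched_input_y),
    # Empty initial state: the parameters are initialized from scratch
    %{},
    epochs: 10,
    # JIT-compile the training step with EXLA
    compiler: EXLA
  )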
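The training cell turns the full tensors into a stream of `{inputs, labels}` batch pairs. The plain-Elixir sketch below mimics that pairing with lists instead of tensors. It assumes `input_x` has 17274 rows (the leading dimension shown in the model summary) and uses dummy alternating labels. The point is to show how many optimizer steps a single epoch takes at `batch_size = 128`.

```elixir
# Sketch: how 17_274 rows split into batches of 128 and get paired with labels.
# Lists stand in for the real input_x / input_y tensors (an assumption for illustration).
rows = 17_274
batch_size = 128

# Dummy inputs (row indices) and dummy alternating 0/1 labels
inputs = Enum.to_list(0..(rows - 1))
labels = Enum.map(inputs, &rem(&1, 2))

# Consecutive chunks of batch_size rows; the final chunk may be shorter
batched_x = Stream.chunk_every(inputs, batch_size)
batched_y = Stream.chunk_every(labels, batch_size)

# Stream.zip yields {input_batch, label_batch} tuples, one per training step
pairs = batched_x |> Stream.zip(batched_y) |> Enum.to_list()

sizes = Enum.map(pairs, fn {x, _y} -> length(x) end)

IO.puts("steps per epoch: #{length(pairs)}")
IO.puts("last batch size: #{List.last(sizes)}")
```

With these numbers, one epoch is 135 steps: 134 full batches plus a final chunk of 122 rows. `Stream.chunk_every/2` keeps that short chunk. By contrast, `Nx.to_batched/3` should, by default (`leftover: :repeat`), pad the last batch back to 128 rows by wrapping around to the start of the tensor. It is worth confirming this in the Nx docs for the version in use.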
Epoch: 0, Batch: 100, Accuracy: 0.5512069 loss: 0.7114335
Epoch: 1, Batch: 100, Accuracy: 0.5563124 loss: 0.6977340
Epoch: 2, Batch: 100, Accuracy: 0.5603344 loss: 0.6930305
Epoch: 3, Batch: 100, Accuracy: 0.5590193 loss: 0.6910381
Epoch: 4, Batch: 100, Accuracy: 0.5579364 loss: 0.6895447
Epoch: 5, Batch: 100, Accuracy: 0.5519030 loss: 0.6893529
Epoch: 6, Batch: 100, Accuracy: 0.5537593 loss: 0.6889438
Epoch: 7, Batch: 100, Accuracy: 0.5537594 loss: 0.6886014
Epoch: 8, Batch: 100, Accuracy: 0.5546876 loss: 0.6883425
Epoch: 9, Batch: 100, Accuracy: 0.5554614 loss: 0.6881454
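The logged loss stays around 0.69 across all ten epochs. It helps to compare that with two trivial baselines for this two-class softmax. The plain-Elixir sketch below computes categorical cross-entropy by hand for (1) a model that always predicts 50/50 and (2) a model that always predicts a fixed 55/45 split. The 55/45 split is a hypothetical choice that matches the roughly 55% accuracy in the log. It is not a measured property of the complaints dataset.

```elixir
# Categorical cross-entropy for one example: -sum(y_true * log(y_pred))
cross_entropy = fn y_true, y_pred ->
  Enum.zip(y_true, y_pred)
  |> Enum.map(fn {t, p} -> -t * :math.log(p) end)
  |> Enum.sum()
end

# Baseline 1: always output [0.5, 0.5]; the loss is ln(2) whichever class is correct
chance_loss = cross_entropy.([1.0, 0.0], [0.5, 0.5])

# Baseline 2: always output a hypothetical 55/45 prior, averaged over a dataset
# that is assumed to be 55% class 0 and 45% class 1
prior = [0.55, 0.45]

prior_loss =
  0.55 * cross_entropy.([1.0, 0.0], prior) +
    0.45 * cross_entropy.([0.0, 1.0], prior)

IO.puts("always 50/50: #{Float.round(chance_loss, 4)}")
IO.puts("always 55/45: #{Float.round(prior_loss, 4)}")
```

The first baseline is ln 2 ≈ 0.6931 and the second is about 0.6881, essentially the final epoch's loss. If the real labels are split roughly 55/45, this would suggest the network has mostly learned the class prior rather than signal from the complaint text.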
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[2]
      EXLA.Backend
      [0.1251661330461502, -0.125175341963768]
    >,
    "kernel" => #Nx.Tensor<
      f32[60][2]
      EXLA.Backend
      [
        [-0.11456689238548279, -0.22838549315929413],
        [-0.12396243959665298, 0.3480182886123657],
        [-0.24799303710460663, 0.4086136221885681],
        [-0.25208061933517456, 0.1209348812699318],
        [-0.41774049401283264, 0.017988236621022224],
        [0.04517333209514618, 0.025405297055840492],
        [0.11701110005378723, 0.3374139368534088],
        [-0.1964828222990036, 0.6522907614707947],
        [-0.18669693171977997, 0.231783926486969],
        [-0.2889942228794098, 0.21732266247272491],
        [-0.019810926169157028, -0.21428732573986053],
        [-0.4443662762641907, 0.5153412222862244],
        [0.4344821572303772, -0.2549160122871399],
        [0.24886995553970337, -0.29946666955947876],
        [0.26993128657341003, -0.12220335006713867],
        [-0.43190282583236694, 0.4242774546146393],
        [0.1516372263431549, -0.22656863927841187],
        [0.13703890144824982, -0.16237571835517883],
        [0.14989663660526276, 0.2655700743198395],
        [-0.1689331829547882, 0.4276685118675232],
        [0.4065585732460022, -0.2818610370159149],
        [-0.3965660333633423, 0.6050839424133301],
        [0.3432924747467041, -0.3004985749721527],
        [0.06472918391227722, ...],
        ...
      ]
    >
  },
  "embedding_0" => %{
    "kernel" => #Nx.Tensor<
      f32[100][100]
      EXLA.Backend
      [
        [0.4756399691104889, -0.49437612295150757, 0.4316394627094269, -0.4074139893054962, -0.4901922643184662, -0.36140480637550354, 0.36421868205070496, 0.3011569380760193, 0.43712350726127625, 0.5534155368804932, 0.5888078808784485, -0.33357852697372437, 0.5388597846031189, 0.40615323185920715, 0.481162428855896, 0.38006094098091125, -0.638387143611908, 0.4308280944824219, 0.4813270568847656, 0.5982890725135803, 0.4459790885448456, 0.3848690390586853, 0.5315402746200562, -0.5542699098587036, -0.3511618375778198, 0.3024648129940033, 0.1730550229549408, -0.41640833020210266, -0.35587364435195923, 0.5026905536651611, 0.3404591977596283, 0.5630281567573547, 0.6074140071868896, 0.5338529348373413, -0.615528404712677, -0.5442991256713867, 0.3665221333503723, 0.44749459624290466, -0.38137248158454895, 0.3597576916217804, 0.26979339122772217, 0.44773054122924805, -0.27586621046066284, -0.2849093973636627, -0.5049300789833069, -0.5402419567108154, 0.673927903175354, ...],
        ...
      ]
    >
  },
  "lstm_0" => %{
    "bias" => {#Nx.Tensor<
       f32[120]
       EXLA.Backend
       [-0.6068086624145508, -0.6412422060966492, 0.1200716495513916, -0.16444075107574463, 0.3364875018596649, -0.24439403414726257, 0.5298935174942017, 0.15725243091583252, -0.6854284405708313, -0.7712937593460083, 0.962135374546051, -0.09957625716924667, -0.5467063188552856, -0.2684676945209503, -0.679925799369812, -0.13650093972682953, 0.09936871379613876, 0.10097546875476837, -0.47389480471611023, 0.8439937233924866, -0.6029786467552185, -0.1408299058675766, -0.5871002078056335, 0.4825431704521179, -0.41130533814430237, 0.1793312281370163, 0.21258889138698578, 0.4704096019268036, -0.5082970857620239, 0.2585344910621643, -0.4169670641422272, 1.053611159324646, 0.10052215307950974, 0.012130340561270714, -0.8054720759391785, -0.2702268064022064, 0.8752445578575134, -0.6760954856872559, -0.2951154410839081, 0.02220281772315502, -0.4468700885772705, -0.9130715131759644, -1.2395175695419312, 0.45675531029701233, 0.051569171249866486, ...]
     >,
     #Nx.Tensor<
       f32[120]
       EXLA.Backend
       [-0.7288936972618103, -0.760654866695404, -0.2817695438861847, -0.3098216950893402, -0.059826578944921494, -0.6543993949890137, 0.05059446766972542, 0.20688965916633606, -0.7661867737770081, -0.875282883644104, 0.9727373123168945, -0.41664382815361023, -0.47966110706329346, -0.4021231234073639, -0.49192363023757935, -0.012789309024810791, 0.06188487261533737, -0.14829552173614502, -0.5137802362442017, 0.6509908437728882, -0.5031194686889648, -0.025931375101208687, -0.8159288763999939, 0.38978806138038635, 0.25578823685646057, 0.19951215386390686, 0.23506265878677368, 0.35722053050994873, -0.6560932397842407, 0.595476508140564, -0.5279211401939392, 0.9366962313652039, -0.10612500458955765, -0.11484526842832565, -0.8178385496139526, -0.38810819387435913, 0.7290158271789551, -0.7367604374885559, -0.012611097656190395, -0.020670797675848007, -0.7428527474403381, -1.0000216960906982, -0.9989742636680603, -0.12951615452766418, ...]
     >,
     #Nx.Tensor<
       f32[120]
       EXLA.Backend
       [0.30640318989753723, 0.27766096591949463, 0.3808199465274811, -0.3197195827960968, -1.4097706079483032, 0.6068865060806274, -1.4276896715164185, 0.25107407569885254, 0.5309128165245056, -0.13659098744392395, 1.2280856370925903, 0.6880354881286621, -0.00910650473088026, 0.3965328633785248, 0.6851245760917664, -1.4194806814193726, 0.8196843862533569, 0.755574643611908, 0.9890422821044922, 1.399299144744873, -0.2444954216480255, 0.5668771862983704, 0.06373570114374161, 0.6989892721176147, 0.46983757615089417, 0.6393361687660217, -0.341346800327301, 0.5220888257026672, -1.2845314741134644, 1.0785467624664307, 0.14364880323410034, 1.2751268148422241, -1.5366579294204712, -1.3190473318099976, -0.45921701192855835, -1.175851583480835, -1.0107700824737549, -0.10251428931951523, 0.6335068345069885, -0.37535175681114197, -0.6084280014038086, 0.2330879420042038, -0.5187585353851318, ...]
     >,
     #Nx.Tensor<
       f32[120]
       EXLA.Backend
       [-0.6604985594749451, -0.645201563835144, -0.18411312997341156, -0.22245639562606812, -0.33754822611808777, -0.4308028817176819, -0.1461275964975357, -0.07358670234680176, -0.7805112600326538, -0.9343949556350708, -0.35994160175323486, -0.7456783056259155, -0.4592522084712982, -0.356851726770401, -0.2927922010421753, -0.1059064194560051, 0.1651652455329895, -0.13765408098697662, -0.6344941258430481, -0.33435356616973877, -0.10497663915157318, -0.3560033142566681, -0.412919819355011, 0.41109102964401245, 0.20860286056995392, -0.7334303259849548, 0.1886008381843567, 0.3448133170604706, -0.562018632888794, 0.5120474696159363, -0.4046480357646942, 0.7070919871330261, -0.3554325997829437, -0.29739540815353394, -0.8193110227584839, -0.43955162167549133, 0.619873046875, -0.7053048014640808, -0.16967345774173737, -0.05964067205786705, -0.7261925935745239, -1.1039772033691406, ...]
     >},
    "hidden_kernel" => {#Nx.Tensor<
       f32[120][120]
       EXLA.Backend
       [
         [0.10788093507289886, 0.3512379825115204, -0.6248757839202881, -0.25008416175842285, 0.9393187165260315, -1.0760244131088257, -0.19958363473415375, 1.3663206100463867, -0.8373463153839111, 0.651411771774292, 0.35242757201194763, -0.789999783039093, 0.560447096824646, 0.591607928276062, 0.9425632357597351, 0.13658848404884338, 0.23705781996250153, 1.0845650434494019, 0.29009148478507996, 0.6239418387413025, 0.45298105478286743, 0.1120116263628006, 1.2759485244750977, 0.5315880179405212, -0.6587555408477783, 0.08043013513088226, 0.9660731554031372, 0.7736319899559021, -0.5391457080841064, 1.4012311697006226, 0.34877267479896545, 0.9900428652763367, -0.2637501657009125, 1.158486247062683, 0.38105133175849915, 1.2292758226394653, 1.1948695182800293, 1.0596129894256592, -0.6972832083702087, 1.1139757633209229, 0.9683577418327332, 0.3131406307220459, -0.35455724596977234, -0.30302369594573975, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][120]
       EXLA.Backend
       [
         [0.358450710773468, -0.2557828426361084, 0.10614006966352463, 0.04867010563611984, 0.8855361342430115, 1.1383148431777954, 0.10895808786153793, 0.8240983486175537, -0.6843536496162415, 0.6658166646957397, 0.3062891364097595, -0.7579455375671387, 0.5910908579826355, 0.662002682685852, 0.8434790968894958, 0.18384899199008942, -0.12716799974441528, 1.0196118354797363, 0.3334205746650696, 0.5405780673027039, 0.36579740047454834, -0.9010316729545593, 1.3652138710021973, 0.3605203628540039, -0.011785297654569149, -0.611190140247345, 1.0387834310531616, 0.763027548789978, -0.49559757113456726, 1.211172103881836, 0.0923260748386383, 0.49107322096824646, -0.3789612948894501, 0.9953087568283081, 0.3741160035133362, 0.7606132626533508, 0.9321890473365784, 1.0513157844543457, -0.4996790289878845, 0.9191576242446899, 0.814562201499939, 0.4356539249420166, 0.43873095512390137, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][120]
       EXLA.Backend
       [
         [-0.9162583351135254, -0.35054391622543335, -0.390370637178421, -0.9522840976715088, -0.882036566734314, 0.7854264974594116, -0.319416880607605, 1.0988832712173462, -0.9691332578659058, -0.16318337619304657, -0.13923679292201996, 0.20210163295269012, -0.7247295379638672, 0.0729798749089241, -0.08772959560155869, -0.29626843333244324, -0.008382383733987808, 0.9805390238761902, -0.7421141862869263, 0.28473320603370667, 0.05266217514872551, 0.59184330701828, -0.6791769862174988, -0.5857567191123962, 0.7672326564788818, -0.45282861590385437, -0.39024612307548523, -0.7432284355163574, -0.42507204413414, 0.4078441560268402, -0.35714033246040344, 0.8266581296920776, -0.43008852005004883, -0.4594729244709015, 0.6216307878494263, -0.8508207201957703, -0.40239715576171875, -0.44667401909828186, -0.31040412187576294, 0.36623865365982056, 0.6253090500831604, -0.6062566041946411, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][120]
       EXLA.Backend
       [
         [0.1334833800792694, 0.10071207582950592, -0.19375473260879517, -0.2175440937280655, 0.28261232376098633, -0.5313495397567749, -0.2761338949203491, 1.4415931701660156, 0.05195466801524162, 0.6653857231140137, 0.6318196654319763, -1.0091872215270996, 0.4887752830982208, 0.5448973178863525, 1.4287916421890259, 0.21670128405094147, -0.10501915216445923, 1.858127474784851, 0.37639763951301575, 0.2822466790676117, 0.6510127186775208, -0.7911188006401062, 1.5903093814849854, 0.5963724851608276, 0.6545816659927368, -0.684876561164856, 0.9205214977264404, 1.081081748008728, -0.5220369696617126, 1.3410801887512207, 0.36190569400787354, 0.7563324570655823, -0.26644060015678406, 1.1586782932281494, 0.36546164751052856, 1.3622890710830688, 1.0111808776855469, 1.118896484375, -0.6868638396263123, 0.8346412181854248, 0.9821230173110962, ...],
         ...
       ]
     >},
    "input_kernel" => {#Nx.Tensor<
       f32[100][120]
       EXLA.Backend
       [
         [-0.22517281770706177, 0.2330886274576187, 0.26489388942718506, -0.4754554331302643, -0.2764003872871399, -0.3151744604110718, -0.5738303661346436, 0.6054437160491943, 0.5670934915542603, 0.49993038177490234, 0.11194263398647308, 0.2954546809196472, 0.7397674322128296, 0.3066859245300293, 0.3888559639453888, -0.4686156213283539, 0.8454335331916809, -0.11548548191785812, 0.4553149938583374, -0.4523054361343384, 0.5684559345245361, 0.8590189218521118, 0.43880578875541687, 0.5149311423301697, 0.5222757458686829, -0.20426802337169647, 1.1208033561706543, 0.6668455004692078, 0.3600870370864868, 0.947739839553833, 0.2951829731464386, 0.9133455753326416, -0.3860554099082947, -0.32358044385910034, 0.20058099925518036, 0.19087252020835876, 1.2092773914337158, 0.4592578709125519, -0.8490164875984192, 0.2735322415828705, -0.2258705049753189, -0.3875930607318878, -0.10666191577911377, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[100][120]
       EXLA.Backend
       [
         [0.0908169224858284, 0.25650590658187866, 0.613166093826294, 0.197557732462883, 0.07342406362295151, 0.33218738436698914, -0.3089587092399597, 0.5680950880050659, 0.5540980100631714, 0.3553132712841034, -0.22676599025726318, 0.26317131519317627, 0.5392262935638428, 0.3632277250289917, 0.6455200910568237, -0.33749109506607056, 0.7817844748497009, 0.17298834025859833, 0.5132368206977844, -0.43405571579933167, 0.6071807146072388, 0.7985832095146179, 0.26781535148620605, 0.596405565738678, 0.5411176085472107, -0.320674866437912, 1.423883080482483, 0.6093699932098389, 0.37598443031311035, 1.0189014673233032, 0.48958510160446167, 0.9566540718078613, -0.22570694983005524, -0.18764962255954742, 0.07313736528158188, 0.26614925265312195, 0.9762657284736633, 0.3514813780784607, -0.7382213473320007, 0.3183710277080536, -0.24300618469715118, -0.5253326892852783, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[100][120]
       EXLA.Backend
       [
         [-0.5585066676139832, -0.304169237613678, -0.26811397075653076, 0.32110777497291565, 0.39580443501472473, 0.16679610311985016, 0.4381210207939148, 0.7626484632492065, -0.17587588727474213, -0.4944244623184204, -0.7519116401672363, -0.849659264087677, 0.4314565658569336, -0.44638580083847046, -0.7402040362358093, 0.3498849868774414, 0.014392167329788208, 0.2685610353946686, 0.20389670133590698, -0.633255660533905, 0.4375564157962799, 0.03051411546766758, 1.0173429250717163, -0.9753749966621399, -0.32472386956214905, -0.2310732752084732, -0.46713176369667053, -0.3331509530544281, 0.8307963013648987, 0.0348331443965435, -0.6700356006622314, -0.5212703347206116, 0.2389097362756729, 0.5111476182937622, -0.37461715936660767, -0.498336523771286, 0.5988596081733704, -0.9082761406898499, -0.5600574016571045, -0.6374697685241699, 1.242306113243103, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[100][120]
       EXLA.Backend
       [
         [-0.2233022004365921, 0.3642134368419647, 0.605087161064148, 0.09154772758483887, 0.41787266731262207, 0.3687729835510254, -0.1874573528766632, 0.3084380626678467, 0.48815447092056274, 0.288083016872406, 0.26539427042007446, -0.46608367562294006, 0.612269937992096, 0.35601305961608887, 0.8623830080032349, -0.24748845398426056, 0.8094604015350342, 0.38750773668289185, 0.6023486256599426, 0.4532455801963806, 0.7492527961730957, 0.5804922580718994, 0.2564087212085724, 0.6596786379814148, 0.6180156469345093, -0.585131049156189, 1.0382261276245117, 0.6604195237159729, 0.4935021996498108, 0.8505986928939819, 0.30538034439086914, 0.6377082467079163, 0.07395821064710617, -0.008431482128798962, -0.01718183048069477, 0.2545641362667084, 1.0079272985458374, 0.41212913393974304, -0.8448607325553894, 0.17696410417556763, ...],
         ...
       ]
     >}
  },
  "lstm_1" => %{
    "bias" => {#Nx.Tensor<
       f32[60]
       EXLA.Backend
       [-0.18341106176376343, -0.9243686199188232, -0.19160209596157074, -0.06065799295902252, -0.4303320348262787, 0.619128406047821, -0.7710100412368774, -0.44940507411956787, 0.5548861026763916, 0.5062773823738098, -0.8682750463485718, -0.9573096632957458, -0.8494192361831665, 0.5330183506011963, -0.44420722126960754, -0.9474890232086182, 0.003021322190761566, -0.3567139506340027, -0.12338213622570038, -0.4589076638221741, -0.6747087836265564, -0.5978382229804993, 0.6920117735862732, 0.00686887139454484, -0.2843252122402191, 0.5009809136390686, -0.35029929876327515, -0.44580861926078796, -0.7479485273361206, -1.1829873323440552, -0.7547271251678467, -0.5318294167518616, -0.07108983397483826, -0.9716689586639404, -0.4790057837963104, 0.5215351581573486, 0.5290589332580566, -0.7840427756309509, -0.92562335729599, -0.4603904187679291, -0.8354495763778687, -0.11161450296640396, -0.6068910360336304, -0.42815962433815, ...]
     >,
     #Nx.Tensor<
       f32[60]
       EXLA.Backend
       [0.6804019808769226, -0.699414074420929, -0.45327988266944885, -0.35844528675079346, -0.42983701825141907, 0.5335893034934998, -0.6355966925621033, -0.549679696559906, 0.5751937627792358, 0.5351293683052063, -0.6759853363037109, -1.2699031829833984, -0.5553288459777832, 0.4462152421474457, -0.49574947357177734, 0.37106940150260925, -0.08575860410928726, -0.15258634090423584, 0.09312330931425095, -0.6270717978477478, -0.7505979537963867, -0.6616138219833374, 0.6712358593940735, 0.2239842563867569, -0.40174755454063416, 0.44956734776496887, -0.39673495292663574, -0.3824487328529358, -0.8419355154037476, -0.052910901606082916, 0.8518306612968445, -0.3787192404270172, -0.2649839222431183, -0.8759109377861023, 1.225016713142395, 0.5358517169952393, 0.5783258676528931, -0.6812098026275635, 0.9010128974914551, -0.3434351086616516, 0.007674288935959339, -1.0556445121765137, -0.4569246172904968, ...]
     >,
     #Nx.Tensor<
       f32[60]
       EXLA.Backend
       [-0.07786507904529572, -0.14066745340824127, 0.21797983348369598, 0.4326586425304413, 0.12532973289489746, -2.2134063243865967, 0.6737685203552246, 0.38561344146728516, 1.0082470178604126, 0.771798849105835, -0.9056801795959473, -0.40424439311027527, 0.2508355975151062, -0.7841559052467346, -0.1986011415719986, 0.7531395554542542, -1.0036131143569946, 0.3681122958660126, 0.21005812287330627, -0.2543736696243286, 0.6155755519866943, 0.5249324440956116, 1.2180362939834595, 0.24463632702827454, -0.20888324081897736, 0.6982148289680481, 0.23401881754398346, -0.32930976152420044, -0.6250767111778259, -0.29982998967170715, -0.6776163578033447, -0.0866536796092987, -0.06168237328529358, -0.6064894199371338, -0.2976151704788208, 1.0346397161483765, 1.0402748584747314, 0.3370100259780884, 0.4493761658668518, -0.0687326192855835, -0.5876778960227966, 0.47508886456489563, ...]
     >,
     #Nx.Tensor<
       f32[60]
       EXLA.Backend
       [-0.014073977246880531, -0.9333531260490417, -0.3932308554649353, -0.2617798149585724, -0.4120787978172302, 0.5247744917869568, -0.5954556465148926, -0.4337135851383209, -0.3000921905040741, -0.37982460856437683, -0.9539894461631775, -0.9942632913589478, -0.5676699876785278, -0.3738715648651123, -0.43704116344451904, -0.9408429265022278, -0.34798961877822876, -0.6257126331329346, 0.4768255054950714, -0.46106183528900146, -0.6339935660362244, -0.5934643745422363, -0.22896303236484528, 0.06802326440811157, -0.358540415763855, -0.3800552189350128, -0.48113417625427246, -0.415779709815979, -0.7686687707901001, -0.9083367586135864, -3.6132237911224365, -0.3674773573875427, -0.3569193184375763, -0.8976558446884155, -0.11789632588624954, -0.23673051595687866, -0.3276209235191345, -0.8087908625602722, 0.26622435450553894, -0.681583821773529, -0.8279891610145569, ...]
     >},
    "hidden_kernel" => {#Nx.Tensor<
       f32[60][60]
       EXLA.Backend
       [
         [0.8149242997169495, 0.6868818998336792, 0.8296692967414856, 0.04104558005928993, -0.672260046005249, -0.3831019699573517, -0.49183714389801025, -0.23623262345790863, 0.1390298455953598, 0.18757423758506775, 0.01599043793976307, -0.3954319953918457, -0.3826022446155548, 0.5734179615974426, -0.4551549553871155, -0.48584339022636414, 0.6904323697090149, 0.9519748091697693, 1.1198066473007202, -0.11534616351127625, -0.21510937809944153, 0.31511276960372925, 0.5141907930374146, 0.7482413649559021, -0.8782660961151123, 1.045454740524292, -0.3620418608188629, 0.3273802399635315, -0.5370044112205505, -0.38227686285972595, -0.5593936443328857, -1.4275431632995605, 0.215467169880867, -0.11988919973373413, -0.8910923004150391, 0.4649423658847809, 0.804772675037384, -0.309302419424057, -0.6807579398155212, -0.4532768428325653, -0.5667983293533325, -0.5154366493225098, -1.2079955339431763, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[60][60]
       EXLA.Backend
       [
         [0.793323814868927, 0.11953596025705338, 0.7746643424034119, 0.04961645230650902, -0.39079779386520386, 0.1146659255027771, 0.6680335402488708, 0.03795963153243065, 0.6632503271102905, 0.33424490690231323, 0.69684237241745, -0.31309935450553894, -0.6280642151832581, 0.40980300307273865, -0.7115398049354553, 0.5349920392036438, 0.5307707190513611, -0.13484381139278412, 0.945764422416687, -0.10667821764945984, -1.0601099729537964, 0.04631057009100914, 0.568169116973877, 0.6014842987060547, -0.4642559885978699, 0.28926363587379456, -0.4174582064151764, 0.43141165375709534, -0.5650730133056641, 0.6618462800979614, 1.122145414352417, -1.4744189977645874, 0.29758939146995544, -0.17096078395843506, 3.3031349182128906, 0.5110371708869934, 0.4635887145996094, -0.1492435187101364, 0.12703798711299896, -0.3236512243747711, 0.40120822191238403, 0.17457297444343567, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[60][60]
       EXLA.Backend
       [
         [1.0875036716461182, -1.0162317752838135, 0.32510000467300415, 0.056312769651412964, -0.46246853470802307, -1.4880632162094116, 0.7796785831451416, -1.0549145936965942, -0.5140237808227539, 0.19759666919708252, -0.6880502104759216, 0.06203341484069824, 0.4295230209827423, -0.6768867373466492, -0.3241910934448242, 0.5643408894538879, 0.6144492626190186, 0.9292702078819275, 0.9756212830543518, -0.4731598496437073, -0.2200871855020523, -0.3215710520744324, -0.22164367139339447, -0.04326136037707329, 0.04068143665790558, 1.0577102899551392, 0.6599220037460327, 0.4999141991138458, 0.13533508777618408, -0.1282011717557907, 0.4594724178314209, 0.5638686418533325, 0.23980186879634857, -0.7966639995574951, 0.18221479654312134, -0.08567328751087189, 0.47228872776031494, 0.49250873923301697, -0.49068912863731384, -0.7881981730461121, -0.14168231189250946, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[60][60]
       EXLA.Backend
       [
         [0.5581820607185364, 0.7084558606147766, 0.9461260437965393, -0.47317537665367126, -0.6384133696556091, 0.15450984239578247, -0.23960112035274506, -0.029845207929611206, 0.3309059143066406, 0.4182018041610718, 0.012467822059988976, -0.3697502911090851, -0.6376978158950806, 0.851270854473114, -0.6514398455619812, -0.42584893107414246, 0.7886115908622742, -0.20063063502311707, 1.1557307243347168, -0.033951859921216965, -1.160847783088684, -0.00449313223361969, -0.6316667199134827, 0.588198184967041, -0.46046924591064453, 0.5943189263343811, -0.8891192078590393, -0.7727584838867188, -0.5157036781311035, -0.4221958816051483, -3.02982234954834, -0.35822680592536926, -0.23690663278102875, -0.11070509999990463, 1.6361840963363647, 0.5303181409835815, 0.47293728590011597, -0.12145724147558212, 0.901959240436554, -0.43027544021606445, ...],
         ...
       ]
     >},
    "input_kernel" => {#Nx.Tensor<
       f32[120][60]
       EXLA.Backend
       [
         [-0.2229905128479004, 0.07128430902957916, -0.3170720040798187, 0.47008344531059265, -0.9237403273582458, 0.7696903944015503, 0.3192524313926697, -0.8252050280570984, -0.5051939487457275, 0.026261145249009132, 1.2963849306106567, -0.9394668340682983, 0.059015750885009766, -0.1052999347448349, -0.11394556611776352, -0.9541149139404297, -0.7331550717353821, 0.9627928733825684, -0.013242818415164948, -0.5375745296478271, -0.9710343480110168, -0.5815613865852356, -0.3748815357685089, 1.5032639503479004, -1.2584089040756226, -0.9105948805809021, 0.48886385560035706, 0.4324774444103241, -0.11898104101419449, 0.04111045226454735, -0.29982277750968933, 0.4714553654193878, -0.12024517357349396, -0.9578234553337097, 0.663011372089386, -0.3438766598701477, -0.3213392198085785, -1.4615230560302734, 0.09529570490121841, -0.05152870714664459, -0.688315749168396, 0.9451417326927185, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][60]
       EXLA.Backend
       [
         [-0.42292624711990356, 0.5430710911750793, -0.21740177273750305, 0.4434795677661896, -0.7596100568771362, 0.6143496036529541, -0.26403871178627014, -0.09822661429643631, -0.39692407846450806, 0.6618760824203491, 1.3505511283874512, -0.1821458339691162, 0.12859275937080383, -0.2598292827606201, -0.208560973405838, 0.19370515644550323, -0.6715968251228333, 1.3023240566253662, 0.5996056199073792, -0.5182926058769226, -0.3554723858833313, -0.2900538444519043, -0.2775692939758301, 0.7317833304405212, -0.7962754368782043, -0.7965075373649597, 0.7054013013839722, -0.6736835837364197, -0.20065228641033173, 0.5850132703781128, -0.42285799980163574, -0.48797133564949036, -0.7282596230506897, -0.39593443274497986, 1.37859046459198, -0.3343508541584015, -0.20926474034786224, -0.64011150598526, 0.7374411225318909, 0.5465848445892334, 0.05645184591412544, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][60]
       EXLA.Backend
       [
         [0.9291626214981079, -0.9520668387413025, -0.35041478276252747, -0.08397278934717178, 0.2276611030101776, -0.7198747396469116, 0.5969547629356384, -0.3112131953239441, -0.20959077775478363, -0.22053366899490356, 0.08457772433757782, 1.0217751264572144, -0.07047414034605026, 0.6975047588348389, -0.03135767579078674, -0.05799692124128342, 1.2063730955123901, -0.3666382133960724, 0.7342422008514404, -0.46793264150619507, -0.5958633422851562, -0.5633532404899597, -0.6141194105148315, -0.37839555740356445, -0.053077224642038345, -0.6654996871948242, -0.2391616702079773, -0.0207744799554348, -0.8414919972419739, 0.4718431234359741, -1.3738021850585938, 0.2176326960325241, -0.5225270390510559, -0.4451864957809448, -0.7946727871894836, -0.37634170055389404, 0.16265809535980225, 0.6172932982444763, 0.41356128454208374, -0.4509551227092743, ...],
         ...
       ]
     >,
     #Nx.Tensor<
       f32[120][60]
       EXLA.Backend
       [
         [-0.790632963180542, 0.06195172294974327, -0.39926397800445557, -0.49038833379745483, -0.9322876930236816, 0.8836627006530762, 0.4735221862792969, -0.9031734466552734, -1.010717511177063, -0.48718148469924927, 1.200811743736267, -1.0804530382156372, 0.0706436038017273, -0.4983786642551422, -0.14185523986816406, -0.9715566635131836, -1.053016185760498, 1.755744457244873, 0.0918399840593338, -0.5686318278312683, -0.9415125250816345, -0.6188305020332336, 0.9015722274780273, 2.0455636978149414, -0.732143223285675, -0.9958089590072632, 0.5037012100219727, 0.3952459394931793, -0.12595726549625397, 0.4200909733772278, -0.309084951877594, 0.4975631833076477, 0.3589181900024414, -0.9218600392341614, 0.7198611497879028, -0.9687642455101013, -0.7808705568313599, -1.2150698900222778, 1.01679265499115, ...],
         ...
       ]
     >}
  }
}
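The `lstm_1` entry above holds three tuples of four tensors each. `input_kernel` has four 120×60 matrices, `hidden_kernel` has four 60×60 matrices and `bias` has four vectors of length 60. That layout matches the standard LSTM cell, which computes four gated quantities per time step over 60 hidden units. The equations below sketch that cell under two assumptions. First, the 120 input features come from whatever layer feeds the LSTM (presumably the embedding). Second, the four tensors in each tuple map to the input, forget, candidate and output gates in that order. Here $x_t \in \mathbb{R}^{1\times 120}$, $W_\ast \in \mathbb{R}^{120\times 60}$ (input kernel), $U_\ast \in \mathbb{R}^{60\times 60}$ (hidden kernel) and $b_\ast \in \mathbb{R}^{60}$ (bias):

```latex
\begin{aligned}
i_t &= \sigma\left(x_t W_i + h_{t-1} U_i + b_i\right) \\
f_t &= \sigma\left(x_t W_f + h_{t-1} U_f + b_f\right) \\
g_t &= \tanh\left(x_t W_g + h_{t-1} U_g + b_g\right) \\
o_t &= \sigma\left(x_t W_o + h_{t-1} U_o + b_o\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh\left(c_t\right) \\[4pt]
\text{params}(\texttt{lstm\_1}) &= 4\,(120 \cdot 60 + 60 \cdot 60 + 60) = 4 \cdot 10\,860 = 43\,440
\end{aligned}
```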

Save model state

# Uncomment to write the trained parameters to /data/model_params.bin:
# model_params
# |> :erlang.term_to_binary()
# |> then(&File.write!("/data/model_params.bin", &1, [:write, :binary]))
:ok
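There is one caveat with the cell above. The dump earlier shows the parameters living on `EXLA.Backend`, where a tensor holds a reference to a device buffer rather than the numbers themselves. A raw `:erlang.term_to_binary/1` dump may therefore not load correctly in a fresh session. The following is a minimal sketch that assumes Nx's `Nx.backend_copy/2`, which walks maps and tuples of tensors. It copies everything to `Nx.BinaryBackend` before serializing. The `ParamsIO` module and its `save!/2` and `load!/1` functions are hypothetical names used for illustration.

```elixir
# Standalone scripts only; in Livebook, Nx is already installed by the setup cell
Mix.install([{:nx, "~> 0.7"}])

defmodule ParamsIO do
  # Hypothetical helper module, not part of Nx or Axon.

  # Copies every tensor in a (possibly nested) map/tuple container to the
  # pure-Elixir binary backend, then writes it as a compressed Erlang term.
  def save!(params, path) do
    binary =
      params
      |> Nx.backend_copy(Nx.BinaryBackend)
      |> :erlang.term_to_binary([:compressed])

    File.write!(path, binary)
  end

  # Reads the term back; tensors come back on Nx.BinaryBackend and can be
  # moved to another backend with Nx.backend_transfer/2 if needed.
  def load!(path) do
    path
    |> File.read!()
    |> :erlang.binary_to_term()
  end
end

# Usage with a tiny stand-in for the real parameter map
params = %{
  "dense" => %{
    "kernel" => Nx.iota({2, 3}, type: :f32),
    "bias" => {Nx.tensor([0.5, -0.5])}
  }
}

path = Path.join(System.tmp_dir!(), "model_params_sketch.bin")
:ok = ParamsIO.save!(params, path)
restored = ParamsIO.load!(path)
```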

Validate the model

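# Load the parameters saved earlier to /data/model_params.bin
model_params =
  "/data/model_params.bin"
  |> File.read!()
  |> :erlang.binary_to_term()

# Evaluate on the held-out test set in batches of 128 examples
batch_size = 128
batched_test_input_x = Nx.to_batched(test_input_x, batch_size)
batched_test_input_y = Nx.to_batched(test_input_y, batch_size)

# The evaluator runs the model over each {input, label} batch and keeps
# a running average of the accuracy metric across batches
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(
  Stream.zip(batched_test_input_x, batched_test_input_y),
  model_params,
  compiler: EXLA
)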
Batch: 33, accuracy: 0.5969670
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend
      0.5969670414924622
    >
  }
}
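With one-hot labels, the `:accuracy` figure above is roughly the share of test rows whose highest-scoring class matches the labelled class. The sketch below reproduces that calculation on made-up tensors, assuming one-hot encoded labels. The `accuracy` function is an illustration, not Axon's implementation. The sketch also shows a detail of `Nx.to_batched/3` that matters here. By default, a final batch that is too short is padded with rows repeated from the start of the tensor (`leftover: :repeat`), so a few test examples can be counted twice. Passing `leftover: :discard` drops the remainder instead.

```elixir
# Standalone scripts only; in Livebook, Nx is already installed by the setup cell
Mix.install([{:nx, "~> 0.7"}])

# Share of rows whose argmax prediction equals the argmax of the one-hot label.
accuracy = fn y_true, y_pred ->
  Nx.argmax(y_true, axis: 1)
  |> Nx.equal(Nx.argmax(y_pred, axis: 1))
  |> Nx.mean()
  |> Nx.to_number()
end

# Made-up data: 4 examples, 3 classes; the third prediction is wrong.
y_true = Nx.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
y_pred = Nx.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.5, 0.3, 0.2], [0.6, 0.3, 0.1]])
score = accuracy.(y_true, y_pred)
# => 0.75

# 10 rows in batches of 4: the default pads the last batch with repeated rows,
# while leftover: :discard drops the 2 rows that do not fill a batch.
padded = Nx.iota({10}) |> Nx.to_batched(4) |> Enum.map(&Nx.to_flat_list/1)

discarded =
  Nx.iota({10})
  |> Nx.to_batched(4, leftover: :discard)
  |> Enum.map(&Nx.to_flat_list/1)
```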

Inference

model_params =
  "/data/model_params.bin"
  |> File.read!()
  |> :erlang.binary_to_term()

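# Axon.build/2 returns {init_fn, predict_fn}; only the prediction function
# is needed because the parameters are already trained
{_, predict_fn} = Axon.build(model)
# Close over the loaded parameters so predict_fn takes only the input
predict_fn = fn input -> predict_fn.(model_params, input) end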

# Mortgage
mortgage_example = """
I filed a bankruptcy in XXXX the bankruptcy did not include my mortgage with US BANK however US BANK stopped reporting my mortgage to the credit bureaus. I called US BANK multiple times and each time I was advised that if a person filed bankruptcy that they were not required to report my mortgage payments to the credit bureau. I called again in XXXX of XXXX and was advised by XXXX in the bankruptcy department that in order for them to report my payments to the crew bureau I have needed to complete a form prior to XXXX of XXXX. I escalated the call to the manager XXXX and I informed her that I signed a reaffirmation form in XXXX of XXXX that stated that by signing the form it would make it possible for US Bank to submit beneficial credit reporting on my behalf. I spoke with XXXX from the bankruptcy department XX/XX/XXXX who informed me that us bank decided not to report to the credit bureau on my behalf even though we have a signed agreement on record with the courts.
"""

debt_collection_example = "Cont'd attempts collect debt not owed"

[
  mortgage_example,
  debt_collection_example
]
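|> Enum.map(fn text ->
  # Apply the same cleaning as the training data, then vectorize
  [clean_text.(text)]
  |> then(&Vectorizer.encode_texts_to_tensor(vectorizer, &1))
  |> Nx.reshape({:auto, Vectorizer.sequence_length()})
  |> predict_fn.()
  # Index of the highest-scoring class (a batch of one, so a flat argmax suffices)
  |> Nx.argmax()
  |> Nx.to_number()
end)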
[0, 1]
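The cell above returns bare class indices (`[0, 1]`). To read them as complaint types, each index has to be looked up in the same ordered label list that was used to encode the training labels. The following sketch does that lookup. It also keeps the winning score as a rough confidence, assuming the model ends in a softmax layer. `labels` and `predictions` are hypothetical stand-ins, because this section does not show the notebook's actual label order or raw model outputs.

```elixir
# Standalone scripts only; in Livebook, Nx is already installed by the setup cell
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical label list; it must follow the order used when encoding labels.
labels = ["Mortgage", "Debt collection", "Student loan"]

# Stand-in for the model output: one row of class probabilities per complaint.
predictions = Nx.tensor([[0.81, 0.12, 0.07], [0.05, 0.90, 0.05]])

indices = predictions |> Nx.argmax(axis: 1) |> Nx.to_flat_list()
scores = predictions |> Nx.reduce_max(axes: [1]) |> Nx.to_flat_list()

# Pair each predicted label with its probability.
results =
  Enum.zip_with(indices, scores, fn index, score ->
    {Enum.at(labels, index), score}
  end)
```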