Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Titanic optimized

livebooks/axon/titanic_optimized.livemd

Titanic optimized

Mix.install([
  {:exla, "~> 0.9", override: true},
  {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
  {:kino, "~> 0.14"},
  {:kino_vega_lite, "~> 0.1"},
  {:explorer, "~> 0.9"},
  {:table_rex, "~> 4.0", override: true}
])

エイリアス

alias Explorer.DataFrame
alias Explorer.Series
require Explorer.DataFrame

データの準備

Kaggle のタイタニックデータをダウンロードする

https://www.kaggle.com/competitions/titanic/data

# File-upload widgets for the two Kaggle Titanic CSV files.
# The strings are the labels shown next to each widget.
train_data_input = Kino.Input.file("train data")
test_data_input = Kino.Input.file("test data")

データ読込

# Read the uploaded training file: take the file reference from the input
# value, resolve it to a path and parse the CSV into a dataframe.
train_data =
  train_data_input
  |> Kino.Input.read()
  |> Map.get(:file_ref)
  |> Kino.Input.file_path()
  |> DataFrame.from_csv!()

Kino.DataTable.new(train_data)
# Same steps for the uploaded test file.
test_data =
  test_data_input
  |> Kino.Input.read()
  |> Map.get(:file_ref)
  |> Kino.Input.file_path()
  |> DataFrame.from_csv!()

Kino.DataTable.new(test_data)
# Stack the training rows (without the "Survived" column) on top of the test
# rows so that column statistics can be taken over both files at once.
full_data =
  train_data
  |> DataFrame.discard("Survived")
  |> DataFrame.concat_rows(test_data)

Kino.DataTable.new(full_data)

欠損値補完

# Number of missing (nil) entries in the "Age" column of the combined data.
Series.nil_count(full_data["Age"])
# Returns the first non-nil entry of the series' frequency table, which is
# used as the mode of the series.
#
# `Series.frequencies/1` yields a dataframe with "values" and "counts"
# columns; rows whose value is nil are dropped before picking the first one.
#
# The column is extracted with `DataFrame.pull/2`. A dataframe is a struct
# with atom keys only, so `Map.get(df, "values")` can never find the column
# and yields nil instead of the series.
get_mode = fn series ->
  series
  |> Series.frequencies()
  |> DataFrame.filter(is_not_nil(values))
  |> DataFrame.pull("values")
  |> Series.first()
end
# Summarises a series as a map with three entries: mean (平均値),
# median (中央値) and mode (最頻値). The mode comes from `get_mode`.
get_statistics = fn series ->
  mean = Series.mean(series)
  median = Series.median(series)
  mode = get_mode.(series)

  %{平均値: mean, 中央値: median, 最頻値: mode}
end
# Mean / median / mode summary of "Age" over the combined train + test rows.
get_statistics.(full_data["Age"])
# Draws a count bar chart for column `colname` of dataframe `df`.
#
# Columns with more than 50 distinct values are plotted on a quantitative
# x axis binned into at most 50 bins; all other columns get a nominal x axis
# with one bar per distinct value.
histgram = fn df, colname ->
  values = Series.to_list(df[colname])

  x_axis_opts =
    case values |> MapSet.new() |> MapSet.size() do
      distinct when distinct > 50 -> [type: :quantitative, bin: %{maxbins: 50}]
      _few -> [type: :nominal, bin: nil]
    end

  # The y axis counts the rows falling into each x bucket.
  [width: 600, height: 300]
  |> VegaLite.new()
  |> VegaLite.data_from_values(value: values)
  |> VegaLite.mark(:bar, tooltip: true)
  |> VegaLite.encode_field(:x, "value", x_axis_opts ++ [title: colname])
  |> VegaLite.encode_field(:y, "value", type: :quantitative, aggregate: :count)
end
# Distribution of "Age" over the combined train + test rows.
histgram.(full_data, "Age")
# Inspect the rows whose "Age" is below 10.
full_data
|> DataFrame.filter(col("Age") < 10)
|> Kino.DataTable.new()
# Flag probable children: the name contains "Master", or it contains "Miss"
# and "Parch" is positive. `prob_adult` is simply the negation of that flag.
full_data =
  full_data
  |> DataFrame.mutate(
    prob_child:
      col("Name") |> contains("Master") or
        (col("Name") |> contains("Miss") and
           col("Parch") > 0)
  )
  |> DataFrame.mutate(prob_adult: not prob_child)

Kino.DataTable.new(full_data)
# "Age" statistics for the rows flagged as probable children.
full_data
|> DataFrame.filter(prob_child)
|> DataFrame.pull("Age")
|> get_statistics.()
# "Age" statistics for the remaining rows.
full_data
|> DataFrame.filter(prob_adult)
|> DataFrame.pull("Age")
|> get_statistics.()
# "Age" histograms for the same two groups.
full_data
|> DataFrame.filter(prob_child)
|> histgram.("Age")
full_data
|> DataFrame.filter(prob_adult)
|> histgram.("Age")
# Missing-value count, statistics and histogram for "Fare".
Series.nil_count(full_data["Fare"])
get_statistics.(full_data["Fare"])
histgram.(full_data, "Fare")
# Missing-value count and histogram for "Embarked".
Series.nil_count(full_data["Embarked"])
histgram.(full_data, "Embarked")
# Missing-value count for "Cabin".
Series.nil_count(full_data["Cabin"])

生存率との相関

# Distribution of the "Survived" label in the training data.
histgram.(train_data, "Survived")
# Row counts per "Survived" value, ordered by the value itself (ascending).
survived_counts =
  train_data["Survived"]
  |> Series.frequencies()
  |> DataFrame.sort_by(values)
  |> DataFrame.to_columns()
  |> Map.get("counts")

# Share of rows belonging to the second value in that ordering.
# NOTE(review): this is the survival rate only if "Survived" takes exactly the
# values 0 and 1 - confirm against the data.
survived_rate = Enum.at(survived_counts, 1) / Enum.sum(survived_counts)
# Draws a count bar chart for column `colname` of dataframe `df`, with the
# bars coloured by the values of `color_colname`.
#
# Columns with more than 20 distinct values are plotted on a quantitative
# x axis binned into at most 20 bins; all other columns get a nominal x axis.
color_histgram = fn df, colname, color_colname ->
  values = Series.to_list(df[colname])
  colors = Series.to_list(df[color_colname])

  x_axis_opts =
    case values |> MapSet.new() |> MapSet.size() do
      distinct when distinct > 20 -> [type: :quantitative, bin: %{maxbins: 20}]
      _few -> [type: :nominal, bin: nil]
    end

  # The y axis counts rows per x bucket; colour is treated as a category.
  [width: 600, height: 300]
  |> VegaLite.new()
  |> VegaLite.data_from_values(value: values, color: colors)
  |> VegaLite.mark(:bar, tooltip: true)
  |> VegaLite.encode_field(:x, "value", x_axis_opts ++ [title: colname])
  |> VegaLite.encode_field(:y, "value", type: :quantitative, aggregate: :count)
  |> VegaLite.encode_field(:color, "color", type: :nominal)
end
# "Pclass" histogram coloured by "Survived".
color_histgram.(train_data, "Pclass", "Survived")
# Cross table: one row per "Pclass", with the row counts for each "Survived"
# value pivoted into columns prefixed with "Survived_".
cross_table =
  train_data
  |> DataFrame.group_by(["Pclass", "Survived"])
  |> DataFrame.summarise(count: count(col("Survived")))
  |> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
  |> DataFrame.sort_by(col("Pclass"))

Kino.DataTable.new(cross_table)
# Add the per-class ratio Survived_1 / (Survived_0 + Survived_1).
# (The column is named `suvived_rate`, spelled as in the code.)
cross_table
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
# "Fare" histogram coloured by "Survived".
color_histgram.(train_data, "Fare", "Survived")
# Same cross table and ratio, grouped by fare bucket: floor(Fare / 50) cast
# to an integer, computed on rows whose "Fare" is present.
train_data
|> DataFrame.filter(is_not_nil(col("Fare")))
|> DataFrame.mutate(fare_group: (col("Fare") / 50) |> floor() |> cast(:integer))
|> DataFrame.group_by([:fare_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(fare_group)
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
# "Sex" histogram coloured by "Survived", followed by the same ratio per "Sex".
color_histgram.(train_data, "Sex", "Survived")
train_data
|> DataFrame.group_by(["Sex", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
# Age bucket = floor(Age / 10) cast to an integer, computed on rows whose
# "Age" is present; plotted as a histogram coloured by "Survived".
train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: (col("Age") / 10) |> floor() |> cast(:integer))
|> color_histgram.("age_group", "Survived")
# Counts per age bucket and "Survived" value, pivoted wide, plus the ratio
# Survived_1 / (Survived_0 + Survived_1) in the `suvived_rate` column.
train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: (col("Age") / 10) |> floor() |> cast(:integer))
|> DataFrame.group_by([:age_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(age_group)
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
# "Embarked" histogram coloured by "Survived", followed by the same ratio
# per "Embarked" value.
color_histgram.(train_data, "Embarked", "Survived")
train_data
|> DataFrame.group_by(["Embarked", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
# Mean "Fare" per "Embarked" value.
train_data
|> DataFrame.group_by("Embarked")
|> DataFrame.summarise(mean: col("Fare") |> mean())
|> Kino.DataTable.new()
# "Embarked" histogram coloured by "Pclass", over the combined data.
color_histgram.(full_data, "Embarked", "Pclass")
# "SibSp" and "Parch" histograms coloured by "Survived".
color_histgram.(train_data, "SibSp", "Survived")
color_histgram.(train_data, "Parch", "Survived")
# Histogram of family = SibSp + Parch, coloured by "Survived".
train_data
|> DataFrame.mutate(family: col("SibSp") + col("Parch"))
|> color_histgram.("family", "Survived")
# Inspect the rows whose "Ticket" is the literal string "LINE".
train_data
|> DataFrame.filter(col("Ticket") == "LINE")
|> Kino.DataTable.new()
# For every distinct "Ticket" in the combined data, count how many rows share
# it and subtract one, i.e. the number of other rows carrying the same ticket.
# The two columns of the frequency table are renamed to "Ticket" and "followers".
followers_df =
  full_data["Ticket"]
  |> Series.frequencies()
  |> DataFrame.rename(["Ticket", "followers"])
  |> DataFrame.mutate(followers: followers - 1)

Kino.DataTable.new(followers_df)
# Attach the per-ticket follower count to the training rows.
#
# Rows of `followers_df` for the ticket "LINE" are excluded before the left
# join, so training rows with that ticket (and any other unmatched rows) get
# a nil count, which is then replaced by 0.
#
# The captures use a plain `&`; an HTML-escaped `&amp;` is not valid Elixir.
train_data =
  train_data
  |> DataFrame.join(DataFrame.filter(followers_df, col("Ticket") != "LINE"), how: :left)
  |> then(&DataFrame.put(&1, :followers, Series.fill_missing(&1["followers"], 0)))

Kino.DataTable.new(train_data)
# Inspect rows that share their ticket with nobody (followers == 0) although
# both "SibSp" and "Parch" are positive.
train_data
|> DataFrame.filter(followers == 0 and col("SibSp") > 0 and col("Parch") > 0)
|> Kino.DataTable.new()
# "followers" histogram coloured by "Survived".
color_histgram.(train_data, "followers", "Survived")
# Counts per "followers" and "Survived" value, pivoted wide, plus the ratio
# Survived_1 / (Survived_0 + Survived_1) in the `suvived_rate` column.
train_data
|> DataFrame.group_by(["followers", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(col("followers"))
|> DataFrame.mutate(suvived_rate: col("Survived_1") / (col("Survived_0") + col("Survived_1")))
|> Kino.DataTable.new()
defmodule PreProcess do
  @moduledoc """
  Turns an uploaded Titanic CSV into model inputs.

  `process/4` is the entry point; the other public functions are the
  individual steps it is built from.
  """

  @doc """
  Reads the file uploaded through `kino_input` and parses it as CSV.
  """
  def load_csv(kino_input) do
    file_ref = Map.get(Kino.Input.read(kino_input), :file_ref)
    csv_path = Kino.Input.file_path(file_ref)

    DataFrame.from_csv!(csv_path)
  end

  @doc """
  Fills missing values column by column.

  `fill_map` maps a column name to its replacement value. The special value
  `:median` stands for the median of that column in `data`.
  """
  def fill_empty(data, fill_map) do
    Enum.reduce(fill_map, data, fn
      {column_name, :median}, filled ->
        column = data[column_name]
        DataFrame.put(filled, column_name, Series.fill_missing(column, Series.median(column)))

      {column_name, fill_value}, filled ->
        DataFrame.put(filled, column_name, Series.fill_missing(data[column_name], fill_value))
    end)
  end

  @doc """
  Replaces `columns_names` by their dummy (indicator) columns.

  The dummy columns come first, followed by every other column of `data`.
  """
  def replace_dummy(data, columns_names) do
    indicator_columns = DataFrame.dummies(data, columns_names)
    other_columns = DataFrame.discard(data, columns_names)

    DataFrame.concat_columns(indicator_columns, other_columns)
  end

  @doc """
  Converts a dataframe into a list of single-row tensors.

  The column lists are stacked into one tensor on the EXLA backend, which is
  transposed so that the first axis runs over rows, and then split into
  batches of one row each.
  """
  def to_tensor(data) do
    column_lists = data |> DataFrame.to_columns() |> Map.values()
    row_major = column_lists |> Nx.tensor(backend: EXLA.Backend) |> Nx.transpose()

    Enum.to_list(Nx.to_batched(row_major, 1))
  end

  @doc """
  Loads the CSV behind `kino_input` and returns `{ids, labels, inputs}`.

    * `ids` - the values of column `id_key` as a list
    * `labels` - a list of one-row f32 tensors built from column `label_key`,
      or `nil` when the file has no such column
    * `inputs` - the engineered features as a list of one-row tensors

  `followers_df` is left-joined onto the data to provide the "followers"
  column.
  """
  def process(kino_input, id_key, label_key, followers_df) do
    raw = load_csv(kino_input)
    ids = Series.to_list(raw[id_key])

    labeled? = label_key in DataFrame.names(raw)

    labels =
      if labeled? do
        label_batches(raw[label_key])
      else
        nil
      end

    non_feature_columns = if labeled?, do: [id_key, label_key], else: [id_key]

    # Probable child: name contains "Master", or "Miss" with a positive "Parch".
    flagged =
      raw
      |> DataFrame.discard(non_feature_columns)
      |> DataFrame.mutate(
        prob_child:
          col("Name") |> contains("Master") or
            (col("Name") |> contains("Miss") and
               col("Parch") > 0)
      )

    ages = Series.to_list(flagged["Age"])
    child_flags = Series.to_list(flagged["prob_child"])

    # A missing age becomes 9 for probable children and 30 for everyone else;
    # known ages are kept as they are.
    filled_age =
      ages
      |> Enum.zip_with(child_flags, fn
        nil, true -> 9
        nil, false -> 30
        age, _prob_child -> age
      end)
      |> Series.from_list()

    features =
      flagged
      |> DataFrame.put("Age", filled_age)
      |> DataFrame.join(followers_df, how: :left)
      |> fill_empty(%{"followers" => 0, "Embarked" => "S", "Fare" => :median})
      |> replace_dummy(["Embarked", "Pclass"])
      |> DataFrame.mutate(
        is_man: col("Sex") == "male",
        fare_group: (col("Fare") / 50) |> floor(),
        age_group: (col("Age") / 10) |> floor()
      )
      |> DataFrame.discard(["Cabin", "Name", "Ticket", "Sex", "Fare", "Age", "SibSp", "Parch"])
      |> to_tensor()

    {ids, labels, features}
  end

  # Converts the label series into a list of f32 tensors, one per row, each
  # with a trailing axis of size 1.
  defp label_batches(label_series) do
    label_series
    |> Series.to_tensor(backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.new_axis(1)
    |> Nx.to_batched(1)
    |> Enum.to_list()
  end
end
# Drop the "LINE" ticket from the follower table, so rows with that ticket
# find no match in the join inside `PreProcess.process/4` and get 0 followers.
followers_df = DataFrame.filter(followers_df, col("Ticket") != "LINE")
# Ids, label tensors and feature tensors for the training file.
{
  train_id_list,
  train_label_list,
  train_inputs
} = PreProcess.process(train_data_input, "PassengerId", "Survived", followers_df)
# Same for the test file. `process/4` returns nil for the label list when the
# file has no "Survived" column.
# NOTE(review): presumably the test CSV carries no labels - confirm.
{
  test_id_list,
  test_label_list,
  test_inputs
} = PreProcess.process(test_data_input, "PassengerId", "Survived", followers_df)

モデルの定義

# Feed-forward network: two tanh dense layers of 48 units, with dropout
# (rate 0.2) after the first, and a single sigmoid output unit.
# NOTE(review): the input width of 11 has to equal the number of columns
# produced by `PreProcess.process/4`, which depends on how many distinct
# "Embarked" and "Pclass" values the data contains - confirm.
model =
  Axon.input("input", shape: {nil, 11})
  |> Axon.dense(48, activation: :tanh)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(48, activation: :tanh)
  |> Axon.dense(1, activation: :sigmoid)

学習

# Pair every input tensor with its label tensor. This rebinds `train_data`,
# which held the training dataframe up to this point.
train_data = Enum.zip(train_inputs, train_label_list)
# Live line charts, one plotting "loss" and one plotting "accuracy" against "step".
loss_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "loss", type: :quantitative)
  |> Kino.VegaLite.new()

acc_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "accuracy", type: :quantitative)
  |> Kino.VegaLite.new()

# Show both charts side by side.
Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
# Train for 50 epochs with mean squared error and Adam (learning rate 0.0005),
# starting from an empty model state and compiling with EXLA. Accuracy is
# tracked as an extra metric, and both charts are fed on :epoch_completed events.
trained_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 0.0005))
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
  |> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 50, compiler: EXLA)

未知データに対する推論

# Predict on the test inputs: join the one-row tensors into a single batch,
# run the trained model, round every output to the nearest integer and pair
# the predictions with the test passenger ids.
#
# The captures use a plain `&`; an HTML-escaped `&amp;` is not valid Elixir.
results =
  test_inputs
  |> Nx.concatenate()
  |> then(&Axon.predict(model, trained_state, &1))
  |> Nx.to_flat_list()
  |> Enum.map(&round(&1))
  |> then(
    &%{
      "PassengerId" => test_id_list,
      "Survived" => &1
    }
  )
  |> DataFrame.new()

Kino.DataTable.new(results)
# Serialise the predictions to CSV and offer them as "result.csv".
# An explicit `fn` is used here so the zero-arity download callback does not
# have to sit inside a capture.
results
|> DataFrame.dump_csv!()
|> then(fn csv -> Kino.Download.new(fn -> csv end, filename: "result.csv") end)