Titanic optimized
Mix.install([
{:exla, "~> 0.9", override: true},
{:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
{:kino, "~> 0.14"},
{:kino_vega_lite, "~> 0.1"},
{:explorer, "~> 0.9"},
{:table_rex, "~> 4.0", override: true}
])
エイリアス
alias Explorer.DataFrame
alias Explorer.Series
require Explorer.DataFrame
データの準備
Kaggle のタイタニックデータをダウンロードする
https://www.kaggle.com/competitions/titanic/data
# File-upload widgets for the two CSV files; they are read later via Kino.Input.read/1.
train_data_input = Kino.Input.file("train data")
test_data_input = Kino.Input.file("test data")
データ読込
# Resolve the uploaded training file to a local path and parse it as CSV.
train_data =
train_data_input
|> Kino.Input.read()
|> Map.get(:file_ref)
|> Kino.Input.file_path()
|> DataFrame.from_csv!()
Kino.DataTable.new(train_data)
# Same steps for the uploaded test file.
test_data =
test_data_input
|> Kino.Input.read()
|> Map.get(:file_ref)
|> Kino.Input.file_path()
|> DataFrame.from_csv!()
Kino.DataTable.new(test_data)
# Stack the training rows (without the "Survived" column) on top of the test rows.
# NOTE(review): presumably the test CSV has no "Survived" column, which is why it is
# dropped before concatenating — confirm against the uploaded files.
full_data =
train_data
|> DataFrame.discard("Survived")
|> DataFrame.concat_rows(test_data)
Kino.DataTable.new(full_data)
欠損値補完
# Number of missing values in the "Age" column.
Series.nil_count(full_data["Age"])

# Returns the most frequent non-nil value of `series`.
#
# `Series.frequencies/1` yields a data frame with "values" and "counts" columns,
# ordered by descending count, so after dropping the nil row the first entry of
# "values" is the mode.
get_mode = fn series ->
  series
  |> Series.frequencies()
  |> DataFrame.filter(is_not_nil(values))
  # A data frame is a struct, so `Map.get(df, "values")` never finds the column;
  # `DataFrame.pull/2` is the call that extracts a column as a series.
  |> DataFrame.pull("values")
  |> Series.first()
end
# Summarises `series` as a map holding its mean (平均値), median (中央値)
# and mode (最頻値, computed by `get_mode`).
get_statistics = fn series ->
  mean = Series.mean(series)
  median = Series.median(series)
  mode = get_mode.(series)

  %{平均値: mean, 中央値: median, 最頻値: mode}
end

# Statistics of the "Age" column over the combined data.
get_statistics.(full_data["Age"])
# Builds a 600x300 bar chart counting the rows of `df` per value of column `colname`.
#
# A column with more than 50 distinct values is encoded as quantitative and binned
# (maxbins: 50); any other column is encoded as nominal without binning.
histgram = fn df, colname ->
  values = Series.to_list(df[colname])
  distinct_count = values |> Enum.uniq() |> length()

  x_options =
    if distinct_count > 50 do
      [type: :quantitative, bin: %{maxbins: 50}, title: colname]
    else
      [type: :nominal, bin: nil, title: colname]
    end

  chart = VegaLite.new(width: 600, height: 300)
  chart = VegaLite.data_from_values(chart, value: values)
  chart = VegaLite.mark(chart, :bar, tooltip: true)
  chart = VegaLite.encode_field(chart, :x, "value", x_options)

  VegaLite.encode_field(chart, :y, "value", type: :quantitative, aggregate: :count)
end

# Distribution of "Age" over the combined data.
histgram.(full_data, "Age")
# Show the rows whose Age is below 10.
full_data
|> DataFrame.filter(col("Age") < 10)
|> Kino.DataTable.new()
# Add two boolean columns:
#   prob_child - Name contains "Master", or Name contains "Miss" and Parch > 0
#   prob_adult - negation of prob_child
full_data =
full_data
|> DataFrame.mutate(
prob_child:
col("Name") |> contains("Master") or
(col("Name") |> contains("Miss") and
col("Parch") > 0)
)
|> DataFrame.mutate(prob_adult: not prob_child)
Kino.DataTable.new(full_data)
# Age statistics for the rows flagged prob_child.
full_data
|> DataFrame.filter(prob_child)
|> DataFrame.pull("Age")
|> get_statistics.()
# Age statistics for the rows flagged prob_adult.
full_data
|> DataFrame.filter(prob_adult)
|> DataFrame.pull("Age")
|> get_statistics.()
# Age histograms for the same two groups.
full_data
|> DataFrame.filter(prob_child)
|> histgram.("Age")
full_data
|> DataFrame.filter(prob_adult)
|> histgram.("Age")
# Missing-value count, statistics and distribution of "Fare".
Series.nil_count(full_data["Fare"])
get_statistics.(full_data["Fare"])
histgram.(full_data, "Fare")
# Missing-value count and distribution of "Embarked".
Series.nil_count(full_data["Embarked"])
histgram.(full_data, "Embarked")
# Missing-value count of "Cabin".
Series.nil_count(full_data["Cabin"])
生存率との相関
# Distribution of the "Survived" label in the training data.
histgram.(train_data, "Survived")
# Row counts per "Survived" value, ordered by the value itself (ascending).
survived_counts =
train_data["Survived"]
|> Series.frequencies()
|> DataFrame.sort_by(values)
|> DataFrame.to_columns()
|> Map.get("counts")
# Share of rows belonging to the second value in that ordering.
# NOTE(review): assumes the labels are exactly 0 and 1, so index 1 is the survivor count — confirm.
survived_rate = Enum.at(survived_counts, 1) / Enum.sum(survived_counts)
# Builds a 600x300 bar chart counting the rows of `df` per value of column `colname`,
# with the bars coloured by the (nominal) value of column `color_colname`.
#
# A column with more than 20 distinct values is encoded as quantitative and binned
# (maxbins: 20); any other column is encoded as nominal without binning.
color_histgram = fn df, colname, color_colname ->
  values = Series.to_list(df[colname])
  colors = Series.to_list(df[color_colname])
  distinct_count = values |> Enum.uniq() |> length()

  x_options =
    if distinct_count > 20 do
      [type: :quantitative, bin: %{maxbins: 20}, title: colname]
    else
      [type: :nominal, bin: nil, title: colname]
    end

  chart = VegaLite.new(width: 600, height: 300)
  chart = VegaLite.data_from_values(chart, value: values, color: colors)
  chart = VegaLite.mark(chart, :bar, tooltip: true)
  chart = VegaLite.encode_field(chart, :x, "value", x_options)
  chart = VegaLite.encode_field(chart, :y, "value", type: :quantitative, aggregate: :count)

  VegaLite.encode_field(chart, :color, "color", type: :nominal)
end
# Pclass distribution, coloured by survival.
color_histgram.(train_data, "Pclass", "Survived")

# Pclass x Survived contingency table: one row per class and one count column
# per outcome ("Survived_0", "Survived_1"), ordered by class.
cross_table =
  train_data
  |> DataFrame.group_by(["Pclass", "Survived"])
  |> DataFrame.summarise(count: count(col("Survived")))
  |> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
  |> DataFrame.sort_by(col("Pclass"))

Kino.DataTable.new(cross_table)

# Survival rate per class.
# pivot_wider leaves nil where a class/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
cross_table
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
# Fare distribution, coloured by survival.
color_histgram.(train_data, "Fare", "Survived")

# Survival rate per fare band, where fare_group = floor(Fare / 50).
# Rows without a Fare are dropped first.
train_data
|> DataFrame.filter(is_not_nil(col("Fare")))
|> DataFrame.mutate(fare_group: (col("Fare") / 50) |> floor() |> cast(:integer))
|> DataFrame.group_by([:fare_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(fare_group)
# pivot_wider leaves nil where a band/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
# Sex distribution, coloured by survival.
color_histgram.(train_data, "Sex", "Survived")

# Survival rate per sex.
train_data
|> DataFrame.group_by(["Sex", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
# pivot_wider leaves nil where a sex/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
# Distribution of age bands (age_group = floor(Age / 10)), coloured by survival.
# Rows without an Age are dropped first.
train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: (col("Age") / 10) |> floor() |> cast(:integer))
|> color_histgram.("age_group", "Survived")

# Survival rate per age band.
train_data
|> DataFrame.filter(is_not_nil(col("Age")))
|> DataFrame.mutate(age_group: (col("Age") / 10) |> floor() |> cast(:integer))
|> DataFrame.group_by([:age_group, "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(age_group)
# pivot_wider leaves nil where a band/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
# Embarked distribution, coloured by survival.
color_histgram.(train_data, "Embarked", "Survived")

# Survival rate per port of embarkation.
train_data
|> DataFrame.group_by(["Embarked", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
# pivot_wider leaves nil where a port/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
# Mean Fare per port of embarkation.
train_data
|> DataFrame.group_by("Embarked")
|> DataFrame.summarise(mean: col("Fare") |> mean())
|> Kino.DataTable.new()
# Embarked distribution over the combined data, coloured by Pclass.
color_histgram.(full_data, "Embarked", "Pclass")
# SibSp and Parch distributions, coloured by survival.
color_histgram.(train_data, "SibSp", "Survived")
color_histgram.(train_data, "Parch", "Survived")
# Distribution of family = SibSp + Parch, coloured by survival.
train_data
|> DataFrame.mutate(family: col("SibSp") + col("Parch"))
|> color_histgram.("family", "Survived")
# Show the training rows whose Ticket is the literal string "LINE".
train_data
|> DataFrame.filter(col("Ticket") == "LINE")
|> Kino.DataTable.new()
# For every Ticket value in the combined data: followers = number of rows sharing
# that ticket, minus one.
followers_df =
full_data["Ticket"]
|> Series.frequencies()
|> DataFrame.rename(["Ticket", "followers"])
|> DataFrame.mutate(followers: followers - 1)
Kino.DataTable.new(followers_df)
# Attach the followers count to the training rows. "LINE" tickets are removed from
# the lookup table first, so those rows get no match in the left join and end up
# with followers = 0 after fill_missing.
train_data =
train_data
|> DataFrame.join(DataFrame.filter(followers_df, col("Ticket") != "LINE"), how: :left)
|> then(&DataFrame.put(&1, :followers, Series.fill_missing(&1["followers"], 0)))
Kino.DataTable.new(train_data)
# Rows with no followers although both SibSp and Parch are positive.
train_data
|> DataFrame.filter(followers == 0 and col("SibSp") > 0 and col("Parch") > 0)
|> Kino.DataTable.new()
# followers distribution, coloured by survival.
color_histgram.(train_data, "followers", "Survived")
# Survival rate per followers count.
train_data
|> DataFrame.group_by(["followers", "Survived"])
|> DataFrame.summarise(count: count(col("Survived")))
|> DataFrame.pivot_wider("Survived", "count", names_prefix: "Survived_")
|> DataFrame.sort_by(col("followers"))
# pivot_wider leaves nil where a followers/outcome combination has no rows; those
# counts are treated as 0 so the rate is 0.0 or 1.0 instead of nil.
|> DataFrame.mutate(
  survived_rate:
    fill_missing(col("Survived_1"), 0) /
      (fill_missing(col("Survived_0"), 0) + fill_missing(col("Survived_1"), 0))
)
|> Kino.DataTable.new()
defmodule PreProcess do
  @moduledoc """
  Turns an uploaded passenger CSV into per-row input tensors (and label tensors,
  when the label column is present) for the model.
  """

  @doc """
  Reads the file behind the given Kino file input and parses it as CSV.
  """
  def load_csv(kino_input) do
    csv_path =
      kino_input
      |> Kino.Input.read()
      |> Map.get(:file_ref)
      |> Kino.Input.file_path()

    DataFrame.from_csv!(csv_path)
  end

  @doc """
  Fills missing values column by column.

  `fill_map` maps a column name to the value that replaces its missing entries.
  The atom `:median` stands for the median of that column in `data`.
  """
  def fill_empty(data, fill_map) do
    Enum.reduce(fill_map, data, fn {column_name, fill_value}, filled ->
      replacement =
        case fill_value do
          :median -> Series.median(data[column_name])
          other -> other
        end

      patched = Series.fill_missing(data[column_name], replacement)
      DataFrame.put(filled, column_name, patched)
    end)
  end

  @doc """
  Replaces the columns in `columns_names` with their dummy (one-hot) columns.

  The dummy columns come first, followed by every remaining column of `data`.
  """
  def replace_dummy(data, columns_names) do
    one_hot = DataFrame.dummies(data, columns_names)
    remaining = DataFrame.discard(data, columns_names)

    DataFrame.concat_columns(one_hot, remaining)
  end

  @doc """
  Converts `data` into a list of tensors, one batch of a single row each.

  The column order is the order in which `Map.values/1` returns the columns.
  """
  def to_tensor(data) do
    columns = data |> DataFrame.to_columns() |> Map.values()

    columns
    |> Nx.tensor(backend: EXLA.Backend)
    |> Nx.transpose()
    |> Nx.to_batched(1)
    |> Enum.to_list()
  end

  @doc """
  Loads the CSV behind `kino_input` and builds the model inputs.

  Returns `{id_list, label_list, inputs}`:

    * `id_list` - values of the `id_key` column
    * `label_list` - one `:f32` tensor batch per row of the `label_key` column,
      or `nil` when the CSV has no such column
    * `inputs` - one tensor batch per row of the engineered features

  `followers_df` is left-joined onto the data to supply the "followers" column.
  """
  def process(kino_input, id_key, label_key, followers_df) do
    raw = load_csv(kino_input)
    id_list = Series.to_list(raw[id_key])
    has_label = label_key in DataFrame.names(raw)

    label_list = if has_label, do: label_batches(raw[label_key])
    dropped_columns = if has_label, do: [id_key, label_key], else: [id_key]

    features =
      raw
      |> DataFrame.discard(dropped_columns)
      |> DataFrame.mutate(
        prob_child:
          col("Name") |> contains("Master") or
            (col("Name") |> contains("Miss") and col("Parch") > 0)
      )

    # Missing ages get a fixed default that depends on the prob_child flag.
    filled_age =
      features["Age"]
      |> Series.to_list()
      |> Enum.zip(Series.to_list(features["prob_child"]))
      |> Enum.map(&default_age/1)
      |> Series.from_list()

    inputs =
      features
      |> DataFrame.put("Age", filled_age)
      |> DataFrame.join(followers_df, how: :left)
      |> fill_empty(%{"followers" => 0, "Embarked" => "S", "Fare" => :median})
      |> replace_dummy(["Embarked", "Pclass"])
      |> DataFrame.mutate(is_man: col("Sex") == "male")
      |> DataFrame.mutate(fare_group: (col("Fare") / 50) |> floor())
      |> DataFrame.mutate(age_group: (col("Age") / 10) |> floor())
      |> DataFrame.discard(["Cabin", "Name", "Ticket", "Sex", "Fare", "Age", "SibSp", "Parch"])
      |> to_tensor()

    {id_list, label_list, inputs}
  end

  # Turns the label series into a list of {1, 1}-shaped :f32 tensors, one per row.
  defp label_batches(labels) do
    labels
    |> Series.to_tensor(backend: EXLA.Backend)
    |> Nx.as_type(:f32)
    |> Nx.new_axis(1)
    |> Nx.to_batched(1)
    |> Enum.to_list()
  end

  # Age used for an {age, prob_child} pair: 9 or 30 when the age is missing,
  # the recorded age otherwise.
  defp default_age({nil, true}), do: 9
  defp default_age({nil, false}), do: 30
  defp default_age({age, _prob_child}), do: age
end
# Drop the "LINE" ticket from the lookup table so those rows get no match in the
# left join inside PreProcess.process and are filled with followers = 0 there.
followers_df = DataFrame.filter(followers_df, col("Ticket") != "LINE")
# Ids, label batches and input batches for the training file.
{
train_id_list,
train_label_list,
train_inputs
} = PreProcess.process(train_data_input, "PassengerId", "Survived", followers_df)
# Same for the test file. PreProcess.process returns nil as the label list when
# the CSV has no "Survived" column.
{
test_id_list,
test_label_list,
test_inputs
} = PreProcess.process(test_data_input, "PassengerId", "Survived", followers_df)
モデルの定義
# Feed-forward network: 11 input features -> dense 48 (tanh) -> dropout 0.2
# -> dense 48 (tanh) -> dense 1 (sigmoid).
# NOTE(review): 11 has to equal the number of columns PreProcess.process keeps
# per row — re-check whenever the feature set changes.
model =
Axon.input("input", shape: {nil, 11})
|> Axon.dense(48, activation: :tanh)
|> Axon.dropout(rate: 0.2)
|> Axon.dense(48, activation: :tanh)
|> Axon.dense(1, activation: :sigmoid)
学習
# Pair every input batch with its label batch; this rebinds train_data, which held
# the training data frame until here.
train_data = Enum.zip(train_inputs, train_label_list)
# Live line chart of "loss" against "step".
loss_plot =
VegaLite.new(width: 300)
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "step", type: :quantitative)
|> VegaLite.encode_field(:y, "loss", type: :quantitative)
|> Kino.VegaLite.new()
# Live line chart of "accuracy" against "step".
acc_plot =
VegaLite.new(width: 300)
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "step", type: :quantitative)
|> VegaLite.encode_field(:y, "accuracy", type: :quantitative)
|> Kino.VegaLite.new()
# Show both charts side by side.
Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
# Train for 50 epochs with Adam (learning rate 0.0005), compiled with EXLA; both
# charts are updated at the end of every epoch.
# NOTE(review): the loss is mean squared error on a sigmoid output; binary
# cross-entropy is the more common choice for 0/1 labels — confirm this is intentional.
trained_state =
model
|> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adam(learning_rate: 0.0005))
|> Axon.Loop.metric(:accuracy, "accuracy")
|> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
|> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 50, compiler: EXLA)
未知データに対する推論
# Concatenate the per-row test batches into one tensor, predict, round each output
# to an integer and pair the predictions with the test ids.
results =
test_inputs
|> Nx.concatenate()
|> then(&Axon.predict(model, trained_state, &1))
|> Nx.to_flat_list()
|> Enum.map(&round(&1))
|> then(
&%{
"PassengerId" => test_id_list,
"Survived" => &1
}
)
|> DataFrame.new()
Kino.DataTable.new(results)
# Serialise the predictions as CSV and offer them as a "result.csv" download.
results
|> DataFrame.dump_csv!()
|> then(&Kino.Download.new(fn -> &1 end, filename: "result.csv"))