Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

機械に学習させる

ch01/01_origin.livemd

機械に学習させる

Mix.install([
  {:axon, "~> 0.5"},
  {:nx, "~> 0.5"},
  {:explorer, "~> 0.5"},
  {:kino, "~> 0.8"}
])

データフレーム

require Explorer.DataFrame, as: DF
Explorer.DataFrame
# Load the Iris dataset bundled with Explorer as a data frame.
iris = Explorer.Datasets.iris()
#Explorer.DataFrame<
  Polars[150 x 5]
  sepal_length f64 [5.1, 4.9, 4.7, 4.6, 5.0, ...]
  sepal_width f64 [3.5, 3.0, 3.2, 3.1, 3.6, ...]
  petal_length f64 [1.4, 1.4, 1.3, 1.5, 1.4, ...]
  petal_width f64 [0.2, 0.2, 0.2, 0.2, 0.2, ...]
  species string ["Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", ...]
>

学習用のデータを準備する

# Standardize every feature column (z-score): subtract the column mean, then
# divide by the column's standard deviation so each feature ends up with zero
# mean and unit spread. Dividing by the variance instead would leave each
# column with a spread of 1 / std, i.e. on a different scale per feature.
# NOTE: the outputs captured further down in this notebook were produced by the
# earlier variance-based scaling; re-evaluate the cells to refresh them.
normalized_iris =
  DF.mutate(
    iris,
    for col <- across(~w[sepal_width sepal_length petal_length petal_width]) do
      {col.name, (col - mean(col)) / standard_deviation(col)}
    end
  )
#Explorer.DataFrame<
  Polars[150 x 5]
  sepal_length f64 [-1.0840606189132322, -1.3757361217598405, -1.66741162460645,
   -1.8132493760297554, -1.2298983703365363, ...]
  sepal_width f64 [2.3722896125315045, -0.28722789030650403, 0.7765791108287005, 0.2446756102610982,
   2.9041931130991068, ...]
  petal_length f64 [-0.7576391687443839, -0.7576391687443839, -0.7897606710936369,
   -0.7255176663951307, -0.7576391687443839, ...]
  petal_width f64 [-1.7147014356654708, -1.7147014356654708, -1.7147014356654708,
   -1.7147014356654708, -1.7147014356654708, ...]
  species string ["Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", ...]
>
# Shuffle the rows so that the positional train/test slices taken later are
# not tied to the original row order of the dataset.
shuffled_normalized_iris = DF.shuffle(normalized_iris)
#Explorer.DataFrame<
  Polars[150 x 5]
  sepal_length float [1.2493434038596445, -1.2298983703365356, 1.6868566581295583,
   0.22847914389651144, -0.2090341103734024, ...]
  sepal_width float [-0.28722789030650403, -4.010552394279718, 0.7765791108287006,
   -4.5424558948473175, -2.9467453931445133, ...]
  petal_length float [0.3987349158287289, -0.1473306241085746, 0.3023704087809695,
   0.07751989233619747, 0.3987349158287289, ...]
  petal_width float [0.8607846993460835, -0.3411088303259749, 0.3456874723437727,
   -0.3411088303259749, 1.3758819263483943, ...]
  species string ["Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor",
   "Iris-virginica", ...]
>

トレーニング用データとテスト用データに分ける

# Rows 0..119 (120 rows) of the shuffled frame are used for training.
train_df = shuffled_normalized_iris |> DF.slice(0..119)

# Rows 120..149 (the remaining 30 rows) are held out for testing.
test_df = shuffled_normalized_iris |> DF.slice(120..149)
#Explorer.DataFrame<
  Polars[30 x 5]
  sepal_length float [0.8118301495897308, -1.959087127453059, 1.2493434038596445,
   0.8118301495897308, -1.0840606189132314, ...]
  sepal_width float [0.7765791108287006, -4.010552394279718, 0.24467561026109824,
   -1.3510348914417087, 2.372289612531505, ...]
  petal_length float [0.23812740408246316, -0.7897606710936372, 0.20600590173321015,
   0.5914639299242476, -0.7576391687443842, ...]
  petal_width float [0.5173865480112098, -1.5430023599980331, 0.3456874723437727,
   1.7192800776832686, -1.7147014356654704, ...]
  species string ["Iris-versicolor", "Iris-setosa", "Iris-versicolor", "Iris-virginica",
   "Iris-setosa", ...]
>

テンソルに変換する

# Names of the numeric feature columns and of the label column.
feature_columns = ~w[sepal_length sepal_width petal_length petal_width]
label_column = "species"

# Stacks the feature columns of a data frame along axis 1, producing one
# row of features per data frame row.
to_features = fn df ->
  Nx.stack(df[feature_columns], axis: 1)
end

# Maps a species name to its integer class index.
species_to_index = fn
  "Iris-setosa" -> 0
  "Iris-versicolor" -> 1
  "Iris-virginica" -> 2
end

# One-hot encodes the label column: each class index gets a trailing axis and
# is compared against iota [0, 1, 2], yielding a u8 row with a single 1.
to_one_hot = fn df ->
  class_indices =
    df
    |> DF.pull(label_column)
    |> Explorer.Series.to_list()
    |> Enum.map(species_to_index)

  class_indices
  |> Nx.tensor(type: :u8)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 3}, axis: -1))
end

x_train = to_features.(train_df)
y_train = to_one_hot.(train_df)

x_test = to_features.(test_df)
y_test = to_one_hot.(test_df)
#Nx.Tensor<
  u8[30][3]
  [
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, ...],
    ...
  ]
>

Axon を使った多項ロジスティック回帰

まずはモデル定義

# Multinomial logistic regression: a single dense layer with 3 output units
# (one per species) followed by a softmax activation.
iris_input = Axon.input("iris_features")
model = Axon.dense(iris_input, 3, activation: :softmax)
#Axon<
  inputs: %{"iris_features" => nil}
  outputs: "softmax_0"
  nodes: 3
>
# Render the model as a graph, using an f32 template of shape {1, 4}
# (one sample with four features) as the example input.
Axon.Display.as_graph(model, Nx.template({1, 4}, :f32))

インプットパイプラインを宣言

# Endless stream that yields the full training set as a single
# {features, labels} batch every time it is pulled.
data_stream = Stream.repeatedly(fn -> {x_train, y_train} end)
#Function<51.6935098/2 in Stream.repeatedly/1>

学習ループを回す

# Supervised training loop: categorical cross-entropy loss, SGD optimizer,
# with accuracy tracked as an additional metric.
training_loop =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy)

# Run the loop over the data stream starting from an empty state, with
# `iterations: 500` and `epochs: 10`.
trained_model_state =
  Axon.Loop.run(training_loop, data_stream, %{}, iterations: 500, epochs: 10)
Epoch: 0, Batch: 450, accuracy: 0.8049530 loss: 0.4976027
Epoch: 1, Batch: 450, accuracy: 0.8740941 loss: 0.4151042
Epoch: 2, Batch: 450, accuracy: 0.9118048 loss: 0.3731178
Epoch: 3, Batch: 450, accuracy: 0.9333282 loss: 0.3441738
Epoch: 4, Batch: 450, accuracy: 0.9285069 loss: 0.3219456
Epoch: 5, Batch: 450, accuracy: 0.9249966 loss: 0.3039479
Epoch: 6, Batch: 450, accuracy: 0.9320097 loss: 0.2889073
Epoch: 7, Batch: 450, accuracy: 0.9469921 loss: 0.2760663
Epoch: 8, Batch: 450, accuracy: 0.9538645 loss: 0.2649276
Epoch: 9, Batch: 450, accuracy: 0.9630468 loss: 0.2551444
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[3]
      [-0.39006364345550537, 1.434131383895874, -1.0440688133239746]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][3]
      [
        [-0.3026171624660492, 0.8268378376960754, 1.2651383876800537],
        [1.0946170091629028, -0.18169042468070984, -0.23133960366249084],
        [-1.5668177604675293, 0.2944740056991577, 1.5785413980484009],
        [-1.1826088428497314, -0.2164556086063385, 2.283381223678589]
      ]
    >
  }
}

学習済みモデルの評価

# The held-out test set, wrapped as a single {features, labels} batch.
data = [{x_test, y_test}]

# Evaluation loop for the model that reports accuracy.
evaluation_loop =
  model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)

# Evaluate using the parameters obtained from training.
Axon.Loop.run(evaluation_loop, data, trained_model_state)
Batch: 0, accuracy: 0.9000000
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      0.8999999761581421
    >
  }
}

学習済みモデルの書き出し

# Serialize the trained parameters to a binary with Nx.
serialized_model_state =
  trained_model_state
  |> Nx.serialize()

# Write the binary to "iris_model_state.nx" (relative path);
# File.write!/2 raises on failure.
File.write!("iris_model_state.nx", serialized_model_state)
:ok