Model hooks
# Install Axon (version 0.5.0 or newer) so the examples below can run.
Mix.install([{:axon, ">= 0.5.0"}])
:ok
Creating models with hooks
Sometimes it’s useful to inspect or visualize the values of intermediate layers in your model during the forward or backward pass. For example, it’s common to visualize the gradients of activation functions to ensure your model is learning in a stable manner. Axon supports this functionality via model hooks.
Model hooks are a means of unidirectional communication with an executing model. Hooks are unidirectional in the sense that you can only receive information from your model, and not send information back.
Hooks are attached per-layer and can execute at 4 different points in model execution: on the pre-forward, forward, or backward pass of the model, or during model initialization. You can also configure the same hook to execute on all 4 events. You can attach hooks to models using `Axon.attach_hook/3`:
# A dense layer with two hooks: one prints the layer's output on the forward
# pass, the other prints the layer's parameters when they are initialized.
# The ReLU that follows has its own forward hook.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.attach_hook(&IO.inspect(&1, label: :dense_forward), on: :forward)
  |> Axon.attach_hook(&IO.inspect(&1, label: :dense_init), on: :initialize)
  |> Axon.relu()
  |> Axon.attach_hook(&IO.inspect(&1, label: :relu), on: :forward)

{init_fn, predict_fn} = Axon.build(model)

# A 2x4 f32 tensor of 0.0..7.0, used here as the template for initialization.
input = Nx.iota({2, 4}, type: :f32)
params = init_fn.(input, %{})
dense_init: %{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[0.6067318320274353, 0.5483129620552063, -0.05663269758224487, -0.48249542713165283, -0.18357598781585693, 0.6496620774269104, 0.4919115900993347, -0.08380156755447388],
[-0.19745409488677979, 0.10483592748641968, -0.43387970328330994, -0.1041460633277893, -0.4129607081413269, -0.6482449769973755, 0.6696910262107849, 0.4690167307853699],
[-0.18194729089736938, -0.4856645464897156, 0.39400774240493774, -0.28496378660202026, 0.32120805978775024, -0.41854584217071533, 0.5671316981315613, -0.21937215328216553],
[0.4516749978065491, -0.23585206270217896, -0.6682141423225403, 0.4286096692085266, -0.14930623769760132, -0.3825327157974243, 0.2700549364089966, -0.3888852596282959]
]
>
}
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[0.6067318320274353, 0.5483129620552063, -0.05663269758224487, -0.48249542713165283, -0.18357598781585693, 0.6496620774269104, 0.4919115900993347, -0.08380156755447388],
[-0.19745409488677979, 0.10483592748641968, -0.43387970328330994, -0.1041460633277893, -0.4129607081413269, -0.6482449769973755, 0.6696910262107849, 0.4690167307853699],
[-0.18194729089736938, -0.4856645464897156, 0.39400774240493774, -0.28496378660202026, 0.32120805978775024, -0.41854584217071533, 0.5671316981315613, -0.21937215328216553],
[0.4516749978065491, -0.23585206270217896, -0.6682141423225403, 0.4286096692085266, -0.14930623769760132, -0.3825327157974243, 0.2700549364089966, -0.3888852596282959]
]
>
}
}
Notice how during initialization the `:dense_init` hook fired and inspected the layer’s parameters. Now when executing, you’ll see outputs for `:dense_forward` and `:relu`:
# Run the forward pass; forward hooks print their values before the result is returned.
predict_fn.(params, input)
relu: #Nx.Tensor<
f32[2][8]
[
[0.7936763167381287, 0.0, 0.0, 0.61175537109375, 0.0, 0.0, 2.614119291305542, 0.0],
[3.5096981525421143, 0.0, 0.0, 0.0, 0.0, 0.0, 10.609275817871094, 0.0]
]
>
#Nx.Tensor<
f32[2][8]
[
[0.7936763167381287, 0.0, 0.0, 0.61175537109375, 0.0, 0.0, 2.614119291305542, 0.0],
[3.5096981525421143, 0.0, 0.0, 0.0, 0.0, 0.0, 10.609275817871094, 0.0]
]
>
It’s important to note that hooks execute in the order they were attached to a layer. If you attach 2 hooks to the same layer that run different functions on the same event, they will run in the order they were attached:
# Two forward hooks on the same dense layer, attached :hook1 first, then :hook2.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.attach_hook(&IO.inspect(&1, label: :hook1), on: :forward)
  |> Axon.attach_hook(&IO.inspect(&1, label: :hook2), on: :forward)
  |> Axon.relu()

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(input, %{})
predict_fn.(params, input)
hook2: #Nx.Tensor<
f32[2][8]
[
[-0.6567458510398865, 2.2303993701934814, -1.540865421295166, -1.873536229133606, -2.386439085006714, -1.248870849609375, -2.9092607498168945, -0.1976098120212555],
[2.4088101387023926, 5.939034461975098, -2.024522066116333, -7.58249568939209, -10.193460464477539, 0.33839887380599976, -10.836882591247559, 1.8173918724060059]
]
>
#Nx.Tensor<
f32[2][8]
[
[0.0, 2.2303993701934814, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[2.4088101387023926, 5.939034461975098, 0.0, 0.0, 0.0, 0.33839887380599976, 0.0, 1.8173918724060059]
]
>
Notice that `:hook1` fires before `:hook2`.
You can also specify a hook to fire on all events:
# A single hook registered for every event (`on: :all`) on the first dense layer.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.attach_hook(fn value -> IO.inspect(value) end, on: :all)
  |> Axon.relu()
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model)
{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>,
#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>}
On initialization:
# Initialize parameters; the `on: :all` hook fires for the initialization event.
params = init_fn.(input, %{})
%{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[0.2199305295944214, -0.05434012413024902, -0.07989239692687988, -0.4456246793270111, -0.2792319655418396, -0.1601254940032959, -0.6115692853927612, 0.37740427255630493],
[-0.3606935739517212, 0.6091846823692322, -0.3203054368495941, -0.6252920031547546, -0.41500264406204224, -0.20729252696037292, -0.6763507127761841, -0.6776859164237976],
[0.659041702747345, -0.615885317325592, -0.45865312218666077, 0.18774819374084473, 0.31994110345840454, -0.3055777847766876, -0.3537192642688751, 0.4297131896018982],
[0.06112170219421387, 0.13321959972381592, 0.5566524863243103, -0.1115691065788269, -0.3557875156402588, -0.03118818998336792, -0.5788122415542603, -0.6988758444786072]
]
>
}
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[0.2199305295944214, -0.05434012413024902, -0.07989239692687988, -0.4456246793270111, -0.2792319655418396, -0.1601254940032959, -0.6115692853927612, 0.37740427255630493],
[-0.3606935739517212, 0.6091846823692322, -0.3203054368495941, -0.6252920031547546, -0.41500264406204224, -0.20729252696037292, -0.6763507127761841, -0.6776859164237976],
[0.659041702747345, -0.615885317325592, -0.45865312218666077, 0.18774819374084473, 0.31994110345840454, -0.3055777847766876, -0.3537192642688751, 0.4297131896018982],
[0.06112170219421387, 0.13321959972381592, 0.5566524863243103, -0.1115691065788269, -0.3557875156402588, -0.03118818998336792, -0.5788122415542603, -0.6988758444786072]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.0]
>,
"kernel" => #Nx.Tensor<
f32[8][1]
[
[0.3259686231613159],
[0.4874255657196045],
[0.6338149309158325],
[0.4437469244003296],
[-0.22870665788650513],
[0.8108665943145752],
[7.919073104858398e-4],
[0.4469025135040283]
]
>
}
}
On pre-forward and forward:
# Run the forward pass; the `on: :all` hook prints on the pre-forward and forward events.
predict_fn.(params, input)
#Nx.Tensor<
f32[2][4]
[
[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0]
]
>
#Nx.Tensor<
f32[2][8]
[
[1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228],
[3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334]
]
>
#Nx.Tensor<
f32[2][8]
[
[1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228],
[3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334]
]
>
#Nx.Tensor<
f32[2][1]
[
[0.6458775401115417],
[1.1593825817108154]
]
>
And on the backward pass:
# Differentiate the model output with respect to `params`, which runs the backward pass.
Nx.Defn.grad(fn params -> predict_fn.(params, input) end).(params)
#Nx.Tensor<
f32[2][4]
[
[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0]
]
>
#Nx.Tensor<
f32[2][8]
[
[1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228],
[3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334]
]
>
#Nx.Tensor<
f32[2][8]
[
[1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228],
[3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334]
]
>
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.6519372463226318, 0.4874255657196045, 0.6338149309158325, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[1.3038744926452637, 1.949702262878418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.9558117389678955, 2.4371278285980225, 0.6338149309158325, 0.0, 0.0, 0.0, 0.0, 0.0],
[2.6077489852905273, 2.924553394317627, 1.267629861831665, 0.0, 0.0, 0.0, 0.0, 0.0],
[3.259686231613159, 3.4119789600372314, 1.9014447927474976, 0.0, 0.0, 0.0, 0.0, 0.0]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
[2.0]
>,
"kernel" => #Nx.Tensor<
f32[8][1]
[
[4.599111557006836],
[0.06578820943832397],
[0.43234577775001526],
[0.0],
[0.0],
[0.0],
[0.0],
[0.0]
]
>
}
}
Finally, you can specify hooks to run only when the model is built in a certain mode, such as training or inference mode. You can read more about these modes in the Training and inference mode guide:
# This hook is restricted to `mode: :train`, so it only fires when the model
# is built in training mode.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.attach_hook(fn value -> IO.inspect(value) end, on: :forward, mode: :train)
  |> Axon.relu()

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(input, %{})
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[-0.13241732120513916, 0.6946331858634949, -0.6328000426292419, -0.684409499168396, -0.39569517970085144, -0.10005003213882446, 0.2501150965690613, 0.14561182260513306],
[-0.5495109558105469, 0.459137499332428, -0.4059434235095978, -0.4489462077617645, -0.6331832408905029, 0.05011630058288574, -0.35836488008499146, -0.2661571800708771],
[0.29260867834091187, 0.42186349630355835, 0.32596689462661743, -0.12340176105499268, 0.6767188906669617, 0.2658537030220032, 0.5745270848274231, 6.475448608398438e-4],
[0.16781508922576904, 0.23747843503952026, -0.5311254858970642, 0.22617805004119873, -0.5153165459632874, 0.19729173183441162, -0.5706893801689148, -0.5531126260757446]
]
>
}
}
The model was built in training mode so the hook will run:
# Built with `mode: :train`: the hook fires, and the result is a map with :prediction and :state.
predict_fn.(params, input)
#Nx.Tensor<
f32[2][8]
[
[0.539151668548584, 2.0152997970581055, -1.347386121749878, -0.017215579748153687, -0.8256950974464417, 1.173698902130127, -0.9213788509368896, -1.9241999387741089],
[-0.3468663692474365, 9.267749786376953, -6.322994232177734, -4.139533042907715, -4.295599460601807, 2.8265457153320312, -1.3390271663665771, -4.616241931915283]
]
>
%{
prediction: #Nx.Tensor<
f32[2][8]
[
[0.539151668548584, 2.0152997970581055, 0.0, 0.0, 0.0, 1.173698902130127, 0.0, 0.0],
[0.0, 9.267749786376953, 0.0, 0.0, 0.0, 2.8265457153320312, 0.0, 0.0]
]
>,
state: %{}
}
# Rebuild the same model in inference mode; the train-only hook does not apply here.
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(input, %{})
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[4][8]
[
[0.02683490514755249, -0.28041765093803406, 0.15839070081710815, 0.16674137115478516, -0.5444575548171997, -0.34951671957969666, 0.08247309923171997, 0.6700448393821716],
[0.6001952290534973, -0.26907777786254883, 0.4580194354057312, -0.060002803802490234, -0.5385662317276001, -0.46773862838745117, 0.25804388523101807, -0.6824946999549866],
[0.13328874111175537, -0.46421635150909424, -0.5192649960517883, -0.0429919958114624, 0.0771912932395935, -0.447194904088974, 0.30910569429397583, -0.6105270981788635],
[0.5253992676734924, 0.41786473989486694, 0.6903378367424011, 0.6038702130317688, 0.06673228740692139, 0.4242702126502991, -0.6737087368965149, -0.6956207156181335]
]
>
}
}
The model was built in inference mode so the hook will not run:
# Inference-mode forward pass: no hook output is printed, only the result tensor.
predict_fn.(params, input)
#Nx.Tensor<
f32[2][8]
[
[2.4429705142974854, 0.056083738803863525, 1.490502953529358, 1.6656239032745361, 0.0, 0.0, 0.0, 0.0],
[7.585843086242676, 0.0, 4.640434741973877, 4.336091041564941, 0.0, 0.0, 0.0, 0.0]
]
>