Load NetCDF into Nx
Mix.install([
{:netcdf, ">= 0.0.0", github: "dockyard/netcdf"},
{:nx, "~> 0.3"}
])
Load the example data
The repository contains an example NetCDF file at `priv/hello_world.nc`. We will load the data and format it into Nx tensors.
# The path is relative to the repository root, so this assumes the Livebook
# was started from there.
filepath = "./priv/hello_world.nc"

# Any result other than {:ok, file} fails the match and stops the cell here.
{:ok, file} = filepath |> NetCDF.File.open()
# Load every variable named in `file.variables`, keeping the order in which
# the names are listed. A load that does not return {:ok, var} fails the match
# rather than being silently skipped.
variables =
  Enum.map(file.variables, fn var_name ->
    {:ok, var} = NetCDF.Variable.load(file, var_name)
    var
  end)
# Translates a variable's type atom into the spelling passed to Nx:
# :i8/:i16/:i32/:i64 become :s8/:s16/:s32/:s64. Any other value is
# returned unchanged.
as_nx_type = fn netcdf_type ->
  Map.get(%{i8: :s8, i16: :s16, i32: :s32, i64: :s64}, netcdf_type, netcdf_type)
end
# Build a map of variable name => Nx tensor. Each tensor is created from the
# variable's value, with its type converted through `as_nx_type`.
var_tensors =
  Map.new(variables, fn var ->
    nx_type = as_nx_type.(var.type)
    {var.name, Nx.tensor(var.value, type: nx_type)}
  end)
# We know that the "temp" variable is a function of lat, lon and time, so its
# tensor is reshaped to {lat, lon, time} with named axes. The lat and lon extents
# come from the coordinate tensors; :auto lets Nx derive the time extent from
# the remaining number of elements.
size_lat = var_tensors["lat"] |> Nx.size()
size_lon = var_tensors["lon"] |> Nx.size()

var_tensors =
  Map.update!(
    var_tensors,
    "temp",
    &Nx.reshape(&1, {size_lat, size_lon, :auto}, names: [:lat, :lon, :time])
  )
We can now manipulate the data as regular Nx tensors. In the example below, we select only the data that falls within a given geofence (latitude and longitude bounds).
# Geofence bounds. Both ends are inclusive (see the >= / <= comparisons below)
# and are compared directly against the values of the "lat" and "lon" tensors.
min_lat = -85
max_lat = -70
min_lon = 10
max_lon = 30

temp = var_tensors["temp"]
lat = var_tensors["lat"]
lon = var_tensors["lon"]

# Masks that are 1 where a coordinate lies inside the bounds and 0 elsewhere.
lat_selector = Nx.logical_and(Nx.greater_equal(lat, min_lat), Nx.less_equal(lat, max_lat))
lon_selector = Nx.logical_and(Nx.greater_equal(lon, min_lon), Nx.less_equal(lon, max_lon))

# argmax over a 0/1 mask with tie_break: :low yields the index of the first match.
lat_start = Nx.argmax(lat_selector, tie_break: :low)
lon_start = Nx.argmax(lon_selector, tie_break: :low)

# Summing the mask counts the matches.
# Note: Nx.to_number doesn't work in defn, so the length calculations must be done
# outside of defns. This happens because dynamic shapes aren't supported in Nx Defn
lat_len = lat_selector |> Nx.sum() |> Nx.to_number()
lon_len = lon_selector |> Nx.sum() |> Nx.to_number()

# NOTE(review): a start index plus a length can only describe one contiguous run,
# so this presumably relies on the in-bounds coordinates being adjacent — confirm
# for the data at hand.
sliced_temp =
  Nx.slice(
    temp,
    [lat_start, lon_start, 0],
    [lat_len, lon_len, Nx.axis_size(temp, :time)]
  )

sliced_lat = Nx.slice(lat, [lat_start], [lat_len])
sliced_lon = Nx.slice(lon, [lon_start], [lon_len])

{sliced_temp, sliced_lat, sliced_lon}