Powered by AppSignal & Oban Pro
Would you like to see your link here? Contact us

Load NetCDF into Nx

examples/load_to_nx.livemd

Load NetCDF into Nx

Mix.install([
  {:netcdf, ">= 0.0.0", github: "dockyard/netcdf"},
  {:nx, "~> 0.3"}
])

Load the example data

The repository contains an example NetCDF file at priv/hello_world.nc. We will load its data and convert it into Nx tensors.

# Relative path: this assumes the Livebook was started from the repository's root directory.
filepath = "./priv/hello_world.nc"
{:ok, file} = NetCDF.File.open(filepath)

# Load every variable listed in `file.variables`, keeping that order.
# A failed load does not match `{:ok, _}` and therefore raises a MatchError.
variables =
  Enum.map(file.variables, fn var_name ->
    {:ok, var} = NetCDF.Variable.load(file, var_name)
    var
  end)

# Translate a variable's type atom (as passed in from `var.type`) into the type atom Nx
# expects: signed integers :i8..:i64 are renamed to Nx's :s8..:s64, and any other type
# atom is returned unchanged.
as_nx_type = fn type ->
  Map.get(%{i8: :s8, i16: :s16, i32: :s32, i64: :s64}, type, type)
end

# Build a map from variable name to an Nx tensor holding that variable's values,
# with the element type translated by `as_nx_type`.
var_tensors =
  Map.new(variables, fn var ->
    nx_type = as_nx_type.(var.type)
    {var.name, Nx.tensor(var.value, type: nx_type)}
  end)

# In this dataset, "temp" is known to be a function of lat, lon and time, so its tensor
# is reshaped to {lat, lon, time} with named axes. The time length is left as :auto,
# i.e. inferred from the element count.
# NOTE(review): assumes the stored values are laid out lat-major, then lon, then time.
size_lat = Nx.size(var_tensors["lat"])
size_lon = Nx.size(var_tensors["lon"])

var_tensors =
  Map.update!(var_tensors, "temp", fn temp_values ->
    target_shape = {size_lat, size_lon, :auto}
    Nx.reshape(temp_values, target_shape, names: [:lat, :lon, :time])
  end)

We can now work with these as regular Nx tensors. In the example below, we select only the data contained within a given geofence (latitude and longitude bounds).

min_lat = -85
max_lat = -70

min_lon = 10
max_lon = 30

temp = var_tensors["temp"]
lat = var_tensors["lat"]
lon = var_tensors["lon"]

# Selection masks: 1 where the coordinate lies inside the closed bounds, 0 elsewhere.
lat_selector = Nx.greater_equal(lat, min_lat) |> Nx.logical_and(Nx.less_equal(lat, max_lat))
lon_selector = Nx.greater_equal(lon, min_lon) |> Nx.logical_and(Nx.less_equal(lon, max_lon))

# Converts a 1-D selection mask into the integer `{start, length}` window for Nx.slice.
# Nx.slice cuts out a single rectangular window, so the mask must select a non-empty,
# contiguous run of indices (true when the coordinate values are sorted, ascending or
# descending). Both conditions are checked so that an empty geofence or an unsorted
# coordinate raises a clear error instead of an obscure slice failure or silently
# wrong data.
#
# Note: Nx.to_number doesn't work in defn, so these calculations must be done outside of
# defn. This happens because dynamic shapes aren't supported in Nx defn.
selection_window = fn selector, coord_name ->
  len = selector |> Nx.sum() |> Nx.to_number()

  if len == 0 do
    raise ArgumentError, "no #{coord_name} values fall within the requested bounds"
  end

  # With tie_break: :low, argmax returns the index of the first selected entry.
  start = selector |> Nx.argmax(tie_break: :low) |> Nx.to_number()

  # The selection is contiguous iff all `len` entries of [start, start + len) are
  # selected. This window is always in bounds: the last selected index is at least
  # start + len - 1.
  selected_in_window = selector |> Nx.slice([start], [len]) |> Nx.sum() |> Nx.to_number()

  if selected_in_window != len do
    raise ArgumentError,
          "the selected #{coord_name} values are not contiguous; " <>
            "is the #{coord_name} coordinate sorted?"
  end

  {start, len}
end

{lat_start, lat_len} = selection_window.(lat_selector, "lat")
{lon_start, lon_len} = selection_window.(lon_selector, "lon")

sliced_temp =
  Nx.slice(temp, [lat_start, lon_start, 0], [lat_len, lon_len, Nx.axis_size(temp, :time)])

sliced_lat = Nx.slice(lat, [lat_start], [lat_len])
sliced_lon = Nx.slice(lon, [lon_start], [lon_len])

{sliced_temp, sliced_lat, sliced_lon}