torchtree.nf.planar
Classes
PlanarTransform | Implementation of the transformation used in planar flow.
Module Contents
- class torchtree.nf.planar.PlanarTransform(u: torch.nn.Parameter, w: torch.nn.Parameter, b: torch.nn.Parameter)
Bases: torch.nn.Module
Implementation of the transformation used in planar flow:
f(z) = z + u * tanh(w^T z + b)
where z is the input and u, w, and b are learnable parameters. The input z has shape (batch_size, input_size).
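The following is a minimal, framework-free sketch of the formula above in plain Python, not torchtree's code. It applies f row by row to a (batch_size, input_size) batch given as nested lists, and also returns log|det df/dz| for each row, the quantity a normalizing flow needs to update the density. The log-determinant uses the standard planar-flow identity det(I + u psi^T) = 1 + u^T psi with psi = (1 - tanh^2(w^T z + b)) w; this listing does not say whether or how PlanarTransform computes it. The names planar_forward and dot are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def planar_forward(
    z: List[Vector], u: Vector, w: Vector, b: float
) -> Tuple[List[Vector], List[float]]:
    """Hypothetical helper: apply f(z) = z + u * tanh(w^T z + b) to each row of z.

    Returns the transformed rows and, per row, log|det df/dz| computed with
    the planar-flow identity det(I + u psi^T) = 1 + u^T psi, where
    psi = (1 - tanh^2(w^T z + b)) * w.
    """
    outputs, log_dets = [], []
    for row in z:
        h = math.tanh(dot(w, row) + b)  # scalar activation for this row
        outputs.append([zi + ui * h for zi, ui in zip(row, u)])
        psi = [(1.0 - h * h) * wi for wi in w]  # derivative of tanh times w
        log_dets.append(math.log(abs(1.0 + dot(u, psi))))
    return outputs, log_dets
```

For example, with z = [[0.0, 0.0]], u = [0.5, 0.5], w = [1.0, 1.0] and b = 0.0, the tanh term is 0, so the output row is [0.0, 0.0] and its log-determinant is log 2.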
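- Parameters:
- u (torch.nn.Parameter): vector that scales the tanh term in f(z).
- w (torch.nn.Parameter): weights of the projection w^T z inside the tanh.
- b (torch.nn.Parameter): bias added inside the tanh.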
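This listing does not say whether PlanarTransform restricts u. For the tanh nonlinearity, the planar transform is guaranteed to be invertible when w^T u >= -1, and Rezende and Mohamed (2015) enforce this by replacing u with u_hat = u + (m(w^T u) - w^T u) * w / ||w||^2, where m(x) = -1 + log(1 + e^x). The plain-Python sketch below shows that reparameterization, assuming w is nonzero; constrain_u and softplus are hypothetical names, not part of torchtree.

```python
import math
from typing import List

Vector = List[float]


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def constrain_u(u: Vector, w: Vector) -> Vector:
    """Hypothetical helper: return u_hat with w^T u_hat = -1 + softplus(w^T u) > -1.

    Uses u_hat = u + (m(w^T u) - w^T u) * w / ||w||^2, m(x) = -1 + softplus(x).
    Assumes w is not the zero vector (otherwise ||w||^2 = 0).
    """
    wu = sum(wi * ui for wi, ui in zip(w, u))
    w_norm_sq = sum(wi * wi for wi in w)
    m = -1.0 + softplus(wu)
    scale = (m - wu) / w_norm_sq  # shift u along w so that w^T u_hat = m
    return [ui + scale * wi for ui, wi in zip(u, w)]
```

The shift is taken along w only, so the component of u orthogonal to w is left unchanged, and m maps any real w^T u to a value strictly greater than -1.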