torchtree.nf.planar

Classes

PlanarTransform

Implementation of the transformation used in planar flow.

Module Contents

class torchtree.nf.planar.PlanarTransform(u: torch.nn.Parameter, w: torch.nn.Parameter, b: torch.nn.Parameter)

Bases: torch.nn.Module

Implementation of the transformation used in planar flow:

f(z) = z + u * tanh(dot(w.T, z) + b)

where z is the input and u, w, and b are learnable parameters. Inputs are batched with shape (batch_size, input_size), and the map is applied to each row of z independently.
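The formula is written for a single sample. A minimal batched sketch, assuming the shapes listed under Parameters, writes the per-sample dot product as `z @ w.T`, which gives one scalar per row. The function name `planar_map` is hypothetical and not part of torchtree, and the sketch uses u directly rather than any adjusted version of it.

```python
import torch


def planar_map(z: torch.Tensor, u: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: f(z) = z + u * tanh(w^T z + b), applied row by row.

    z: (batch_size, input_size); u, w: (1, input_size); b: (1,)
    """
    # (batch_size, input_size) @ (input_size, 1) -> (batch_size, 1): one w^T z + b per sample
    activation = z @ w.T + b
    # u (1, input_size) broadcasts against tanh(...) (batch_size, 1) -> (batch_size, input_size)
    return z + u * torch.tanh(activation)


torch.manual_seed(0)
batch_size, input_size = 4, 3
z = torch.randn(batch_size, input_size)
u = torch.randn(1, input_size)
w = torch.randn(1, input_size)
b = torch.randn(1)
y = planar_map(z, u, w, b)
print(y.shape)  # torch.Size([4, 3])
```

Because u has shape (1, input_size) and the tanh term has shape (batch_size, 1), broadcasting produces the full (batch_size, input_size) update without an explicit loop over samples.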

Parameters:
  • u (Parameter) – scaling factor with shape (1, input_size)

  • w (Parameter) – weight with shape (1, input_size)

  • b (Parameter) – bias with shape (1,)

u
w
b
forward(z: torch.Tensor) → torch.Tensor
log_abs_det_jacobian(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
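For a planar map the Jacobian of f is I + u ψ(z)ᵀ with ψ(z) = (1 − tanh²(wᵀz + b)) w, so by the matrix determinant lemma log|det ∂f/∂z| = log|1 + uᵀψ(z)|, one value per sample. The (x, y) signature matches the convention of `torch.distributions.transforms.Transform.log_abs_det_jacobian`, where presumably x is the input and y = f(x). The sketch below computes the closed form from x alone; the helper name `planar_log_abs_det` is hypothetical, it uses u directly, and it is not torchtree's implementation. It checks the result against a Jacobian built by autograd.

```python
import torch


def planar_log_abs_det(x: torch.Tensor, u: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log|det df/dx| for f(x) = x + u * tanh(x @ w.T + b).

    x: (batch_size, input_size); u, w: (1, input_size); b: (1,). Returns (batch_size,).
    """
    # h'(a) = 1 - tanh(a)^2, one value per sample: shape (batch_size, 1)
    h_prime = 1.0 - torch.tanh(x @ w.T + b) ** 2
    # psi(x) = h'(w^T x + b) * w, shape (batch_size, input_size)
    psi = h_prime * w
    # matrix determinant lemma: det(I + u psi^T) = 1 + u^T psi
    return torch.log(torch.abs(1.0 + psi @ u.T)).squeeze(-1)


torch.manual_seed(1)
x = torch.randn(5, 3, dtype=torch.float64)
u = torch.randn(1, 3, dtype=torch.float64)
w = torch.randn(1, 3, dtype=torch.float64)
b = torch.randn(1, dtype=torch.float64)


def f_single(v: torch.Tensor) -> torch.Tensor:
    # single-sample planar map, used only to build a reference Jacobian
    return v + u[0] * torch.tanh(torch.dot(w[0], v) + b[0])


log_det = planar_log_abs_det(x, u, w, b)
# reference: log|det J| from the full (3, 3) autograd Jacobian of each row
reference = torch.stack([
    torch.linalg.slogdet(torch.autograd.functional.jacobian(f_single, x_i)).logabsdet
    for x_i in x
])
print(torch.allclose(log_det, reference))  # True
```

The closed form costs O(input_size) per sample instead of the O(input_size³) of a generic determinant, which is the usual reason planar flows are cheap to evaluate.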
u_hat() → torch.Tensor
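The method name suggests the standard planar-flow reparameterization û from Rezende and Mohamed (2015), which keeps the tanh planar map invertible by guaranteeing wᵀû > −1: û = u + (m(wᵀu) − wᵀu) w / ‖w‖², with m(a) = −1 + log(1 + eᵃ). This page does not document the body of `u_hat`, so the sketch below (hypothetical helper `planar_u_hat`) shows that standard construction under this assumption, not torchtree's code.

```python
import torch
import torch.nn.functional as F


def planar_u_hat(u: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: adjust u so that w^T u_hat > -1.

    u, w: (1, input_size), w nonzero. Returns a tensor with the same shape as u.
    """
    # scalar w^T u
    wu = (w * u).sum()
    # m(a) = -1 + softplus(a) = -1 + log(1 + exp(a)), which is always > -1
    m_wu = -1.0 + F.softplus(wu)
    # shift u along w so that its inner product with w becomes exactly m(w^T u)
    return u + (m_wu - wu) * w / (w * w).sum()


w = torch.tensor([[1.0, 2.0, -1.0]])
u = torch.tensor([[-1.0, -1.0, 0.5]])  # w^T u = -3.5, violating w^T u >= -1
u_hat = planar_u_hat(u, w)
print(torch.dot(w[0], u[0]).item(), torch.dot(w[0], u_hat[0]).item())  # -3.5, then about -0.97
```

If a flow substitutes û for u in f(z), invertibility holds for any u (with nonzero w); whether torchtree's `forward` makes that substitution is not stated on this page.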