torchtree.evolution.tree_height_transform

Classes

GeneralNodeHeightTransform

Transform from ratios to node heights.

DifferenceNodeHeightTransform

Transform from node height differences to node heights.

Module Contents

class torchtree.evolution.tree_height_transform.GeneralNodeHeightTransform(tree: TimeTreeModel, cache_size=0)

Bases: torch.distributions.Transform

Transform from ratios to node heights.

bijective = True
sign

Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.

tree
taxa_count
sort_indices()
update_bounds() -> None

Called when the tree topology changes.

log_abs_det_jacobian(x, y)

Computes the log absolute determinant of the Jacobian, \(\log |dy/dx|\), given the input x and the output y.
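The parameterization itself is not spelled out on this page. As a rough sketch only, assuming one common ratio scheme (the root height is a free parameter, and every other internal node sits a fraction r of the way between a lower bound and its parent's height), the forward map and its log-determinant could look like the following. The function name heights_from_ratios and the arguments parents and bounds are hypothetical and are not part of torchtree's API.

```python
import torch


def heights_from_ratios(root_height, ratios, parents, bounds):
    """Illustrative ratio-to-height map (an assumption, not torchtree's code).

    Internal nodes are listed in preorder with the root at index 0.
    parents[i] is the index of the parent of internal node i (unused for the root).
    bounds[i] is a lower bound on the height of internal node i,
    for example the height of its oldest descendant tip.
    ratios[i - 1] in (0, 1) places node i between bounds[i] and its parent.
    """
    heights = [root_height]
    log_det = torch.zeros(())
    for i in range(1, len(parents)):
        # Interval available to node i: from its lower bound up to its parent.
        span = heights[parents[i]] - bounds[i]
        heights.append(bounds[i] + ratios[i - 1] * span)
        # The Jacobian is triangular, so its log-determinant is the sum
        # of the log diagonal entries d(height_i)/d(ratio_i) = span.
        log_det = log_det + torch.log(span)
    return torch.stack(heights), log_det


root_height = torch.tensor(10.0)
ratios = torch.tensor([0.5, 0.2])
parents = [-1, 0, 1]                      # node 1 under the root, node 2 under node 1
bounds = torch.tensor([0.0, 1.0, 0.0])    # the root entry is not used
heights, log_det = heights_from_ratios(root_height, ratios, parents, bounds)
```

In this toy example the two spans are 9 and 5.5, so the log-determinant is log(9) + log(5.5). The same value can be cross-checked numerically with torch.autograd.functional.jacobian.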

class torchtree.evolution.tree_height_transform.DifferenceNodeHeightTransform(tree_model: TimeTreeModel, k: float = 0.0, cache_size=0)

Bases: torch.distributions.Transform

Transform from node height differences to node heights.

The height \(x_i\) of node \(i\) is parameterized as

\[x_i = \max(x_{c(i,0)}, x_{c(i,1)}) + y_i\]

where \(x_{c(i,j)}\) is the height of the \(j\)-th child of node \(i\) and \(y_i \in \mathbb{R}^+\). If k > 0, the max function is approximated with logsumexp so that the gradient can propagate.
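The recursion above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the library's implementation: the helper names smooth_max and heights_from_differences, the children list, and the particular scaling logsumexp(k * values) / k are all assumptions made for the example.

```python
import torch


def smooth_max(a, b, k):
    """max(a, b), or a logsumexp approximation of it when k > 0 (assumed scaling)."""
    if k > 0:
        return torch.logsumexp(k * torch.stack((a, b)), dim=0) / k
    return torch.maximum(a, b)


def heights_from_differences(tip_heights, y, children, k=0.0):
    """Illustrative difference-to-height map (hypothetical helper).

    tip_heights: heights of the n tips (nodes 0 .. n-1).
    y: positive increments, one per internal node.
    children: (left, right) node indices for each internal node, in postorder;
              internal node j receives index n + j.
    """
    heights = list(tip_heights)
    for j, (left, right) in enumerate(children):
        # x_i = max(x_left, x_right) + y_i
        heights.append(smooth_max(heights[left], heights[right], k) + y[j])
    return torch.stack(heights)


# Three tips at heights 0, 0 and 1. Node 3 joins tips 0 and 1;
# node 4 (the root) joins node 3 and tip 2.
tips = torch.tensor([0.0, 0.0, 1.0])
y = torch.tensor([0.5, 2.0], requires_grad=True)
children = [(0, 1), (3, 2)]

exact = heights_from_differences(tips, y, children)            # hard max
smooth = heights_from_differences(tips, y, children, k=50.0)   # logsumexp
```

With the hard max the internal heights are 0.5 and 3.0. The logsumexp version is never smaller than the hard max, and under this scaling the gap shrinks as k grows.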

bijective = True
sign

Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.

tree
taxa_count
k
log_abs_det_jacobian(x, y)

Computes the log absolute determinant of the Jacobian, \(\log |dy/dx|\), given the input x and the output y.
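As a worked note on the parameterization in the class description, written in that description's notation (heights \(x\), increments \(y\)) rather than the argument names of this method, and assuming internal nodes are indexed so that children come before their parents: each height \(x_i\) depends on its own increment \(y_i\) with unit slope and otherwise only on the increments of its descendants. The Jacobian of heights with respect to increments is therefore triangular with ones on the diagonal:

\[\frac{\partial x_i}{\partial y_i} = 1, \qquad \frac{\partial x_i}{\partial y_j} = 0 \ \text{ for } j > i, \qquad \log\left|\det \frac{\partial x}{\partial y}\right| = 0\]

The same argument holds when max is replaced by a logsumexp approximation, since the approximation involves only the children's heights and not \(y_i\).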