# Source code for torchtree.inference.mcmc.operator

from __future__ import annotations

import abc
import math
import statistics
from collections import deque
from typing import Any

import torch
from torch import Tensor

from torchtree.core.identifiable import Identifiable
from torchtree.core.utils import process_objects, register_class
from torchtree.typing import ID, Parameter


class MCMCOperator(Identifiable, abc.ABC):
    """Abstract base class for MCMC proposal operators.

    Handles the bookkeeping shared by all operators: accept/reject counts,
    a sliding window of recent accept/reject outcomes, adaptive tuning of a
    single tunable parameter towards a target acceptance probability, and
    (de)serialization of operator state.

    :param id_: identifier of the operator.
    :param parameters: parameters the operator proposes new values for.
    :param weight: relative weight used when choosing among operators.
    :param target_acceptance_probability: acceptance probability that
        adaptation drives the tuning parameter towards.
    :param kwargs: optional keys ``disable_adaptation`` (bool, default
        ``False``) and ``acceptance_window_length`` (int, default 100).
    """

    def __init__(
        self,
        id_: ID,
        parameters: list[Parameter],
        weight: float,
        target_acceptance_probability: float,
        **kwargs,
    ) -> None:
        super().__init__(id_)
        self.parameters = parameters
        self.weight = weight
        self.target_acceptance_probability = target_acceptance_probability
        self._adapt_count = 0
        self._accept = 0
        self._reject = 0
        self._disable_adaptation = kwargs.get("disable_adaptation", False)
        self._accept_window_length = kwargs.get("acceptance_window_length", 100)
        self._accept_window = deque()

    @property
    @abc.abstractmethod
    def tuning_parameter(self) -> float:
        """Tuning parameter in its natural scale (e.g. scaler, window width)."""

    @property
    @abc.abstractmethod
    def adaptable_parameter(self) -> float:
        """Tuning parameter mapped to the unconstrained scale used by ``tune``."""

    @adaptable_parameter.setter
    def adaptable_parameter(self, value: float) -> None:
        self.set_adaptable_parameter(value)
        self._adapt_count += 1

    @abc.abstractmethod
    def set_adaptable_parameter(self, value: float) -> None:
        """Map ``value`` from the unconstrained scale back to the natural scale."""

    @abc.abstractmethod
    def _step(self) -> Tensor:
        """Propose new parameter values and return the log Hastings ratio."""

    def step(self) -> Tensor:
        """Save the current parameter tensors, then apply the proposal.

        The saved tensors are restored by :meth:`reject`.

        :return: log Hastings ratio of the move.
        """
        self.saved_tensors = [
            parameter.tensor.clone() for parameter in self.parameters
        ]
        return self._step()

    def accept(self) -> None:
        """Record an accepted move in the counters and sliding window."""
        self._accept += 1
        self._accept_window.append(1)
        if len(self._accept_window) > self._accept_window_length:
            self._accept_window.popleft()

    def reject(self) -> None:
        """Record a rejected move and restore the pre-step parameter values."""
        for parameter, saved_tensor in zip(self.parameters, self.saved_tensors):
            parameter.tensor = saved_tensor
        self._reject += 1
        self._accept_window.append(0)
        if len(self._accept_window) > self._accept_window_length:
            self._accept_window.popleft()

    def smoothed_acceptance_rate(self) -> float:
        """Return the acceptance rate over the sliding window (NaN when empty)."""
        if len(self._accept_window) == 0:
            return math.nan
        return statistics.mean(self._accept_window)

    def tune(self, acceptance_prob: Tensor, sample: int, accepted: bool) -> None:
        """Adapt the tuning parameter towards the target acceptance probability.

        Robbins-Monro style update with step size ``1 / (2 + adapt_count)``
        applied on the unconstrained (adaptable) scale. ``sample`` and
        ``accepted`` are part of the interface but unused here.
        """
        if not self._disable_adaptation:
            assert 0.0 <= acceptance_prob <= 1.0
            new_parameter = self.adaptable_parameter + (
                acceptance_prob.item() - self.target_acceptance_probability
            ) / (2 + self._adapt_count)
            self.adaptable_parameter = new_parameter

    def state_dict(self) -> dict[str, Any]:
        """Return the operator state, including subclass-specific entries."""
        state_dict = {
            "id": self.id,
            "adapt_count": self._adapt_count,
            "accept": self._accept,
            "reject": self._reject,
            "accept_window": list(self._accept_window),
        }
        state_dict.update(self._state_dict())
        return state_dict

    @abc.abstractmethod
    def _state_dict(self) -> dict[str, Any]:
        """Return subclass-specific state entries."""

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore state previously produced by :meth:`state_dict`."""
        self._adapt_count = state_dict["adapt_count"]
        self._accept = state_dict["accept"]
        self._reject = state_dict["reject"]
        self._accept_window = deque(state_dict["accept_window"])
        self._load_state_dict(state_dict)

    @abc.abstractmethod
    def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore subclass-specific state entries."""

    @staticmethod
    def _parse_json(data, dic):
        """Parse options shared by all operators from a JSON dictionary."""
        id_ = data["id"]
        parameters = process_objects(data["parameters"], dic, force_list=True)
        weight = data.get("weight", 1.0)
        target_acceptance_probability = data.get("target_acceptance_probability", 0.24)
        optionals = {}
        optionals["disable_adaptation"] = data.get("disable_adaptation", False)
        # Bug fix: the default used to be False, which set
        # _accept_window_length to False (== 0) in __init__ and made the
        # sliding window permanently empty (smoothed_acceptance_rate -> NaN).
        # The default must match __init__'s default of 100.
        optionals["acceptance_window_length"] = data.get(
            "acceptance_window_length", 100
        )
        return id_, parameters, weight, target_acceptance_probability, optionals
@register_class
class ScalerOperator(MCMCOperator):
    """Operator that rescales one randomly chosen element of a parameter.

    A multiplier ``s`` is drawn uniformly from ``[scaler, 1/scaler]`` and one
    coordinate of one randomly selected parameter is multiplied by it. The
    reported log Hastings ratio is ``-log(s)``.
    """

    def __init__(
        self,
        id_: ID,
        parameters: list[Parameter],
        weight: float,
        target_acceptance_probability: float,
        scaler: float,
        **kwargs,
    ):
        super().__init__(
            id_, parameters, weight, target_acceptance_probability, **kwargs
        )
        self._scaler = scaler

    @property
    def tuning_parameter(self) -> float:
        return self._scaler

    @MCMCOperator.adaptable_parameter.getter
    def adaptable_parameter(self) -> float:
        # logit-style transform: unconstrained value for a scaler in (0, 1)
        return math.log(1.0 / self._scaler - 1.0)

    def set_adaptable_parameter(self, value: float) -> None:
        # inverse of the transform used in ``adaptable_parameter``
        self._scaler = 1.0 / (math.exp(value) + 1.0)

    def _step(self) -> Tensor:
        # draw the multiplier uniformly in [scaler, 1/scaler]
        u = torch.rand(1).item()
        multiplier = self._scaler + u * ((1.0 / self._scaler) - self._scaler)
        which = torch.randint(0, len(self.parameters), (1,)).item()
        coord = torch.randint(0, len(self.parameters[which].tensor), (1,)).item()
        values = self.parameters[which].tensor
        values[coord] *= multiplier
        # reassign through the property so listeners are notified;
        # in-place indexed assignment alone would not trigger them
        self.parameters[which].tensor = values
        return -torch.tensor(
            multiplier,
            device=self.parameters[0].device,
            dtype=self.parameters[0].dtype,
        ).log()

    def _state_dict(self) -> dict[str, Any]:
        return {"scaler": self._scaler}

    def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._scaler = state_dict["scaler"]

    @classmethod
    def from_json(cls, data, dic):
        (
            id_,
            parameters,
            weight,
            target_acceptance_probability,
            optionals,
        ) = MCMCOperator._parse_json(data, dic)
        scaler = data.get("scaler", 0.1)
        return cls(
            id_, parameters, weight, target_acceptance_probability, scaler, **optionals
        )
@register_class
class SlidingWindowOperator(MCMCOperator):
    """Operator that shifts one randomly chosen element of a parameter.

    A shift is drawn uniformly from ``[-width/2, width/2]`` and added to one
    coordinate of one randomly selected parameter. The proposal is symmetric,
    so the log Hastings ratio is zero.
    """

    def __init__(
        self,
        id_: ID,
        parameters: list[Parameter],
        weight: float,
        target_acceptance_probability: float,
        width: float,
        **kwargs,
    ) -> None:
        super().__init__(
            id_, parameters, weight, target_acceptance_probability, **kwargs
        )
        self._width = width

    @property
    def tuning_parameter(self) -> float:
        return self._width

    @MCMCOperator.adaptable_parameter.getter
    def adaptable_parameter(self) -> float:
        # the width is positive, so it is adapted on the log scale
        return math.log(self._width)

    def set_adaptable_parameter(self, value: float) -> None:
        self._width = math.exp(value)

    def _step(self) -> Tensor:
        delta = self._width * (torch.rand(1).item() - 0.5)
        which = torch.randint(0, len(self.parameters), (1,)).item()
        coord = torch.randint(0, len(self.parameters[which].tensor), (1,)).item()
        values = self.parameters[which].tensor
        values[coord] += delta
        # reassign through the property so listeners are notified;
        # in-place indexed assignment alone would not trigger them
        self.parameters[which].tensor = values
        return torch.tensor(
            0.0, device=self.parameters[0].device, dtype=self.parameters[0].dtype
        )

    def _state_dict(self) -> dict[str, Any]:
        return {"width": self._width}

    def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._width = state_dict["width"]

    @classmethod
    def from_json(cls, data, dic):
        (
            id_,
            parameters,
            weight,
            target_acceptance_probability,
            optionals,
        ) = MCMCOperator._parse_json(data, dic)
        width = data.get("width", 0.1)
        return cls(
            id_, parameters, weight, target_acceptance_probability, width, **optionals
        )
@register_class
class DirichletOperator(MCMCOperator):
    """Operator proposing new values for a simplex-valued parameter.

    The current values multiplied by ``scaler`` serve as the concentration
    of a Dirichlet distribution from which the proposal is sampled; larger
    scalers concentrate the proposal around the current state. Only the
    first parameter in ``parameters`` is proposed to.
    """

    def __init__(
        self,
        id_: ID,
        # fixed annotation: was ``Parameter``, but the base class stores and
        # iterates a list (step/reject) and _step indexes parameters[0]
        parameters: list[Parameter],
        weight: float,
        target_acceptance_probability: float,
        scaler: float,
        **kwargs,
    ) -> None:
        super().__init__(
            id_, parameters, weight, target_acceptance_probability, **kwargs
        )
        self._scaler = scaler

    @property
    def tuning_parameter(self) -> float:
        return self._scaler

    @MCMCOperator.adaptable_parameter.getter
    def adaptable_parameter(self) -> float:
        # the scaler is positive, so it is adapted on the log scale
        return math.log(self._scaler)

    def set_adaptable_parameter(self, value: float) -> None:
        self._scaler = math.exp(value)

    def _step(self) -> Tensor:
        """Sample a proposal and return the log Hastings ratio ``b - f``."""
        old_values = self.parameters[0].tensor
        # forward kernel: Dirichlet centered on the current values
        dist_old = torch.distributions.Dirichlet(old_values * self._scaler)
        new_values = dist_old.sample()
        self.parameters[0].tensor = new_values
        # f: forward density of the proposed state,
        # b: backward density of the old state under the reverse kernel
        f = dist_old.log_prob(new_values)
        b = torch.distributions.Dirichlet(new_values * self._scaler).log_prob(
            old_values
        )
        return b - f

    def _state_dict(self) -> dict[str, Any]:
        return {"scaler": self._scaler}

    def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._scaler = state_dict["scaler"]

    @classmethod
    def from_json(cls, data, dic):
        (
            id_,
            parameters,
            weight,
            target_acceptance_probability,
            optionals,
        ) = MCMCOperator._parse_json(data, dic)
        scaler = data.get("scaler", 1.0)
        return cls(
            id_, parameters, weight, target_acceptance_probability, scaler, **optionals
        )