Modules

DistributionModule

Bases: Module

Base class for constructing learnable distributions. This subclass of torch.nn.Module acts like a torch.distributions.Distribution object with learnable torch.nn.Parameter attributes. It works by lazily constructing distributions as needed. Here is a simple example of distribution matching using learnable distributions with reparameterized gradients.

from rs_distributions import modules as rsm
import torch

q = rsm.FoldedNormal(10., 5.)
p = torch.distributions.HalfNormal(1.)

opt = torch.optim.Adam(q.parameters())

steps = 10_000
num_samples = 256
for i in range(steps):
    opt.zero_grad()
    z = q.rsample((num_samples,))
    kl = (q.log_prob(z) - p.log_prob(z)).mean()
    kl.backward()
    opt.step()
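
Concrete distribution modules are declared by subclassing with the distribution_class keyword argument, which __init_subclass__ (see the source below) uses to copy the wrapped distribution's signature, docstring, and argument constraints. Attribute access is proxied, so the underlying distribution is rebuilt from the current parameter values on each use. A minimal sketch of what such a declaration could look like, assuming DistributionModule is exported from rs_distributions.modules and using torch.distributions.Normal as the wrapped class:

```python
import torch
from rs_distributions import modules as rsm

# hypothetical subclass declaration mirroring __init_subclass__ below;
# `DistributionModule` is assumed to be exported from rs_distributions.modules
class LearnableNormal(
    rsm.DistributionModule, distribution_class=torch.distributions.Normal
):
    pass

q = LearnableNormal(0.0, 1.0)
print(q.mean, q.stddev)      # proxied to a freshly constructed Normal
print(list(q.parameters()))  # unconstrained parameters backing loc and scale
```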
Source code in src/rs_distributions/modules/distribution.py
class DistributionModule(torch.nn.Module):
    """
    Base class for constructing learnable distributions.
    This subclass of `torch.nn.Module` acts like a `torch.distributions.Distribution`
    object with learnable `torch.nn.Parameter` attributes.
    It works by lazily constructing distributions as needed.
    Here is a simple example of distribution matching using learnable distributions with reparameterized gradients.

    ```python
    from rs_distributions import modules as rsm
    import torch

    q = rsm.FoldedNormal(10., 5.)
    p = torch.distributions.HalfNormal(1.)

    opt = torch.optim.Adam(q.parameters())

    steps = 10_000
    num_samples = 256
    for i in range(steps):
        opt.zero_grad()
        z = q.rsample((num_samples,))
        kl = (q.log_prob(z) - p.log_prob(z)).mean()
        kl.backward()
        opt.step()
    ```
    """

    distribution_class = torch.distributions.Distribution
    __doc__ = distribution_class.__doc__
    arg_constraints = distribution_class.arg_constraints

    def __init__(self, *args, **kwargs):
        super().__init__()
        sig = signature(self.distribution_class)
        bargs = sig.bind(*args, **kwargs)
        bargs.apply_defaults()
        for arg in self.distribution_class.arg_constraints:
            param = bargs.arguments.pop(arg)
            param = self._constrain_arg_if_needed(arg, param)
            setattr(self, f"_transformed_{arg}", param)
        self._extra_args = bargs.arguments

    def __repr__(self):
        rstring = super().__repr__().split("\n")[1:]
        rstring = [str(self.distribution_class) + " DistributionModule("] + rstring
        return "\n".join(rstring)

    def _distribution(self):
        kwargs = {
            k: self._realize_parameter(getattr(self, f"_transformed_{k}"))
            for k in self.distribution_class.arg_constraints
        }
        kwargs.update(self._extra_args)
        return self.distribution_class(**kwargs)

    def _constrain_arg_if_needed(self, name, value):
        if isinstance(value, TransformedParameter):
            return value
        cons = self.distribution_class.arg_constraints[name]
        if cons == torch.distributions.constraints.dependent:
            # dependent constraints have no registered transform; use identity
            transform = torch.distributions.AffineTransform(0.0, 1.0)
        else:
            transform = torch.distributions.constraint_registry.transform_to(cons)
        return TransformedParameter(value, transform)

    @staticmethod
    def _realize_parameter(param):
        if isinstance(param, TransformedParameter):
            return param()
        return param

    def __getattr__(self, name: str):
        # proxy distribution attributes (e.g. mean, log_prob) through a
        # freshly constructed distribution
        if name in self.distribution_class.arg_constraints or hasattr(
            self.distribution_class, name
        ):
            q = self._distribution()
            return getattr(q, name)
        return super().__getattr__(name)

    @staticmethod
    def _extract_distributions(*modules, base_class=torch.distributions.Distribution):
        """
        extract all torch.distributions.Distribution subclasses from one or
        more modules into a dict {name: cls}
        """
        d = {}
        for module in modules:
            for k in module.__all__:
                distribution_class = getattr(module, k)
                if not hasattr(distribution_class, "arg_constraints"):
                    continue
                if not hasattr(distribution_class.arg_constraints, "items"):
                    continue
                if issubclass(distribution_class, base_class):
                    d[k] = distribution_class
        return d

    def __init_subclass__(cls, /, distribution_class, **kwargs):
        super().__init_subclass__(**kwargs)
        # copy the wrapped distribution's constructor signature and docstring
        # onto the subclass
        update_wrapper(
            cls.__init__,
            distribution_class.__init__,
        )
        cls.distribution_class = distribution_class
        cls.arg_constraints = distribution_class.arg_constraints
        cls.__doc__ = distribution_class.__doc__
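
The _constrain_arg_if_needed method relies on PyTorch's constraint registry: each declared argument constraint maps to a transform into the constrained space, and its inverse yields the unconstrained value that is actually stored. A small round-trip illustration using the positive constraint (which transform the registry returns is an implementation detail of torch):

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraint_registry import transform_to

# the registry maps a constraint to a transform into the constrained space
t = transform_to(constraints.positive)

raw = t.inv(torch.tensor(5.0))  # unconstrained storage value
value = t(raw)                  # mapped back into the constrained space
assert torch.isclose(value, torch.tensor(5.0))
assert (t(torch.randn(4)) > 0).all()  # any raw value lands in the constraint
```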

TransformedParameter

Bases: Module

A torch.nn.Module subclass representing a constrained variable.

Source code in src/rs_distributions/modules/transformed_parameter.py
class TransformedParameter(torch.nn.Module):
    """
    A `torch.nn.Module` subclass representing a constrained variable.
    """

    def __init__(self, value, transform):
        """
        Args:
            value : Tensor
                The initial value of this learnable parameter
            transform : torch.distributions.Transform
                A transform instance which is applied to the underlying, unconstrained value
        """
        super().__init__()
        value = torch.as_tensor(value)  # support floats
        if isinstance(value, torch.nn.Parameter):
            self._value = value
            value.data = transform.inv(value)
        else:
            self._value = torch.nn.Parameter(transform.inv(value))
        self.transform = transform

    def forward(self):
        return self.transform(self._value)
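
Calling the module applies the transform to the stored unconstrained value. A brief sketch, assuming TransformedParameter is importable from rs_distributions.modules:

```python
import torch
from rs_distributions.modules import TransformedParameter  # import path assumed

tp = TransformedParameter(5.0, torch.distributions.ExpTransform())
print(tp._value)  # the unconstrained storage value, log(5) here
print(tp())       # forward() re-applies the transform: tensor(5.)
```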

__init__(value, transform)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | `Tensor` | The initial value of this learnable parameter | required |
| `transform` | `torch.distributions.Transform` | A transform instance which is applied to the underlying, unconstrained value | required |
Source code in src/rs_distributions/modules/transformed_parameter.py
def __init__(self, value, transform):
    """
    Args:
        value : Tensor
            The initial value of this learnable parameter
        transform : torch.distributions.Transform
            A transform instance which is applied to the underlying, unconstrained value
    """
    super().__init__()
    value = torch.as_tensor(value)  # support floats
    if isinstance(value, torch.nn.Parameter):
        self._value = value
        value.data = transform.inv(value)
    else:
        self._value = torch.nn.Parameter(transform.inv(value))
    self.transform = transform
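
Because the stored parameter lives in unconstrained space, plain gradient steps can never push the realized value outside its constraint. A minimal optimization sketch under the same assumed import:

```python
import torch
from rs_distributions.modules import TransformedParameter  # import path assumed

tp = TransformedParameter(2.0, torch.distributions.ExpTransform())
opt = torch.optim.SGD(tp.parameters(), lr=1e-1)
for _ in range(200):
    opt.zero_grad()
    loss = (tp() - 0.5) ** 2  # pull the constrained value toward 0.5
    loss.backward()
    opt.step()
# tp() stays strictly positive throughout: exp(raw) > 0 for any raw value
```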