Skip to content

Transforms

DiagTransform

Bases: Transform

Applies a transformation to the diagonal of a square matrix.

Source code in `src/rs_distributions/transforms/fill_scale_tril.py` (lines 47–88)
class DiagTransform(Transform):
    """
    Applies a transformation to the diagonal of a square matrix.

    Only the entries selected by ``x.diagonal(dim1=-2, dim2=-1)`` are mapped through
    ``diag_transform``; every off-diagonal entry is passed through unchanged.

    Args:
        diag_transform (torch.distributions.transforms.Transform): transform applied
            to the vector of diagonal entries (last two dims of the input are the matrix).
    """

    def __init__(self, diag_transform):
        super().__init__()
        self.diag_transform = diag_transform

    @property
    def domain(self):
        # NOTE(review): constraints are taken from the diagonal transform, so their
        # event_dim describes the diagonal entries rather than the whole matrix —
        # confirm this is what callers (e.g. ComposeTransform) expect.
        return self.diag_transform.domain

    @property
    def codomain(self):
        return self.diag_transform.codomain

    @property
    def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Args:
            x (torch.Tensor): Input matrix
        Returns:
            torch.Tensor: Transformed matrix (a new tensor; ``x`` is not modified)
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        transformed_diagonal = self.diag_transform(diagonal)
        result = x.diagonal_scatter(transformed_diagonal, dim1=-2, dim2=-1)

        return result

    def _inverse(self, y):
        """
        Args:
            y (torch.Tensor): Transformed matrix
        Returns:
            torch.Tensor: Matrix whose diagonal has been mapped back through
            ``diag_transform.inv``; off-diagonal entries are unchanged.
        """
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        result = y.diagonal_scatter(self.diag_transform.inv(diagonal), dim1=-2, dim2=-1)
        return result

    def log_abs_det_jacobian(self, x, y):
        """
        Log absolute determinant of the Jacobian of the diagonal map.

        Off-diagonal entries are left unchanged and contribute nothing, so only
        ``diag_transform``'s term on the diagonal is returned. It is not summed here;
        for an element-wise ``diag_transform`` there is one value per diagonal entry.

        Args:
            x (torch.Tensor): Input matrix
            y (torch.Tensor): Output matrix, i.e. ``self(x)``
        Returns:
            torch.Tensor: ``diag_transform.log_abs_det_jacobian`` evaluated on the
            input diagonal and the output diagonal.
        """
        x_diagonal = x.diagonal(dim1=-2, dim2=-1)
        # The diagonal transform's "y" is the transformed diagonal, not the whole
        # output matrix. Passing the full matrix makes output-dependent Jacobians
        # (e.g. PowerTransform) broadcast against every matrix entry.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)

FillScaleTriL

Bases: ComposeTransform

A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix. The diagonal of the matrix is transformed with `diag_transform`.

Source code in `src/rs_distributions/transforms/fill_scale_tril.py` (lines 91–128)
class FillScaleTriL(ComposeTransform):
    """
    A `ComposeTransform` that reshapes a real-valued vector into a lower triangular matrix.
    The diagonal of the matrix is transformed with `diag_transform`.

    Args:
        diag_transform (torch.distributions.transforms.Transform, optional): transform
            applied to the diagonal entries. Defaults to ``SoftplusTransform`` followed
            by ``AffineTransform(1e-5, 1.0)``, i.e. ``softplus(x) + 1e-5``, which keeps
            the diagonal bounded away from zero.
    """

    def __init__(self, diag_transform=None):
        if diag_transform is None:
            diag_transform = ComposeTransform(
                (
                    SoftplusTransform(),
                    AffineTransform(1e-5, 1.0),
                )
            )
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])
        self.diag_transform = diag_transform

    @property
    def bijective(self):
        return True

    def log_abs_det_jacobian(self, x, y):
        """
        Log absolute determinant of the Jacobian of vector -> scale-tril.

        FillTriL contributes a zero term, so only the diagonal transform's term is
        returned.

        Args:
            x (torch.Tensor): Input real-valued vector
            y (torch.Tensor): Output lower triangular matrix, i.e. ``self(x)``
        Returns:
            torch.Tensor: ``diag_transform.log_abs_det_jacobian`` evaluated on the
            diagonal before and after transformation.
        """
        # NOTE(review): the per-diagonal-entry terms are returned unsummed; confirm
        # callers reduce over the last dimension.
        x_diagonal = FillTriL()._call(x).diagonal(dim1=-2, dim2=-1)
        # The diagonal of ``y`` is the *transformed* diagonal, which is the correct
        # output argument for transforms whose Jacobian depends on their output.
        y_diagonal = y.diagonal(dim1=-2, dim2=-1)
        return self.diag_transform.log_abs_det_jacobian(x_diagonal, y_diagonal)

    @staticmethod
    def params_size(event_size):
        """
        Returns the number of parameters required to create an n-by-n lower triangular matrix, which is given by n*(n+1)//2

        Args:
            event_size (int): size of event
        Returns:
            int: Number of parameters needed

        """
        return event_size * (event_size + 1) // 2

params_size(event_size) staticmethod

Returns the number of parameters required to create an $n \times n$ lower triangular matrix, which is $n(n+1)/2$ (computed as `n * (n + 1) // 2`).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `event_size` | `int` | Size of the event. | *required* |

Returns: `int` — the number of parameters needed.

Source code in `src/rs_distributions/transforms/fill_scale_tril.py` (lines 117–128)
@staticmethod
def params_size(event_size):
    """
    Count the free entries of an ``event_size``-by-``event_size`` lower triangular matrix.

    Args:
        event_size (int): size of event
    Returns:
        int: Number of parameters needed, ``n * (n + 1) // 2`` with ``n = event_size``
    """
    n = event_size
    triangle_entries = (n * (n + 1)) // 2
    return triangle_entries

FillTriL

Bases: Transform

Transform for converting a real-valued vector into a lower triangular matrix

Source code in `src/rs_distributions/transforms/fill_scale_tril.py` (lines 7–44)
class FillTriL(Transform):
    """
    Transform for converting a real-valued vector into a lower triangular matrix.

    The inverse flattens the lower triangle back into a vector. The transform is
    treated as volume-preserving: its log-abs-det-Jacobian is identically zero.
    """

    def __init__(self):
        super().__init__()

    @property
    def bijective(self):
        return True

    @property
    def domain(self):
        return constraints.real_vector

    @property
    def codomain(self):
        return constraints.lower_triangular

    def _call(self, x):
        """
        Converts real-valued vector to lower triangular matrix.

        Args:
            x (torch.Tensor): input real-valued vector
        Returns:
            torch.Tensor: Lower triangular matrix
        """
        tril = vec_to_tril_matrix(x)
        return tril

    def _inverse(self, y):
        """Flatten the lower triangle of matrix ``y`` back into a vector."""
        flat = tril_matrix_to_vec(y)
        return flat

    def log_abs_det_jacobian(self, x, y):
        """Return one zero per batch element of ``x`` (the trailing vector dim is dropped)."""
        # ``new_zeros`` inherits x's dtype and device.
        return x.new_zeros(x.shape[:-1])