Skip to content

transforms

LinearTransform

Bases: Transform

Class for linear transformations.

The transformation is defined as y = slope * x + intercept.

Attributes:

Name Type Description
slope float

The slope of the linear transformation.

intercept float

The intercept of the linear transformation.

Source code in gallifrey/inference/transforms.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
class LinearTransform(Transform):
    """
    Class for linear transformations.

    The transformation is defined as y = slope * x + intercept.

    Attributes
    ----------
    slope : jnp.ndarray
        The slope of the linear transformation (converted with
        ``jnp.asarray`` on construction).
    intercept : jnp.ndarray
        The intercept of the linear transformation (converted with
        ``jnp.asarray`` on construction).
    """

    def __init__(self, slope: ScalarFloat, intercept: ScalarFloat):
        """
        Initializes the LinearTransform object.

        Parameters
        ----------
        slope : ScalarFloat
            The slope of the linear transformation.
        intercept : ScalarFloat
            The intercept of the linear transformation.
        """
        self.slope = jnp.asarray(slope)
        self.intercept = jnp.asarray(intercept)

    def apply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the linear transformation to input x.

        Parameters
        ----------
        x :  Float[ArrayLike, "..."]
            The input data.

        Returns
        -------
         Float[jnp.ndarray, "..."]
            The transformed data, ``slope * x + intercept``.

        """
        return self.slope * x + self.intercept

    def unapply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the linear transformation to input x.

        Parameters
        ----------
        x :  Float[ArrayLike, "..."]
            The (reverse) transformed data.

        Returns
        -------
         Float[jnp.ndarray, "..."]
            The un-transformed data, ``(x - intercept) / slope``.
        """
        return jnp.asarray((x - self.intercept) / self.slope)

    def apply_mean(
        self,
        mean_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the linear transformation to a mean value.

        A mean transforms exactly like a data point, so this delegates
        to ``apply``.

        Parameters
        ----------
        mean_val : Float[ArrayLike, "..."]
            The mean value to be transformed.

        Returns
        -------
        Float[jnp.ndarray, "..."]
            The transformed mean value.
        """
        return self.apply(jnp.asarray(mean_val))

    def unapply_mean(
        self,
        mean_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the linear transformation to a mean value.

        Delegates to ``unapply``.

        Parameters
        ----------
        mean_val : Float[ArrayLike, "..."]
            The mean value to be un-transformed.

        Returns
        -------
        Float[jnp.ndarray, "..."]
            The un-transformed mean value.
        """
        return self.unapply(jnp.asarray(mean_val))

    def apply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the linear transformation to a variance value.

        Only the slope matters for a variance: the result is
        ``slope**2 * var_val``; the intercept is not used.

        Parameters
        ----------
        var_val : Float[ArrayLike, "..."]
            The variance value to be transformed.

        Returns
        -------
        Float[jnp.ndarray, "..."]
            The transformed variance value.
        """
        return jnp.asarray(self.slope**2 * var_val)

    def unapply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the linear transformation to a variance value.

        The result is ``var_val / slope**2``; the intercept is not used.

        Parameters
        ----------
        var_val : Float[ArrayLike, "..."]
            The variance value to be un-transformed.

        Returns
        -------
        Float[jnp.ndarray, "..."]
            The un-transformed variance value.
        """
        return jnp.asarray((1 / (self.slope**2)) * var_val)

    def apply_mean_var(
        self,
        mean_val: Float[ArrayLike, "..."],
        var_val: Float[ArrayLike, "..."],
    ) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
        """
        Applies the linear transformation to mean and variance values.

        Parameters
        ----------
        mean_val : Float[ArrayLike, "..."]
            The mean value to be transformed.
        var_val : Float[ArrayLike, "..."]
            The variance value to be transformed.

        Returns
        -------
        tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
            A tuple containing the transformed mean and variance values.

        """
        m = self.apply_mean(mean_val)
        v = self.apply_var(var_val)
        return (m, v)

    def unapply_mean_var(
        self,
        mean_val: Float[ArrayLike, "..."],
        var_val: Float[ArrayLike, "..."],
    ) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
        """
        Unapplies the linear transformation to mean and variance values.

        Parameters
        ----------
        mean_val : Float[ArrayLike, "..."]
            The mean value to be un-transformed.
        var_val : Float[ArrayLike, "..."]
            The variance value to be un-transformed.

        Returns
        -------
        tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
            A tuple containing the un-transformed mean and variance values.

        """
        m = self.unapply_mean(mean_val)
        v = self.unapply_var(var_val)
        return (m, v)

    @staticmethod
    def _summarize(
        data: Float[ArrayLike, "..."],
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Returns the non-NaN values of data, their minimum and their range.

        Raises
        ------
        ValueError
            If fewer than 2 non-NaN values remain, or if all remaining
            values are identical (zero range).
        """
        # Convert once up front: jnp.isnan does not accept plain Python
        # sequences, while jnp.asarray does.
        values = jnp.asarray(data)
        values = values[~jnp.isnan(values)]
        if len(values) < 2:
            raise ValueError("Cannot scale with <2 values.")
        lowest = jnp.min(values)
        span = jnp.max(values) - lowest
        # A zero range would be used as a divisor and yield an inf/NaN slope.
        if span == 0:
            raise ValueError(
                "Cannot scale data whose non-NaN values are all identical."
            )
        return values, lowest, span

    @classmethod
    def from_data_range(
        cls,
        data: Float[ArrayLike, "..."],
        lo: ScalarFloat | ScalarInt,
        hi: ScalarFloat | ScalarInt,
    ) -> LinearTransform:
        """
        Creates a LinearTransform instance such that data
        is scaled to [lo, hi].

        NaN values are ignored in the calculation.

        Parameters
        ----------
        data :  Float[ArrayLike, "..."]
            The input data.
        lo : ScalarFloat | ScalarInt
            The lower bound of the desired range.
        hi : ScalarFloat | ScalarInt
            The upper bound of the desired range.

        Returns
        -------
        LinearTransform
            A LinearTransform instance with slope and intercept
            such that data is scaled to [lo, hi].

        Raises
        ------
        ValueError
            If the input data contains less than 2 non-NaN values,
            or if all non-NaN values are identical.

        """
        _, tmin, span = cls._summarize(data)
        slope = (hi - lo) / span
        # Choose the intercept so that the data minimum maps onto lo.
        intercept = -slope * tmin + lo
        return cls(slope, intercept)

    @classmethod
    def from_data_width(
        cls,
        data: Float[ArrayLike, "..."],
        width: ScalarFloat | ScalarInt,
    ) -> LinearTransform:
        """
        Creates a LinearTransform instance such that the range of the data
        (max - min) is scaled to the given width and the data mean is
        mapped to zero.

        The transformed data lies exactly in [-width/2, width/2] only when
        the data mean coincides with the midpoint of its range.

        NaN values are ignored in the calculation.

        Parameters
        ----------
        data :  Float[ArrayLike, "..."]
            The input data.
        width : ScalarFloat | ScalarInt
            The desired width of the data.

        Returns
        -------
        LinearTransform
            A LinearTransform instance with slope ``width / (max - min)``
            and an intercept that maps the data mean to zero.

        Raises
        ------
        ValueError
            If the input data contains less than 2 non-NaN values,
            or if all non-NaN values are identical.
        """
        values, _, span = cls._summarize(data)
        slope = width / span
        intercept = -(jnp.asarray(width) * values.mean()) / span
        return cls(slope, intercept)

__init__(slope, intercept)

Initializes the LinearTransform object.

Parameters:

Name Type Description Default
slope ScalarFloat

The slope of the linear transformation.

required
intercept ScalarFloat

The intercept of the linear transformation.

required
Source code in gallifrey/inference/transforms.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def __init__(self, slope: ScalarFloat, intercept: ScalarFloat):
    """
    Set up the transform y = slope * x + intercept.

    Parameters
    ----------
    slope : ScalarFloat
        Multiplicative factor of the transformation.
    intercept : ScalarFloat
        Additive offset of the transformation.
    """
    # Both coefficients are stored as JAX arrays.
    self.slope, self.intercept = jnp.asarray(slope), jnp.asarray(intercept)

apply(x)

Applies the linear transformation to input x.

Parameters:

Name Type Description Default
x Float[ArrayLike, "..."]

The input data.

required

Returns:

Type Description
Float[jnp.ndarray, "..."]

The transformed data.

Source code in gallifrey/inference/transforms.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def apply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Map x through the linear transformation.

    Parameters
    ----------
    x :  Float[ArrayLike, "..."]
        Values to transform.

    Returns
    -------
     Float[jnp.ndarray, "..."]
        ``slope * x + intercept``.

    """
    scaled = self.slope * x
    return scaled + self.intercept

apply_mean(mean_val)

Applies the linear transformation to a mean value.

Parameters:

Name Type Description Default
mean_val Float[ArrayLike, '...']

The mean value to be transformed.

required

Returns:

Type Description
Float[ndarray, '...']

The transformed mean value.

Source code in gallifrey/inference/transforms.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def apply_mean(
    self,
    mean_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Transform a mean value.

    The mean is converted to a JAX array and passed to ``apply``.

    Parameters
    ----------
    mean_val : Float[ArrayLike, "..."]
        Mean value in the original space.

    Returns
    -------
    Float[jnp.ndarray, "..."]
        Mean value in the transformed space.
    """
    mean_arr = jnp.asarray(mean_val)
    return self.apply(mean_arr)

apply_mean_var(mean_val, var_val)

Applies the linear transformation to mean and variance values.

Parameters:

Name Type Description Default
mean_val Float[ArrayLike, '...']

The mean value to be transformed.

required
var_val Float[ArrayLike, '...']

The variance value to be transformed.

required

Returns:

Type Description
tuple[Float[ndarray, '...'], Float[ndarray, '...']]

A tuple containing the transformed mean and variance values.

Source code in gallifrey/inference/transforms.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def apply_mean_var(
    self,
    mean_val: Float[ArrayLike, "..."],
    var_val: Float[ArrayLike, "..."],
) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
    """
    Transform a mean and a variance in one call.

    The mean goes through ``apply_mean`` and the variance through
    ``apply_var``.

    Parameters
    ----------
    mean_val : Float[ArrayLike, "..."]
        Mean value in the original space.
    var_val : Float[ArrayLike, "..."]
        Variance value in the original space.

    Returns
    -------
    tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
        ``(transformed_mean, transformed_variance)``.

    """
    return self.apply_mean(mean_val), self.apply_var(var_val)

apply_var(var_val)

Applies the linear transformation to a variance value.

Parameters:

Name Type Description Default
var_val Float[ArrayLike, '...']

The variance value to be transformed.

required

Returns:

Type Description
Float[ndarray, '...']

The transformed variance value.

Source code in gallifrey/inference/transforms.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def apply_var(
    self,
    var_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Transform a variance value.

    The variance is multiplied by the squared slope; the intercept is
    not used.

    Parameters
    ----------
    var_val : Float[ArrayLike, "..."]
        Variance value in the original space.

    Returns
    -------
    Float[jnp.ndarray, "..."]
        ``slope**2 * var_val``.
    """
    squared_slope = self.slope**2
    return jnp.asarray(squared_slope * var_val)

from_data_range(data, lo, hi) classmethod

Creates a LinearTransform instance such that data is scaled to [lo, hi].

NaN values are ignored in the calculation.

Parameters:

Name Type Description Default
data Float[ArrayLike, "..."]

The input data.

required
lo ScalarFloat | ScalarInt

The lower bound of the desired range.

required
hi ScalarFloat | ScalarInt

The upper bound of the desired range.

required

Returns:

Type Description
LinearTransform

A LinearTransform instance with slope and intercept such that data is scaled to [lo, hi].

Raises:

Type Description
ValueError

If the input data contains less than 2 non-NaN values.

Source code in gallifrey/inference/transforms.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
@classmethod
def from_data_range(
    cls,
    data: Float[ArrayLike, "..."],
    lo: ScalarFloat | ScalarInt,
    hi: ScalarFloat | ScalarInt,
) -> LinearTransform:
    """
    Creates a LinearTransform instance such that data
    is scaled to [lo, hi].

    NaN values are ignored in the calculation.

    Parameters
    ----------
    data :  Float[ArrayLike, "..."]
        The input data.
    lo : ScalarFloat | ScalarInt
        The lower bound of the desired range.
    hi : ScalarFloat | ScalarInt
        The upper bound of the desired range.

    Returns
    -------
    LinearTransform
        A LinearTransform instance with slope and intercept
        such that data is scaled to [lo, hi].

    Raises
    ------
    ValueError
        If the input data contains less than 2 non-NaN values,
        or if all non-NaN values are identical.

    """
    # Convert once up front: jnp.isnan does not accept plain Python
    # sequences, while jnp.asarray does.
    values = jnp.asarray(data)
    values = values[~jnp.isnan(values)]
    if len(values) < 2:
        raise ValueError("Cannot scale with <2 values.")
    tmin = jnp.min(values)
    span = jnp.max(values) - tmin
    # A zero range would be used as a divisor and yield an inf/NaN slope.
    if span == 0:
        raise ValueError(
            "Cannot scale data whose non-NaN values are all identical."
        )
    slope = (hi - lo) / span
    # Choose the intercept so that the data minimum maps onto lo.
    intercept = -slope * tmin + lo
    return cls(slope, intercept)

from_data_width(data, width) classmethod

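Creates a LinearTransform instance such that the range of the data (max - min) is scaled to the given width and the data mean is mapped to zero. The transformed data lies exactly in [-width/2, width/2] only when the data mean coincides with the midpoint of its range.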

NaN values are ignored in the calculation.

Parameters:

Name Type Description Default
data Float[ArrayLike, "..."]

The input data.

required
width ScalarFloat | ScalarInt

The desired width of the data.

required

Returns:

Type Description
LinearTransform

A LinearTransform instance with slope width / (max - min) and an intercept that maps the data mean to zero.

Raises:

Type Description
ValueError

If the input data contains less than 2 non-NaN values.

Source code in gallifrey/inference/transforms.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
@classmethod
def from_data_width(
    cls,
    data: Float[ArrayLike, "..."],
    width: ScalarFloat | ScalarInt,
) -> LinearTransform:
    """
    Creates a LinearTransform instance such that the range of the data
    (max - min) is scaled to the given width and the data mean is
    mapped to zero.

    The transformed data lies exactly in [-width/2, width/2] only when
    the data mean coincides with the midpoint of its range.

    NaN values are ignored in the calculation.

    Parameters
    ----------
    data :  Float[ArrayLike, "..."]
        The input data.
    width : ScalarFloat | ScalarInt
        The desired width of the data.

    Returns
    -------
    LinearTransform
        A LinearTransform instance with slope ``width / (max - min)``
        and an intercept that maps the data mean to zero.

    Raises
    ------
    ValueError
        If the input data contains less than 2 non-NaN values,
        or if all non-NaN values are identical.
    """
    # Convert once up front: jnp.isnan does not accept plain Python
    # sequences, while jnp.asarray does.
    values = jnp.asarray(data)
    values = values[~jnp.isnan(values)]
    if len(values) < 2:
        raise ValueError("Cannot scale with <2 values.")

    span = values.max() - values.min()
    # A zero range would be used as a divisor and yield an inf/NaN slope.
    if span == 0:
        raise ValueError(
            "Cannot scale data whose non-NaN values are all identical."
        )
    slope = width / span
    intercept = -(jnp.asarray(width) * values.mean()) / span
    return cls(slope, intercept)

unapply(x)

Unapplies the linear transformation to input x.

Parameters:

Name Type Description Default
x Float[ArrayLike, "..."]

The (reverse) transformed data.

required

Returns:

Type Description
Float[jnp.ndarray, "..."]

The un-transformed data.

Source code in gallifrey/inference/transforms.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def unapply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Invert the linear transformation for input x.

    Parameters
    ----------
    x :  Float[ArrayLike, "..."]
        Values in the transformed space.

    Returns
    -------
     Float[jnp.ndarray, "..."]
        ``(x - intercept) / slope``, as a JAX array.
    """
    shifted = x - self.intercept
    return jnp.asarray(shifted / self.slope)

unapply_mean(mean_val)

Unapplies the linear transformation to a mean value.

Parameters:

Name Type Description Default
mean_val Float[ArrayLike, '...']

The mean value to be un-transformed.

required

Returns:

Type Description
Float[ArrayLike, '...']

The un-transformed mean value.

Source code in gallifrey/inference/transforms.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def unapply_mean(
    self,
    mean_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Unapplies the linear transformation to a mean value.

    The mean is converted to a JAX array and passed to ``unapply``.

    Parameters
    ----------
    mean_val : Float[ArrayLike, "..."]
        The mean value to be un-transformed.

    Returns
    -------
    Float[jnp.ndarray, "..."]
        The un-transformed mean value.
    """
    return self.unapply(jnp.asarray(mean_val))

unapply_mean_var(mean_val, var_val)

Unapplies the linear transformation to mean and variance values.

Parameters:

Name Type Description Default
mean_val Float[ArrayLike, '...']

The mean value to be un-transformed.

required
var_val Float[ArrayLike, '...']

The variance value to be un-transformed.

required

Returns:

Type Description
tuple[Float[ArrayLike, '...'], Float[ArrayLike, '...']]

A tuple containing the un-transformed mean and variance values.

Source code in gallifrey/inference/transforms.py
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def unapply_mean_var(
    self,
    mean_val: Float[ArrayLike, "..."],
    var_val: Float[ArrayLike, "..."],
) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
    """
    Unapplies the linear transformation to mean and variance values.

    The mean goes through ``unapply_mean`` and the variance through
    ``unapply_var``.

    Parameters
    ----------
    mean_val : Float[ArrayLike, "..."]
        The mean value to be un-transformed.
    var_val : Float[ArrayLike, "..."]
        The variance value to be un-transformed.

    Returns
    -------
    tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
        A tuple containing the un-transformed mean and variance values.

    """
    m = self.unapply_mean(mean_val)
    v = self.unapply_var(var_val)
    return (m, v)

unapply_var(var_val)

Unapplies the linear transformation to a variance value.

Parameters:

Name Type Description Default
var_val Float[ArrayLike, '...']

The variance value to be un-transformed.

required

Returns:

Type Description
Float[ndarray, '...']

The un-transformed variance value.

Source code in gallifrey/inference/transforms.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def unapply_var(
    self,
    var_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Invert the linear transformation for a variance value.

    The variance is multiplied by the reciprocal of the squared slope;
    the intercept is not used.

    Parameters
    ----------
    var_val : Float[ArrayLike, "..."]
        Variance value in the transformed space.

    Returns
    -------
    Float[jnp.ndarray, "..."]
        ``(1 / slope**2) * var_val``.
    """
    inverse_sq_slope = 1 / (self.slope**2)
    return jnp.asarray(inverse_sq_slope * var_val)

LogTransform

Bases: Transform

Class for log transformations.

The transformation is defined as y = log(x).

Source code in gallifrey/inference/transforms.py
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
class LogTransform(Transform):
    """
    Class for log transformations.

    The transformation is defined as y = log(x).

    The transformation has no parameters. Operations declared on the
    base class that have no meaningful definition for a log transform
    (variance-only conversion and data-driven scaling) are implemented
    as explicit ``NotImplementedError`` stubs so that the class can be
    instantiated.

    """

    def apply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the log transformation to input x.

        Parameters
        ----------
        x : Float[ArrayLike, "..."]
            The input data.

        Returns
        -------
        Float[jnp.ndarray, "..."]
            The transformed data.

        """

        return jnp.log(x)

    def unapply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the log transformation to input x.

        Parameters
        ----------
        x : Float[ArrayLike, "..."]
            The (reverse) transformed data.

        Returns
        -------
         Float[jnp.ndarray, "..."]
            The un-transformed data.
        """
        return jnp.exp(x)

    def apply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Not supported for the log transformation.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "LogTransform cannot transform a variance on its own."
        )

    def unapply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Not supported for the log transformation.

        The un-transformed variance depends on the mean as well; use
        ``unapply_mean_var`` instead.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "LogTransform cannot un-transform a variance without the mean; "
            "use unapply_mean_var instead."
        )

    def unapply_mean_var(
        self,
        mean_val: Float[ArrayLike, "..."],
        var_val: Float[ArrayLike, "..."],
    ) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
        """
        Unapplies the log transformation to mean and variance values.

        Computes ``exp(mean + var / 2)`` for the mean and
        ``(exp(var) - 1) * exp(2 * mean + var)`` for the variance
        (these match the log-normal moment formulas).

        Parameters
        ----------
        mean_val : Float[ArrayLike, "..."]
            The mean value to be un-transformed.
        var_val : Float[ArrayLike, "..."]
            The variance value to be un-transformed.

        Returns
        -------
        tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
            A tuple containing the un-transformed mean and variance values.
        """
        m = jnp.exp(mean_val + var_val / 2)
        v = (jnp.exp(var_val) - 1) * jnp.exp(2 * mean_val + var_val)
        return (m, v)

    @classmethod
    def from_data_range(
        cls,
        data: Float[ArrayLike, "..."],
        lo: ScalarFloat | ScalarInt,
        hi: ScalarFloat | ScalarInt,
    ) -> LogTransform:
        """
        Not supported: the log transformation has no free parameters
        that could be fitted to a target range.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "LogTransform has no parameters to fit to a data range."
        )

    @classmethod
    def from_data_width(
        cls,
        data: Float[ArrayLike, "..."],
        width: ScalarFloat | ScalarInt,
    ) -> LogTransform:
        """
        Not supported: the log transformation has no free parameters
        that could be fitted to a target width.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "LogTransform has no parameters to fit to a data width."
        )

apply(x)

Applies the log transformation to input x.

Parameters:

Name Type Description Default
x Float[ArrayLike, '...']

The input data.

required

Returns:

Type Description
Float[ndarray, '...']

The transformed data.

Source code in gallifrey/inference/transforms.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def apply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Take the natural logarithm of x.

    Parameters
    ----------
    x : Float[ArrayLike, "..."]
        Values to transform.

    Returns
    -------
    Float[jnp.ndarray, "..."]
        ``jnp.log(x)``.

    """
    logged = jnp.log(x)
    return logged

unapply(x)

Unapplies the log transformation to input x.

Parameters:

Name Type Description Default
x Float[ArrayLike, '...']

The (reverse) transformed data.

required

Returns:

Type Description
Float[jnp.ndarray, "..."]

The un-transformed data.

Source code in gallifrey/inference/transforms.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
def unapply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Invert the log transformation by exponentiating x.

    Parameters
    ----------
    x : Float[ArrayLike, "..."]
        Values in log space.

    Returns
    -------
     Float[jnp.ndarray, "..."]
        ``jnp.exp(x)``.
    """
    exponentiated = jnp.exp(x)
    return exponentiated

unapply_mean_var(mean_val, var_val)

Unapplies the log transformation to mean and variance values.

Parameters:

Name Type Description Default
mean_val Float[ArrayLike, '...']

The mean value to be un-transformed.

required
var_val Float[ArrayLike, '...']

The variance value to be un-transformed.

required

Returns:

Type Description
tuple[Float[ndarray, '...'], Float[ndarray, '...']]

A tuple containing the un-transformed mean and variance values.

Source code in gallifrey/inference/transforms.py
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def unapply_mean_var(
    self,
    mean_val: Float[ArrayLike, "..."],
    var_val: Float[ArrayLike, "..."],
) -> tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]:
    """
    Convert a log-space mean and variance back to the original space.

    The mean is ``exp(mean + var / 2)`` and the variance is
    ``(exp(var) - 1) * exp(2 * mean + var)``.

    Parameters
    ----------
    mean_val : Float[ArrayLike, "..."]
        Mean value in log space.
    var_val : Float[ArrayLike, "..."]
        Variance value in log space.

    Returns
    -------
    tuple[Float[jnp.ndarray, "..."], Float[jnp.ndarray, "..."]]
        ``(un-transformed mean, un-transformed variance)``.
    """
    mean_out = jnp.exp(mean_val + var_val / 2)
    spread = jnp.exp(var_val) - 1
    var_out = spread * jnp.exp(2 * mean_val + var_val)
    return mean_out, var_out

Transform

Bases: ABC

Abstract base class for data transformations.

Source code in gallifrey/inference/transforms.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class Transform(ABC):
    """
    Abstract base class for data transformations.

    Concrete subclasses must implement every method marked abstract
    below; calling an instance is shorthand for ``apply``.
    """

    def __call__(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """Applies the transformation to input (forwards to ``apply``)."""
        transformed = self.apply(x)
        return transformed

    @abstractmethod
    def apply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the transformation to input.

        Abstract: subclasses provide the implementation.
        """

    @abstractmethod
    def unapply(
        self,
        x: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the transformation to input.

        Abstract: subclasses provide the implementation.
        """

    @abstractmethod
    def apply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Applies the transformation to a variance value.

        Abstract: subclasses provide the implementation.
        """

    @abstractmethod
    def unapply_var(
        self,
        var_val: Float[ArrayLike, "..."],
    ) -> Float[jnp.ndarray, "..."]:
        """
        Unapplies the transformation to a variance value.

        Abstract: subclasses provide the implementation.
        """

    @classmethod
    @abstractmethod
    def from_data_range(
        cls,
        data: Float[ArrayLike, "..."],
        lo: ScalarFloat | ScalarInt,
        hi: ScalarFloat | ScalarInt,
    ) -> Transform:
        """
        Creates a Transform instance such that data is scaled to [lo, hi].

        Abstract: subclasses provide the implementation.
        """

    @classmethod
    @abstractmethod
    def from_data_width(
        cls,
        data: Float[ArrayLike, "..."],
        width: ScalarFloat | ScalarInt,
    ) -> Transform:
        """
        Creates a Transform instance such that the width of the data
        is scaled to the given width.

        Abstract: subclasses provide the implementation.
        """

__call__(x)

Applies the transformation to input.

Source code in gallifrey/inference/transforms.py
14
15
16
17
18
19
def __call__(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """Applies the transformation to input (forwards to ``apply``)."""
    transformed = self.apply(x)
    return transformed

apply(x) abstractmethod

Applies the transformation to input.

Source code in gallifrey/inference/transforms.py
21
22
23
24
25
26
27
@abstractmethod
def apply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Applies the transformation to input.

    Abstract: subclasses provide the implementation.
    """

apply_var(var_val) abstractmethod

Applies the transformation to a variance value.

Source code in gallifrey/inference/transforms.py
37
38
39
40
41
42
43
@abstractmethod
def apply_var(
    self,
    var_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Applies the transformation to a variance value.

    Abstract: subclasses provide the implementation.
    """

from_data_range(data, lo, hi) abstractmethod classmethod

Creates a Transform instance such that data is scaled to [lo, hi].

Source code in gallifrey/inference/transforms.py
53
54
55
56
57
58
59
60
61
62
@classmethod
@abstractmethod
def from_data_range(
    cls,
    data: Float[ArrayLike, "..."],
    lo: ScalarFloat | ScalarInt,
    hi: ScalarFloat | ScalarInt,
) -> Transform:
    """
    Creates a Transform instance such that data is scaled to [lo, hi].

    Abstract: subclasses provide the implementation.
    """

from_data_width(data, width) abstractmethod classmethod

Creates a Transform instance such that the width of the data is scaled to the given width.

Source code in gallifrey/inference/transforms.py
64
65
66
67
68
69
70
71
72
73
@classmethod
@abstractmethod
def from_data_width(
    cls,
    data: Float[ArrayLike, "..."],
    width: ScalarFloat | ScalarInt,
) -> Transform:
    """
    Creates a Transform instance such that the width of the data
    is scaled to the given width.

    Abstract: subclasses provide the implementation.
    """

unapply(x) abstractmethod

Unapplies the transformation to input.

Source code in gallifrey/inference/transforms.py
29
30
31
32
33
34
35
@abstractmethod
def unapply(
    self,
    x: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Unapplies the transformation to input.

    Abstract: subclasses provide the implementation.
    """

unapply_var(var_val) abstractmethod

Unapplies the transformation to a variance value.

Source code in gallifrey/inference/transforms.py
45
46
47
48
49
50
51
@abstractmethod
def unapply_var(
    self,
    var_val: Float[ArrayLike, "..."],
) -> Float[jnp.ndarray, "..."]:
    """
    Unapplies the transformation to a variance value.

    Abstract: subclasses provide the implementation.
    """