Convolution
generax.flows.conv.CircularConv (BijectiveTransform)
Circular convolution. Equivalent to a regular convolution with circular padding. https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf
Source code in generax/flows/conv.py
class CircularConv(BijectiveTransform):
"""Circular convolution. Equivalent to a regular convolution with circular padding.
https://papers.nips.cc/paper/2019/file/b1f62fa99de9f27a048344d55c5ef7a6-Paper.pdf
"""
filter_shape: Tuple[int] = eqx.field(static=True)
w: Array
def __init__(self,
input_shape: Tuple[int],
filter_shape: Tuple[int]=(3, 3),
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
"""
assert len(filter_shape) == 2
self.filter_shape = filter_shape
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
w = random.normal(key, shape=self.filter_shape + (C, C))
self.w = jax.vmap(jax.vmap(util.whiten))(w)
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
Kx, Ky, _, _ = self.w.shape
# See how much we need to roll the filter
W_x = (Kx - 1) // 2
W_y = (Ky - 1) // 2
# Pad the filter to match the fft size and roll it so that its center is at (0,0)
W_padded = jnp.pad(self.w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))
# Apply the FFT to get the convolution
image_fft = fft_channel_vmap(x)
W_fft = fft_double_channel_vmap(W_padded)
if inverse:
z_fft = jnp.einsum("abij,abj->abi", W_fft, image_fft)
z = ifft_channel_vmap(z_fft).real
else:
# For deconv, we need to invert the W over the channel dims
W_fft_inv = inv_height_width_vmap(W_fft)
x_fft = jnp.einsum("abij,abj->abi", W_fft_inv, image_fft)
z = ifft_channel_vmap(x_fft).real
# The log determinant is the log det of the frequencies over the channel dims
log_det = -slogdet_height_width_vmap(W_fft).sum()
if inverse:
log_det = -log_det
return z, log_det
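For orientation, here is a minimal usage sketch, assuming the import path above and a small unbatched `(H, W, C)` input (per the assert in `__call__`); shapes and tolerances are illustrative:

```python
import jax.numpy as jnp
import jax.random as random
from generax.flows.conv import CircularConv

key = random.PRNGKey(0)
x = random.normal(key, shape=(8, 8, 3))  # unbatched (H, W, C) input

layer = CircularConv(input_shape=(8, 8, 3), filter_shape=(3, 3), key=key)
z, log_det = layer(x)                   # forward transformation
x_rec, log_det_inv = layer.inverse(z)   # inverse() calls __call__ with inverse=True

# Bijective up to floating point error; the two log determinants cancel
assert jnp.allclose(x, x_rec, atol=1e-4)
assert jnp.allclose(log_det, -log_det_inv, atol=1e-4)
```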
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int] = (3, 3), *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full kernel will have shape (Kx, Ky, C, C)
Source code in generax/flows/conv.py
def __init__(self,
input_shape: Tuple[int],
filter_shape: Tuple[int]=(3, 3),
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
"""
assert len(filter_shape) == 2
self.filter_shape = filter_shape
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
w = random.normal(key, shape=self.filter_shape + (C, C))
self.w = jax.vmap(jax.vmap(util.whiten))(w)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/conv.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
Returns:
`(z, log_det)`
Source code in generax/flows/conv.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
See http://developer.download.nvidia.com/compute/cuda/2_2/sdk/website/projects/convolutionFFT2D/doc/convolutionFFT2D.pdf
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
Kx, Ky, _, _ = self.w.shape
# See how much we need to roll the filter
W_x = (Kx - 1) // 2
W_y = (Ky - 1) // 2
# Pad the filter to match the fft size and roll it so that its center is at (0,0)
W_padded = jnp.pad(self.w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))
# Apply the FFT to get the convolution
image_fft = fft_channel_vmap(x)
W_fft = fft_double_channel_vmap(W_padded)
if inverse:
z_fft = jnp.einsum("abij,abj->abi", W_fft, image_fft)
z = ifft_channel_vmap(z_fft).real
else:
# For deconv, we need to invert the W over the channel dims
W_fft_inv = inv_height_width_vmap(W_fft)
x_fft = jnp.einsum("abij,abj->abi", W_fft_inv, image_fft)
z = ifft_channel_vmap(x_fft).real
# The log determinant is the log det of the frequencies over the channel dims
log_det = -slogdet_height_width_vmap(W_fft).sum()
if inverse:
log_det = -log_det
return z, log_det
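The padding-and-rolling above implements the convolution theorem: multiplying by `W_fft` in frequency space is a circular (wrap-around) convolution with a centered filter. A single-channel sketch of that identity using plain `jnp.fft` (a helper-free analogue of `fft_channel_vmap`):

```python
import jax.numpy as jnp
import jax.random as random

H, W, K = 8, 8, 3
kx, kw = random.split(random.PRNGKey(0))
x = random.normal(kx, (H, W))
w = random.normal(kw, (K, K))
s = (K - 1) // 2  # half-width of the filter

# Flip, pad to the image size, and roll so the filter center sits at (0, 0),
# exactly as in __call__ above
w_pad = jnp.pad(w[::-1, ::-1], ((0, H - K), (0, W - K)))
w_pad = jnp.roll(w_pad, (-s, -s), axis=(0, 1))

# Convolution theorem: elementwise product in frequency space
z_fft = jnp.fft.ifft2(jnp.fft.fft2(w_pad) * jnp.fft.fft2(x)).real

# Direct circular convolution with the centered filter:
# z[i, j] = sum_{a,b} w[a, b] * x[(i + a - s) % H, (j + b - s) % W]
z_direct = sum(
    w[a, b] * jnp.roll(x, (s - a, s - b), axis=(0, 1))
    for a in range(K) for b in range(K)
)
assert jnp.allclose(z_fft, z_direct, atol=1e-4)
```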
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/conv.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
generax.flows.conv.CaleyOrthogonalConv (BijectiveTransform)
Cayley parametrization of an orthogonal convolution. https://arxiv.org/pdf/2104.07167.pdf
Source code in generax/flows/conv.py
class CaleyOrthogonalConv(BijectiveTransform):
"""Caley parametrization of an orthogonal convolution.
https://arxiv.org/pdf/2104.07167.pdf
"""
filter_shape: Tuple[int] = eqx.field(static=True)
v: Array
g: Array
def __init__(self,
input_shape: Tuple[int],
filter_shape: Tuple[int] = (3, 3),
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
"""
assert len(filter_shape) == 2
self.filter_shape = filter_shape
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
k1, k2 = random.split(key, 2)
self.v = random.normal(k1, shape=self.filter_shape + (C, C))
self.g = random.normal(k2, shape=())
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
w = self.g*self.v/jnp.linalg.norm(self.v)
Kx, Ky, _, _ = w.shape
# See how much we need to roll the filter
W_x = (Kx - 1) // 2
W_y = (Ky - 1) // 2
# Pad the filter to match the fft size and roll it so that its center is at (0,0)
W_padded = jnp.pad(w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))
# Apply the FFT to get the convolution
image_fft = fft_channel_vmap(x)
W_fft = fft_double_channel_vmap(W_padded)
A_fft = W_fft - W_fft.conj().transpose((0, 1, 3, 2))
I = jnp.eye(W_fft.shape[-1])
if inverse:
IpA_inv = inv_height_width_vmap(I[None,None] + A_fft)
y_fft = jnp.einsum("abij,abj->abi", IpA_inv, image_fft)
z_fft = y_fft - jnp.einsum("abij,abj->abi", A_fft, y_fft)
z = ifft_channel_vmap(z_fft).real
else:
ImA_inv = inv_height_width_vmap(I[None,None] - A_fft)
y_fft = jnp.einsum("abij,abj->abi", ImA_inv, image_fft)
z_fft = y_fft + jnp.einsum("abij,abj->abi", A_fft, y_fft)
z = ifft_channel_vmap(z_fft).real
log_det = jnp.array(0.0)
return z, log_det
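The reason `log_det` is identically zero: at each frequency, `A_fft` is skew-Hermitian by construction, and the forward map applies the Cayley transform Q = (I + A)(I - A)^{-1}, which is unitary, so |det Q| = 1 everywhere. A standalone sketch of that fact for a single C-by-C block:

```python
import jax.numpy as jnp
import jax.random as random

C = 4
k1, k2 = random.split(random.PRNGKey(0))
M = random.normal(k1, (C, C)) + 1j * random.normal(k2, (C, C))

A = M - M.conj().T                   # skew-Hermitian, as in A_fft above
I = jnp.eye(C)
Q = (I + A) @ jnp.linalg.inv(I - A)  # Cayley transform

# Q is unitary, so it preserves volume: log|det Q| = 0
assert jnp.allclose(Q.conj().T @ Q, I, atol=1e-5)
assert jnp.allclose(jnp.linalg.slogdet(Q)[1], 0.0, atol=1e-5)
```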
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/conv.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/conv.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], filter_shape: Tuple[int] = (3, 3), *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full kernel will have shape (Kx, Ky, C, C)
Source code in generax/flows/conv.py
def __init__(self,
input_shape: Tuple[int],
filter_shape: Tuple[int] = (3, 3),
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
- `filter_shape`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
"""
assert len(filter_shape) == 2
self.filter_shape = filter_shape
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
k1, k2 = random.split(key, 2)
self.v = random.normal(k1, shape=self.filter_shape + (C, C))
self.g = random.normal(k2, shape=())
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
Returns:
`(z, log_det)`
Source code in generax/flows/conv.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
w = self.g*self.v/jnp.linalg.norm(self.v)
Kx, Ky, _, _ = w.shape
# See how much we need to roll the filter
W_x = (Kx - 1) // 2
W_y = (Ky - 1) // 2
# Pad the filter to match the fft size and roll it so that its center is at (0,0)
W_padded = jnp.pad(w[::-1,::-1,:,:], ((0, H - Kx), (0, W - Ky), (0, 0), (0, 0)))
W_padded = jnp.roll(W_padded, (-W_x, -W_y), axis=(0, 1))
# Apply the FFT to get the convolution
image_fft = fft_channel_vmap(x)
W_fft = fft_double_channel_vmap(W_padded)
A_fft = W_fft - W_fft.conj().transpose((0, 1, 3, 2))
I = jnp.eye(W_fft.shape[-1])
if inverse:
IpA_inv = inv_height_width_vmap(I[None,None] + A_fft)
y_fft = jnp.einsum("abij,abj->abi", IpA_inv, image_fft)
z_fft = y_fft - jnp.einsum("abij,abj->abi", A_fft, y_fft)
z = ifft_channel_vmap(z_fft).real
else:
ImA_inv = inv_height_width_vmap(I[None,None] - A_fft)
y_fft = jnp.einsum("abij,abj->abi", ImA_inv, image_fft)
z_fft = y_fft + jnp.einsum("abij,abj->abi", A_fft, y_fft)
z = ifft_channel_vmap(z_fft).real
log_det = jnp.array(0.0)
return z, log_det
generax.flows.conv.OneByOneConv (BijectiveTransform)
1x1 convolution. Uses a dense parametrization because the channel dimension will probably never be that big. Costs O(C^3). Used in GLOW https://arxiv.org/pdf/1807.03039.pdf
Source code in generax/flows/conv.py
class OneByOneConv(BijectiveTransform):
""" 1x1 convolution. Uses a dense parametrization because the channel dimension will probably
never be that big. Costs O(C^3). Used in GLOW https://arxiv.org/pdf/1807.03039.pdf
"""
w: Array
def __init__(self,
input_shape: Tuple[int],
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
"""
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
w = random.normal(key, shape=(C, C))
self.w = util.whiten(w)
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
# Using lax.conv instead of matrix multiplication over the channel dimension
# is faster and also more numerically stable for some reason.
# Run the flow
if not inverse:
z = util.conv(self.w[None,None,:,:], x)
else:
w_inv = jnp.linalg.inv(self.w)
z = util.conv(w_inv[None,None,:,:], x)
log_det = jnp.linalg.slogdet(self.w)[1]*H*W
if inverse:
log_det = -log_det
return z, log_det
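Because the same C-by-C matrix acts on the channel vector at every pixel, the Jacobian of the flattened map is block diagonal with H*W copies of `w`, which is where `log_det = slogdet(w)[1] * H * W` above comes from. A sketch that checks this against a dense Jacobian (plain `jnp`, independent of `util.conv`):

```python
import jax
import jax.numpy as jnp
import jax.random as random

H, W, C = 4, 4, 3
kx, kw = random.split(random.PRNGKey(0))
x = random.normal(kx, (H, W, C))
w = random.normal(kw, (C, C))

# A 1x1 convolution is a per-pixel matrix multiply over channels
f = lambda v: jnp.einsum('ij,hwj->hwi', w, v.reshape(H, W, C)).ravel()

# Block-diagonal Jacobian: H*W blocks of w
log_det = H * W * jnp.linalg.slogdet(w)[1]
J = jax.jacobian(f)(x.ravel())
assert jnp.allclose(jnp.linalg.slogdet(J)[1], log_det, atol=1e-4)
```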
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/conv.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/conv.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
Source code in generax/flows/conv.py
def __init__(self,
input_shape: Tuple[int],
*,
key: PRNGKeyArray,
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `key`: A `jax.random.PRNGKey` for initialization
"""
super().__init__(input_shape=input_shape,
**kwargs)
H, W, C = input_shape
w = random.normal(key, shape=(C, C))
self.w = util.whiten(w)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
Returns:
`(z, log_det)`
Source code in generax/flows/conv.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
# Using lax.conv instead of matrix multiplication over the channel dimension
# is faster and also more numerically stable for some reason.
# Run the flow
if not inverse:
z = util.conv(self.w[None,None,:,:], x)
else:
w_inv = jnp.linalg.inv(self.w)
z = util.conv(w_inv[None,None,:,:], x)
log_det = jnp.linalg.slogdet(self.w)[1]*H*W
if inverse:
log_det = -log_det
return z, log_det
generax.flows.conv.HaarWavelet (BijectiveTransform)
Wavelet flow https://arxiv.org/pdf/2010.13821.pdf
Source code in generax/flows/conv.py
class HaarWavelet(BijectiveTransform):
"""Wavelet flow https://arxiv.org/pdf/2010.13821.pdf"""
W: Array = eqx.field(static=True)
output_shape: Tuple[int] = eqx.field(static=True)
def __init__(self,
input_shape: Tuple[int],
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
"""
H, W, C = input_shape
if H%2 != 0:
raise ValueError('Height must be even')
if W%2 != 0:
raise ValueError('Width must be even')
super().__init__(input_shape=input_shape,
**kwargs)
self.output_shape = (H//2, W//2, 4*C)
# Construct the filter
p, n = 0.5, -0.5
W = np.array([[[p, p],
[p, p]],
[[p, n],
[p, n]],
[[p, p],
[n, n]],
[[p, n],
[n, p]]])
W = W.transpose((1, 2, 0)) # (H, W, O)
W = W[:,:,None,:] # (H, W, I, O). We'll be applying this channelwise
self.W = W
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
def haar_conv(x):
"""(H, W) -> (H/2, W/2, 4))"""
return util.conv(self.W, x[:,:,None], stride=2)
if not inverse:
H, W, C = x.shape
z = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(x)
# Rescale the lowpass to have same mean
z = z.at[:,:,0].mul(0.5)
z = einops.rearrange(z, 'H W D C -> H W (C D)', D=4)
else:
h = einops.rearrange(x, 'H W (C D) -> H W C D', D=4)
# Rescale
h = h.at[:,:,:,0].mul(2.0)
h = einops.rearrange(h, 'H W C (M N) -> (H M) (W N) C', M=2, N=2)
h = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(h)
z = einops.rearrange(h, 'H W (M N) C -> (H M) (W N) C', M=2, N=2)
total_dim = util.list_prod(x.shape)
log_det = jnp.log(0.5)*total_dim/4
if inverse:
log_det = -log_det
return z, log_det
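Where `log_det = log(0.5) * total_dim / 4` comes from: the four 2x2 Haar filters with entries of magnitude 1/2 form an orthogonal map on each patch (|det| = 1), and the extra 0.5 rescaling of the lowpass channel contributes log(1/2) once per 2x2 patch, of which there are total_dim/4. A sketch with the same filter matrix as `__init__`:

```python
import numpy as np

p, n = 0.5, -0.5
# Rows: lowpass, horizontal, vertical, diagonal; columns: the 2x2 patch pixels
M = np.array([[p, p, p, p],
              [p, n, p, n],
              [p, p, n, n],
              [p, n, n, p]])

# The filter bank is orthogonal, so it alone contributes nothing to log_det
assert np.allclose(M @ M.T, np.eye(4))

# The forward pass additionally halves the lowpass channel
S = np.diag([0.5, 1.0, 1.0, 1.0])
assert np.allclose(np.linalg.slogdet(S @ M)[1], np.log(0.5))
```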
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/conv.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/conv.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], **kwargs)
Arguments:
- `input_shape`: The input shape. Output size is the same as shape.
Source code in generax/flows/conv.py
def __init__(self,
input_shape: Tuple[int],
**kwargs):
"""**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
"""
H, W, C = input_shape
if H%2 != 0:
raise ValueError('Height must be even')
if W%2 != 0:
raise ValueError('Width must be even')
super().__init__(input_shape=input_shape,
**kwargs)
self.output_shape = (H//2, W//2, 4*C)
# Construct the filter
p, n = 0.5, -0.5
W = np.array([[[p, p],
[p, p]],
[[p, n],
[p, n]],
[[p, p],
[n, n]],
[[p, n],
[n, p]]])
W = W.transpose((1, 2, 0)) # (H, W, O)
W = W[:,:,None,:] # (H, W, I, O). We'll be applying this channelwise
self.W = W
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
Returns:
`(z, log_det)`
Source code in generax/flows/conv.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
def haar_conv(x):
"""(H, W) -> (H/2, W/2, 4))"""
return util.conv(self.W, x[:,:,None], stride=2)
if not inverse:
H, W, C = x.shape
z = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(x)
# Rescale the lowpass to have same mean
z = z.at[:,:,0].mul(0.5)
z = einops.rearrange(z, 'H W D C -> H W (C D)', D=4)
else:
h = einops.rearrange(x, 'H W (C D) -> H W C D', D=4)
# Rescale
h = h.at[:,:,:,0].mul(2.0)
h = einops.rearrange(h, 'H W C (M N) -> (H M) (W N) C', M=2, N=2)
h = jax.vmap(haar_conv, in_axes=-1, out_axes=-1)(h)
z = einops.rearrange(h, 'H W (M N) C -> (H M) (W N) C', M=2, N=2)
total_dim = util.list_prod(x.shape)
log_det = jnp.log(0.5)*total_dim/4
if inverse:
log_det = -log_det
return z, log_det
generax.flows.pac_flow.EmergingConv (PACFlow)
Emerging convolutions https://arxiv.org/pdf/1901.11137.pdf. This is a special case of PAC flows.
Source code in generax/flows/pac_flow.py
class EmergingConv(PACFlow):
"""Emerging convolutions https://arxiv.org/pdf/1901.11137.pdf
This is a special case of PAC flows
"""
def __init__(self,
input_shape: Tuple[int],
kernel_size: int = 5,
order_type: str = "s_curve",
zero_init: bool = True,
*,
key: PRNGKeyArray,
**kwargs):
super().__init__(input_shape=input_shape,
feature_dim=None,
kernel_size=kernel_size,
order_type=order_type,
pixel_adaptive=False,
zero_init=zero_init,
key=key,
**kwargs)
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Initialize the parameters of the layer based on the data.
Arguments:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
Returns:
A new layer with the parameters initialized.
Source code in generax/flows/pac_flow.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Apply the inverse transformation.
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
Returns:
`(z, log_det)`
Source code in generax/flows/pac_flow.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Inherited from generax.flows.pac_flow.PACFlow.__call__.
Source code in generax/flows/pac_flow.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
# Apply the linear function
z, diag_jacobian = pac_ldu_mvp(x,
self.theta,
self.w,
self.order,
inverse=inverse,
**self.im2col_kwargs)
# Get the log det
if self.theta is not None:
flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
else:
flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))
log_det = jnp.log(jnp.abs(flat_diag)).sum()
if inverse:
log_det = -log_det
assert z.shape == self.input_shape
assert log_det.shape == ()
return z, log_det
__init__(self, input_shape: Tuple[int], kernel_size: int = 5, order_type: str = 's_curve', zero_init: bool = True, *, key: PRNGKeyArray, **kwargs)
Source code in generax/flows/pac_flow.py
def __init__(self,
input_shape: Tuple[int],
kernel_size: int = 5,
order_type: str = "s_curve",
zero_init: bool = True,
*,
key: PRNGKeyArray,
**kwargs):
super().__init__(input_shape=input_shape,
feature_dim=None,
kernel_size=kernel_size,
order_type=order_type,
pixel_adaptive=False,
zero_init=zero_init,
key=key,
**kwargs)
generax.flows.pac_flow.PACFlow (BijectiveTransform)
Pixel adaptive convolutions. These get too numerically unstable to use in practice. https://eddiecunningham.github.io/pdfs/PAC_Flow.pdf
Source code in generax/flows/pac_flow.py
class PACFlow(BijectiveTransform):
"""Pixel adaptive convolutions. Gets too numerically unstable to use in practice...
https://eddiecunningham.github.io/pdfs/PAC_Flow.pdf
"""
kernel_shape: Tuple[int] = eqx.field(static=True)
feature_dim: int = eqx.field(static=True)
order_type: str = eqx.field(static=True)
pixel_adaptive: bool = eqx.field(static=True)
im2col_kwargs: Any = eqx.field(static=True)
order: Array = eqx.field(static=True)
w: Array
theta: Union[Array, None]
def __init__(self,
input_shape: Tuple[int],
feature_dim: int = 8,
kernel_size: int=5,
order_type: str="s_curve",
pixel_adaptive: bool=True,
zero_init: bool = True,
*,
key: PRNGKeyArray,
**kwargs):
"""
**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `feature_dim`: The dimension of the features
- `kernel_size`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
- `order_type`: The order to convolve in. Either "raster" or "s_curve"
- `pixel_adaptive`: Whether to use pixel adaptive convolutions
- `zero_init`: Whether to initialize the weights to zero
"""
super().__init__(input_shape=input_shape,
**kwargs)
assert kernel_size%2 == 1
self.kernel_shape = (kernel_size, kernel_size)
self.feature_dim = feature_dim
self.order_type = order_type
self.pixel_adaptive = pixel_adaptive
H, W, C = input_shape
# Extract the im2col kwargs
Kx, Ky = self.kernel_shape
pad = Kx//2
self.im2col_kwargs = dict(filter_shape=self.kernel_shape,
stride=(1, 1),
padding=((pad, Kx - pad - 1),
(pad, Ky - pad - 1)),
lhs_dilation=(1, 1),
rhs_dilation=(1, 1),
dimension_numbers=("NHWC", "HWIO", "NHWC"))
# Determine the order to convolve
order_shape = H, W, 1
if self.order_type == "raster":
order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
elif self.order_type == "s_curve":
order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
order[::2, :, :] = order[::2, :, :][:, ::-1]
order = order*1.0 # Turn into a float
self.order = order
# Initialize the weights
k1, k2 = random.split(key, 2)
w = random.normal(k1, shape=self.kernel_shape + (C, C))
if zero_init:
pad = Kx//2
w = w.at[pad,pad,jnp.arange(C),jnp.arange(C)].set(1.0)
self.w = w
if self.pixel_adaptive:
self.theta = random.normal(k2, shape=(H, W, 2*C + self.feature_dim))
else:
self.theta = None
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
# Apply the linear function
z, diag_jacobian = pac_ldu_mvp(x,
self.theta,
self.w,
self.order,
inverse=inverse,
**self.im2col_kwargs)
# Get the log det
if self.theta is not None:
flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
else:
flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))
log_det = jnp.log(jnp.abs(flat_diag)).sum()
if inverse:
log_det = -log_det
assert z.shape == self.input_shape
assert log_det.shape == ()
return z, log_det
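The `log_det` above only needs `diag_jacobian` because convolving pixels in a fixed order (raster or s-curve) makes the Jacobian triangular under that ordering, and a triangular determinant is the product of its diagonal. A hypothetical 1-D illustration of the principle (a toy causal filter, not the `pac_ldu_mvp` internals):

```python
import jax
import jax.numpy as jnp

# Causal "convolution": each output depends on the current and previous input only
w_prev, w_curr = 0.7, 1.3
def causal_conv(x):
    return w_curr * x + w_prev * jnp.concatenate([jnp.zeros(1), x[:-1]])

x = jnp.arange(5.0)
J = jax.jacobian(causal_conv)(x)

assert jnp.allclose(J, jnp.tril(J))            # triangular Jacobian
log_det = jnp.log(jnp.abs(jnp.diag(J))).sum()  # sum of log|diagonal|
assert jnp.allclose(log_det, jnp.linalg.slogdet(J)[1])
```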
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None)
Inherited from generax.flows.base.BijectiveTransform.data_dependent_init.
Source code in generax/flows/pac_flow.py
def data_dependent_init(self,
x: Array,
y: Optional[Array] = None,
key: PRNGKeyArray = None):
"""Initialize the parameters of the layer based on the data.
**Arguments**:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information
- `key`: A `jax.random.PRNGKey` for initialization
**Returns**:
A new layer with the parameters initialized.
"""
return self
inverse(self, x: Array, y: Optional[Array] = None, **kwargs) -> Array
Inherited from generax.flows.base.BijectiveTransform.inverse.
Source code in generax/flows/pac_flow.py
def inverse(self,
x: Array,
y: Optional[Array] = None,
**kwargs) -> Array:
"""Apply the inverse transformation.
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
**Returns**:
(z, log_det)
"""
return self(x, y=y, inverse=True, **kwargs)
__init__(self, input_shape: Tuple[int], feature_dim: int = 8, kernel_size: int = 5, order_type: str = 's_curve', pixel_adaptive: bool = True, zero_init: bool = True, *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The input shape. Output size is the same as shape.
- `feature_dim`: The dimension of the features
- `kernel_size`: Height and width for the convolutional filter, (Kx, Ky). The full kernel will have shape (Kx, Ky, C, C)
- `order_type`: The order to convolve in. Either "raster" or "s_curve"
- `pixel_adaptive`: Whether to use pixel adaptive convolutions
- `zero_init`: Whether to initialize the weights to zero
Source code in generax/flows/pac_flow.py
def __init__(self,
input_shape: Tuple[int],
feature_dim: int = 8,
kernel_size: int=5,
order_type: str="s_curve",
pixel_adaptive: bool=True,
zero_init: bool = True,
*,
key: PRNGKeyArray,
**kwargs):
"""
**Arguments**:
- `input_shape`: The input shape. Output size is the same as shape.
- `feature_dim`: The dimension of the features
- `kernel_size`: Height and width for the convolutional filter, (Kx, Ky). The full
kernel will have shape (Kx, Ky, C, C)
- `order_type`: The order to convolve in. Either "raster" or "s_curve"
- `pixel_adaptive`: Whether to use pixel adaptive convolutions
- `zero_init`: Whether to initialize the weights to zero
"""
super().__init__(input_shape=input_shape,
**kwargs)
assert kernel_size%2 == 1
self.kernel_shape = (kernel_size, kernel_size)
self.feature_dim = feature_dim
self.order_type = order_type
self.pixel_adaptive = pixel_adaptive
H, W, C = input_shape
# Extract the im2col kwargs
Kx, Ky = self.kernel_shape
pad = Kx//2
self.im2col_kwargs = dict(filter_shape=self.kernel_shape,
stride=(1, 1),
padding=((pad, Kx - pad - 1),
(pad, Ky - pad - 1)),
lhs_dilation=(1, 1),
rhs_dilation=(1, 1),
dimension_numbers=("NHWC", "HWIO", "NHWC"))
# Determine the order to convolve
order_shape = H, W, 1
if self.order_type == "raster":
order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
elif self.order_type == "s_curve":
order = np.arange(1, 1 + util.list_prod(order_shape)).reshape(order_shape)
order[::2, :, :] = order[::2, :, :][:, ::-1]
order = order*1.0 # Turn into a float
self.order = order
# Initialize the weights
k1, k2 = random.split(key, 2)
w = random.normal(k1, shape=self.kernel_shape + (C, C))
if zero_init:
pad = Kx//2
w = w.at[pad,pad,jnp.arange(C),jnp.arange(C)].set(1.0)
self.w = w
if self.pixel_adaptive:
self.theta = random.normal(k2, shape=(H, W, 2*C + self.feature_dim))
else:
self.theta = None
__call__(self, x: Array, y: Optional[Array] = None, inverse: bool = False, **kwargs) -> Array
Arguments:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
Returns:
`(z, log_det)`
Source code in generax/flows/pac_flow.py
def __call__(self,
x: Array,
y: Optional[Array] = None,
inverse: bool = False,
**kwargs) -> Array:
"""
**Arguments**:
- `x`: The input to the transformation
- `y`: The conditioning information
- `inverse`: Whether to invert the transformation
**Returns**:
`(z, log_det)`
"""
assert x.shape == self.input_shape, 'Only works on unbatched data'
H, W, C = x.shape
# Apply the linear function
z, diag_jacobian = pac_ldu_mvp(x,
self.theta,
self.w,
self.order,
inverse=inverse,
**self.im2col_kwargs)
# Get the log det
if self.theta is not None:
flat_diag = diag_jacobian.reshape(self.theta.shape[:-3] + (-1,))
else:
flat_diag = diag_jacobian.reshape(x.shape[:1] + (-1,))
log_det = jnp.log(jnp.abs(flat_diag)).sum()
if inverse:
log_det = -log_det
assert z.shape == self.input_shape
assert log_det.shape == ()
return z, log_det