Resnet Blocks
generax.nn.resnet_blocks.GatedResBlock
Gated residual block for 1d data or images.
Source code in generax/nn/resnet_blocks.py
class GatedResBlock(eqx.Module):
    """Gated residual block for 1d data or images.

    Computes `x + a*sigmoid(b)`, where `(a, b)` are the two halves (along the
    last axis) of `linear2(activation(h))`, `h = linear1(x)`, and `h` is
    optionally shifted/scaled by `linear_cond(activation(y))`.
    """
    # Projects activation(y) to 2*hidden_size values (shift, scale).
    # None when the block is built without `cond_shape`.
    linear_cond: Union[Union[WeightNormDense,ConvAndGroupNorm], None]
    linear1: Union[WeightNormDense,ConvAndGroupNorm]
    # WeightNormConv for image inputs, WeightNormDense for 1d inputs.
    linear2: Union[WeightNormDense,WeightNormConv]
    activation: Callable
    input_shape: Tuple[int] = eqx.field(static=True)
    hidden_size: int = eqx.field(static=True)
    cond_shape: Union[Tuple[int],None] = eqx.field(static=True)
    filter_shape: Union[Tuple[int],None] = eqx.field(static=True)
    groups: Union[int,None] = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 hidden_size: int,
                 groups: Optional[int] = None,
                 filter_shape: Optional[Tuple[int]] = None,
                 cond_shape: Optional[Tuple[int]] = None,
                 activation: Callable = jax.nn.swish,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The input size, either `(dim,)` or `(H, W, C)`.
          Output size is the same as `input_shape`.
        - `hidden_size`: The hidden layer size.
        - `groups`: Number of groups for the group norm. Only allowed for
          image inputs, and must divide `hidden_size`.
        - `filter_shape`: The convolution filter shape. Required for image
          inputs.
        - `cond_shape`: The size of the conditioning information.
        - `activation`: The activation function after each hidden layer.
        - `key`: A `jax.random.PRNGKey` for initialization

        **Raises**:

        - `ValueError`: if `input_shape` is not 1d or 3d, or if `groups` does
          not divide `hidden_size`.
        """
        super().__init__(**kwargs)
        if len(input_shape) not in [1, 3]:
            raise ValueError('Expected 1d or 3d input shape')
        image = False
        if len(input_shape) == 3:
            H, W, C = input_shape
            image = True
            assert filter_shape is not None, 'Must pass in filter shape when processing images'
        self.input_shape = input_shape
        self.hidden_size = hidden_size
        self.cond_shape = cond_shape
        self.filter_shape = filter_shape
        self.activation = activation
        if groups is not None:
            assert image, 'groups is only supported for image inputs'
            if hidden_size % groups != 0:
                raise ValueError('Hidden size must be divisible by groups')
        self.groups = groups
        k1, k2, k3 = random.split(key, 3)

        # Initialize the conditioning parameters
        if cond_shape is not None:
            if len(cond_shape) == 1:
                self.linear_cond = WeightNormDense(in_size=cond_shape[0],
                                                   out_size=2*hidden_size,
                                                   key=k1)
            else:
                self.linear_cond = ConvAndGroupNorm(input_shape=cond_shape,
                                                    out_size=2*hidden_size,
                                                    filter_shape=filter_shape,
                                                    groups=groups,
                                                    key=k1)
        else:
            self.linear_cond = None

        if image:
            self.linear1 = ConvAndGroupNorm(input_shape=input_shape,
                                            out_size=hidden_size,
                                            filter_shape=filter_shape,
                                            groups=groups,
                                            key=k2)
            hidden_shape = (H, W, hidden_size)
            # 2*C outputs: one half is the value, the other half the gate.
            self.linear2 = WeightNormConv(input_shape=hidden_shape,
                                          out_size=2*C,
                                          filter_shape=filter_shape,
                                          key=k3)
        else:
            self.linear1 = WeightNormDense(in_size=input_shape[0],
                                           out_size=hidden_size,
                                           key=k2)
            self.linear2 = WeightNormDense(in_size=hidden_size,
                                           out_size=2*input_shape[0],
                                           key=k3)

    def data_dependent_init(self,
                            x: Array,
                            y: Array = None,
                            key: PRNGKeyArray = None) -> eqx.Module:
        """Initialize the parameters of the layer based on the data.

        **Arguments**:

        - `x`: A batch of data to initialize the parameters with; its shape
          must be `(batch,) + input_shape`.
        - `y`: The (batched) conditioning information, or None.
        - `key`: A `jax.random.PRNGKey` for initialization

        **Returns**:

        A new layer with the parameters initialized. If `y` is None, the
        existing `linear_cond` is kept unchanged.
        """
        assert x.shape[1:] == self.input_shape, 'Only works on batched data'
        k1, k2, k3 = random.split(key, 3)
        updated_layer = self

        # Initialize the conditioning parameters. `__call__` feeds
        # `activation(y)` (not raw `y`) into `linear_cond`, so the
        # data-dependent statistics must be computed on that same input.
        if y is not None:
            y_act = eqx.filter_vmap(self.activation)(y)
            linear_cond = self.linear_cond.data_dependent_init(y_act, key=k1)
            h = eqx.filter_vmap(linear_cond)(y_act)
            shift, scale = jnp.split(h, 2, axis=-1)
            updated_layer = eqx.tree_at(lambda tree: tree.linear_cond,
                                        updated_layer, linear_cond)

        # Linear + shift/scale + activation, mirroring __call__ on the batch
        linear1 = self.linear1.data_dependent_init(x, key=k2)
        x = eqx.filter_vmap(linear1)(x)
        if y is not None:
            # Apply the modulation per example so shift/scale broadcast exactly
            # as in the unbatched __call__ (e.g. a (hidden,) shift over an
            # (H, W, hidden) image) instead of aligning the batch axis with a
            # spatial axis.
            x = jax.vmap(lambda xi, si, ci: si + xi*(1 + ci))(x, shift, scale)
        x = eqx.filter_vmap(self.activation)(x)

        # Linear + gate
        linear2 = self.linear2.data_dependent_init(x, key=k3)

        # Turn the new parameters into a new module
        updated_layer = eqx.tree_at(lambda tree: tree.linear1, updated_layer, linear1)
        updated_layer = eqx.tree_at(lambda tree: tree.linear2, updated_layer, linear2)
        return updated_layer

    def __call__(self, x: Array, y: Array = None) -> Array:
        """**Arguments:**

        - `x`: A JAX array with shape `input_shape`.
        - `y`: A JAX array to condition on with shape `cond_shape`.

        **Returns:**

        A JAX array with shape `input_shape`.
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        x_in = x
        # The conditioning input will shift/scale x
        if y is not None:
            h = self.linear_cond(self.activation(y))
            shift, scale = jnp.split(h, 2, axis=-1)
        # Linear + shift/scale + activation
        x = self.linear1(x)
        if y is not None:
            x = shift + x*(1 + scale)
        x = self.activation(x)
        # Linear + gate: the second half of the output gates the first half
        x = self.linear2(x)
        a, b = jnp.split(x, 2, axis=-1)
        return x_in + a*jax.nn.sigmoid(b)
__init__(self, input_shape: Tuple[int], hidden_size: int, groups: Optional[int] = None, filter_shape: Optional[Tuple[int]] = None, cond_shape: Optional[Tuple[int]] = None, activation: Callable = jax.nn.swish, *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The input size. Output size is the same as `input_shape`.
- `hidden_size`: The hidden layer size.
- `groups`: The number of group-norm groups. Only allowed for image inputs, and must divide `hidden_size`.
- `filter_shape`: The convolution filter shape. Required for image inputs.
- `cond_shape`: The size of the conditioning information.
- `activation`: The activation function after each hidden layer.
- `key`: A `jax.random.PRNGKey` for initialization.
Source code in generax/nn/resnet_blocks.py
def __init__(self,
             input_shape: Tuple[int],
             hidden_size: int,
             groups: Optional[int] = None,
             filter_shape: Optional[Tuple[int]] = None,
             cond_shape: Optional[Tuple[int]] = None,
             activation: Callable = jax.nn.swish,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """**Arguments**:

    - `input_shape`: The input size, either `(dim,)` or `(H, W, C)`.
      Output size is the same as `input_shape`.
    - `hidden_size`: The hidden layer size.
    - `groups`: Number of groups for the group norm. Only allowed for
      image inputs, and must divide `hidden_size`.
    - `filter_shape`: The convolution filter shape. Required for image
      inputs.
    - `cond_shape`: The size of the conditioning information.
    - `activation`: The activation function after each hidden layer.
    - `key`: A `jax.random.PRNGKey` for initialization

    **Raises**:

    - `ValueError`: if `input_shape` is not 1d or 3d, or if `groups` does
      not divide `hidden_size`.
    """
    super().__init__(**kwargs)
    if len(input_shape) not in [1, 3]:
        raise ValueError('Expected 1d or 3d input shape')
    image = False
    if len(input_shape) == 3:
        H, W, C = input_shape
        image = True
        assert filter_shape is not None, 'Must pass in filter shape when processing images'
    self.input_shape = input_shape
    self.hidden_size = hidden_size
    self.cond_shape = cond_shape
    self.filter_shape = filter_shape
    self.activation = activation
    if groups is not None:
        assert image, 'groups is only supported for image inputs'
        if hidden_size % groups != 0:
            raise ValueError('Hidden size must be divisible by groups')
    self.groups = groups
    k1, k2, k3 = random.split(key, 3)

    # Initialize the conditioning parameters
    if cond_shape is not None:
        if len(cond_shape) == 1:
            self.linear_cond = WeightNormDense(in_size=cond_shape[0],
                                               out_size=2*hidden_size,
                                               key=k1)
        else:
            self.linear_cond = ConvAndGroupNorm(input_shape=cond_shape,
                                                out_size=2*hidden_size,
                                                filter_shape=filter_shape,
                                                groups=groups,
                                                key=k1)
    else:
        self.linear_cond = None

    if image:
        self.linear1 = ConvAndGroupNorm(input_shape=input_shape,
                                        out_size=hidden_size,
                                        filter_shape=filter_shape,
                                        groups=groups,
                                        key=k2)
        hidden_shape = (H, W, hidden_size)
        # 2*C outputs: one half is the value, the other half the gate.
        self.linear2 = WeightNormConv(input_shape=hidden_shape,
                                      out_size=2*C,
                                      filter_shape=filter_shape,
                                      key=k3)
    else:
        self.linear1 = WeightNormDense(in_size=input_shape[0],
                                       out_size=hidden_size,
                                       key=k2)
        self.linear2 = WeightNormDense(in_size=hidden_size,
                                       out_size=2*input_shape[0],
                                       key=k3)
data_dependent_init(self, x: Array, y: Array = None, key: PRNGKeyArray = None) -> Module
Initialize the parameters of the layer based on the data.
Arguments:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/nn/resnet_blocks.py
def data_dependent_init(self,
                        x: Array,
                        y: Array = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
    """Initialize the parameters of the layer based on the data.

    **Arguments**:

    - `x`: A batch of data to initialize the parameters with; its shape
      must be `(batch,) + input_shape`.
    - `y`: The (batched) conditioning information, or None.
    - `key`: A `jax.random.PRNGKey` for initialization

    **Returns**:

    A new layer with the parameters initialized. If `y` is None, the
    existing `linear_cond` is kept unchanged.
    """
    assert x.shape[1:] == self.input_shape, 'Only works on batched data'
    k1, k2, k3 = random.split(key, 3)
    updated_layer = self

    # Initialize the conditioning parameters. `__call__` feeds
    # `activation(y)` (not raw `y`) into `linear_cond`, so the
    # data-dependent statistics must be computed on that same input.
    if y is not None:
        y_act = eqx.filter_vmap(self.activation)(y)
        linear_cond = self.linear_cond.data_dependent_init(y_act, key=k1)
        h = eqx.filter_vmap(linear_cond)(y_act)
        shift, scale = jnp.split(h, 2, axis=-1)
        updated_layer = eqx.tree_at(lambda tree: tree.linear_cond,
                                    updated_layer, linear_cond)

    # Linear + shift/scale + activation, mirroring __call__ on the batch
    linear1 = self.linear1.data_dependent_init(x, key=k2)
    x = eqx.filter_vmap(linear1)(x)
    if y is not None:
        # Apply the modulation per example so shift/scale broadcast exactly
        # as in the unbatched __call__ (e.g. a (hidden,) shift over an
        # (H, W, hidden) image) instead of aligning the batch axis with a
        # spatial axis.
        x = jax.vmap(lambda xi, si, ci: si + xi*(1 + ci))(x, shift, scale)
    x = eqx.filter_vmap(self.activation)(x)

    # Linear + gate
    linear2 = self.linear2.data_dependent_init(x, key=k3)

    # Turn the new parameters into a new module
    updated_layer = eqx.tree_at(lambda tree: tree.linear1, updated_layer, linear1)
    updated_layer = eqx.tree_at(lambda tree: tree.linear2, updated_layer, linear2)
    return updated_layer
__call__(self, x: Array, y: Array = None) -> Array
Arguments:
- `x`: A JAX array with shape `input_shape`.
- `y`: A JAX array to condition on with shape `cond_shape`.
Returns:
A JAX array with shape `input_shape`.
Source code in generax/nn/resnet_blocks.py
def __call__(self, x: Array, y: Array = None) -> Array:
    """Apply the gated residual block to a single (unbatched) example.

    **Arguments:**

    - `x`: A JAX array with shape `input_shape`.
    - `y`: A JAX array to condition on with shape `cond_shape`.

    **Returns:**

    A JAX array with shape `input_shape`, equal to `x + value*sigmoid(gate)`.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    residual = x

    hidden = self.linear1(x)
    if y is not None:
        # Shift/scale the hidden features using the conditioning input.
        cond = self.linear_cond(self.activation(y))
        shift, scale = jnp.split(cond, 2, axis=-1)
        hidden = shift + hidden*(1 + scale)
    hidden = self.activation(hidden)

    # Second half of linear2's output gates the first half.
    value, gate = jnp.split(self.linear2(hidden), 2, axis=-1)
    return residual + value*jax.nn.sigmoid(gate)
generax.nn.resnet_blocks.Block
Group norm, optional shift and scale, SiLU activation, then a 3x3 convolution.
Source code in generax/nn/resnet_blocks.py
class Block(eqx.Module):
    """Group norm, (shift+scale), activation, conv

    Applies group norm, an optional `shift + h*(1 + scale)` modulation, SiLU,
    and a 3x3 `WeightNormConv` to a single unbatched `(H, W, C)` input.
    """
    input_shape: Tuple[int] = eqx.field(static=True)
    conv: WeightNormConv
    # Holds a ChannelConvention wrapping eqx.nn.GroupNorm (see __init__).
    norm: eqx.nn.GroupNorm

    def __init__(self,
                 input_shape: Tuple[int],
                 out_size: int,
                 groups: int,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The `(H, W, C)` shape of a single input.
        - `out_size`: The number of output channels of the convolution.
        - `groups`: The number of group-norm groups; must divide `C`.
        - `key`: A `jax.random.PRNGKey` for initializing the convolution.

        **Raises**:

        - `ValueError`: if `groups` does not divide `C`.
        """
        super().__init__(**kwargs)
        # Unpacking (rather than indexing) also rejects non-3d shapes.
        _, _, C = input_shape
        if C%groups != 0:
            raise ValueError("The number of groups must divide the number of channels.")
        self.norm = ChannelConvention(eqx.nn.GroupNorm(groups=groups, channels=C))
        self.conv = WeightNormConv(input_shape=input_shape,
                                   filter_shape=(3, 3),
                                   out_size=out_size,
                                   key=key)
        self.input_shape = self.conv.input_shape

    def data_dependent_init(self,
                            x: Array,
                            y: Optional[Array] = None,
                            key: PRNGKeyArray = None) -> eqx.Module:
        """Currently a no-op: `x`, `y` and `key` are ignored and the block is
        returned unchanged."""
        return self

    def __call__(self,
                 x: Array,
                 y: Array = None,
                 shift_scale: Optional[Array] = None) -> Array:
        """**Arguments:**

        - `x`: A JAX array with shape `input_shape`.
        - `y`: Unused; accepted for interface compatibility.
        - `shift_scale`: Optional pair `(shift, scale)` (e.g. the result of
          `jnp.split(..., 2, axis=-1)`) applied after the group norm.

        **Returns:**

        The convolution output for `x`.
        """
        assert x.shape == self.input_shape, 'Only works on unbatched data'
        h = self.norm(x)
        if shift_scale is not None:
            shift, scale = shift_scale
            h = shift + h*(1 + scale)
        h = jax.nn.silu(h)
        h = self.conv(h)
        return h
__init__(self, input_shape: Tuple[int], out_size: int, groups: int, *, key: PRNGKeyArray, **kwargs)
Source code in generax/nn/resnet_blocks.py
def __init__(self,
             input_shape: Tuple[int],
             out_size: int,
             groups: int,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """Build the group norm and the 3x3 convolution.

    **Arguments**:

    - `input_shape`: The `(H, W, C)` shape of a single input.
    - `out_size`: The number of output channels of the convolution.
    - `groups`: The number of group-norm groups; must divide `C`.
    - `key`: A `jax.random.PRNGKey` for initializing the convolution.
    """
    super().__init__(**kwargs)
    _, _, channels = input_shape
    if channels % groups:
        raise ValueError("The number of groups must divide the number of channels.")
    group_norm = eqx.nn.GroupNorm(groups=groups, channels=channels)
    self.norm = ChannelConvention(group_norm)
    self.conv = WeightNormConv(input_shape=input_shape,
                               filter_shape=(3, 3),
                               out_size=out_size,
                               key=key)
    # Record the shape as seen by the conv layer.
    self.input_shape = self.conv.input_shape
data_dependent_init(self, x: Array, y: Optional[Array] = None, key: PRNGKeyArray = None) -> Module
Source code in generax/nn/resnet_blocks.py
def data_dependent_init(self,
                        x: Array,
                        y: Optional[Array] = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
    """Data-dependent initialization hook.

    Currently a no-op: `x`, `y` and `key` are ignored and the block itself
    is returned unchanged.
    """
    return self
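__call__(self, x: Array, y: Array = None, shift_scale: Optional[Array] = None) -> Array
Apply the block to an unbatched input `x`. `y` is ignored, and `shift_scale` is an optional `(shift, scale)` pair applied after the group norm.
Source code in generax/nn/resnet_blocks.py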
def __call__(self,
             x: Array,
             y: Array = None,
             shift_scale: Optional[Array] = None) -> Array:
    """Group norm, optional shift/scale, SiLU, then the 3x3 convolution.

    **Arguments:**

    - `x`: A JAX array with shape `input_shape`.
    - `y`: Unused; accepted for interface compatibility.
    - `shift_scale`: Optional pair `(shift, scale)` (e.g. the result of
      `jnp.split(..., 2, axis=-1)`) applied as `shift + h*(1 + scale)`
      after the group norm.

    **Returns:**

    The convolution output for `x`.
    """
    assert x.shape == self.input_shape, 'Only works on unbatched data'
    h = self.norm(x)
    if shift_scale is not None:
        shift, scale = shift_scale
        h = shift + h*(1 + scale)
    h = jax.nn.silu(h)
    h = self.conv(h)
    return h
generax.nn.resnet_blocks.ImageResBlock
Gated residual block for images.
Source code in generax/nn/resnet_blocks.py
class ImageResBlock(eqx.Module):
    """Gated residual block for images.

    Computes `res_conv(x) + gca(block2(block1(x), shift_scale))`, where
    `shift_scale` is an optional projection of the conditioning input `y`
    and `res_conv` is a conv when the channel count changes (identity
    otherwise).
    """
    # Projects silu(y) to 2*hidden_size values (shift, scale).
    # None when the block is built without `cond_shape`.
    linear_cond: Union[WeightNormDense, None]
    block1: Block
    block2: Block
    # WeightNormConv when out_size != C, otherwise an identity.
    res_conv: Union[WeightNormConv, eqx.nn.Identity]
    gca: GatedGlobalContext
    input_shape: Tuple[int] = eqx.field(static=True)
    hidden_size: int = eqx.field(static=True)
    out_size: int = eqx.field(static=True)
    cond_shape: Union[Tuple[int], None] = eqx.field(static=True)
    groups: Union[int,None] = eqx.field(static=True)

    def __init__(self,
                 input_shape: Tuple[int],
                 hidden_size: int,
                 out_size: int,
                 groups: Optional[int] = None,
                 cond_shape: Optional[Tuple[int]] = None,
                 *,
                 key: PRNGKeyArray,
                 **kwargs):
        """**Arguments**:

        - `input_shape`: The `(H, W, C)` input size. The output has shape
          `(H, W, out_size)`.
        - `hidden_size`: The hidden layer size.
        - `out_size`: The number of output channels.
        - `groups`: The number of group-norm groups. Despite the `None`
          default it is required, and must divide both `C` and `hidden_size`.
        - `cond_shape`: The size of the conditioning information. Must be 1d.
        - `key`: A `jax.random.PRNGKey` for initialization

        **Raises**:

        - `ValueError`: if `groups` is missing or does not divide the channel
          counts, or if `cond_shape` is not 1d.
        """
        super().__init__(**kwargs)
        H, W, C = input_shape
        self.input_shape = input_shape
        self.hidden_size = hidden_size
        self.cond_shape = cond_shape
        self.out_size = out_size
        # Both internal Blocks use group norm, so a group count is mandatory;
        # without this check the `%` below fails with an opaque TypeError.
        if groups is None:
            raise ValueError('groups must be specified for ImageResBlock')
        if hidden_size % groups != 0:
            raise ValueError('Hidden size must be divisible by groups')
        self.groups = groups
        k1, k2, k3, k4, k5 = random.split(key, 5)

        # Initialize the conditioning parameters
        if cond_shape is not None:
            if len(cond_shape) != 1:
                raise ValueError('Conditioning shape must be 1d')
            self.linear_cond = WeightNormDense(in_size=cond_shape[0],
                                               out_size=2*hidden_size,
                                               key=k1)
        else:
            self.linear_cond = None

        self.block1 = Block(input_shape=input_shape,
                            out_size=hidden_size,
                            groups=groups,
                            key=k2)
        self.block2 = Block(input_shape=(H, W, hidden_size),
                            out_size=out_size,
                            groups=groups,
                            key=k3)
        self.gca = GatedGlobalContext(input_shape=(H, W, out_size),
                                      key=k4)
        # Project the residual path only when the channel count changes.
        if out_size != C:
            self.res_conv = WeightNormConv(input_shape=input_shape,
                                           out_size=out_size,
                                           filter_shape=(3, 3),
                                           key=k5)
        else:
            self.res_conv = eqx.nn.Identity()

    def data_dependent_init(self,
                            x: Array,
                            y: Array = None,
                            key: PRNGKeyArray = None) -> eqx.Module:
        """Data-dependent initialization hook.

        **Arguments**:

        - `x`: The data to initialize the parameters with (currently unused).
        - `y`: The conditioning information (currently unused).
        - `key`: A `jax.random.PRNGKey` for initialization (currently unused).

        **Returns**:

        The layer itself, unchanged: no parameters of this block are
        currently initialized from data.
        """
        return self

    def __call__(self, x: Array, y: Array = None) -> Array:
        """**Arguments:**

        - `x`: A JAX array with shape `input_shape`, i.e. `(H, W, C)`.
        - `y`: A JAX array to condition on with shape `cond_shape`.

        **Returns:**

        A JAX array with shape `(H, W, out_size)`.

        **Raises:**

        - `ValueError`: if `y` is given but the block has no `cond_shape`.
        """
        x_in = x
        h = self.block1(x)
        # The conditioning input will shift/scale the normalized features of block2
        if y is not None:
            if self.linear_cond is None:
                raise ValueError('Received conditioning input y, but the block was built without cond_shape')
            hh = self.linear_cond(jax.nn.silu(y))
            shift_scale = jnp.split(hh, 2, axis=-1)
        else:
            shift_scale = None
        h = self.block2(h, shift_scale=shift_scale)
        h = self.gca(h)
        return self.res_conv(x_in) + h
__init__(self, input_shape: Tuple[int], hidden_size: int, out_size: int, groups: Optional[int] = None, cond_shape: Optional[Tuple[int]] = None, *, key: PRNGKeyArray, **kwargs)
Arguments:
- `input_shape`: The `(H, W, C)` input size. The output has shape `(H, W, out_size)`.
- `hidden_size`: The hidden layer size.
- `out_size`: The number of output channels.
- `groups`: The number of group-norm groups. Required in practice, and must divide both `C` and `hidden_size`.
- `cond_shape`: The size of the conditioning information. Must be 1d.
- `activation`: Not a parameter of this block; SiLU is always used.
- `key`: A `jax.random.PRNGKey` for initialization.
Source code in generax/nn/resnet_blocks.py
def __init__(self,
             input_shape: Tuple[int],
             hidden_size: int,
             out_size: int,
             groups: Optional[int] = None,
             cond_shape: Optional[Tuple[int]] = None,
             *,
             key: PRNGKeyArray,
             **kwargs):
    """**Arguments**:

    - `input_shape`: The `(H, W, C)` input size. The output has shape
      `(H, W, out_size)`.
    - `hidden_size`: The hidden layer size.
    - `out_size`: The number of output channels.
    - `groups`: The number of group-norm groups. Despite the `None`
      default it is required, and must divide both `C` and `hidden_size`.
    - `cond_shape`: The size of the conditioning information. Must be 1d.
    - `key`: A `jax.random.PRNGKey` for initialization

    **Raises**:

    - `ValueError`: if `groups` is missing or does not divide the channel
      counts, or if `cond_shape` is not 1d.
    """
    super().__init__(**kwargs)
    H, W, C = input_shape
    self.input_shape = input_shape
    self.hidden_size = hidden_size
    self.cond_shape = cond_shape
    self.out_size = out_size
    # Both internal Blocks use group norm, so a group count is mandatory;
    # without this check the `%` below fails with an opaque TypeError.
    if groups is None:
        raise ValueError('groups must be specified for ImageResBlock')
    if hidden_size % groups != 0:
        raise ValueError('Hidden size must be divisible by groups')
    self.groups = groups
    k1, k2, k3, k4, k5 = random.split(key, 5)

    # Initialize the conditioning parameters
    if cond_shape is not None:
        if len(cond_shape) != 1:
            raise ValueError('Conditioning shape must be 1d')
        self.linear_cond = WeightNormDense(in_size=cond_shape[0],
                                           out_size=2*hidden_size,
                                           key=k1)
    else:
        self.linear_cond = None

    self.block1 = Block(input_shape=input_shape,
                        out_size=hidden_size,
                        groups=groups,
                        key=k2)
    self.block2 = Block(input_shape=(H, W, hidden_size),
                        out_size=out_size,
                        groups=groups,
                        key=k3)
    self.gca = GatedGlobalContext(input_shape=(H, W, out_size),
                                  key=k4)
    # Project the residual path only when the channel count changes.
    if out_size != C:
        self.res_conv = WeightNormConv(input_shape=input_shape,
                                       out_size=out_size,
                                       filter_shape=(3, 3),
                                       key=k5)
    else:
        self.res_conv = eqx.nn.Identity()
data_dependent_init(self, x: Array, y: Array = None, key: PRNGKeyArray = None) -> Module
Initialize the parameters of the layer based on the data. Currently this is a no-op that returns the layer unchanged.
Arguments:
- `x`: The data to initialize the parameters with.
- `y`: The conditioning information.
- `key`: A `jax.random.PRNGKey` for initialization.
Returns: A new layer with the parameters initialized.
Source code in generax/nn/resnet_blocks.py
def data_dependent_init(self,
                        x: Array,
                        y: Array = None,
                        key: PRNGKeyArray = None) -> eqx.Module:
    """Data-dependent initialization hook.

    **Arguments**:

    - `x`: The data to initialize the parameters with (currently unused).
    - `y`: The conditioning information (currently unused).
    - `key`: A `jax.random.PRNGKey` for initialization (currently unused).

    **Returns**:

    The layer itself, unchanged: no parameters of this block are currently
    initialized from data.
    """
    return self
__call__(self, x: Array, y: Array = None) -> Array
Arguments:
- `x`: A JAX array with shape `input_shape`, i.e. `(H, W, C)`.
- `y`: A JAX array to condition on with shape `cond_shape`.
Returns:
A JAX array with shape `(H, W, out_size)`.
Source code in generax/nn/resnet_blocks.py
def __call__(self, x: Array, y: Array = None) -> Array:
    """**Arguments:**

    - `x`: A JAX array with shape `input_shape`, i.e. `(H, W, C)`.
    - `y`: A JAX array to condition on with shape `cond_shape`.

    **Returns:**

    A JAX array with shape `(H, W, out_size)`.

    **Raises:**

    - `ValueError`: if `y` is given but the block has no `cond_shape`.
    """
    x_in = x
    h = self.block1(x)
    # The conditioning input will shift/scale the normalized features of block2
    if y is not None:
        if self.linear_cond is None:
            raise ValueError('Received conditioning input y, but the block was built without cond_shape')
        hh = self.linear_cond(jax.nn.silu(y))
        shift_scale = jnp.split(hh, 2, axis=-1)
    else:
        shift_scale = None
    h = self.block2(h, shift_scale=shift_scale)
    h = self.gca(h)
    return self.res_conv(x_in) + h