class EquivarianceWrapper
EquivarianceWrapper(preprocessing:dl4to.preprocessing.Preprocessing=None, rotate:bool=True, mirror:bool=True, dim:int=2, rotate_twice:bool=False, sample_rate:float=1.0)
A class that represents an equivariance wrapper [1] that implements group equivariance via group averaging [2].
| | Type | Default | Details |
|---|---|---|---|
| preprocessing | dl4to.preprocessing.Preprocessing | None | The preprocessing strategy to use. The equivariance wrapper uses it to obtain the scalar and vector field information of the input. |
| rotate | bool | True | Whether to include rotational equivariance in the transformation group. |
| mirror | bool | True | Whether to include mirror equivariance in the transformation group. |
| dim | int | 2 | The dimension of the transformation group. Specifically, a 2D transformation group does not consider rotations and mirrors along the z-axis. |
| rotate_twice | bool | False | Whether to use double rotations, where the input is rotated twice along two different axes. This may result in a larger transformation group. |
| sample_rate | float | 1.0 | The rate of transformations that are randomly sampled in the forward pass. sample_rate=1.0 means that all transformations are used in the wrapper. A smaller value may be beneficial if memory constraints do not allow applying all transformations in each forward pass. |
EquivarianceWrapper.mirror_input
EquivarianceWrapper.mirror_input(x:Tensor, flip_dimensions:list)
Returns a torch.Tensor, which is a mirrored version of the input x.
| | Type | Default | Details |
|---|---|---|---|
| x | Tensor | | The input that should be mirrored/flipped. |
| flip_dimensions | list | | The dimensions along which the input should be mirrored. |
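
Mirroring a tensor along a list of dimensions can be illustrated with plain PyTorch. The following is a minimal sketch that assumes the operation behaves like torch.flip; the helper name mirror_sketch is hypothetical and is not part of dl4to.

```python
import torch


def mirror_sketch(x: torch.Tensor, flip_dimensions: list) -> torch.Tensor:
    """Hypothetical helper: mirror x along every dimension in flip_dimensions."""
    # torch.flip reverses the order of the entries along each listed dimension.
    return torch.flip(x, dims=flip_dimensions)


x = torch.arange(6).reshape(2, 3)         # [[0, 1, 2], [3, 4, 5]]
mirrored = mirror_sketch(x, [1])          # reverse each row
restored = mirror_sketch(mirrored, [1])   # mirroring twice undoes the flip
```

Because a mirror is its own inverse, applying it twice along the same dimension recovers the input.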
EquivarianceWrapper.rotate_input
EquivarianceWrapper.rotate_input(x:Tensor, rotations:int, plane:list)
Returns a torch.Tensor, which is a rotated version of the input x.
| | Type | Default | Details |
|---|---|---|---|
| x | Tensor | | The input that should be rotated. |
| rotations | int | | The number of 90° rotations that should be performed. Four rotations result in the identity. |
| plane | list | | The plane in which the input should be rotated. |
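
The statement that four 90° rotations give the identity can be checked with torch.rot90. This is a minimal sketch under the assumption that the plane is given as a pair of tensor dimensions; the helper name rotate_sketch is hypothetical and is not part of dl4to.

```python
import torch


def rotate_sketch(x: torch.Tensor, rotations: int, plane: list) -> torch.Tensor:
    """Hypothetical helper: rotate x by rotations * 90 degrees in the given plane."""
    # torch.rot90 rotates in the plane spanned by the two dimensions in dims.
    return torch.rot90(x, k=rotations, dims=plane)


x = torch.arange(4).reshape(2, 2)          # [[0, 1], [2, 3]]
once = rotate_sketch(x, 1, [0, 1])         # a single 90 degree rotation
full_turn = rotate_sketch(x, 4, [0, 1])    # four rotations: back to the input
```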
EquivarianceWrapper.get_transforms
EquivarianceWrapper.get_transforms(sample_rate:float=None)
Returns a list of all group actions that are applied to an input in the equivariance wrapper.
| | Type | Default | Details |
|---|---|---|---|
| sample_rate | float | None | The rate of transformations that are randomly sampled from the equivariance wrapper. If None, equivariance_wrapper.sample_rate is used. A value of 1.0 means that all transformations are considered. |
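
To make the idea of a sampled list of group actions concrete, the sketch below enumerates the 2D case as pairs of a rotation count and a mirror flag, then draws a random subset. It is an illustration only: the function names, the pair representation, and the reading of sample_rate as the fraction of transformations to keep are assumptions, not the library's implementation.

```python
import itertools
import random


def list_transforms_sketch(rotate: bool = True, mirror: bool = True) -> list:
    """Hypothetical enumeration of 2D group actions as (rotation count, mirrored) pairs."""
    rotation_counts = range(4) if rotate else [0]
    mirror_flags = [False, True] if mirror else [False]
    return list(itertools.product(rotation_counts, mirror_flags))


def sample_transforms_sketch(transforms: list, sample_rate: float = 1.0, seed=None) -> list:
    """Keep a random subset; sample_rate is assumed to be the fraction to keep."""
    rng = random.Random(seed)
    n_keep = max(1, round(sample_rate * len(transforms)))  # always keep at least one
    return rng.sample(transforms, n_keep)


all_transforms = list_transforms_sketch()   # 4 rotations x 2 mirror states
half = sample_transforms_sketch(all_transforms, sample_rate=0.5, seed=0)
```

With both rotations and mirrors enabled, this enumeration yields the eight elements of the symmetry group of a square. Sampling fewer of them lowers the memory cost of a forward pass, at the price of only approximating the full group average.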
EquivarianceWrapper.__call__
EquivarianceWrapper.__call__(model:Module)
Applies the equivariance wrapper to a torch.nn.Module model object and returns a dl4to.models.EquivariantModel object.
| | Type | Default | Details |
|---|---|---|---|
| model | Module | | The model that should be turned into an equivariant model. |
class EquivariantModel
EquivariantModel(model:Module, equivariance_wrapper:dl4to.models.EquivarianceWrapper) :: Module
A class that represents an equivariant model with respect to a specific equivariance wrapper.
| | Type | Default | Details |
|---|---|---|---|
| model | Module | | A PyTorch neural network. |
| equivariance_wrapper | dl4to.models.EquivarianceWrapper | | The equivariance wrapper that is applied to the model. |
EquivariantModel.__call__
EquivariantModel.__call__(model_inputs:Tensor, sample_rate:float=None)
The forward method for the equivariant model.
| | Type | Default | Details |
|---|---|---|---|
| model_inputs | Tensor | | The model inputs that are obtained as output of the preprocessing. |
| sample_rate | float | None | The rate of transformations that are randomly sampled from the equivariance wrapper. If None, equivariance_wrapper.sample_rate is used. A value of 1.0 means that all transformations are applied in the forward pass. |
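
A forward pass based on group averaging can be sketched in plain PyTorch: transform the input with each group element, apply the model, undo the transformation on the output, and average the results. The class below is a simplified illustration, not the dl4to implementation. It assumes scalar fields of shape (batch, channels, height, width) and uses only the four 90° rotations; vector fields would also need their components transformed. The name GroupAveragedSketch is hypothetical.

```python
import torch
from torch import nn


class GroupAveragedSketch(nn.Module):
    """Hypothetical wrapper: average a model over the four 90 degree rotations."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = []
        for k in range(4):
            # Rotate the input, apply the model, then rotate the output back.
            transformed = torch.rot90(x, k, dims=[2, 3])
            y = self.model(transformed)
            outputs.append(torch.rot90(y, -k, dims=[2, 3]))
        # Averaging over the whole group makes the wrapped model equivariant.
        return torch.stack(outputs).mean(dim=0)


torch.manual_seed(0)
base = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # not rotation-equivariant on its own
equivariant = GroupAveragedSketch(base)
x = torch.randn(1, 1, 8, 8)

with torch.no_grad():
    # Rotating the input first or the output afterwards gives the same result.
    rotate_then_apply = equivariant(torch.rot90(x, 1, dims=[2, 3]))
    apply_then_rotate = torch.rot90(equivariant(x), 1, dims=[2, 3])
```

The two results agree up to floating-point error only when the average runs over every group element; averaging over a random subset trades exact equivariance for lower memory use.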
[1] Dittmer, Sören, et al. "SELTO: Sample-Efficient Learned Topology Optimization." arXiv preprint arXiv:2209.05098 (2022).
[2] Puny, Omri, et al. "Frame averaging for invariant and equivariant network design." arXiv preprint arXiv:2110.03336 (2021).