class EquivarianceWrapper
EquivarianceWrapper(preprocessing: dl4to.preprocessing.Preprocessing = None, rotate: bool = True, mirror: bool = True, dim: int = 2, rotate_twice: bool = False, sample_rate: float = 1.0)
A class that represents an equivariance wrapper [1] that implements group equivariance via group averaging [2].
Parameter | Type | Default | Details
---|---|---|---
preprocessing | dl4to.preprocessing.Preprocessing | None | The preprocessing strategy to use. The equivariance wrapper uses it to obtain the scalar and vector field information of the input.
rotate | bool | True | Whether to include rotations in the transformation group, i.e., whether the model should be rotation-equivariant.
mirror | bool | True | Whether to include mirroring (reflections) in the transformation group.
dim | int | 2 | The dimension of the transformation group. A 2D transformation group does not consider rotations and mirrors along the z-axis.
rotate_twice | bool | False | Whether to use double rotations, in which the input is rotated twice about two different axes. This may result in a larger transformation group.
sample_rate | float | 1.0 | The rate of transformations that are randomly sampled in each forward pass. sample_rate=1.0 means that all transformations are used. A smaller value may be beneficial if memory constraints do not allow all transformations to be applied in each forward pass.
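For reference, group averaging [2] can be written out for a finite group $G$ of transformations (for instance, the 8 rotations and mirrors of a square) and a model $f$, assuming that each $g \in G$ acts linearly on both the inputs and the outputs. This is the textbook construction and not necessarily the exact formulation used inside dl4to:

$$
f_G(x) = \frac{1}{|G|} \sum_{g \in G} g^{-1} \cdot f(g \cdot x)
$$

Equivariance follows by reindexing with $g' = gh$, which runs over all of $G$ as $g$ does, so that $g^{-1} = h\, g'^{-1}$. Linearity then lets $h$ be pulled out of the sum:

$$
f_G(h \cdot x) = \frac{1}{|G|} \sum_{g \in G} g^{-1} \cdot f(gh \cdot x) = h \cdot \frac{1}{|G|} \sum_{g' \in G} g'^{-1} \cdot f(g' \cdot x) = h \cdot f_G(x) \quad \text{for all } h \in G.
$$

Evaluating $f_G$ costs $|G|$ forward passes of $f$ per input, which is the memory cost that a sample_rate below 1.0 is meant to reduce.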
EquivarianceWrapper.mirror_input
EquivarianceWrapper.mirror_input(x: Tensor, flip_dimensions: list)
Returns a torch.Tensor that is a mirrored version of the input x.
Parameter | Type | Default | Details
---|---|---|---
x | Tensor | | The input that should be mirrored (flipped).
flip_dimensions | list | | The dimensions along which the input should be mirrored.
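The sketch below illustrates what mirroring along flip_dimensions does, using NumPy's np.flip, which has the same semantics as torch.flip(x, dims=...). The standalone mirror_input function is a hypothetical stand-in for illustration only, and the (batch, channel, height, width) layout, in which the spatial axes are 2 and 3, is an assumption; the documentation does not state which axes the wrapper flips.

```python
import numpy as np

def mirror_input(x: np.ndarray, flip_dimensions: list) -> np.ndarray:
    """Hypothetical stand-in for EquivarianceWrapper.mirror_input.

    Reverses the order of entries along every axis in flip_dimensions,
    like torch.flip(x, dims=flip_dimensions) does for a torch.Tensor.
    """
    return np.flip(x, axis=tuple(flip_dimensions))

# Assumed layout: (batch, channel, height, width).
x = np.arange(2 * 1 * 2 * 3).reshape(2, 1, 2, 3)

print(x[0, 0].tolist())                     # [[0, 1, 2], [3, 4, 5]]
print(mirror_input(x, [3])[0, 0].tolist())  # [[2, 1, 0], [5, 4, 3]]  (left-right)
print(mirror_input(x, [2])[0, 0].tolist())  # [[3, 4, 5], [0, 1, 2]]  (up-down)

# Mirroring is an involution: applying the same flip twice restores the input.
assert np.array_equal(mirror_input(mirror_input(x, [3]), [3]), x)
```

Because each mirror is its own inverse, undoing a mirror on the output side of the wrapper only requires applying the same flip again.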
EquivarianceWrapper.rotate_input
EquivarianceWrapper.rotate_input(x: Tensor, rotations: int, plane: list)
Returns a torch.Tensor that is a rotated version of the input x.
Parameter | Type | Default | Details
---|---|---|---
x | Tensor | | The input that should be rotated.
rotations | int | | The number of 90° rotations that should be performed. Four rotations result in the identity.
plane | list | | The plane in which the input should be rotated.
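Similarly, here is a sketch of rotating by a number of 90° steps in a given plane, using np.rot90, whose k and axes arguments correspond to the k and dims arguments of torch.rot90. The standalone rotate_input function and the channel-first layout are assumptions for illustration, not the library's code.

```python
import numpy as np

def rotate_input(x: np.ndarray, rotations: int, plane: list) -> np.ndarray:
    """Hypothetical stand-in for EquivarianceWrapper.rotate_input.

    Rotates x by rotations * 90 degrees in the plane spanned by the two axes
    listed in plane, like torch.rot90(x, k=rotations, dims=plane).
    """
    return np.rot90(x, k=rotations, axes=tuple(plane))

# Assumed layout: (batch, channel, height, width).
x = np.arange(4).reshape(1, 1, 2, 2)

print(x[0, 0].tolist())                          # [[0, 1], [2, 3]]
print(rotate_input(x, 1, [2, 3])[0, 0].tolist())  # [[1, 3], [0, 2]]

# Four quarter turns give back the identity ...
assert np.array_equal(rotate_input(x, 4, [2, 3]), x)
# ... and a rotation is undone by the same number of turns in the opposite direction.
assert np.array_equal(rotate_input(rotate_input(x, 1, [2, 3]), -1, [2, 3]), x)
```

For a 3D volume laid out as (batch, channel, x, y, z), the three coordinate planes in this layout would be [2, 3], [2, 4] and [3, 4].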
EquivarianceWrapper.get_transforms
EquivarianceWrapper.get_transforms(sample_rate: float = None)
Returns a list of all group actions that are applied to an input in the equivariance wrapper.
Parameter | Type | Default | Details
---|---|---|---
sample_rate | float | None | The rate of transformations that are randomly sampled from the equivariance wrapper. If None, equivariance_wrapper.sample_rate is used. A value of 1.0 means that all transformations are considered.
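One possible way to enumerate the group actions and subsample them at a given sample_rate is sketched below for the 2D case in plain Python. The (rotations, flip) encoding, the default_sample_rate keyword standing in for equivariance_wrapper.sample_rate, the rounding rule and the keep-at-least-one policy are all assumptions made for this illustration; the library's get_transforms may represent and sample transformations differently.

```python
import random
from itertools import product

def get_transforms(sample_rate=None, *, rotate=True, mirror=True,
                   default_sample_rate=1.0, rng=None):
    """Hypothetical sketch: enumerate 2D group actions and randomly subsample them.

    Each action is a pair (rotations, flip): rotate by rotations * 90 degrees,
    then mirror if flip is True. With rotate and mirror both enabled this
    yields the 8 symmetries of a square.
    """
    if sample_rate is None:
        # Fall back to the wrapper-level rate, as documented for sample_rate=None.
        sample_rate = default_sample_rate
    rng = rng or random.Random()
    rotation_counts = range(4) if rotate else [0]
    flips = [False, True] if mirror else [False]
    transforms = list(product(rotation_counts, flips))
    if sample_rate >= 1.0:
        return transforms
    # Keep at least one transformation so that the forward pass is never empty.
    n = max(1, round(sample_rate * len(transforms)))
    return rng.sample(transforms, n)

print(len(get_transforms()))                                       # 8
print(len(get_transforms(mirror=False)))                           # 4
print(len(get_transforms(sample_rate=0.5, rng=random.Random(0))))  # 4
```

Sampling without replacement (random.Random.sample) guarantees that no transformation is used twice within a single call.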
EquivarianceWrapper.__call__
EquivarianceWrapper.__call__(model: Module)
Applies the equivariance wrapper to a torch.nn.Module model object and returns a dl4to.models.EquivariantModel object.
Parameter | Type | Default | Details
---|---|---|---
model | Module | | The model that should be turned into an equivariant model.
class EquivariantModel
EquivariantModel(model: Module, equivariance_wrapper: dl4to.models.EquivarianceWrapper) :: Module
A class that represents an equivariant model with respect to a specific equivariance wrapper.
Parameter | Type | Default | Details
---|---|---|---
model | Module | | A PyTorch neural network.
equivariance_wrapper | dl4to.models.EquivarianceWrapper | | The equivariance wrapper that is applied to the model.
EquivariantModel.__call__
EquivariantModel.__call__(model_inputs: Tensor, sample_rate: float = None)
The forward method for the equivariant model.
Parameter | Type | Default | Details
---|---|---|---
model_inputs | Tensor | | The model inputs, obtained as the output of the preprocessing.
sample_rate | float | None | The rate of transformations that are randomly sampled from the equivariance wrapper. If None, equivariance_wrapper.sample_rate is used. A value of 1.0 means that all transformations are applied in the forward pass.
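To tie the pieces together, here is a self-contained NumPy sketch of the wrapper pattern: calling a wrapper object on a model returns a new model whose forward pass averages $g^{-1} \cdot f(g \cdot x)$ over the 8 rotations and mirrors of a square. ToyEquivarianceWrapper, ToyEquivariantModel, GROUP, act and act_inverse are hypothetical names. The sketch treats every channel as a scalar field, so it ignores the vector field information that the library obtains from the preprocessing, and it always applies all transformations.

```python
import numpy as np

# Hypothetical 2D group: (quarter turns k, mirror m) pairs acting on the last
# two axes of a (batch, channel, height, width) array; 4 x 2 = 8 elements.
GROUP = [(k, m) for k in range(4) for m in (False, True)]

def act(x, k, m):
    """Apply g: rotate by k quarter turns, then mirror the last axis if m."""
    y = np.rot90(x, k=k, axes=(-2, -1))
    return np.flip(y, axis=-1) if m else y

def act_inverse(x, k, m):
    """Apply g^{-1}: undo the mirror first, then rotate back by k quarter turns."""
    y = np.flip(x, axis=-1) if m else x
    return np.rot90(y, k=-k, axes=(-2, -1))

class ToyEquivariantModel:
    """Hypothetical analogue of EquivariantModel, for scalar fields only."""

    def __init__(self, model, group):
        self.model = model
        self.group = group

    def __call__(self, x):
        # Group averaging: mean over g of g^{-1}(model(g x)).
        outputs = [act_inverse(self.model(act(x, k, m)), k, m) for k, m in self.group]
        return np.mean(outputs, axis=0)

class ToyEquivarianceWrapper:
    """Hypothetical analogue of EquivarianceWrapper: calling it wraps a model."""

    def __init__(self, group=GROUP):
        self.group = group

    def __call__(self, model):
        return ToyEquivariantModel(model, self.group)

def toy_model(x):
    # Deliberately NOT equivariant: scales each pixel by its column index.
    return x * np.arange(x.shape[-1], dtype=float)

equivariant_model = ToyEquivarianceWrapper()(toy_model)

x = np.random.default_rng(0).random((1, 1, 5, 5))
k, m = 1, True
# Transforming the input transforms the wrapped model's output in the same way ...
assert np.allclose(equivariant_model(act(x, k, m)), act(equivariant_model(x), k, m))
# ... while the unwrapped model does not have this property.
assert not np.allclose(toy_model(act(x, k, m)), act(toy_model(x), k, m))
```

The equivariance check above holds for every element of GROUP, not only the one tested, because the 8 actions form a group. If only a random subset of transformations is averaged, as with a sample_rate below 1.0, the output of a single forward pass is in general no longer exactly equivariant.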
[1] Dittmer, Sören, et al. "SELTO: Sample-Efficient Learned Topology Optimization." arXiv preprint arXiv:2209.05098 (2022).
[2] Puny, Omri, et al. "Frame averaging for invariant and equivariant network design." arXiv preprint arXiv:2110.03336 (2021).