class TrainableTopoSolver

TrainableTopoSolver(criterion:dl4to.criteria.Criterion, model:Module, optimizer:Optimizer, preprocessing:dl4to.preprocessing.Preprocessing=<dl4to.preprocessing.TrivialPreprocessing object at 0x7ff686e806a0>, name:str=None) :: TopoSolver

A topo solver that is trainable and can be used for learned topology optimization.

| Parameter | Type | Default | Details |
|---|---|---|---|
| criterion | dl4to.criteria.Criterion | | The loss criterion that should be used for the training. |
| model | Module | | A PyTorch neural network. Make sure that the input and output dimensions are correct. |
| optimizer | Optimizer | | A PyTorch optimizer, for instance torch.optim.Adam. Make sure to set params=model.parameters() if you want to use the optimizer to train the neural network. |
| preprocessing | dl4to.preprocessing.Preprocessing | <dl4to.preprocessing.TrivialPreprocessing object at 0x7ff686e806a0> | The preprocessing that should be used in the pipeline. |
| name | str | None | The name of the topo solver. |

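The note on `params=model.parameters()` can be illustrated with plain PyTorch. The following is a minimal sketch, not taken from dl4to: the small convolutional network and its channel counts (7 input channels, 1 output channel) are hypothetical placeholders, and the channel counts a real dl4to pipeline expects depend on the chosen preprocessing. It only shows how to bind the optimizer to the model and how to check the output shape with a dummy input before passing both objects as the `model` and `optimizer` arguments.

```python
import torch

# Hypothetical stand-in network; the channel counts are assumptions,
# not values prescribed by dl4to.
model = torch.nn.Sequential(
    torch.nn.Conv3d(in_channels=7, out_channels=8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv3d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
    torch.nn.Sigmoid(),
)

# Bind the optimizer to the model's parameters so that optimizer steps
# actually update the network.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Sanity-check input and output dimensions with a dummy batch of shape
# (batch, channels, x, y, z) before starting a training run.
dummy_input = torch.zeros(1, 7, 4, 4, 4)
with torch.no_grad():
    output = model(dummy_input)

print(output.shape)  # torch.Size([1, 1, 4, 4, 4])
```

An optimizer constructed from a different model's parameters would run without error but leave this network unchanged, which is why the binding is worth checking explicitly.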
TrainableTopoSolver.get_args_as_dict

TrainableTopoSolver.get_args_as_dict()

Returns basic properties and arguments of the topo solver as a dictionary.
TrainableTopoSolver.train

TrainableTopoSolver.train(root:str, dataloader_train:DataLoader, dataloader_val:DataLoader=None, epochs:int=100, validation_interval:int=10, verbose:bool=True, patience:bool=None)

Run the training for the topo solver.

| Parameter | Type | Default | Details |
|---|---|---|---|
| root | str | | The directory where the training results should be saved. |
| dataloader_train | DataLoader | | The dataloader that contains the training data. |
| dataloader_val | DataLoader | None | The dataloader that contains the validation data. |
| epochs | int | 100 | The maximal number of training epochs. |
| validation_interval | int | 10 | The number of epochs after which a validation step is performed and printed. |
| verbose | bool | True | Whether to print information on the current training status, like the current loss and epoch. |
| patience | bool | None | If the validation score does not improve for patience epochs in a row, then the training is stopped and the best model is used. |
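
How `epochs`, `validation_interval` and `patience` interact can be sketched in plain Python. This is an illustration only, not dl4to's implementation. The function `train_with_early_stopping` and its callbacks `train_one_epoch` and `validate` are hypothetical names. The sketch assumes that a lower validation score is better, and that each validation without improvement counts as `validation_interval` epochs toward `patience`. The documentation above does not specify either detail.

```python
from typing import Callable, Optional


def train_with_early_stopping(
    train_one_epoch: Callable[[int], None],
    validate: Callable[[int], float],
    epochs: int = 100,
    validation_interval: int = 10,
    patience: Optional[int] = None,
) -> dict:
    """Illustrative training loop with early stopping (hypothetical helper)."""
    best_score = float("inf")  # assumption: lower validation score is better
    best_epoch = None
    epochs_without_improvement = 0
    last_epoch = 0

    for epoch in range(1, epochs + 1):
        last_epoch = epoch
        train_one_epoch(epoch)

        # Validation runs only every `validation_interval` epochs.
        if epoch % validation_interval != 0:
            continue

        score = validate(epoch)
        if score < best_score:
            # New best model: remember it and reset the patience counter.
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            # Assumption: count the epochs elapsed since the last validation.
            epochs_without_improvement += validation_interval

        # With patience=None, early stopping is disabled.
        if patience is not None and epochs_without_improvement >= patience:
            break

    return {"best_score": best_score, "best_epoch": best_epoch, "stopped_at": last_epoch}


# Usage with made-up validation scores, one per epoch.
fake_scores = [5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 2.0]
result = train_with_early_stopping(
    train_one_epoch=lambda epoch: None,
    validate=lambda epoch: fake_scores[epoch - 1],
    epochs=len(fake_scores),
    validation_interval=1,
    patience=3,
)
print(result)  # {'best_score': 3.0, 'best_epoch': 3, 'stopped_at': 6}
```

In this run the score stops improving after epoch 3, so training ends at epoch 6 and the better score at epoch 7 is never reached. A small `patience` trades this risk for shorter training.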