class TrainableTopoSolver[source]
TrainableTopoSolver(criterion:dl4to.criteria.Criterion,model:Module,optimizer:Optimizer,preprocessing:dl4to.preprocessing.Preprocessing=<dl4to.preprocessing.TrivialPreprocessing object at 0x7ff686e806a0>,name:str=None) ::TopoSolver
A topo solver that is trainable and can be used for learned topology optimization.
| | Type | Default | Details |
|---|---|---|---|
| criterion | dl4to.criteria.Criterion | | The loss criterion that should be used for the training. |
| model | Module | | A PyTorch neural network. Make sure that the input and output dimensions are correct. |
| optimizer | Optimizer | | A PyTorch optimizer, for instance torch.optim.Adam. Make sure to set params=model.parameters() if you want to use the optimizer to train the neural network. |
| preprocessing | dl4to.preprocessing.Preprocessing | <dl4to.preprocessing.TrivialPreprocessing object at 0x7ff686e806a0> | The preprocessing that should be used in the pipeline. |
| name | str | None | The name of the topo solver. |
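
The note on the optimizer argument can be illustrated with plain PyTorch. This is a minimal sketch, not taken from the dl4to documentation: the network below is a placeholder, and its channel counts and layer sizes are assumptions chosen only for illustration. It shows an optimizer built with params=model.parameters() and checks that every model parameter is registered with it.

```python
import torch
from torch import nn

# Placeholder 3D network. The channel counts are illustrative assumptions;
# the real input and output dimensions depend on the chosen preprocessing.
model = nn.Sequential(
    nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
    nn.Sigmoid(),
)

# Pass the model parameters explicitly so the optimizer updates this network.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Sanity check: the optimizer tracks exactly the parameters of the model.
model_param_ids = {id(p) for p in model.parameters()}
optim_param_ids = {
    id(p) for group in optimizer.param_groups for p in group["params"]
}
assert model_param_ids == optim_param_ids

# The pair would then be handed to the solver together with a criterion,
# following the signature above (criterion is left to the reader):
# solver = TrainableTopoSolver(criterion=criterion, model=model, optimizer=optimizer)
```

An optimizer created from a different module's parameters would leave this network's weights unchanged during training, which is the situation the check above guards against.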
TrainableTopoSolver.get_args_as_dict[source]
TrainableTopoSolver.get_args_as_dict()
Returns basic properties and arguments of the topo solver as a dictionary.
TrainableTopoSolver.train[source]
TrainableTopoSolver.train(root:str,dataloader_train:DataLoader,dataloader_val:DataLoader=None,epochs:int=100,validation_interval:int=10,verbose:bool=True,patience:bool=None)
Run the training for the topo solver.
| | Type | Default | Details |
|---|---|---|---|
| root | str | | The directory where the training results should be saved. |
| dataloader_train | DataLoader | | The dataloader that contains the training data. |
| dataloader_val | DataLoader | None | The dataloader that contains the validation data. |
| epochs | int | 100 | The maximum number of training epochs. |
| validation_interval | int | 10 | The number of epochs between validation steps. The result of each validation step is printed. |
| verbose | bool | True | Whether to print information on the current training status, such as the current loss and epoch. |
| patience | bool | None | If the validation score does not improve for patience epochs in a row, training is stopped and the best model is used. Note that the value denotes a number of epochs, despite the bool annotation. |
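
The stopping rule described for patience can be sketched in plain Python. This is a hedged illustration, not dl4to's implementation: the helper run_with_patience is hypothetical, it assumes one validation score per epoch, and it assumes that a lower score is better.

```python
def run_with_patience(val_losses, patience=None):
    """Return (stop_index, best_index) for a sequence of validation losses.

    Hypothetical helper, not part of dl4to. Lower loss is assumed to be
    better, and each entry is assumed to come from one epoch.
    """
    best_loss = float("inf")
    best_index = -1
    epochs_without_improvement = 0
    for i, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best score: remember it and reset the counter.
            best_loss, best_index = loss, i
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # Stop once the score has failed to improve `patience` times in a row.
        if patience is not None and epochs_without_improvement >= patience:
            return i, best_index
    # No early stop: all epochs were run.
    return len(val_losses) - 1, best_index


stop_index, best_index = run_with_patience([1.0, 0.8, 0.9, 0.85, 0.7], patience=2)
```

With patience=2 the loop stops at index 3 and reports index 1 as the best epoch, even though a later epoch would have improved further. A small patience saves time but can end training before a late improvement.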