class TrainableTopoSolver

TrainableTopoSolver(criterion:dl4to.criteria.Criterion, model:Module, optimizer:Optimizer, preprocessing:dl4to.preprocessing.Preprocessing=dl4to.preprocessing.TrivialPreprocessing(), name:str=None) :: TopoSolver

A topo solver that is trainable and can be used for learned topology optimization.

| Argument | Type | Default | Details |
| --- | --- | --- | --- |
| criterion | dl4to.criteria.Criterion | (required) | The loss criterion used for training. |
| model | Module | (required) | A PyTorch neural network. Make sure that the input and output dimensions are correct. |
| optimizer | Optimizer | (required) | A PyTorch optimizer, for instance torch.optim.Adam. Make sure to set params=model.parameters() if you want to use the optimizer to train the neural network. |
| preprocessing | dl4to.preprocessing.Preprocessing | a dl4to.preprocessing.TrivialPreprocessing instance | The preprocessing used in the pipeline. |
| name | str | None | The name of the topo solver. |
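The note on the optimizer argument can be illustrated with a minimal sketch. The network below is only a placeholder: its channel counts, the learning rate, and the final sigmoid are assumptions made for illustration, not values prescribed by dl4to. The constructor call at the end follows the signature documented above and is left as a comment because it needs a dl4to criterion.

```python
import torch

# Placeholder 3D network. The channel counts (4 in, 1 out) are illustrative
# assumptions; they must match the tensors your preprocessing produces.
model = torch.nn.Sequential(
    torch.nn.Conv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv3d(in_channels=8, out_channels=1, kernel_size=3, padding=1),
    torch.nn.Sigmoid(),  # assumption: outputs are meant to lie in (0, 1)
)

# Bind the optimizer to the model's parameters so that training updates them.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Quick shape check on a dummy batch of size 1 with an 8x8x8 grid.
dummy_input = torch.zeros(1, 4, 8, 8, 8)
dummy_output = model(dummy_input)

# With dl4to installed and a criterion available, the solver would be built as:
# solver = TrainableTopoSolver(criterion=criterion, model=model,
#                              optimizer=optimizer, name="my_solver")
```

Passing an optimizer that was created for a different set of parameters would leave the network's weights unchanged during training, which is why the binding is worth checking before constructing the solver.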

TrainableTopoSolver.get_args_as_dict

TrainableTopoSolver.get_args_as_dict()

Returns basic properties and arguments of the topo solver as a dictionary.

TrainableTopoSolver.train

TrainableTopoSolver.train(root:str, dataloader_train:DataLoader, dataloader_val:DataLoader=None, epochs:int=100, validation_interval:int=10, verbose:bool=True, patience:bool=None)

Run the training for the topo solver.

| Argument | Type | Default | Details |
| --- | --- | --- | --- |
| root | str | (required) | The directory where the training results are saved. |
| dataloader_train | DataLoader | (required) | The dataloader that contains the training data. |
| dataloader_val | DataLoader | None | The dataloader that contains the validation data. |
| epochs | int | 100 | The maximum number of training epochs. |
| validation_interval | int | 10 | The number of epochs between validation steps. The result of each validation step is printed. |
| verbose | bool | True | Whether to print information on the current training status, such as the current loss and epoch. |
| patience | bool | None | If the validation score does not improve for patience epochs in a row, training stops and the best model is used. The annotation is bool, but the description treats the value as a number of epochs, so an integer is presumably expected. |
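How epochs, validation_interval, and patience interact can be sketched in plain Python. This is not the dl4to implementation: the function name run_training_schedule and the validate callback are hypothetical, the score is assumed to be a loss (lower is better), and patience is assumed to count epochs since the last improvement, checked only at validation steps.

```python
from typing import Callable, Optional, Tuple


def run_training_schedule(
    validate: Callable[[int], float],
    epochs: int = 100,
    validation_interval: int = 10,
    patience: Optional[int] = None,
) -> Tuple[int, float, int]:
    """Schematic training loop; returns (best_epoch, best_score, last_epoch).

    `validate(epoch)` stands in for a validation pass and returns a loss.
    """
    best_score = float("inf")
    best_epoch = 0  # 0 means "no improvement recorded yet"
    last_epoch = 0
    for epoch in range(1, epochs + 1):
        last_epoch = epoch
        # One pass over the training dataloader would happen here.
        if epoch % validation_interval != 0:
            continue  # validate only every `validation_interval` epochs
        score = validate(epoch)
        if score < best_score:
            best_score, best_epoch = score, epoch
        elif patience is not None and epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs: stop early
    return best_epoch, best_score, last_epoch


# Hypothetical validation losses: the best value occurs at epoch 20.
losses = {10: 1.0, 20: 0.5, 30: 0.6, 40: 0.7, 50: 0.8}
result_early = run_training_schedule(lambda e: losses.get(e, 1.0), patience=20)
result_full = run_training_schedule(lambda e: losses.get(e, 1.0))
```

Under these assumptions, the run with patience=20 stops at epoch 40, two validation steps after the best score at epoch 20, while the run without patience continues through all 100 epochs and still reports epoch 20 as the best.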