1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
import torch

from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class MyCustomTrainer(nnUNetTrainer):
    """nnU-Net v2 trainer with custom optimisation hyperparameters.

    Overrides, after the base constructor has run:

    * ``initial_lr``   = 5e-3
    * ``num_epochs``   = 500
    * ``weight_decay`` = 3e-5

    Optimisation is SGD (momentum 0.99, Nesterov) with polynomial
    learning-rate decay over ``num_epochs``.
    """

    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 unpack_dataset: bool = True, device: torch.device = torch.device('cuda')):
        """Initialise the base trainer, then override the training hyperparameters.

        Args:
            plans: forwarded unchanged to ``nnUNetTrainer.__init__``.
            configuration: forwarded unchanged to ``nnUNetTrainer.__init__``.
            fold: forwarded unchanged to ``nnUNetTrainer.__init__``.
            dataset_json: forwarded unchanged to ``nnUNetTrainer.__init__``.
            unpack_dataset: forwarded unchanged to ``nnUNetTrainer.__init__``.
            device: torch device passed to the base trainer (default: CUDA).
        """
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        # Assigned after super().__init__() so these replace whatever the base set.
        self.initial_lr = 5e-3
        self.num_epochs = 500
        self.weight_decay = 3e-5

    def configure_optimizers(self):
        """Create the SGD optimizer and its polynomial learning-rate scheduler.

        Returns:
            ``(optimizer, lr_scheduler)``. The base ``nnUNetTrainer`` unpacks
            this method's result into ``self.optimizer, self.lr_scheduler``,
            so returning the optimizer alone breaks trainer initialisation.
        """
        optimizer = torch.optim.SGD(
            self.network.parameters(),
            lr=self.initial_lr,
            weight_decay=self.weight_decay,
            momentum=0.99,
            nesterov=True,
        )
        # Decay the learning rate from initial_lr over num_epochs steps.
        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
        return optimizer, lr_scheduler

    def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
        """Return the base trainer's augmentation settings unchanged.

        Kept as an explicit override point: adjust the four values here
        (e.g. ``mirror_axes``) to customise augmentation for this trainer.
        The method name, including the ``inital`` spelling, must match the
        base-class method it overrides.

        Returns:
            ``(rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes)``
            exactly as produced by the parent implementation.
        """
        rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
            super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
        return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
|