model package
model.airnet_snl module
model.train_loop module
- model.train_loop.run_inference(testLoader, model, filepath, isFileDict=True)
Run inference with the AirNet-SNL model passed as the model argument, using the weights saved at filepath.
- Args:
testLoader (DataLoader): Loader for the dataset on which to run inference.
model (nn.Module): Model to use for inference.
filepath (str): Path to the saved model weights.
isFileDict (bool): True if the file at filepath stores a dictionary.
- Returns:
Predictions: Tensor of size [nSamples, 1, imgSize, imgSize].
- Example:
    import airnetSNL.model.airnet_snl as snl
    import airnetSNL.model.train_loop as tl
    import airnetSNL.dataset.dataset_utils as du
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    # Downsample the 451 full-scan angles by a factor of 8.
    angles = du.decimateAngles(nAnglesFull=451,
                               downsample=8)
    imgSize = 336
    batchSize = 10
    totalSamples = 100

    model = snl.AirNetSNL(angles=angles,
                          n_iterations=12,
                          n_cnn=10,
                          imgSize=imgSize,
                          batchSize=batchSize,
                          includeSkipConnection=False)
    model = model.cuda()  # requires a CUDA-capable GPU
    filepath = './model.pth'  # weights saved during training

    # Placeholder data: sinograms are [nSamples, 1, nAngles, imgSize].
    testSinograms = torch.zeros(totalSamples, 1, len(angles), imgSize)
    testImages = torch.zeros(totalSamples, 1, imgSize, imgSize)
    testSet = TensorDataset(testSinograms, testImages)
    testLoader = DataLoader(testSet, batch_size=batchSize)
    y_img_pred = tl.run_inference(testLoader, model, filepath)
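The docstring does not spell out how run_inference uses filepath and isFileDict. The following is a minimal sketch of a typical PyTorch inference loop under stated assumptions, not the airnetSNL implementation: run_inference_sketch and make_toy are hypothetical names, isFileDict=True is assumed to mean the file holds a state dict, and a tiny linear model stands in for AirNetSNL so the sketch runs on a CPU.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_inference_sketch(test_loader, model, filepath, is_file_dict=True):
    """Hypothetical re-creation of a run_inference-style loop.

    Assumed, not taken from airnetSNL: is_file_dict=True means filepath holds
    a state dict; False means it holds a whole pickled nn.Module.
    """
    if is_file_dict:
        # Copy the saved tensors into the model that was passed in.
        model.load_state_dict(torch.load(filepath, map_location="cpu"))
    else:
        # Replace the passed-in model with the pickled one (PyTorch >= 1.13).
        model = torch.load(filepath, map_location="cpu", weights_only=False)
    device = next(model.parameters()).device
    model.eval()  # inference behavior for dropout / batch-norm layers
    predictions = []
    with torch.no_grad():  # no gradients are needed at test time
        for sinograms, _images in test_loader:  # ground-truth images unused
            predictions.append(model(sinograms.to(device)).cpu())
    return torch.cat(predictions, dim=0)  # [nSamples, 1, imgSize, imgSize]


def make_toy():
    # Hypothetical stand-in for AirNetSNL: 4 angles x 8 detectors -> 8x8 image.
    return nn.Sequential(nn.Flatten(), nn.Linear(4 * 8, 8 * 8),
                         nn.Unflatten(1, (1, 8, 8)))


# Usage: save a toy model's weights, then load them into a fresh instance.
torch.manual_seed(0)
trained = make_toy()
weights = os.path.join(tempfile.mkdtemp(), "toy.pth")
torch.save(trained.state_dict(), weights)
loader = DataLoader(TensorDataset(torch.zeros(25, 1, 4, 8),
                                  torch.zeros(25, 1, 8, 8)),
                    batch_size=10)
preds = run_inference_sketch(loader, make_toy(), weights)
print(preds.shape)  # torch.Size([25, 1, 8, 8])
```

Because the toy sinograms are all zeros, each prediction equals the bias of the saved linear layer, which shows that the saved weights, not the fresh initialization, were used. The last batch holds only 5 samples, and concatenation still yields all 25 predictions.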
- model.train_loop.train_model(model: Module, optimizer: Adam, train_loader: DataLoader, nEpochs: int, saveModel: bool = False, resumeFrom: int = 0, saveFilePath: str = '', loadFilePath: str = '')
Train a model.
- Args:
model (Module): Model to be trained.
optimizer (Adam): Optimizer, which holds the optimization parameters, e.g. the learning rate.
train_loader (DataLoader): Training data loader, which holds the dataset parameters, e.g. the batch size.
nEpochs (int): Number of epochs to train.
saveModel (bool): If True, save the model whenever the loss improves.
resumeFrom (int): Epoch number from which to resume training.
saveFilePath (str): Path where the model is saved.
loadFilePath (str): Path of the model to load.
- Example:
    import airnetSNL.model.train_loop as tl
    import airnetSNL.model.airnet_snl as snl
    import airnetSNL.dataset.dataset_utils as du
    import torch
    from torch.utils.data import TensorDataset, DataLoader
    from torch import optim

    # Downsample the 451 full-scan angles by a factor of 8.
    angles = du.decimateAngles(nAnglesFull=451,
                               downsample=8)
    imgSize = 336
    batchSize = 10
    model = snl.AirNetSNL(angles=angles,
                          n_iterations=12,
                          n_cnn=10,
                          imgSize=imgSize,
                          batchSize=batchSize,
                          includeSkipConnection=False)

    # The optimizer carries the learning rate.
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    # Placeholder training data (all zeros).
    trainSinograms = torch.zeros(100, 1, len(angles), imgSize)
    trainImages = torch.zeros(100, 1, imgSize, imgSize)
    trainSet = TensorDataset(trainSinograms, trainImages)
    trainLoader = DataLoader(trainSet, batch_size=batchSize)

    # Train for one epoch, starting at epoch 0, without saving checkpoints.
    tl.train_model(model=model,
                   optimizer=optimizer,
                   train_loader=trainLoader,
                   nEpochs=1,
                   saveModel=False,
                   resumeFrom=0,
                   saveFilePath='./testModel.pth')
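The arguments of train_model suggest a standard PyTorch loop with resuming and save-on-improvement. The following is a minimal sketch of such a loop, not the airnetSNL implementation: train_model_sketch is a hypothetical name, and the MSE loss, the state-dict checkpoint format, and the assumption that epochs run from resumeFrom to resumeFrom + nEpochs - 1 are guesses rather than documented behavior. A tiny linear model stands in for AirNetSNL.

```python
import math
import os
import tempfile

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train_model_sketch(model, optimizer, train_loader, n_epochs,
                       save_model=False, resume_from=0,
                       save_file_path="", load_file_path=""):
    """Hypothetical loop mirroring train_model's arguments.

    Assumed, not taken from airnetSNL: MSE loss, state-dict checkpoints,
    and epochs numbered resume_from .. resume_from + n_epochs - 1.
    """
    if load_file_path:
        # Restore saved weights (optimizer state is not restored here).
        model.load_state_dict(torch.load(load_file_path, map_location="cpu"))
    loss_fn = nn.MSELoss()
    device = next(model.parameters()).device
    best_loss = math.inf
    history = []  # (epoch, mean training loss) pairs
    for epoch in range(resume_from, resume_from + n_epochs):
        model.train()
        total = 0.0
        for sinograms, images in train_loader:
            sinograms, images = sinograms.to(device), images.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(sinograms), images)
            loss.backward()
            optimizer.step()
            total += loss.item() * sinograms.size(0)  # undo per-batch mean
        epoch_loss = total / len(train_loader.dataset)
        history.append((epoch, epoch_loss))
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            if save_model:
                # Checkpoint only when the epoch loss improves.
                torch.save(model.state_dict(), save_file_path)
    return history


# Usage with a toy sinogram-to-image model (4 angles, 8x8 images).
torch.manual_seed(0)
toy = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8, 8 * 8),
                    nn.Unflatten(1, (1, 8, 8)))
loader = DataLoader(TensorDataset(torch.randn(20, 1, 4, 8),
                                  torch.zeros(20, 1, 8, 8)),
                    batch_size=5)
ckpt = os.path.join(tempfile.mkdtemp(), "toy.pth")
history = train_model_sketch(toy, optim.Adam(toy.parameters(), lr=1e-2),
                             loader, n_epochs=20, save_model=True,
                             resume_from=3, save_file_path=ckpt)
print(history[0][0], len(history))  # 3 20
```

Saving only on improvement means the file at save_file_path always holds the lowest-loss weights seen so far. Resuming in this sketch restores only the model weights; a full resume would usually also save and reload optimizer.state_dict().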