model package

model.airnet_snl module

model.train_loop module

model.train_loop.run_inference(testLoader, model, filepath, isFileDict=True)

Run inference with an AirNet-SNL model, which is passed in as an argument.

Args:
  • testLoader (DataLoader): Data loader over the test set used for inference.

  • model (nn.Module): Model to use for inference.

  • filepath (str): Path to the saved model weights.

  • isFileDict (bool): True if the file at filepath stores a dictionary.

Returns:

Predictions: Tensor of size [nSamples, 1, imgSize, imgSize].

Example:
    import airnetSNL.model.airnet_snl as snl
    import airnetSNL.model.train_loop as tl
    import airnetSNL.dataset.dataset_utils as du
    import torch
    from torch.utils.data import TensorDataset, DataLoader

    angles = du.decimateAngles(nAnglesFull=451,
                               downsample=8)
    imgSize = 336
    batchSize = 10
    totalSamples = 100

    model = snl.AirNetSNL(angles=angles,
                          n_iterations=12,
                          n_cnn=10,
                          imgSize=imgSize,
                          batchSize=batchSize,
                          includeSkipConnection=False)
    model = model.cuda()  # requires a CUDA-capable GPU
    filepath = './model.pth'

    # Placeholder data: all-zero sinograms and ground-truth images.
    testSinograms = torch.zeros(totalSamples, 1, len(angles), imgSize)
    testImages = torch.zeros(totalSamples, 1, imgSize, imgSize)
    testSet = TensorDataset(testSinograms.cpu(), testImages.cpu())
    testLoader = DataLoader(testSet, batch_size=batchSize)

    # run_inference lives in train_loop, so call it through the tl module.
    y_img_pred = tl.run_inference(testLoader, model, filepath)

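The returned tensor holds one prediction per test sample. As a sanity check on sizes, the following minimal pure-Python sketch reproduces how a DataLoader splits totalSamples into batches (default drop_last=False) and the resulting prediction shape. The helpers batch_sizes and prediction_shape are hypothetical and not part of airnetSNL, and concatenating per-batch outputs along the first dimension is an assumption about run_inference. Because AirNetSNL takes batchSize as a constructor argument, it may expect full batches; choosing totalSamples as a multiple of batchSize, as the example above does, avoids a short final batch.

```python
def batch_sizes(n_samples: int, batch_size: int, drop_last: bool = False) -> list[int]:
    """Return the size of each batch a sequential DataLoader would yield.

    Full batches of batch_size come first, followed by one smaller final
    batch for the remainder unless drop_last=True.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * n_full
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


def prediction_shape(n_samples: int, batch_size: int, img_size: int) -> tuple[int, int, int, int]:
    """Shape after stacking per-batch predictions along dim 0 (assumed behavior)."""
    n_out = sum(batch_sizes(n_samples, batch_size))
    return (n_out, 1, img_size, img_size)


sizes = batch_sizes(100, 10)
print(len(sizes), sizes[-1])           # 10 10
print(prediction_shape(100, 10, 336))  # (100, 1, 336, 336)
print(batch_sizes(105, 10)[-1])        # 5 (a short final batch)
```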
model.train_loop.train_model(model: Module, optimizer: Adam, train_loader: DataLoader, nEpochs: int, saveModel: bool = False, resumeFrom: int = 0, saveFilePath: str = '', loadFilePath: str = '')

Train a model.

Args:
  • model (Module): Model to be trained.

  • optimizer (Adam): Optimizer, which holds optimization parameters such as the learning rate.

  • train_loader (DataLoader): Data loader for the training set, which holds dataset parameters such as the batch size.

  • nEpochs (int): Number of epochs for training.

  • saveModel (bool): Save model if loss improves.

  • resumeFrom (int): Epoch number from which to resume training.

  • saveFilePath (str): Path where the model is saved.

  • loadFilePath (str): Path of the saved model to load.
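The interaction of saveModel and resumeFrom can be illustrated with a minimal pure-Python sketch: run the epochs from resume_from up to n_epochs and checkpoint only when the epoch loss beats the best loss so far. This is an illustration under stated assumptions, not the train_model implementation. run_epochs, train_one_epoch, and save_checkpoint are hypothetical names; treating nEpochs as the exclusive end epoch is an assumption; and the sketch does not restore the best loss from an earlier run when resuming. In real PyTorch code, save_checkpoint would typically call torch.save(model.state_dict(), saveFilePath).

```python
from typing import Callable, Optional


def run_epochs(
    train_one_epoch: Callable[[int], float],
    n_epochs: int,
    resume_from: int = 0,
    save_model: bool = False,
    save_checkpoint: Optional[Callable[[int, float], None]] = None,
) -> list[int]:
    """Run epochs resume_from .. n_epochs - 1, checkpointing when loss improves.

    Assumption: n_epochs is the exclusive end epoch, not a count of extra epochs.
    Returns the epochs at which a checkpoint was saved.
    """
    best_loss = float("inf")  # not restored from a previous run in this sketch
    saved_epochs = []
    for epoch in range(resume_from, n_epochs):
        loss = train_one_epoch(epoch)  # one pass over the training data
        if loss < best_loss:
            best_loss = loss
            if save_model:
                saved_epochs.append(epoch)
                if save_checkpoint is not None:
                    save_checkpoint(epoch, loss)
    return saved_epochs


losses = [0.9, 0.7, 0.8, 0.5, 0.6]  # made-up per-epoch losses
print(run_epochs(lambda e: losses[e], n_epochs=5, save_model=True))                 # [0, 1, 3]
print(run_epochs(lambda e: losses[e], n_epochs=5, resume_from=2, save_model=True))  # [2, 3]
```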

Example:
    import airnetSNL.model.train_loop as tl
    import airnetSNL.model.airnet_snl as snl
    import airnetSNL.dataset.dataset_utils as du
    import torch
    from torch.utils.data import TensorDataset, DataLoader
    from torch import optim

    # Keep a decimated subset of the 451 projection angles (downsample=8).
    angles = du.decimateAngles(nAnglesFull=451,
                               downsample=8)
    imgSize = 336
    batchSize = 10
    model = snl.AirNetSNL(angles=angles,
                          n_iterations=12,
                          n_cnn=10,
                          imgSize=imgSize,
                          batchSize=batchSize,
                          includeSkipConnection=False)

    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    # Placeholder data: 100 all-zero sinograms and target images.
    trainSinograms = torch.zeros(100, 1, len(angles), imgSize)
    trainImages = torch.zeros(100, 1, imgSize, imgSize)
    trainSet = TensorDataset(trainSinograms, trainImages)
    trainLoader = DataLoader(trainSet, batch_size=batchSize)

    # Train for one epoch from scratch without saving checkpoints.
    tl.train_model(model=model,
                   optimizer=optimizer,
                   train_loader=trainLoader,
                   nEpochs=1,
                   saveModel=False,
                   resumeFrom=0,
                   saveFilePath='./testModel.pth')