AirNet-SNL Examples

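The examples below show how to train AirNet-SNL and how to run inference with it. Both examples use zero-filled placeholder tensors in place of real sinogram and image data. Substitute your own datasets before using them in practice.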

Example 1: Training AirNet-SNL

 1import airnetSNL.model.train_loop as tl
 2import airnetSNL.model.airnet_snl as snl
 3import airnetSNL.dataset.dataset_utils as du
 4import torch
 5from torch.utils.data import TensorDataset, DataLoader
 6from torch import optim
 7
 8angles = du.decimateAngles(nAnglesFull=451,
 9                            downsample=8)
10imgSize = 336
11batchSize = 10
12model = snl.AirNetSNL(angles=angles,
13                        n_iterations=12,
14                        n_cnn=10,
15                        imgSize=imgSize,
16                        batchSize=batchSize,
17                        includeSkipConnection=False)
18
19optimizer = optim.Adam(model.parameters(), lr=1e-5)
20
21trainSinograms = torch.zeros(100, 1, len(angles), imgSize)
22trainImages = torch.zeros(100, 1, imgSize, imgSize)
23trainSet = TensorDataset(trainSinograms, trainImages)
24trainLoader = DataLoader(trainSet, batch_size=batchSize)
25
26tl.train_model(model=model,
27                optimizer=optimizer,
28                train_loader=trainLoader,
29                nEpochs=1,
30                saveModel=False,
31                resumeFrom=0,
32                saveFilePath='./testModel.pth')

Example 2: Running inference with AirNet-SNL

 1import airnetSNL.model.airnet_snl as snl
 2import airnetSNL.dataset.dataset_utils as du
 3import torch
 4from torch.utils.data import TensorDataset, DataLoader
 5
 6angles = du.decimateAngles(nAnglesFull=451,
 7                            downsample=8)
 8imgSize = 336
 9batchSize = 10
10totalSamples = 100
11
12model = snl.AirNetSNL(angles=angles,
13                        n_iterations=12,
14                        n_cnn=10,
15                        imgSize=imgSize,
16                        batchSize=batchSize,
17                        includeSkipConnection=False)
18model = model.cuda()
19filepath = './model.pth'
20testSinograms = torch.zeros(totalSamples, 1, len(angles), imgSize)
21testImages = torch.zeros(totalSamples, 1, imgSize, imgSize)
22testSet = TensorDataset(testSinograms.cpu(), testImages.cpu())
23testLoader = DataLoader(testSet, batch_size=batchSize)
24y_img_pred = run_inference(testLoader, model, filepath)