AirNet-SNL Examples
The examples below show how to train AirNet-SNL and how to run inference with a trained model.
Example 1: Training AirNet-SNL
# Example: train AirNet-SNL for one epoch on placeholder data.
# The all-zero tensors below only illustrate the expected tensor shapes;
# replace them with real sinograms and target images for actual training.
import airnetSNL.model.train_loop as tl
import airnetSNL.model.airnet_snl as snl
import airnetSNL.dataset.dataset_utils as du
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import optim

# Angle set derived from a 451-angle acquisition with downsample=8
# (presumably keeps every 8th angle — see du.decimateAngles to confirm).
angles = du.decimateAngles(nAnglesFull=451,
                           downsample=8)
imgSize = 336
batchSize = 10
model = snl.AirNetSNL(angles=angles,
                      n_iterations=12,
                      n_cnn=10,
                      imgSize=imgSize,
                      batchSize=batchSize,
                      includeSkipConnection=False)

optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Placeholder dataset of 100 samples:
#   sinograms: (100, 1, len(angles), imgSize)
#   images:    (100, 1, imgSize, imgSize)
# 100 is a multiple of batchSize, so every batch is full. The model is built
# with a fixed batchSize; if it requires full batches (TODO confirm), pass
# drop_last=True to DataLoader when using real data of arbitrary length.
trainSinograms = torch.zeros(100, 1, len(angles), imgSize)
trainImages = torch.zeros(100, 1, imgSize, imgSize)
trainSet = TensorDataset(trainSinograms, trainImages)
trainLoader = DataLoader(trainSet, batch_size=batchSize)

# saveModel=False, so saveFilePath is presumably not written in this run.
tl.train_model(model=model,
               optimizer=optimizer,
               train_loader=trainLoader,
               nEpochs=1,
               saveModel=False,
               resumeFrom=0,
               saveFilePath='./testModel.pth')
Example 2: Running inference with AirNet-SNL
# Example: run inference with a trained AirNet-SNL model on placeholder data.
# The all-zero tensors below only illustrate the expected tensor shapes;
# replace them with real test sinograms and images.
import airnetSNL.model.airnet_snl as snl
import airnetSNL.dataset.dataset_utils as du
import torch
from torch.utils.data import TensorDataset, DataLoader

# NOTE(review): `run_inference` is called on the last line but is never
# imported or defined in this example. Import it from the airnetSNL package
# (its module is not shown here) before running; otherwise the last line
# raises NameError.

# Must match the configuration the model at `filepath` was trained with.
angles = du.decimateAngles(nAnglesFull=451,
                           downsample=8)
imgSize = 336
batchSize = 10
totalSamples = 100

model = snl.AirNetSNL(angles=angles,
                      n_iterations=12,
                      n_cnn=10,
                      imgSize=imgSize,
                      batchSize=batchSize,
                      includeSkipConnection=False)
# Unconditional .cuda(): this example requires a CUDA-capable GPU.
model = model.cuda()
# Presumably the saved weights that run_inference loads — TODO confirm.
filepath = './model.pth'
testSinograms = torch.zeros(totalSamples, 1, len(angles), imgSize)
testImages = torch.zeros(totalSamples, 1, imgSize, imgSize)
# .cpu() is a no-op for these CPU tensors. It keeps the dataset on the CPU
# if GPU tensors are substituted in.
testSet = TensorDataset(testSinograms.cpu(), testImages.cpu())
testLoader = DataLoader(testSet, batch_size=batchSize)
y_img_pred = run_inference(testLoader, model, filepath)