Source code for model.train_loop

import errno
import numpy as np
import os
import torch
from torch.nn import functional as F


[docs]def train_model(model: torch.nn.Module, optimizer: torch.optim.Adam, train_loader: torch.utils.data.DataLoader, nEpochs: int, saveModel: bool = False, resumeFrom: int = 0, saveFilePath: str = '', loadFilePath: str = ''): """Train a model. Args: * model (Module): Model to be trained. * optimizer (Adam): Optimization parameters, e.g. learning rate. * train_loader (DataLoader): Dataset parameters, e.g. batch size. * nEpochs (int): Number of epochs for training. * saveModel (bool): Save model if loss improves. * resumeFrom (int): Resume training from epoch number. * saveFilePath (str): Where to save the model. * loadFilePath (str): Which model to load. Example: .. code-block:: python :linenos: import airnetSNL.model.train_loop as tl import airnetSNL.model.airnet_snl as snl import airnetSNL.dataset.dataset_utils as du import torch from torch.utils.data import TensorDataset, DataLoader from torch import optim angles = du.decimateAngles(nAnglesFull=451, downsample=8) imgSize = 336 batchSize = 10 model = snl.AirNetSNL(angles=angles, n_iterations=12, n_cnn=10, imgSize=imgSize, batchSize=batchSize, includeSkipConnection=False) optimizer = optim.Adam(model.parameters(), lr=1e-5) trainSinograms = torch.zeros(100, 1, len(angles), imgSize) trainImages = torch.zeros(100, 1, imgSize, imgSize) trainSet = TensorDataset(trainSinograms, trainImages) trainLoader = DataLoader(trainSet, batch_size=batchSize) tl.train_model(model=model, optimizer=optimizer, train_loader=trainLoader, nEpochs=1, saveModel=False, resumeFrom=0, saveFilePath='./testModel.pth') """ def train_one_batch(x_sinogram, y_img_gt): y_img_pred = model(x_sinogram) loss = F.mse_loss(y_img_pred, y_img_gt) optimizer.zero_grad() # Clears Gradient loss.backward() # Calculate Gradients optimizer.step() # Update Weights loss = loss.item() return loss def train_epoch(): total_loss = 0 for x_sinogram, y_img in train_loader: x_sinogram = x_sinogram.to('cuda') y_img = y_img.to('cuda') batch_loss = train_one_batch(x_sinogram, y_img) total_loss += batch_loss return total_loss def save_model(epoch_num): state = { 'epoch': epoch_num, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss_per_epoch': loss_per_epoch, } torch.save(state, saveFilePath) model.cuda() if not saveFilePath: raise ValueError("Need to specify saveFilePath to save model.") saved_epoch = 0 loss_per_epoch = -1 * torch.ones(nEpochs) loss_per_epoch = loss_per_epoch.type(torch.cuda.FloatTensor).cuda() if resumeFrom > 0: if not os.path.isfile(loadFilePath): raise FileNotFoundError(errno.ENOENT, os.stderror(errno.ENOENT), loadFilePath) checkpoint = torch.load(loadFilePath) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) saved_epoch = checkpoint['epoch'] saved_loss = checkpoint['loss_per_epoch'] # Initialize loss_per_epoch with saved_loss if len(saved_loss) < nEpochs: loss_per_epoch = -1 * np.ones(nEpochs) loss_per_epoch[:len(saved_loss)] = saved_loss print('Loaded model!') print(f'saved_epoch = {saved_epoch}, nEpochs = {nEpochs}') save_interval = 1 minLoss = torch.tensor(1e9).type(torch.cuda.FloatTensor).cuda() for e in range(saved_epoch, nEpochs): total_loss = train_epoch() loss_per_epoch[e] = total_loss # Save model if loss decreases if e % save_interval == 0 and saveModel and total_loss < minLoss: minLoss = total_loss save_model(e) print(f'epoch = {e}, loss = {total_loss}, saved!') else: print(f'epoch = {e}, loss = {total_loss}') # Save if checkpoint before nEpochs if saved_epoch < nEpochs and saveModel: epoch = nEpochs - 1 save_model(epoch) print('Saved model!')
[docs]def run_inference(testLoader, model, filepath, isFileDict=True): """Run the AirNet-SNL model by passing it in the argument. Args: * testLoader (DataLoader): Dataset for running inference. * model (nn.Module): Use this model for inference. * filepath (str): Saved model weights. * isFileDict (bool): True if filepath is a dictionary Returns: Predictions: Tensor of size [nSamples, 1, imgSize, imgSize]. Example: .. code-block:: python :linenos: import airnetSNL.model.airnet_snl as snl import airnetSNL.dataset.dataset_utils as du import torch from torch.utils.data import TensorDataset, DataLoader angles = du.decimateAngles(nAnglesFull=451, downsample=8) imgSize = 336 batchSize = 10 totalSamples = 100 model = snl.AirNetSNL(angles=angles, n_iterations=12, n_cnn=10, imgSize=imgSize, batchSize=batchSize, includeSkipConnection=False) model = model.cuda() filepath = './model.pth' testSinograms = torch.zeros(totalSamples, 1, len(angles), imgSize) testImages = torch.zeros(totalSamples, 1, imgSize, imgSize) testSet = TensorDataset(testSinograms.cpu(), testImages.cpu()) testLoader = DataLoader(testSet, batch_size=batchSize) y_img_pred = run_inference(testLoader, model, filepath) """ if not filepath: raise ValueError("Need to specify filepath to load model.") if isFileDict: checkpoint = torch.load(filepath) model.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(torch.load(filepath)) model.eval().cuda() totalLoss = 0 for i, (xSino, yImg) in enumerate(testLoader): xSino_gpu = xSino.to('cuda') pred = model(xSino_gpu).data.cpu() loss = F.mse_loss(pred, yImg) loss += loss.item() totalLoss += loss print(f'totalLoss = {totalLoss}') if i == 0: preds = pred else: preds = torch.cat((preds, pred), dim=0) return preds