Training a ResNet Model for Channel Estimation

Now that we have a channel estimation dataset, we can train a model that takes the channel estimates at the pilot locations and predicts the channel values over the whole channel matrix.

We use the PyTorch framework to train our models, but other machine learning tools can also be used. The following diagram shows the neural network structure used for the channel estimation model.

NN Structure

So let’s get started by importing the required modules.

[1]:
import numpy as np
import time, datetime

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR

Loading the dataset

We first load the dataset files generated in the previous step and then create datasets and dataloaders for training, validation, and test.

[2]:
# Define our dataset object
class ChEstDataset(Dataset):
    def __init__(self, fileName):
        samples, labels = np.load(fileName)
        self.samples = np.float32(np.transpose(samples,(0,3,1,2)))
        self.labels = np.float32(np.transpose(labels,(0,3,1,2)))
        self.numSamples = len(self.samples)

    def __len__(self):
        return self.numSamples

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

# Instantiate 3 datasets for train, validation, and test
trainDs = ChEstDataset("ChestTrain.npy")
validDs = ChEstDataset("ChestValid.npy")
testDs  = ChEstDataset("ChestTest.npy")

print(f"{len(trainDs)} Training Samples")
print(f"{len(validDs)} Validation Samples")
print(f"{len(testDs)} Test Samples")

# Create the data loaders
batchSize = 64
trainDl = DataLoader(trainDs, batchSize, shuffle=True)
validDl = DataLoader(validDs, batchSize, shuffle=False)
testDl = DataLoader(testDs, batchSize, shuffle=False)

16000 Training Samples
2400 Validation Samples
2400 Test Samples
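
The transpose in the dataset class above moves the last axis of the saved arrays to the channel dimension, so each sample and label is stored channels-first with two channels, matching the model's input and output depth. As a quick sanity check, we can look at the shape of one training pair:

[ ]:
# Inspect the shape of one sample/label pair.
sample, label = trainDs[0]
print("Sample shape:", sample.shape)
print("Label shape: ", label.shape)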
[3]:
# Check for GPU (CUDA) or Apple Metal (MPS) availability
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("Using '%s' device."%({'cuda':'Cuda', 'mps':'Metal','cpu':'CPU'}[device]))

Using 'Metal' device.

Creating the model

Now we can create the model object that will be used for training. We first define the ResBlock (see the diagram above).

[4]:
# Define ResNet block
class ResBlock(nn.Module):
    def __init__(self, inDepth, midDepth, outDepth, kernel=(3,3), stride=(1,1)):
        super().__init__()
        if isinstance(stride, int): stride = (stride, stride)
        if isinstance(kernel, int): kernel = (kernel, kernel)

        self.conv1 = nn.Conv2d(inDepth, midDepth, 1, stride, padding='valid')  # 1x1 conv.
        self.bn1 = nn.BatchNorm2d(midDepth)
        self.conv2 = nn.Conv2d(midDepth, midDepth, kernel, padding='same')
        self.bn2 = nn.BatchNorm2d(midDepth)
        self.conv3 = nn.Conv2d(midDepth, outDepth, 1, stride, padding='valid') # 1x1 conv.
        self.bn3 = nn.BatchNorm2d(outDepth)
        self.relu = nn.ReLU(inplace=True)

        self.downSampleNet = None
        if ((stride != (1,1)) or (inDepth!=outDepth)):
            self.downSampleNet = nn.Sequential(nn.Conv2d(inDepth, outDepth, 1, stride),  # 1x1 conv.
                                               nn.BatchNorm2d(outDepth) )

        for bn in [self.bn1, self.bn2, self.bn3]:
            nn.init.ones_(bn.weight)
            nn.init.zeros_(bn.bias)

        for conv in [self.conv1, self.conv2, self.conv3]:
            nn.init.trunc_normal_(conv.weight, std=1/np.sqrt(np.prod(list(conv.weight.shape)[1:])))
            nn.init.zeros_(conv.bias)

        nn.init.zeros_(self.bn3.weight)  # This could improve results. It makes the block start like identity.

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downSampleNet is None:  out += x
        else:                           out += self.downSampleNet(x)
        out = self.relu(out)

        return out
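
Since bn3.weight is zero-initialized, each block initially contributes nothing on its residual branch and (up to the final ReLU) starts out as an identity mapping. As a quick check, a block with matching input/output depth and unit stride also preserves the spatial dimensions; the tensor shape below is just an example and is not tied to the dataset:

[ ]:
# Shape check: output matches input for inDepth == outDepth and stride (1,1).
blk = ResBlock(64, 16, 64, (3,7))
x = torch.randn(1, 64, 14, 612)   # Example shape only; the real grid size comes from the dataset
print(blk(x).shape)               # -> torch.Size([1, 64, 14, 612])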

And now we define our Channel Estimator model (ChEstNet) using three instances of the ResBlock defined above, together with an additional convolutional layer.

[5]:
# Now the actual ChEstNet model
class ChEstNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.res1 = ResBlock(2, 16, 64, (9,11))   # Res Block 9x11 kernel
        self.res2 = ResBlock(64, 16, 64, (3,7))   # Res Block 3x7 kernel
        self.res3 = ResBlock(64, 16, 64, (3,7))   # Res Block 3x7 kernel

        self.conv = nn.Conv2d(64, 2, 3, padding='same')
        nn.init.trunc_normal_(self.conv.weight, std=1/np.sqrt(np.prod(list(self.conv.weight.shape)[1:])))
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        out = self.res1(x)
        out = self.res2(out)
        out = self.res3(out)
        out = self.conv(out)
        return out

# Instantiate the model and move it to the target device
model = ChEstNet().to(device)
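
As a rough size check, we can count the model's trainable parameters:

[ ]:
# Report the number of trainable parameters.
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"ChEstNet has {numParams:,} trainable parameters.")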

Training the model

We first define functions for the training and evaluation loops and then use them to train the model.

Note: The following cell can take several hours to complete. A file containing the trained parameters of the model is included in this directory, so you can skip the following cells and proceed to the next step.
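
If you choose to skip training, the saved weights can be loaded back into the model with load_state_dict. This is a minimal sketch assuming the same file path used by the training loop below:

[ ]:
# Restore the pre-trained weights instead of training (optional).
model.load_state_dict(torch.load('Models/ChEstModelWeights.pth', map_location=device))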

[6]:
# Training loop for one epoch:
def trainEpoch(dataLoader, model, lossFunction, optimizer):
    modelDevice = next(model.parameters()).device
    model.train() # Set the model to training mode
    lossMin, lossMean, lossMax = np.inf, 0, -np.inf
    for batchNo, (batchSamples, batchLabels) in enumerate(dataLoader):
        # Compute prediction and loss
        batchPredictions = model( batchSamples.to(modelDevice) )
        loss = lossFunction(batchPredictions, batchLabels.to(modelDevice))

        loss.backward()  # Backpropagation
        optimizer.step()
        optimizer.zero_grad()

        lossValue = loss.item()
        lossMean += lossValue
        if lossValue>lossMax: lossMax = lossValue
        if lossValue<lossMin: lossMin = lossValue

    lossMean /= len(dataLoader)
    return lossMin, lossMean, lossMax

# Evaluation loop:
def evaluate(dataLoader, model, lossFunction):
    modelDevice = next(model.parameters()).device
    model.eval()  # Set the model to evaluation mode
    numBatches = len(dataLoader)
    lossMean = 0

    with torch.no_grad():
        for batchNo, (batchSamples, batchLabels) in enumerate(dataLoader):
            batchPredictions = model( batchSamples.to(modelDevice) )
            lossMean += lossFunction(batchPredictions, batchLabels.to(modelDevice)).item()

    lossMean /= numBatches
    return lossMean

numEpochs = 100                    # Number of epochs
learningRate = (0.0001, 0.000001)  # Learning rate starts at 0.0001 and exponentially decays to 0.000001
lossFunction = nn.MSELoss()        # Using MSE as Loss

if isinstance(learningRate,tuple):  # learningRate is a tuple -> use exponentially decaying learning rate
    lr1st, lrLast = learningRate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr1st)
    lrScheduler = ExponentialLR(optimizer, np.exp(np.log(lrLast/lr1st)/(numEpochs-1)))
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)  # learningRate is a number
    lrScheduler = None                                                 # No LR scheduling needed

t0 = time.time()
print("Epoch   Learning Rate   Training Loss   Validation Loss")
print("-----   -------------   -------------   ---------------")
lowestLoss = None
for epoch in range(numEpochs):
    print(" %-4d     %-10f      "%(epoch+1, lrScheduler.get_last_lr()[0]), end="")
    lossMin, lossMean, lossMax = trainEpoch(trainDl, model, lossFunction, optimizer)
    print("%-10f      "%(lossMean), end="")
    validLoss = evaluate(validDl, model, lossFunction)
    if lowestLoss is None:
        lowestLoss = validLoss
        print("%-10f   "%(validLoss))
    elif validLoss<lowestLoss:          # This is the best model so far -> Save it
        lowestLoss, bestEpoch = validLoss, epoch+1
        torch.save(model.state_dict(), 'Models/ChEstModelWeights.pth')
        print("%-10f * "%(validLoss))   # The '*' indicates best so far and saving
    else:
        print("%-10f   "%(validLoss))

    if lrScheduler is not None: lrScheduler.step()

print("Training complete. (Training Time: %s)"%(str(datetime.timedelta(seconds=int(time.time()-t0)))))

Epoch   Learning Rate   Training Loss   Validation Loss
-----   -------------   -------------   ---------------
 1        0.000100        0.131095        0.018107
 2        0.000095        0.007698        0.004004   *
 3        0.000091        0.003329        0.002922   *
 4        0.000087        0.002697        0.002642   *
 5        0.000083        0.002430        0.002455   *
 6        0.000079        0.002270        0.002382   *
 7        0.000076        0.002169        0.002182   *
 8        0.000072        0.002083        0.002352
 9        0.000069        0.002019        0.002460
 10       0.000066        0.001962        0.002147   *
 11       0.000063        0.001917        0.001909   *
 12       0.000060        0.001873        0.002105
 13       0.000057        0.001842        0.001950
 14       0.000055        0.001809        0.001785   *
 15       0.000052        0.001781        0.001917
 16       0.000050        0.001757        0.002527
 17       0.000048        0.001736        0.001796
 18       0.000045        0.001719        0.001724   *
 19       0.000043        0.001700        0.001919
 20       0.000041        0.001682        0.001789
 21       0.000039        0.001668        0.001765
 22       0.000038        0.001657        0.001684   *
 23       0.000036        0.001651        0.001970
 24       0.000034        0.001632        0.001833
 25       0.000033        0.001624        0.001627   *
 26       0.000031        0.001621        0.001659
 27       0.000030        0.001609        0.001677
 28       0.000028        0.001600        0.001732
 29       0.000027        0.001594        0.001583   *
 30       0.000026        0.001586        0.001582   *
 31       0.000025        0.001579        0.001688
 32       0.000024        0.001575        0.001565   *
 33       0.000023        0.001572        0.001834
 34       0.000022        0.001563        0.001663
 35       0.000021        0.001561        0.001573
 36       0.000020        0.001554        0.001563   *
 37       0.000019        0.001550        0.001626
 38       0.000018        0.001545        0.001555   *
 39       0.000017        0.001542        0.001657
 40       0.000016        0.001540        0.001564
 41       0.000016        0.001534        0.001556
 42       0.000015        0.001532        0.001579
 43       0.000014        0.001534        0.001529   *
 44       0.000014        0.001527        0.001537
 45       0.000013        0.001527        0.001547
 46       0.000012        0.001521        0.001553
 47       0.000012        0.001521        0.001513   *
 48       0.000011        0.001517        0.001514
 49       0.000011        0.001517        0.001514
 50       0.000010        0.001513        0.001499   *
 51       0.000010        0.001512        0.001496   *
 52       0.000009        0.001509        0.001511
 53       0.000009        0.001507        0.001502
 54       0.000008        0.001508        0.001525
 55       0.000008        0.001506        0.001508
 56       0.000008        0.001504        0.001497
 57       0.000007        0.001503        0.001489   *
 58       0.000007        0.001501        0.001576
 59       0.000007        0.001498        0.001489   *
 60       0.000006        0.001500        0.001511
 61       0.000006        0.001496        0.001501
 62       0.000006        0.001496        0.001485   *
 63       0.000006        0.001494        0.001509
 64       0.000005        0.001497        0.001494
 65       0.000005        0.001494        0.001485
 66       0.000005        0.001493        0.001508
 67       0.000005        0.001490        0.001483   *
 68       0.000004        0.001492        0.001482   *
 69       0.000004        0.001492        0.001480   *
 70       0.000004        0.001490        0.001484
 71       0.000004        0.001492        0.001483
 72       0.000004        0.001488        0.001497
 73       0.000004        0.001486        0.001487
 74       0.000003        0.001486        0.001488
 75       0.000003        0.001486        0.001477   *
 76       0.000003        0.001485        0.001484
 77       0.000003        0.001487        0.001480
 78       0.000003        0.001484        0.001477   *
 79       0.000003        0.001485        0.001484
 80       0.000003        0.001482        0.001474   *
 81       0.000002        0.001485        0.001479
 82       0.000002        0.001482        0.001483
 83       0.000002        0.001481        0.001477
 84       0.000002        0.001481        0.001475
 85       0.000002        0.001482        0.001474
 86       0.000002        0.001481        0.001485
 87       0.000002        0.001481        0.001471   *
 88       0.000002        0.001481        0.001473
 89       0.000002        0.001483        0.001472
 90       0.000002        0.001479        0.001470   *
 91       0.000002        0.001481        0.001474
 92       0.000001        0.001481        0.001474
 93       0.000001        0.001482        0.001472
 94       0.000001        0.001480        0.001478
 95       0.000001        0.001481        0.001478
 96       0.000001        0.001480        0.001472
 97       0.000001        0.001480        0.001472
 98       0.000001        0.001480        0.001469   *
 99       0.000001        0.001479        0.001475
 100      0.000001        0.001477        0.001470
Training complete. (Training Time: 0:53:29)
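
The learning-rate column above follows the exponential decay set up before the training loop: the decay factor is chosen so that the rate starts at the first value in learningRate and reaches the last value at the final epoch. We can verify this factor directly, using the lr1st, lrLast, and numEpochs variables defined above:

[ ]:
# Verify the decay factor used by the ExponentialLR scheduler above.
gamma = np.exp(np.log(lrLast/lr1st)/(numEpochs-1))
print("Decay factor: %.6f   Final learning rate: %g"%(gamma, lr1st*gamma**(numEpochs-1)))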

Evaluating the model

We now evaluate the trained model on the test dataset using the same MSE loss function.

[7]:
testLoss = evaluate(testDl, model, lossFunction)
print(f"Test Loss: %.6f"%(testLoss))
Test Loss: 0.001460
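
As a final illustration, we can run the trained model on a single test sample. This is a minimal sketch; the estimate has the same channels-first layout as the labels:

[ ]:
# Run inference on one test sample (illustration only).
model.eval()
with torch.no_grad():
    sample, label = testDs[0]
    estimate = model(torch.from_numpy(sample)[None].to(device)).cpu().numpy()[0]
print("Estimated channel shape:", estimate.shape)   # Same shape as 'label'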