Training a ResNet Model for Channel Estimation
Now that we have a channel estimation dataset, we can train a model that receives the channel estimates at the pilot locations and predicts the channel estimates for the whole channel matrix.
We use the PyTorch framework to train our models, but other machine learning tools can also be used. The following diagram shows the neural network structure used for the channel estimation model.

So let’s get started by importing the required modules.
[1]:
import numpy as np
import time, datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
Loading the dataset
We first load the dataset files generated in the previous step. Then we create three datasets and corresponding dataloaders for training, validation, and testing.
[2]:
# Define our dataset object
class ChEstDataset(Dataset):
    def __init__(self, fileName):
        samples, labels = np.load(fileName)
        self.samples = np.float32(np.transpose(samples, (0,3,1,2)))
        self.labels = np.float32(np.transpose(labels, (0,3,1,2)))
        self.numSamples = len(self.samples)

    def __len__(self):
        return self.numSamples

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

# Instantiate 3 datasets for train, validation, and test
trainDs = ChEstDataset("ChestTrain.npy")
validDs = ChEstDataset("ChestValid.npy")
testDs  = ChEstDataset("ChestTest.npy")

print(f"{len(trainDs)} Training Samples")
print(f"{len(validDs)} Validation Samples")
print(f"{len(testDs)} Test Samples")

# Create the data loaders
batchSize = 64
trainDl = DataLoader(trainDs, batchSize, shuffle=True)
validDl = DataLoader(validDs, batchSize, shuffle=False)
testDl  = DataLoader(testDs, batchSize, shuffle=False)
16000 Training Samples
2400 Validation Samples
2400 Test Samples
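As an optional sanity check, we can pull one batch from the training dataloader and inspect its shape. After the channels-first transpose above, the samples and labels have two channels (presumably the real and imaginary parts of the channel); the remaining dimensions depend on the grid size used when generating the dataset in the previous step. This is a minimal sketch under those assumptions:

# Optional sanity check: shapes of one batch, (batch, channels, dim1, dim2)
batchSamples, batchLabels = next(iter(trainDl))
print("Samples:", tuple(batchSamples.shape))
print("Labels: ", tuple(batchLabels.shape))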
[3]:
# Checking GPU availability
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("Using '%s' device."%({'cuda':'Cuda', 'mps':'Metal','cpu':'CPU'}[device]))
Using 'Metal' device.
Creating the model
Now we can create the model object that will be used for training. We first define the ResBlock module (see the picture above).
[4]:
# Define ResNet block
class ResBlock(nn.Module):
    def __init__(self, inDepth, midDepth, outDepth, kernel=(3,3), stride=(1,1)):
        super().__init__()
        if isinstance(stride, int): stride = (stride, stride)
        if isinstance(kernel, int): kernel = (kernel, kernel)
        self.conv1 = nn.Conv2d(inDepth, midDepth, 1, stride, padding='valid')   # 1x1 conv.
        self.bn1 = nn.BatchNorm2d(midDepth)
        self.conv2 = nn.Conv2d(midDepth, midDepth, kernel, padding='same')
        self.bn2 = nn.BatchNorm2d(midDepth)
        self.conv3 = nn.Conv2d(midDepth, outDepth, 1, stride, padding='valid')  # 1x1 conv.
        self.bn3 = nn.BatchNorm2d(outDepth)
        self.relu = nn.ReLU(inplace=True)

        self.downSampleNet = None
        if (stride != (1,1)) or (inDepth != outDepth):
            self.downSampleNet = nn.Sequential(nn.Conv2d(inDepth, outDepth, 1, stride),  # 1x1 conv.
                                               nn.BatchNorm2d(outDepth))

        for bn in [self.bn1, self.bn2, self.bn3]:
            nn.init.ones_(bn.weight)
            nn.init.zeros_(bn.bias)
        for conv in [self.conv1, self.conv2, self.conv3]:
            nn.init.trunc_normal_(conv.weight, std=1/np.sqrt(np.prod(list(conv.weight.shape)[1:])))
            nn.init.zeros_(conv.bias)
        nn.init.zeros_(self.bn3.weight)  # This could improve results. It makes the block start like identity.

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downSampleNet is None: out += x
        else:                          out += self.downSampleNet(x)
        out = self.relu(out)
        return out
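With the default stride of (1,1), a ResBlock only changes the depth and keeps the spatial dimensions, so we can verify it on a random tensor of any grid size. The 14x612 grid below is just an illustrative assumption, not a property of the dataset:

# Quick check of the block (the 14x612 grid size is assumed only for illustration)
block = ResBlock(2, 16, 64, kernel=(3,3))
x = torch.randn(8, 2, 14, 612)
print(block(x).shape)   # Depth goes 2 -> 64; spatial dimensions are unchanged with stride (1,1)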
And now we define our channel estimator model (ChEstNet) using three instances of the ResBlock defined above, together with an additional convolutional layer.
[5]:
# Now the actual ChEstNet model
class ChEstNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.res1 = ResBlock(2, 16, 64, (9,11))   # Res Block, 9x11 kernel
        self.res2 = ResBlock(64, 16, 64, (3,7))   # Res Block, 3x7 kernel
        self.res3 = ResBlock(64, 16, 64, (3,7))   # Res Block, 3x7 kernel
        self.conv = nn.Conv2d(64, 2, 3, padding='same')
        nn.init.trunc_normal_(self.conv.weight, std=1/np.sqrt(np.prod(list(self.conv.weight.shape)[1:])))
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        out = self.res1(x)
        out = self.res2(out)
        out = self.res3(out)
        out = self.conv(out)
        return out
# Instantiate the model and move it to the target device
model = ChEstNet().to(device)
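Before training, it can be helpful to confirm the model size. The following sketch simply counts the trainable parameters of the instantiated model:

# Count the trainable parameters of the model
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"ChEstNet has {numParams:,} trainable parameters")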
Training the model
We first create functions for the training and evaluation loops, then use them to train the model.
Note: The following cell can take several hours to complete. A file containing the trained parameters of the model is included in this directory, so you can skip the following cells and proceed to the next step.
[6]:
# Training loop for one epoch:
def trainEpoch(dataLoader, model, lossFunction, optimizer):
    modelDevice = next(model.parameters()).device
    model.train()                       # Set the model to training mode
    lossMin, lossMean, lossMax = np.inf, 0, -np.inf
    for batchNo, (batchSamples, batchLabels) in enumerate(dataLoader):
        # Compute prediction and loss
        batchPredictions = model( batchSamples.to(modelDevice) )
        loss = lossFunction(batchPredictions, batchLabels.to(modelDevice))

        loss.backward()                 # Backpropagation
        optimizer.step()
        optimizer.zero_grad()

        lossValue = loss.item()
        lossMean += lossValue
        if lossValue>lossMax: lossMax = lossValue
        if lossValue<lossMin: lossMin = lossValue
    lossMean /= len(dataLoader)
    return lossMin, lossMean, lossMax

# Evaluation loop:
def evaluate(dataLoader, model, lossFunction):
    modelDevice = next(model.parameters()).device
    model.eval()                        # Set the model to evaluation mode
    numBatches = len(dataLoader)
    lossMean = 0
    with torch.no_grad():
        for batchNo, (batchSamples, batchLabels) in enumerate(dataLoader):
            batchPredictions = model( batchSamples.to(modelDevice) )
            lossMean += lossFunction(batchPredictions, batchLabels.to(modelDevice)).item()
    lossMean /= numBatches
    return lossMean

numEpochs = 100                         # Number of epochs
learningRate = (0.0001, 0.000001)       # Learning rate starts at 0.0001 and exponentially decays to 0.000001
lossFunction = nn.MSELoss()             # Using MSE as Loss

if isinstance(learningRate, tuple):     # learningRate is a tuple -> use an exponentially decaying learning rate
    lr1st, lrLast = learningRate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr1st)
    lrScheduler = ExponentialLR(optimizer, np.exp(np.log(lrLast/lr1st)/(numEpochs-1)))
else:                                   # learningRate is a number -> use a fixed learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
    lrScheduler = None                  # No LR scheduling needed

t0 = time.time()
print("Epoch Learning Rate Training Loss Validation Loss")
print("----- ------------- ------------- ---------------")
lowestLoss = None
for epoch in range(numEpochs):
    curLR = learningRate if lrScheduler is None else lrScheduler.get_last_lr()[0]
    print(" %-4d %-10f "%(epoch+1, curLR), end="")
    lossMin, lossMean, lossMax = trainEpoch(trainDl, model, lossFunction, optimizer)
    print("%-10f "%(lossMean), end="")
    validLoss = evaluate(validDl, model, lossFunction)
    if lowestLoss is None:
        lowestLoss = validLoss
        print("%-10f "%(validLoss))
    elif validLoss<lowestLoss:          # This is the best model so far -> Save it
        lowestLoss, bestEpoch = validLoss, epoch+1
        torch.save(model.state_dict(), 'Models/ChEstModelWeights.pth')
        print("%-10f * "%(validLoss))   # The '*' indicates best so far and saving
    else:
        print("%-10f "%(validLoss))
    if lrScheduler is not None: lrScheduler.step()
print("Training complete. (Training Time: %s)"%(str(datetime.timedelta(seconds=int(time.time()-t0)))))
Epoch Learning Rate Training Loss Validation Loss
----- ------------- ------------- ---------------
1 0.000100 0.131095 0.018107
2 0.000095 0.007698 0.004004 *
3 0.000091 0.003329 0.002922 *
4 0.000087 0.002697 0.002642 *
5 0.000083 0.002430 0.002455 *
6 0.000079 0.002270 0.002382 *
7 0.000076 0.002169 0.002182 *
8 0.000072 0.002083 0.002352
9 0.000069 0.002019 0.002460
10 0.000066 0.001962 0.002147 *
11 0.000063 0.001917 0.001909 *
12 0.000060 0.001873 0.002105
13 0.000057 0.001842 0.001950
14 0.000055 0.001809 0.001785 *
15 0.000052 0.001781 0.001917
16 0.000050 0.001757 0.002527
17 0.000048 0.001736 0.001796
18 0.000045 0.001719 0.001724 *
19 0.000043 0.001700 0.001919
20 0.000041 0.001682 0.001789
21 0.000039 0.001668 0.001765
22 0.000038 0.001657 0.001684 *
23 0.000036 0.001651 0.001970
24 0.000034 0.001632 0.001833
25 0.000033 0.001624 0.001627 *
26 0.000031 0.001621 0.001659
27 0.000030 0.001609 0.001677
28 0.000028 0.001600 0.001732
29 0.000027 0.001594 0.001583 *
30 0.000026 0.001586 0.001582 *
31 0.000025 0.001579 0.001688
32 0.000024 0.001575 0.001565 *
33 0.000023 0.001572 0.001834
34 0.000022 0.001563 0.001663
35 0.000021 0.001561 0.001573
36 0.000020 0.001554 0.001563 *
37 0.000019 0.001550 0.001626
38 0.000018 0.001545 0.001555 *
39 0.000017 0.001542 0.001657
40 0.000016 0.001540 0.001564
41 0.000016 0.001534 0.001556
42 0.000015 0.001532 0.001579
43 0.000014 0.001534 0.001529 *
44 0.000014 0.001527 0.001537
45 0.000013 0.001527 0.001547
46 0.000012 0.001521 0.001553
47 0.000012 0.001521 0.001513 *
48 0.000011 0.001517 0.001514
49 0.000011 0.001517 0.001514
50 0.000010 0.001513 0.001499 *
51 0.000010 0.001512 0.001496 *
52 0.000009 0.001509 0.001511
53 0.000009 0.001507 0.001502
54 0.000008 0.001508 0.001525
55 0.000008 0.001506 0.001508
56 0.000008 0.001504 0.001497
57 0.000007 0.001503 0.001489 *
58 0.000007 0.001501 0.001576
59 0.000007 0.001498 0.001489 *
60 0.000006 0.001500 0.001511
61 0.000006 0.001496 0.001501
62 0.000006 0.001496 0.001485 *
63 0.000006 0.001494 0.001509
64 0.000005 0.001497 0.001494
65 0.000005 0.001494 0.001485
66 0.000005 0.001493 0.001508
67 0.000005 0.001490 0.001483 *
68 0.000004 0.001492 0.001482 *
69 0.000004 0.001492 0.001480 *
70 0.000004 0.001490 0.001484
71 0.000004 0.001492 0.001483
72 0.000004 0.001488 0.001497
73 0.000004 0.001486 0.001487
74 0.000003 0.001486 0.001488
75 0.000003 0.001486 0.001477 *
76 0.000003 0.001485 0.001484
77 0.000003 0.001487 0.001480
78 0.000003 0.001484 0.001477 *
79 0.000003 0.001485 0.001484
80 0.000003 0.001482 0.001474 *
81 0.000002 0.001485 0.001479
82 0.000002 0.001482 0.001483
83 0.000002 0.001481 0.001477
84 0.000002 0.001481 0.001475
85 0.000002 0.001482 0.001474
86 0.000002 0.001481 0.001485
87 0.000002 0.001481 0.001471 *
88 0.000002 0.001481 0.001473
89 0.000002 0.001483 0.001472
90 0.000002 0.001479 0.001470 *
91 0.000002 0.001481 0.001474
92 0.000001 0.001481 0.001474
93 0.000001 0.001482 0.001472
94 0.000001 0.001480 0.001478
95 0.000001 0.001481 0.001478
96 0.000001 0.001480 0.001472
97 0.000001 0.001480 0.001472
98 0.000001 0.001480 0.001469 *
99 0.000001 0.001479 0.001475
100 0.000001 0.001477 0.001470
Training complete. (Training Time: 0:53:29)
Evaluating the model
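The epochs marked with '*' above are the ones whose weights were saved because they achieved the lowest validation loss so far. Before evaluating on the test set, we may want to reload those best weights rather than using the last-epoch weights. Here is a minimal sketch, assuming the file is the 'Models/ChEstModelWeights.pth' path used by the training loop above (the same approach works if you skipped training and are using the provided parameter file, assuming it uses this path):

# Optionally reload the best weights saved during training before evaluation
model.load_state_dict(torch.load('Models/ChEstModelWeights.pth', map_location=device))
model.eval()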
[7]:
testLoss = evaluate(testDl, model, lossFunction)
print(f"Test Loss: %.6f"%(testLoss))
Test Loss: 0.001460
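Beyond the aggregate MSE, it can be useful to look at the error on individual test samples. The following sketch runs the model on one test batch and prints the per-sample MSE for the first few samples:

# Per-sample MSE on one test batch (a minimal sketch)
model.eval()
samples, labels = next(iter(testDl))
with torch.no_grad():
    predictions = model(samples.to(device)).cpu()
perSampleMSE = ((predictions - labels)**2).mean(dim=(1,2,3))
print(perSampleMSE[:5])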