Evaluating the Trained Channel Estimator
Now that we have a trained model, we can use it in the communication pipeline and compare its performance with two baselines: perfect channel estimation and NeoRadium's least-squares (LS) channel estimation method.
The following diagram shows the pipeline used to evaluate our deep-learning-based channel estimator.

So, let's get started by importing the required modules.
[1]:
import numpy as np
import scipy.io
import time
import matplotlib.pyplot as plt
from neoradium import Carrier, PDSCH, CdlChannel, AntennaPanel, Grid, random
import torch
from torch import nn
Load the trained model
Here we define the channel estimator model and initialize it with the trained parameters.
[2]:
# ResNet Block
class ResBlock(nn.Module):
    def __init__(self, inDepth, midDepth, outDepth, kernel=(3,3), stride=(1,1)):
        super().__init__()
        if isinstance(stride, int): stride = (stride, stride)
        if isinstance(kernel, int): kernel = (kernel, kernel)
        self.conv1 = nn.Conv2d(inDepth, midDepth, 1, stride, padding='valid')   # 1x1 conv.
        self.bn1 = nn.BatchNorm2d(midDepth)
        self.conv2 = nn.Conv2d(midDepth, midDepth, kernel, padding='same')
        self.bn2 = nn.BatchNorm2d(midDepth)
        self.conv3 = nn.Conv2d(midDepth, outDepth, 1, stride, padding='valid')  # 1x1 conv.
        self.bn3 = nn.BatchNorm2d(outDepth)
        self.relu = nn.ReLU(inplace=True)
        self.downSampleNet = None
        if (stride != (1,1)) or (inDepth != outDepth):
            self.downSampleNet = nn.Sequential(nn.Conv2d(inDepth, outDepth, 1, stride),  # 1x1 conv.
                                               nn.BatchNorm2d(outDepth))
        for bn in [self.bn1, self.bn2, self.bn3]:
            nn.init.ones_(bn.weight)
            nn.init.zeros_(bn.bias)
        for conv in [self.conv1, self.conv2, self.conv3]:
            nn.init.trunc_normal_(conv.weight, std=1/np.sqrt(np.prod(list(conv.weight.shape)[1:])))
            nn.init.zeros_(conv.bias)
        nn.init.zeros_(self.bn3.weight)  # This can improve results; it makes the block start like an identity.

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downSampleNet is None: out += x
        else:                          out += self.downSampleNet(x)
        out = self.relu(out)
        return out

# The Channel Estimator model
class ChEstNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.res1 = ResBlock(2, 16, 64, (9,11))     # Res block with a 9x11 kernel
        self.res2 = ResBlock(64, 16, 64, (3,7))     # Res block with a 3x7 kernel
        self.res3 = ResBlock(64, 16, 64, (3,7))     # Res block with a 3x7 kernel
        self.conv = nn.Conv2d(64, 2, 3, padding='same')
        nn.init.trunc_normal_(self.conv.weight, std=1/np.sqrt(np.prod(list(self.conv.weight.shape)[1:])))
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        out = self.res1(x)
        out = self.res2(out)
        out = self.res3(out)
        out = self.conv(out)
        return out
# Checking GPU availability
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("Using '%s' device."%({'cuda':'Cuda', 'mps':'Metal','cpu':'CPU'}[device]))
# Instantiate the model and move it to the target device
model = ChEstNet().to(device)
# Load the trained model parameters:
# Note: The model file "ChEstModel-300.pth" was trained with 300 epochs using a parameter search. It performs
# better than the model trained in the previous step with 100 epochs. You can try both by choosing the
# corresponding line below.
model.load_state_dict(torch.load('Models/ChEstModelWeights.pth', map_location=device, weights_only=True)); # From prev. step.
# model.load_state_dict(torch.load('Models/ChEstModel-300.pth', map_location=device, weights_only=True)); # Trained for 300 Epochs
model.eval(); # Set the model to evaluation mode
Using 'Metal' device.
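Before moving on, an optional sanity check can confirm that the loaded model maps a two-channel L x K input to an output of the same shape. The sizes below (14 symbols, 612 subcarriers) correspond to the 51-RB carrier used later in this notebook; this snippet is just an illustrative check, not part of the original pipeline.
# Optional sanity check: the model should preserve the (2, L, K) shape.
dummy = torch.zeros(1, 2, 14, 612, device=device)   # A batch of one all-zero input
with torch.no_grad():
    print(model(dummy).shape)                       # Expected: torch.Size([1, 2, 14, 612])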
The mlChanEst Function
The mlChanEst function in the following cell receives the DMRS information, the received resource grid, and the trained model as input. It first calculates the channel estimates at the pilot locations using the LS method and then converts these estimates to a set of L x K complex matrices that are fed to the model for inference, where L is the number of OFDM symbols per slot and K is the number of subcarriers.
The model outputs another set of L x K matrices containing the predicted channel information. These matrices are then re-packaged as a 4-D complex numpy array that is returned as the estimated channel.
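At each pilot resource element, the LS estimate is simply the received value divided by the known transmitted pilot value. A short worked example with hypothetical numbers:
# LS estimate at one pilot (hypothetical values): y = h*x + n  =>  hHat = y/x
x = 0.7071 + 0.7071j    # Known transmitted pilot symbol
y = -0.2 + 1.1j         # Received value at the pilot's resource element
hHat = y / x            # LS channel estimate, approx. (0.636+0.919j)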
[3]:
def mlChanEst(dmrs, rxGrid, model):
    rsGrid = rxGrid.bwp.createGrid( len(dmrs.pxxch.portSet) )   # Create an empty resource grid
    dmrs.populateGrid(rsGrid)                                   # Populate the grid with DMRS values
    rsIndexes = rsGrid.getReIndexes("DMRS")                     # This contains the locations of DMRS values

    rr, ll, kk = rxGrid.shape       # Number of RX antennas, number of symbols, number of subcarriers
    pp, ll2, kk2 = rsGrid.shape     # Number of ports (from DMRS)
    if (ll != ll2) or (kk != kk2):
        raise ValueError("The grid size (%dx%d) does not match the DMRS (%dx%d)."%(ll,kk,ll2,kk2))

    modelIn = []
    for p in range(pp):                             # For each DMRS port (i.e., each layer)
        portLs = rsIndexes[1][rsIndexes[0]==p]      # Indexes of symbols containing pilots in this port
        portKs = rsIndexes[2][rsIndexes[0]==p]      # Indexes of subcarriers containing pilots in this port
        ls = np.unique(portLs)                      # Unique indexes of symbols containing pilots in this port
        ks = portKs[portLs==ls[0]]                  # Subcarrier indexes of the pilots (same for all pilot symbols)
        numLs, numKs = len(ls), len(ks)             # Number of pilot symbols and pilot subcarriers
        pilotValues = rsGrid[p,ls,:][:,ks]          # Pilot values in this port
        rxValues = rxGrid.grid[:,ls,:][:,:,ks]      # Received values for the pilot signals
        hEst = np.transpose(rxValues/pilotValues[None,:,:], (1,2,0))   # Channel estimates at pilot locations (L,K,Nr)
        for r in range(rr):                         # For each receiver antenna
            inH = np.zeros((2,)+rxGrid.shape[1:], dtype=np.float64)    # Create one all-zero 3-D matrix
            for li,l in enumerate(ls):
                inH[0,l,ks] = hEst[li,:,r].real     # Set the LS estimates at the pilot locations (real part)
                inH[1,l,ks] = hEst[li,:,r].imag     # Set the LS estimates at the pilot locations (imaginary part)
            modelIn += [ inH ]

    # Package all inputs as a batch in a PyTorch tensor
    modelIn = torch.from_numpy( np.float32(np.stack(modelIn)) ).to(device)
    with torch.no_grad(): modelOut = model(modelIn)     # Run inference for the whole batch
    modelOut = modelOut.cpu().numpy()                   # Bring the results back to CPU and convert to numpy
    estChan = np.transpose( modelOut.reshape((pp,rr)+modelOut.shape[1:]), (3,4,1,0,2) )  # Convert to a 5-D tensor
    estChan = estChan[...,0] + 1j*estChan[...,1]        # Convert to a 4-D complex tensor
    return estChan
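The re-packaging at the end of mlChanEst reshapes the (Np·Nr, 2, L, K) network output into an (L, K, Nr, Np) complex array. Here is a minimal standalone sketch of that step with toy stand-in data, using the sizes from the evaluation below (2 ports, 2 RX antennas):
# Toy illustration of the re-packaging step in mlChanEst (random stand-in data).
pp, rr, L, K = 2, 2, 14, 612                    # Ports, RX antennas, symbols, subcarriers
fakeOut = np.random.rand(pp*rr, 2, L, K)        # Stand-in for the network output
h = np.transpose(fakeOut.reshape((pp, rr, 2, L, K)), (3, 4, 1, 0, 2))
h = h[..., 0] + 1j*h[..., 1]                    # Real/imaginary channels -> complex values
print(h.shape)                                  # (14, 612, 2, 2), i.e. (L, K, Nr, Np)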
Evaluation Pipeline
The following cell implements the evaluation pipeline shown above. It runs the pipeline three times, once for each of the perfect, ML, and LS channel estimation methods, and prints the results at the end. As the results show, the ML-based channel estimation performs better than the LS method, which is based on interpolation.
[4]:
numFrames = 2                   # Number of time-domain frames
snrDbs = [5,10,15,20,25]        # SNR values (in dB) at which we want to evaluate the model
freqDomain = False              # Set to True to apply the channel in the frequency domain

carrier = Carrier(numRbs=51, spacing=30)    # Create a carrier with 51 RBs and 30 kHz subcarrier spacing
bwp = carrier.curBwp                        # The only bandwidth part in the carrier

# Create a PDSCH object
pdsch = PDSCH(bwp, interleavingBundleSize=0, numLayers=2, nID=carrier.cellId, modulation="16QAM")
pdsch.setDMRS(prgSize=0, configType=2, additionalPos=2)     # Specify the DMRS configuration

numSlots = bwp.slotsPerFrame*numFrames      # Total number of slots
results = {}                                # Dictionary to save the results
for chanEstMethod in ["Perfect", "ML", "LS"]:       # Three different channel estimation methods
    results[chanEstMethod] = {}
    print("\nSimulating end-to-end for \"%s\", with \"%s\" channel estimation, in %s domain."%
          ("16QAM", chanEstMethod, "frequency" if freqDomain else "time"))
    print("SNR(dB)  Total Bits  Bit Errors  BER(%)  time(Sec.)")
    print("-------  ----------  ----------  ------  ----------")
    for snrDb in snrDbs:                # For each SNR value in snrDbs
        random.setSeed(123)             # Make the results reproducible for each SNR
        t0 = time.time()                # Start time for this SNR
        carrier.slotNo = 0
        # Create a CdlChannel object
        channel = CdlChannel('C', delaySpread=300, carrierFreq=4e9, dopplerShift=5,
                             txAntenna = AntennaPanel([2,2], polarization="x"),     # 8 TX antennas
                             rxAntenna = AntennaPanel([1,1], polarization="+"),     # 2 RX antennas
                             seed = 123,
                             timing = "nearest")
        bitErrors = 0
        totalBits = 0
        for slotNo in range(numSlots):
            grid = pdsch.getGrid()                      # Create a resource grid populated with DMRS
            numBits = pdsch.getBitSizes(grid)[0]        # Number of bits available in the resource grid
            txBits = random.bits(numBits)               # Create random binary data
            pdsch.populateGrid(grid, txBits)            # Map/modulate the data to the resource grid

            # Store the indexes of the PDSCH data in pdschIndexes to be used later.
            pdschIndexes = pdsch.getReIndexes(grid, "PDSCH")

            # Get the precoding matrix and precode the resource grid
            channelMatrix = channel.getChannelMatrix(bwp)           # Get the channel matrix
            precoder = pdsch.getPrecodingMatrix(channelMatrix)      # Get the precoder matrix from the PDSCH object
            precodedGrid = grid.precode(precoder)                   # Perform the precoding

            if freqDomain:
                rxGrid = precodedGrid.applyChannel(channelMatrix)   # Apply the channel in the frequency domain
                rxGrid = rxGrid.addNoise(snrDb=snrDb)               # Add noise
            else:
                txWaveform = precodedGrid.ofdmModulate()            # OFDM modulation
                maxDelay = channel.getMaxDelay()                    # Get the max. channel delay
                txWaveform = txWaveform.pad(maxDelay)               # Pad with zeros
                rxWaveform = channel.applyToSignal(txWaveform)      # Apply the channel in the time domain
                noisyRxWaveform = rxWaveform.addNoise(snrDb=snrDb, nFFT=bwp.nFFT)   # Add noise
                offset = channel.getTimingOffset()                  # Get timing info for synchronization
                syncedWaveform = noisyRxWaveform.sync(offset)       # Synchronization
                rxGrid = syncedWaveform.ofdmDemodulate(bwp)         # OFDM demodulation

            if chanEstMethod == "Perfect":      # Perfect channel estimation
                estChannelMatrix = channelMatrix @ precoder[None,...]
            elif chanEstMethod == "LS":         # LS + interpolation channel estimation
                estChannelMatrix, noiseEst = rxGrid.estimateChannelLS(pdsch.dmrs, polarInt=False,
                                                                      kernel='linear')
            elif chanEstMethod == "ML":         # ML-based channel estimation
                estChannelMatrix = mlChanEst(pdsch.dmrs, rxGrid, model)
            else: assert(0)

            eqGrid, llrScales = rxGrid.equalize(estChannelMatrix)           # Equalization
            rxBits = pdsch.getHardBitsFromGrid(eqGrid, pdschIndexes)[0]     # Demodulation
            bitErrors += np.abs(rxBits-txBits).sum()                        # Count the bit errors
            totalBits += numBits
            print("\r  %3d    %10d  %10d  %6.2f  %10.2f"%(snrDb, totalBits, bitErrors,
                                                          bitErrors*100/totalBits, time.time()-t0), end='')

            carrier.goNext()    # Prepare the carrier object for the next slot
            channel.goNext()    # Prepare the channel model for the next slot

        dt = time.time()-t0     # Total time for this SNR
        results[chanEstMethod][snrDb] = {"totalBits": totalBits,
                                         "bitErrors": bitErrors,
                                         "BER": bitErrors*100/totalBits,
                                         "Time": dt}
        print("\r  %3d    %10d  %10d  %6.2f  %10.2f"%(snrDb, totalBits, bitErrors,
                                                      bitErrors*100/totalBits, dt))
# Compare the results in a plot:
for i, chanEstMethod in enumerate(['Perfect', 'ML', 'LS']):
    bers = [results[chanEstMethod][snrDb]["BER"] for snrDb in snrDbs]
    plt.plot(snrDbs, bers, label=chanEstMethod)
plt.legend()
plt.title("Bit Error Rate for different methods of Channel Estimation.");
plt.grid()
plt.xlabel("SNR (dB)")
plt.xticks(snrDbs)
plt.ylabel("BER (%)")
plt.yscale('log')
plt.show()
Simulating end-to-end for "16QAM", with "Perfect" channel estimation, in time domain.
SNR(dB)  Total Bits  Bit Errors  BER(%)  time(Sec.)
-------  ----------  ----------  ------  ----------
    5       2545920     1036799   40.72        5.68
   10       2545920      850716   33.41        5.74
   15       2545920      606267   23.81        5.72
   20       2545920      350625   13.77        5.80
   25       2545920      145817    5.73        5.75

Simulating end-to-end for "16QAM", with "ML" channel estimation, in time domain.
SNR(dB)  Total Bits  Bit Errors  BER(%)  time(Sec.)
-------  ----------  ----------  ------  ----------
    5       2545920     1083427   42.56        7.34
   10       2545920      887882   34.87        6.49
   15       2545920      637612   25.04        6.28
   20       2545920      389796   15.31        6.18
   25       2545920      197993    7.78        6.21

Simulating end-to-end for "16QAM", with "LS" channel estimation, in time domain.
SNR(dB)  Total Bits  Bit Errors  BER(%)  time(Sec.)
-------  ----------  ----------  ------  ----------
    5       2545920     1138786   44.73        5.85
   10       2545920      968019   38.02        5.90
   15       2545920      717426   28.18        5.93
   20       2545920      448406   17.61        5.91
   25       2545920      217902    8.56        6.01
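Since all the BER values are saved in the results dictionary, we can also tabulate the three methods side by side; a minimal sketch:
# Side-by-side BER(%) comparison built from the saved results.
print("SNR(dB)  Perfect       ML       LS")
for snrDb in snrDbs:
    print("  %3d    %7.2f  %7.2f  %7.2f"%(snrDb, results["Perfect"][snrDb]["BER"],
                                          results["ML"][snrDb]["BER"],
                                          results["LS"][snrDb]["BER"]))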
