{ "cells": [ { "cell_type": "markdown", "id": "cdd93edc", "metadata": {}, "source": [ "# Training a ResNet Model for Channel Estimation\n", "Now that we have a channel estimation dataset, we can train a model that takes the channel estimates at the pilot locations as input and predicts the channel estimates for the whole channel matrix.\n", "\n", "We use the [PyTorch](https://pytorch.org) framework to train our model, but other machine learning frameworks can also be used. The following diagram shows the structure of the neural network used for the channel estimation model.\n", "\n", "![NN Structure](NN.png)\n", "\n", "Let's get started by importing the required modules." ] }, { "cell_type": "code", "execution_count": 1, "id": "3636339c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import time, datetime\n", "\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import Dataset, DataLoader\n", "from torch.optim.lr_scheduler import ExponentialLR\n" ] }, { "cell_type": "markdown", "id": "aaf34bd5", "metadata": {}, "source": [ "## Loading the dataset\n", "We first load the dataset files generated in the [previous step](MLChestDataGen.ipynb). Then we create a dataset and a data loader for each of the training, validation, and test sets."
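] }, { "cell_type": "markdown", "id": "3f1a9c2e", "metadata": {}, "source": [ "The short sketch below illustrates the data layout that ``ChEstDataset`` (defined in the next cell) assumes. It uses a small random array instead of the real dataset files, and the sizes ``toyNumSamples``, ``toyH``, and ``toyW`` are hypothetical. Each ``.npy`` file is expected to hold the samples and labels stacked along a leading axis of length 2, each in channels-last ``(N, H, W, 2)`` layout. ``np.transpose(..., (0,3,1,2))`` then moves the last axis forward to produce PyTorch's channels-first ``(N, 2, H, W)`` layout." ] }, { "cell_type": "code", "execution_count": null, "id": "3f1a9c2f", "metadata": {}, "outputs": [], "source": [ "# Sketch only: random data with hypothetical sizes, not the real dataset\n", "toyNumSamples, toyH, toyW = 4, 72, 14\n", "toy = np.random.randn(2, toyNumSamples, toyH, toyW, 2) # (samples, labels) stacked, channels-last\n", "toySamples, toyLabels = toy # Unpacks along the leading axis, like np.load(fileName) in the next cell\n", "toySamples = np.float32(np.transpose(toySamples, (0,3,1,2))) # Channels-last -> channels-first\n", "print(toySamples.shape) # (4, 2, 72, 14)"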
] }, { "cell_type": "code", "execution_count": 2, "id": "62358b6f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "16000 Training Samples\n", "2400 Validation Samples\n", "2400 Test Samples\n" ] } ], "source": [ "# Define our dataset object\n", "class ChEstDataset(Dataset):\n", "    def __init__(self, fileName):\n", "        # Each file holds the samples and labels stacked along the first axis,\n", "        # each in channels-last (N, H, W, 2) layout.\n", "        samples, labels = np.load(fileName)\n", "        # Convert to PyTorch's channels-first (N, 2, H, W) layout as float32\n", "        self.samples = np.float32(np.transpose(samples,(0,3,1,2)))\n", "        self.labels = np.float32(np.transpose(labels,(0,3,1,2)))\n", "        self.numSamples = len(self.samples)\n", "\n", "    def __len__(self):\n", "        return self.numSamples\n", "\n", "    def __getitem__(self, idx):\n", "        return self.samples[idx], self.labels[idx]\n", "\n", "# Instantiate 3 datasets for train, validation, and test\n", "trainDs = ChEstDataset(\"ChestTrain.npy\")\n", "validDs = ChEstDataset(\"ChestValid.npy\")\n", "testDs = ChEstDataset(\"ChestTest.npy\")\n", "\n", "print(f\"{len(trainDs)} Training Samples\")\n", "print(f\"{len(validDs)} Validation Samples\")\n", "print(f\"{len(testDs)} Test Samples\")\n", "\n", "# Create the data loaders (only the training data is shuffled)\n", "batchSize = 64\n", "trainDl = DataLoader(trainDs, batchSize, shuffle=True)\n", "validDl = DataLoader(validDs, batchSize, shuffle=False)\n", "testDl = DataLoader(testDs, batchSize, shuffle=False)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "74f53fe9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using 'Metal' device.\n" ] } ], "source": [ "# Checking GPU availability (CUDA first, then Apple Metal/MPS, otherwise CPU)\n", "device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n", "print(\"Using '%s' device.\"%({'cuda':'Cuda', 'mps':'Metal','cpu':'CPU'}[device]))\n" ] }, { "cell_type": "markdown", "id": "696d753a", "metadata": {}, "source": [ "## Creating the model\n", "Now we can create a model object that will be used for training. 
We first define the ``ResBlock`` (see the diagram above)." ] }, { "cell_type": "code", "execution_count": 4, "id": "ed570201", "metadata": {}, "outputs": [], "source": [ "# Define ResNet block\n", "class ResBlock(nn.Module):\n", "    def __init__(self, inDepth, midDepth, outDepth, kernel=(3,3), stride=(1,1)):\n", "        super().__init__()\n", "        if isinstance(stride, int): stride = (stride, stride)\n", "        if isinstance(kernel, int): kernel = (kernel, kernel)\n", "\n", "        self.conv1 = nn.Conv2d(inDepth, midDepth, 1, stride, padding='valid') # 1x1 conv. (applies the stride)\n", "        self.bn1 = nn.BatchNorm2d(midDepth)\n", "        self.conv2 = nn.Conv2d(midDepth, midDepth, kernel, padding='same')\n", "        self.bn2 = nn.BatchNorm2d(midDepth)\n", "        self.conv3 = nn.Conv2d(midDepth, outDepth, 1, padding='valid') # 1x1 conv. (no stride: conv1 already applied it)\n", "        self.bn3 = nn.BatchNorm2d(outDepth)\n", "        self.relu = nn.ReLU(inplace=True)\n", "\n", "        # Projection shortcut, needed when the stride or the depth changes\n", "        self.downSampleNet = None\n", "        if ((stride != (1,1)) or (inDepth!=outDepth)):\n", "            self.downSampleNet = nn.Sequential(nn.Conv2d(inDepth, outDepth, 1, stride), # 1x1 conv.\n", "                                               nn.BatchNorm2d(outDepth) )\n", "\n", "        for bn in [self.bn1, self.bn2, self.bn3]:\n", "            nn.init.ones_(bn.weight)\n", "            nn.init.zeros_(bn.bias)\n", "\n", "        # Truncated-normal init with std = 1/sqrt(fan-in)\n", "        for conv in [self.conv1, self.conv2, self.conv3]:\n", "            nn.init.trunc_normal_(conv.weight, std=1/np.sqrt(np.prod(list(conv.weight.shape)[1:])))\n", "            nn.init.zeros_(conv.bias)\n", "\n", "        # Zero-initializing the scale of the last BN layer makes the residual branch output zero\n", "        # at initialization, so the block starts close to an identity mapping. This can improve results.\n", "        nn.init.zeros_(self.bn3.weight)\n", "\n", "    def forward(self, x):\n", "        out = self.conv1(x)\n", "        out = self.bn1(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv2(out)\n", "        out = self.bn2(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv3(out)\n", "        out = self.bn3(out)\n", "\n", "        # Add the shortcut (projected by downSampleNet when the shape changes)\n", "        if self.downSampleNet is None: out += x\n", "        else: out += self.downSampleNet(x)\n", "        out = self.relu(out)\n", "\n", "        return out" ] }, { "cell_type": "markdown", "id": "f912336e", "metadata": {}, "source": [ "Next, we define our channel estimation model (``ChEstNet``) using three instances of the ``ResBlock`` defined above, followed by a final 3x3 convolutional layer that maps the 64 feature channels back to 2 output channels." ] }, { "cell_type": "code", "execution_count": 5, "id": "56f18c27", "metadata": {}, "outputs": [], "source": [ "# Now the actual ChEstNet model\n", "class ChEstNet(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.res1 = ResBlock(2, 16, 64, (9,11)) # Res Block 9x11 kernel\n", "        self.res2 = ResBlock(64, 16, 64, (3,7)) # Res Block 3x7 kernel\n", "        self.res3 = ResBlock(64, 16, 64, (3,7)) # Res Block 3x7 kernel\n", "\n", "        self.conv = nn.Conv2d(64, 2, 3, padding='same') # Final 3x3 conv. back to 2 channels\n", "        nn.init.trunc_normal_(self.conv.weight, std=1/np.sqrt(np.prod(list(self.conv.weight.shape)[1:])))\n", "        nn.init.zeros_(self.conv.bias)\n", "\n", "    def forward(self, x):\n", "        out = self.res1(x)\n", "        out = self.res2(out)\n", "        out = self.res3(out)\n", "        out = self.conv(out)\n", "        return out\n", "\n", "# Instantiate the model and move it to the target device\n", "model = ChEstNet().to(device)" ] }, { "cell_type": "markdown", "id": "300f065a", "metadata": {}, "source": [ "## Training the model\n", "We first define functions for the training and evaluation loops and then use them to train the model.\n", "\n", "**Note**: The following cell can take several hours to complete. 
A file containing the trained parameters of the model is included in this directory, so you can skip the following cells and proceed to the [next step](MLChestEvaluateTorch.ipynb)." ] }, { "cell_type": "code", "execution_count": 6, "id": "311e963e", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch Learning Rate Training Loss Validation Loss\n", "----- ------------- ------------- ---------------\n", " 1 0.000100 0.131095 0.018107 \n", " 2 0.000095 0.007698 0.004004 * \n", " 3 0.000091 0.003329 0.002922 * \n", " 4 0.000087 0.002697 0.002642 * \n", " 5 0.000083 0.002430 0.002455 * \n", " 6 0.000079 0.002270 0.002382 * \n", " 7 0.000076 0.002169 0.002182 * \n", " 8 0.000072 0.002083 0.002352 \n", " 9 0.000069 0.002019 0.002460 \n", " 10 0.000066 0.001962 0.002147 * \n", " 11 0.000063 0.001917 0.001909 * \n", " 12 0.000060 0.001873 0.002105 \n", " 13 0.000057 0.001842 0.001950 \n", " 14 0.000055 0.001809 0.001785 * \n", " 15 0.000052 0.001781 0.001917 \n", " 16 0.000050 0.001757 0.002527 \n", " 17 0.000048 0.001736 0.001796 \n", " 18 0.000045 0.001719 0.001724 * \n", " 19 0.000043 0.001700 0.001919 \n", " 20 0.000041 0.001682 0.001789 \n", " 21 0.000039 0.001668 0.001765 \n", " 22 0.000038 0.001657 0.001684 * \n", " 23 0.000036 0.001651 0.001970 \n", " 24 0.000034 0.001632 0.001833 \n", " 25 0.000033 0.001624 0.001627 * \n", " 26 0.000031 0.001621 0.001659 \n", " 27 0.000030 0.001609 0.001677 \n", " 28 0.000028 0.001600 0.001732 \n", " 29 0.000027 0.001594 0.001583 * \n", " 30 0.000026 0.001586 0.001582 * \n", " 31 0.000025 0.001579 0.001688 \n", " 32 0.000024 0.001575 0.001565 * \n", " 33 0.000023 0.001572 0.001834 \n", " 34 0.000022 0.001563 0.001663 \n", " 35 0.000021 0.001561 0.001573 \n", " 36 0.000020 0.001554 0.001563 * \n", " 37 0.000019 0.001550 0.001626 \n", " 38 0.000018 0.001545 0.001555 * \n", " 39 0.000017 0.001542 0.001657 \n", " 40 0.000016 0.001540 0.001564 \n", " 41 0.000016 0.001534 0.001556 
\n", " 42 0.000015 0.001532 0.001579 \n", " 43 0.000014 0.001534 0.001529 * \n", " 44 0.000014 0.001527 0.001537 \n", " 45 0.000013 0.001527 0.001547 \n", " 46 0.000012 0.001521 0.001553 \n", " 47 0.000012 0.001521 0.001513 * \n", " 48 0.000011 0.001517 0.001514 \n", " 49 0.000011 0.001517 0.001514 \n", " 50 0.000010 0.001513 0.001499 * \n", " 51 0.000010 0.001512 0.001496 * \n", " 52 0.000009 0.001509 0.001511 \n", " 53 0.000009 0.001507 0.001502 \n", " 54 0.000008 0.001508 0.001525 \n", " 55 0.000008 0.001506 0.001508 \n", " 56 0.000008 0.001504 0.001497 \n", " 57 0.000007 0.001503 0.001489 * \n", " 58 0.000007 0.001501 0.001576 \n", " 59 0.000007 0.001498 0.001489 * \n", " 60 0.000006 0.001500 0.001511 \n", " 61 0.000006 0.001496 0.001501 \n", " 62 0.000006 0.001496 0.001485 * \n", " 63 0.000006 0.001494 0.001509 \n", " 64 0.000005 0.001497 0.001494 \n", " 65 0.000005 0.001494 0.001485 \n", " 66 0.000005 0.001493 0.001508 \n", " 67 0.000005 0.001490 0.001483 * \n", " 68 0.000004 0.001492 0.001482 * \n", " 69 0.000004 0.001492 0.001480 * \n", " 70 0.000004 0.001490 0.001484 \n", " 71 0.000004 0.001492 0.001483 \n", " 72 0.000004 0.001488 0.001497 \n", " 73 0.000004 0.001486 0.001487 \n", " 74 0.000003 0.001486 0.001488 \n", " 75 0.000003 0.001486 0.001477 * \n", " 76 0.000003 0.001485 0.001484 \n", " 77 0.000003 0.001487 0.001480 \n", " 78 0.000003 0.001484 0.001477 * \n", " 79 0.000003 0.001485 0.001484 \n", " 80 0.000003 0.001482 0.001474 * \n", " 81 0.000002 0.001485 0.001479 \n", " 82 0.000002 0.001482 0.001483 \n", " 83 0.000002 0.001481 0.001477 \n", " 84 0.000002 0.001481 0.001475 \n", " 85 0.000002 0.001482 0.001474 \n", " 86 0.000002 0.001481 0.001485 \n", " 87 0.000002 0.001481 0.001471 * \n", " 88 0.000002 0.001481 0.001473 \n", " 89 0.000002 0.001483 0.001472 \n", " 90 0.000002 0.001479 0.001470 * \n", " 91 0.000002 0.001481 0.001474 \n", " 92 0.000001 0.001481 0.001474 \n", " 93 0.000001 0.001482 0.001472 \n", " 94 0.000001 0.001480 0.001478 \n", " 
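95 0.000001 0.001481 0.001478 \n", " 96 0.000001 0.001480 0.001472 \n", " 97 0.000001 0.001480 0.001472 \n", " 98 0.000001 0.001480 0.001469 * \n", " 99 0.000001 0.001479 0.001475 \n", " 100 0.000001 0.001477 0.001470 \n", "Training complete. (Training Time: 0:53:29)\n" ] } ], "source": [ "# Training loop for one epoch:\n", "def trainEpoch(dataLoader, model, lossFunction, optimizer):\n", "    modelDevice = next(model.parameters()).device\n", "    model.train() # Set the model to training mode\n", "    lossMin, lossMean, lossMax = np.inf, 0, -np.inf\n", "    for batchNo, (batchSamples, batchLabels) in enumerate(dataLoader):\n", "        # Compute prediction and loss\n", "        batchPredictions = model( batchSamples.to(modelDevice) )\n", "        loss = lossFunction(batchPredictions, batchLabels.to(modelDevice))\n", "\n", "        loss.backward() # Backpropagation\n", "        optimizer.step()\n", "        optimizer.zero_grad()\n", "\n", "        lossValue = loss.item()\n", "        lossMean += lossValue\n", "        if lossValue>lossMax: lossMax = lossValue\n", "        if lossValue<lossMin: lossMin = lossValue\n", "\n", "    lossMean /= len(dataLoader) # Average loss over all batches\n", "    return lossMin, lossMean, lossMax\n", "\n", "# Evaluation loop (returns the average loss over all batches):\n", "def evaluate(dataLoader, model, lossFunction):\n", "    modelDevice = next(model.parameters()).device\n", "    model.eval() # Set the model to evaluation mode\n", "    lossMean = 0\n", "    with torch.no_grad(): # No gradients are needed for evaluation\n", "        for batchSamples, batchLabels in dataLoader:\n", "            batchPredictions = model( batchSamples.to(modelDevice) )\n", "            lossMean += lossFunction(batchPredictions, batchLabels.to(modelDevice)).item()\n", "    return lossMean/len(dataLoader)\n", "\n", "numEpochs = 100\n", "learningRate = (0.0001, 0.000001) # (first, last): consistent with the learning rates printed below\n", "lossFunction = nn.MSELoss() # Assumption: mean squared error loss\n", "\n", "if isinstance(learningRate, tuple): # (first, last) --> use exponentially decaying learning rate\n", "    lr1st, lrLast = learningRate\n", "    optimizer = torch.optim.Adam(model.parameters(), lr=lr1st)\n", "    # Constant per-epoch decay factor that takes the LR from lr1st (first epoch) to lrLast (last epoch)\n", "    lrScheduler = ExponentialLR(optimizer, np.exp(np.log(lrLast/lr1st)/(numEpochs-1)))\n", "else:\n", "    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate) # learningRate is a number\n", "    lrScheduler = None # No LR scheduling needed\n", "\n", "t0 = time.time()\n", "print(\"Epoch Learning Rate Training Loss Validation Loss\")\n", "print(\"----- ------------- ------------- ---------------\")\n", "lowestLoss = None\n", "for epoch in range(numEpochs):\n", "    # Read the current LR from the optimizer (works with or without a scheduler)\n", "    print(\" %-4d %-10f \"%(epoch+1, optimizer.param_groups[0]['lr']), end=\"\")\n", "    lossMin, lossMean, lossMax = trainEpoch(trainDl, model, lossFunction, optimizer)\n", "    print(\"%-10f \"%(lossMean), end=\"\")\n", "    validLoss = evaluate(validDl, model, lossFunction)\n", "    if lowestLoss is None:\n", "        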
lowestLoss = validLoss\n", "        print(\"%-10f \"%(validLoss))\n", "    elif validLoss<lowestLoss: # --> New best model: save it\n", "        lowestLoss, bestEpoch = validLoss, epoch+1\n", "        torch.save(model.state_dict(), 'Models/ChEstModelWeights.pth')\n", "        print(\"%-10f * \"%(validLoss)) # The '*' indicates best so far and saving\n", "    else:\n", "        print(\"%-10f \"%(validLoss))\n", "\n", "    if lrScheduler is not None: lrScheduler.step()\n", "\n", "print(\"Training complete. (Training Time: %s)\"%(str(datetime.timedelta(seconds=int(time.time()-t0)))))\n" ] }, { "cell_type": "markdown", "id": "6e5ccd80", "metadata": {}, "source": [ "## Evaluating the model\n", "Finally, we compute the average loss of the trained model (the weights after the last epoch) on the test dataset." ] }, { "cell_type": "code", "execution_count": 7, "id": "2ad038dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.001460\n" ] } ], "source": [ "testLoss = evaluate(testDl, model, lossFunction)\n", "print(\"Test Loss: %.6f\"%(testLoss))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }