
In the past I've written my own neural network from scratch and also used GPU-accelerated deep learning frameworks, but I had never tried writing GPU-accelerated code myself. I was curious how hard that would be, so I decided to try it using NVIDIA's CUDA platform.

Setting up the dev environment

Setting up CUDA to run natively is quite difficult, and I struggled to get it working (even somehow corrupting my Windows installation?), so I decided to use Docker instead. NVIDIA provides the NVIDIA Container Toolkit for running CUDA-enabled Docker containers on Linux. After setting up the toolkit and verifying that it worked, I set up my own container based on the images that NVIDIA provides for CUDA development.

# Dockerfile

FROM nvidia/cuda:11.3.1-devel-ubuntu20.04

# Set up time zone so cmake install doesn't hang
ENV TZ=America/Chicago
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# Install extras
RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install git -y \
    && apt-get install -y --no-install-recommends cuda-samples-11-3 \
    && apt-get install -y cmake protobuf-compiler \
    && apt-get install gdb -y \
    && apt-get install python3 python3-pip -y \
    && pip3 install matplotlib

# Set up a user so we're not just in root
ARG USERNAME=dev
ARG USER_UID=1000
ARG USER_GID=$USER_UID

# Create the user
RUN groupadd --gid $USER_GID $USERNAME \
    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
    && apt-get update \
    && apt-get install -y sudo \
    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
    && chmod 0440 /etc/sudoers.d/$USERNAME


# Set the default user
USER $USERNAME

WORKDIR /workspaces/cuda-toy

# Set bash as default shell instead of sh
ENV SHELL=/bin/bash

VSCode has great support for running within containers through the Remote - Containers extension. The extension automatically mounts your project's root directory, and you can use VSCode as if it were running locally. I set up a .devcontainer.json file to customize the resulting VSCode instance with a few extensions I wanted to use (notably IntelliCode and Nsight) and to give the container access to the GPU via the --gpus=all run argument.

// .devcontainer.json

{
  "build": {
    "dockerfile": "Dockerfile",
    "args": {
      "tag": "cuda-dev"
    }
  },
  "extensions": [
    "ms-vscode.cpptools",
    "vscodevim.vim",
    "visualstudioexptteam.vscodeintellicode",
    "ms-vscode.cmake-tools",
    "jeff-hykin.better-cpp-syntax",
    "nvidia.nsight-vscode-edition"
  ],
  "runArgs": ["--gpus=all", "--name=cuda_dev_container"]
}

Planning

As I started the project, I knew I wanted to write a simple network in vanilla C++ before messing around with CUDA, so that GPU-related problems would be easier to isolate. Once I had a simple proof-of-concept network written in C++ and running only on my CPU, I'd increase the size of the network and train it on the MNIST handwritten digit classification dataset. Then, I'd replace parts of the C++ CPU code with CUDA code for GPU acceleration and compare the performance of the original and new versions. I'll loosely follow that same outline in this post.

Implementing a CPU-based network

As an initial milestone, I decided that solving XOR would be a good benchmark for an extremely simple network: two input nodes, some number of hidden nodes, and one output node. As a reminder, this is what XOR looks like:

x1  x2  y
0   0   0
0   1   1
1   1   0
1   0   1

One of the most efficient ways to represent neural networks is as a series of matrix operations. For an overview of the math relevant to this matrix representation, read the math document that I put together to organize my thoughts while doing this project.
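As a compact reference, here is the standard matrix form of one linear layer over a mini-batch of $m$ examples. I'm assuming examples are stored as columns, so the bias is a column vector added to every column (which matches how the broadcasting addition kernel later in this post indexes its right-hand side); the symbols are my own notation and may differ from the math document.

$$
\begin{aligned}
Z &= W X + b\,\mathbf{1}_m^{\top}, \qquad A = g(Z), \\
\frac{\partial L}{\partial W} &= \frac{1}{m}\,\Delta X^{\top}, \qquad
\frac{\partial L}{\partial b} = \frac{1}{m}\,\Delta\,\mathbf{1}_m, \qquad
\Delta_{\text{prev}} = \left(W^{\top}\Delta\right) \odot g_{\text{prev}}'(Z_{\text{prev}}), \\
W &\leftarrow W - \eta\,\frac{\partial L}{\partial W}, \qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
\end{aligned}
$$

Here $X$ is $n_{\text{in}} \times m$, $W$ is $n_{\text{out}} \times n_{\text{in}}$, $b$ is $n_{\text{out}} \times 1$, $g$ is the activation applied elementwise, $\eta$ is the learning rate, $L = \frac{1}{m}\sum_j \ell_j$ is the mean loss over the batch, and column $j$ of $\Delta$ is $\partial \ell_j / \partial z_j$. For a sigmoid output with binary cross-entropy, or a softmax output with cross-entropy, the last layer's $\Delta$ simplifies to $A - Y$.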

To perform these matrix operations, I created a Matrix class, declared in matrix.hpp and defined in matrix.cpp. In brief, the Matrix class includes operators for elementwise arithmetic and full multiplication of two matrices, corresponding scalar arithmetic operations, and a host of miscellaneous operations such as transposition and exponentiation/logarithms, to name a few. Early in the project, I used a std::vector<std::vector<float>> to store a matrix's data, but this resulted in very large overheads when accessing or modifying matrix elements. Later, I replaced it with a cleverly laid out float**: an array of row pointers that all point into one contiguous block of floats, mimicking the memory layout of a built-in float[rows][cols] array. The array was created like this:

int rows = 4;
int cols = 5;
float** matrix = new float*[rows];
matrix[0] = new float[rows * cols];
for (auto i = 1; i < rows; i++) {
    matrix[i] = matrix[i - 1] + cols;
}

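This allowed me to access any element at $(i, j)$ either as matrix[i][j] or as matrix[0][i * cols + j]. Because every element lives in the single block starting at matrix[0], that block could be handed to CUDA as a plain float*, which was very handy later.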
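The allocation above needs a matching two-step cleanup, which is easy to get wrong. Below is a minimal sketch of the same contiguous layout wrapped in a pair of helpers; alloc_matrix and free_matrix are hypothetical names for illustration, not functions from the project.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: allocate a rows x cols matrix as one contiguous block
// of floats plus an array of row pointers into that block, so m[i][j] and
// m[0][i * cols + j] refer to the same element.
float** alloc_matrix(std::size_t rows, std::size_t cols) {
    assert(rows > 0 && cols > 0);
    float** m = new float*[rows];
    m[0] = new float[rows * cols]();  // value-initialized to 0.0f
    for (std::size_t i = 1; i < rows; i++) {
        m[i] = m[i - 1] + cols;  // each row starts cols floats after the previous one
    }
    return m;
}

// Hypothetical helper: free the data block first, then the row-pointer array.
void free_matrix(float** m) {
    delete[] m[0];
    delete[] m;
}
```

Since the data is contiguous, passing m[0] and rows * cols * sizeof(float) to a single copy call moves the whole matrix.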

With that basic matrix functionality in place, I was able to move on to creating the fundamental units of the neural network itself: the layers. I wanted to support three kinds of layers (technically just one kind of layer with different activations, but this vernacular is useful for my programming model). These three were linear layers ($z=wx+b$) with ReLU, sigmoid, and softmax activation functions, implemented as LinearReluLayer, LinearSigmoidLayer, and LinearSoftmaxLayer, respectively. Sigmoid and softmax are both suited to the output layer of classification problems (sigmoid for binary or multi-label classification, softmax for multi-class classification where each example has exactly one correct label), while ReLU is a good default for hidden layers.
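For reference, the three activations can be written as plain functions over one column of pre-activations. This is a minimal sketch using std::vector rather than the project's Matrix class, and the function names are hypothetical. The softmax subtracts the column maximum before exponentiating, a standard trick that keeps exp() from overflowing without changing the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// ReLU: max(0, z), applied elementwise.
std::vector<float> relu(const std::vector<float>& z) {
    std::vector<float> a(z.size());
    for (std::size_t i = 0; i < z.size(); i++) a[i] = std::max(0.0f, z[i]);
    return a;
}

// Sigmoid: 1 / (1 + e^-z), applied elementwise; squashes into (0, 1).
std::vector<float> sigmoid(const std::vector<float>& z) {
    std::vector<float> a(z.size());
    for (std::size_t i = 0; i < z.size(); i++) a[i] = 1.0f / (1.0f + std::exp(-z[i]));
    return a;
}

// Softmax over one column: e^(z_i - max) / sum_j e^(z_j - max).
std::vector<float> softmax(const std::vector<float>& z) {
    assert(!z.empty());
    float zmax = *std::max_element(z.begin(), z.end());
    std::vector<float> a(z.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < z.size(); i++) {
        a[i] = std::exp(z[i] - zmax);  // largest term becomes exp(0) = 1
        sum += a[i];
    }
    for (float& v : a) v /= sum;  // outputs sum to 1
    return a;
}
```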

These layer classes were a natural fit for inheritance. Each of those listed above inherits from a LinearLayer class, which implements most of the logic shared by linear layers, such as forward and backward propagation. Each child class is therefore much simpler than it otherwise would have been, only having to implement an activation function and weight/bias initialization. Then, by combining these layers, I could build any fully connected network using those activation functions, as below.
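The inheritance pattern might look roughly like the sketch below. It is not the project's actual LinearLayer: it handles a single input vector, omits backpropagation, and the class names (LinearLayerSketch and its children) and initialization schemes are assumptions for illustration. The children override only the activation and the standard deviation used for weight initialization; He initialization for ReLU and Xavier (Glorot) initialization for sigmoid are common pairings.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical base class: owns W and b and implements a = g(Wx + b) once.
class LinearLayerSketch {
public:
    LinearLayerSketch(std::size_t in, std::size_t out)
        : in_(in), out_(out), w_(in * out), b_(out, 0.0f) {}
    virtual ~LinearLayerSketch() = default;

    // Fill W from a zero-mean normal whose std-dev each child chooses.
    void init_weights(unsigned seed) {
        std::mt19937 gen(seed);
        std::normal_distribution<float> dist(0.0f, init_stddev());
        for (float& w : w_) w = dist(gen);
    }

    // Forward pass for one example; W is out x in, stored row-major.
    std::vector<float> forward(const std::vector<float>& x) const {
        assert(x.size() == in_);
        std::vector<float> a(out_);
        for (std::size_t i = 0; i < out_; i++) {
            float z = b_[i];
            for (std::size_t j = 0; j < in_; j++) z += w_[i * in_ + j] * x[j];
            a[i] = activate(z);  // the only step that differs per child
        }
        return a;
    }

protected:
    virtual float activate(float z) const = 0;
    virtual float init_stddev() const = 0;
    std::size_t in_, out_;
    std::vector<float> w_, b_;
};

// He initialization: std-dev sqrt(2 / fan_in).
class ReluLayerSketch : public LinearLayerSketch {
public:
    using LinearLayerSketch::LinearLayerSketch;
protected:
    float activate(float z) const override { return z > 0.0f ? z : 0.0f; }
    float init_stddev() const override { return std::sqrt(2.0f / in_); }
};

// Xavier/Glorot initialization: std-dev sqrt(2 / (fan_in + fan_out)).
class SigmoidLayerSketch : public LinearLayerSketch {
public:
    using LinearLayerSketch::LinearLayerSketch;
protected:
    float activate(float z) const override { return 1.0f / (1.0f + std::exp(-z)); }
    float init_stddev() const override { return std::sqrt(2.0f / (in_ + out_)); }
};
```

Mirroring the XOR network below, a 2 → 10 → 1 stack would be a ReluLayerSketch(2, 10) feeding a SigmoidLayerSketch(10, 1).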

int x_size = 2;
int l1_size = 10;
int l2_size = 1;
LinearReluLayer l1(x_size, l1_size);
LinearSigmoidLayer l2(l1_size, l2_size);
l2.set_last_layer(true);

This network consists of an input layer of size 2, followed by a ReLU-activated layer of size 10, followed by a sigmoid output layer of size 1. This is the network that was used to solve XOR. When solving MNIST, the network was wrapped in an MnistModel class with methods for propagating forward and backward through the whole network.

The next task after setting up this framework for creating layers and models was loading data. In my XOR network, I didn't bother with batching or shuffling the input data, nor did I have to import data into my program as a Matrix. With MNIST, however, all of these needed to be addressed.

The MNIST data can be downloaded directly from Yann LeCun's website. The website also includes an explanation of the binary format in which the data is stored (I chose not to use third-party CSV versions of the data to get the authentic experience, though in hindsight all it meant was extra work). All the importing code is contained within the Mnist::read_training function.
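The format described there boils down to a short big-endian header followed by raw bytes. In the IDX image file, the header is the magic number 2051 (0x00000803) followed by the image count, row count, and column count, each a 32-bit big-endian integer; the label file uses magic number 2049 and stores only a count. As a rough sketch (parsing from an in-memory buffer rather than a file, with hypothetical names; this is not Mnist::read_training), reading the image header might look like this:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Read a 32-bit big-endian unsigned integer starting at buf[offset].
// IDX stores header integers most-significant byte first, so copying the
// bytes straight into a uint32_t would be wrong on little-endian x86.
std::uint32_t read_be32(const std::vector<unsigned char>& buf, std::size_t offset) {
    assert(offset + 4 <= buf.size());
    return (std::uint32_t(buf[offset]) << 24) | (std::uint32_t(buf[offset + 1]) << 16) |
           (std::uint32_t(buf[offset + 2]) << 8) | std::uint32_t(buf[offset + 3]);
}

struct IdxImageHeader {
    std::uint32_t count, rows, cols;
};

// Parse the 16-byte header of an IDX image file; pixel bytes start at offset 16.
IdxImageHeader parse_image_header(const std::vector<unsigned char>& buf) {
    if (buf.size() < 16) throw std::runtime_error("truncated IDX header");
    if (read_be32(buf, 0) != 2051) throw std::runtime_error("not an IDX image file");
    return {read_be32(buf, 4), read_be32(buf, 8), read_be32(buf, 12)};
}
```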

In every training iteration, a mini-batch is propagated forward and backward through the network to calculate the gradients used for stochastic gradient descent. These mini-batches should be randomly selected from the training set, which is done in the Mnist::get_training_batch function.
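One simple way to pick a random mini-batch is to draw distinct example indices without replacement. The sketch below (a hypothetical helper, not the project's Mnist::get_training_batch) shuffles an index array with std::shuffle and keeps the first batch_size entries; the caller would then copy those examples into the batch matrix.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Return batch_size distinct indices in [0, n), chosen uniformly at random.
std::vector<std::size_t> sample_batch_indices(std::size_t n, std::size_t batch_size,
                                              std::mt19937& gen) {
    assert(batch_size <= n);
    std::vector<std::size_t> idx(n);
    std::iota(idx.begin(), idx.end(), std::size_t{0});  // 0, 1, ..., n-1
    std::shuffle(idx.begin(), idx.end(), gen);           // random permutation
    idx.resize(batch_size);                              // keep the first batch_size
    return idx;
}
```

A common alternative is to shuffle once per epoch and walk through the permutation in batch_size steps, which guarantees every example is seen exactly once per epoch.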

And with that, the CPU-only implementation of my network was complete! Its measured performance is discussed later.

Problems Encountered

One of the first issues I ran into was sub-optimal weight initialization. I first noticed it with the XOR network, which seemed to converge to outputting all 0s or all 1s. As of this writing, I'm still not sure exactly what was wrong with my initialization, but a hacky workaround was to increase the complexity (number of hidden nodes) of the network.

Later, while working with the MNIST dataset, I noticed an issue that appeared superficially similar: the network seemed to converge to always outputting 1. In this case, however, the root of the issue lay in the MNIST data itself. The pixel values are stored on a scale of 0 to 255, which, when fed in directly as input, very quickly saturated the network. The solution was simply to normalize the data as it was read in from the binary files.
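Concretely, normalizing here just means mapping each pixel byte from the range 0 to 255 into [0, 1] as it is read. A minimal sketch, assuming the raw pixel bytes are already in a std::vector (the function name is hypothetical):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Map raw MNIST pixel bytes (0..255) to floats in [0, 1] so that large raw
// values don't saturate the network's activations.
std::vector<float> normalize_pixels(const std::vector<unsigned char>& raw) {
    std::vector<float> out(raw.size());
    for (std::size_t i = 0; i < raw.size(); i++) {
        out[i] = static_cast<float>(raw[i]) / 255.0f;
    }
    return out;
}
```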

Adding GPU acceleration

The most computationally intensive part of my neural network implementation is the matrix operations. These all work essentially through doubly (and, for multiplication, triply) nested for loops. As the code is written so far, it runs on a single CPU thread, so only one element is computed at a time, which makes this very time-consuming. Here is an example of how matrix addition could be implemented this way.

constexpr int N = 4;   // matrix size
float m1[N][N] = {};   // N x N input
float m2[N][N] = {};   // N x N input
float result[N][N];    // N x N output

for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
        result[i][j] = m1[i][j] + m2[i][j];
    }
}

Instead of these nested loops, we can exploit the parallel nature of GPUs to dramatically speed up the whole process. GPUs essentially work by running many threads of computation simultaneously. So if we launch $N^2$ threads arranged as an $N \times N$ block, each thread can compute just one element, eliminating the nested loops entirely. This one-thread-per-element mapping is related to, but distinct from, a grid-stride loop, in which each thread processes several elements spaced one grid-width apart so that a fixed-size launch can cover arrays larger than the number of threads. An example of the same addition written in the one-thread-per-element style is below (note that the last three lines run on every thread).

constexpr int N = 16;  // launched as N x N threads
float m1[N][N];        // N x N inputs and output, as above
float m2[N][N];
float result[N][N];

int row = threadIdx.y;
int col = threadIdx.x;
result[row][col] = m1[row][col] + m2[row][col];
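The version above assumes there are at least as many threads as elements. A grid-stride loop removes that assumption: each thread starts at its global index and steps forward by the total number of threads in the grid (blockDim.x * gridDim.x), so a launch of any size covers an array of any length. Since CUDA code needs a GPU to run, the sketch below emulates the idea in plain C++, looping over simulated thread IDs that stand in for blockIdx.x * blockDim.x + threadIdx.x; on a real GPU the outer loop disappears and each of its iterations runs as its own thread. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU emulation of a 1-D grid-stride loop. In a real kernel, `tid` would be
// blockIdx.x * blockDim.x + threadIdx.x and `total_threads` would be
// blockDim.x * gridDim.x; here we loop over every simulated thread in turn.
void add_grid_stride(const std::vector<float>& a, const std::vector<float>& b,
                     std::vector<float>& out, std::size_t total_threads) {
    assert(a.size() == b.size() && out.size() == a.size());
    assert(total_threads > 0);
    for (std::size_t tid = 0; tid < total_threads; tid++) {             // "each thread"
        for (std::size_t i = tid; i < a.size(); i += total_threads) {  // stride by grid size
            out[i] = a[i] + b[i];
        }
    }
}
```

With 64 simulated threads and 1,000 elements, each thread handles 15 or 16 elements, and the result is identical to launching one thread per element.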

Of all the matrix operations, matrix multiplication is by far the most frequently used in my implementation. As a baseline test of how much speedup GPU computation provided, I decided to first replace just the matrix multiplication. That setup was benchmarked, and the results are shown later.

After getting the hang of writing GPU functions, porting the rest turned out to be fairly straightforward. Eventually, all the matrix operations were replaced by GPU functions, defined in matrix_kernels.cu.

A note on CUDA

CUDA is NVIDIA's platform and toolkit for writing and running programs on its GPUs. Several alternatives exist, such as OpenCL (which works on almost any GPU) and Apple's Metal (which works only on Apple devices), but I decided to use CUDA for a couple of reasons. I already had an NVIDIA GPU in my system and don't have an Apple system, which ruled out Metal. And I couldn't find nearly as much documentation or as many examples for OpenCL as for CUDA, so I figured CUDA would be the better choice.

CUDA allows programmers to define kernels: functions that are launched from the CPU and executed in parallel by many GPU threads, each running the same code on different data. Kernels are marked with the __global__ declaration specifier. Notably, it was far simpler to transfer 1-dimensional arrays to the GPU, since a single cudaMemcpy call copies one contiguous block of memory; this is why the contiguous internal representation of the matrices detailed earlier came in handy. Here is the kernel for the matrix addition discussed earlier.

__global__ void matadd_k(float* lhs, float* rhs, size_t rows, size_t cols,
                         bool sub, bool broadcast) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  if (row < rows && col < cols) {
    lhs[row * cols + col] +=
        (sub ? -1 : 1) * (broadcast ? rhs[row] : rhs[row * cols + col]);
  }
}

The kernel takes some extra parameters so that it can also handle subtraction and broadcasting. Broadcasting here means applying a column vector to every column of the matrix: element i of rhs is added to (or subtracted from) every entry in row i, as when adding a 4x1 matrix to a 4x4 matrix.
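When writing kernels like this, it helps to have a plain C++ reference with identical semantics to compare against. The sketch below is a hypothetical CPU counterpart of matadd_k (not part of the project) that works on the same flat row-major layout; with broadcast set, rhs is treated as a rows x 1 column vector.

```cpp
#include <cassert>
#include <cstddef>

// CPU reference for matadd_k: lhs (rows x cols, row-major) is updated in
// place with lhs += rhs, or lhs -= rhs when sub is true. When broadcast is
// true, rhs has only `rows` entries and rhs[row] applies to the whole row.
void matadd_cpu(float* lhs, const float* rhs, std::size_t rows, std::size_t cols,
                bool sub, bool broadcast) {
    assert(lhs != nullptr && rhs != nullptr);
    const float sign = sub ? -1.0f : 1.0f;
    for (std::size_t row = 0; row < rows; row++) {
        for (std::size_t col = 0; col < cols; col++) {
            float r = broadcast ? rhs[row] : rhs[row * cols + col];
            lhs[row * cols + col] += sign * r;
        }
    }
}
```

Running this and the kernel on the same random inputs and comparing the results within a small tolerance is a quick way to catch indexing mistakes.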

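The kernel above also uses several built-in index variables to calculate which row and column each thread acts upon. This is because of how CUDA is organized: threads are grouped into blocks of up to three dimensions, and blocks are in turn grouped into a grid of up to three dimensions. threadIdx is a thread's position within its block, blockIdx is the block's position within the grid, and blockDim is the size of each block, so blockIdx.x * blockDim.x + threadIdx.x gives a thread's global column index (and likewise with .y for the row). Because the grid may contain more threads than there are matrix elements, the row < rows && col < cols check keeps the extra threads from writing out of bounds.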


Example of 2-dimensional grid and block organization. Source.
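To see how the index formula and the bounds check in matadd_k work together, the plain C++ sketch below emulates a 2-D launch on the CPU: it loops over every (block, thread) pair, computes row and col exactly as the kernel does, and counts how often each matrix element is touched. The grid size is rounded up with the same (n + BLOCK_SIZE - 1) / BLOCK_SIZE ceiling division used in the launch code later in this post, so some threads in the last blocks fall outside the matrix and are skipped. The function name and square block shape are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Emulate a 2-D CUDA launch on the CPU and count how many times each element
// of a rows x cols matrix is visited. Returns the counts in row-major order.
std::vector<int> emulate_2d_launch(std::size_t rows, std::size_t cols, std::size_t block) {
    assert(block > 0);
    const std::size_t grid_x = (cols + block - 1) / block;  // ceil(cols / block)
    const std::size_t grid_y = (rows + block - 1) / block;  // ceil(rows / block)
    std::vector<int> visits(rows * cols, 0);
    for (std::size_t by = 0; by < grid_y; by++) {               // blockIdx.y
        for (std::size_t bx = 0; bx < grid_x; bx++) {           // blockIdx.x
            for (std::size_t ty = 0; ty < block; ty++) {        // threadIdx.y
                for (std::size_t tx = 0; tx < block; tx++) {    // threadIdx.x
                    std::size_t row = by * block + ty;  // blockIdx.y * blockDim.y + threadIdx.y
                    std::size_t col = bx * block + tx;  // blockIdx.x * blockDim.x + threadIdx.x
                    if (row < rows && col < cols) visits[row * cols + col]++;  // bounds check
                }
            }
        }
    }
    return visits;
}
```

For a 28 x 50 matrix with 16 x 16 blocks, the grid is 4 x 2 blocks (2,048 threads) for 1,400 elements: every element is visited exactly once, and the remaining 648 threads do nothing.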

Then, for each individual kernel launch, you choose how many threads and blocks to use and how they are shaped. Here is an example that also shows how data is transferred from host (CPU) memory to device (GPU) memory, and back, using CUDA's memory-copying functions.

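#define BLOCK_SIZE 16

// Host-side wrapper: copies both operands to the GPU, launches matadd_k,
// and copies the result back into lhs.
void matadd_wrapper(float* lhs, float* rhs, size_t rows, size_t cols, bool sub,
                    bool broadcast) {
  // A broadcast rhs is a rows x 1 column vector; otherwise it is rows x cols.
  size_t l_bytes = sizeof(float) * rows * cols;
  size_t r_bytes = sizeof(float) * rows * (broadcast ? 1 : cols);

  float *d_l, *d_r;
  cudaMalloc(&d_l, l_bytes);
  cudaMalloc(&d_r, r_bytes);

  cudaMemcpy(d_l, lhs, l_bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(d_r, rhs, r_bytes, cudaMemcpyHostToDevice);

  // Round up so the grid covers every element; the kernel's bounds check
  // discards the extra threads in the last row/column of blocks.
  unsigned int grid_x = (cols + BLOCK_SIZE - 1) / BLOCK_SIZE;
  unsigned int grid_y = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE;
  dim3 gridsize(grid_x, grid_y);
  dim3 blocksize(BLOCK_SIZE, BLOCK_SIZE);

  matadd_k<<<gridsize, blocksize>>>(d_l, d_r, rows, cols, sub, broadcast);

  // cudaMemcpy on the default stream waits for the kernel to finish first.
  cudaMemcpy(lhs, d_l, l_bytes, cudaMemcpyDeviceToHost);
  cudaFree(d_l);
  cudaFree(d_r);
}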

The <<<...>>> syntax marks a kernel launch, and the grid and block dimensions go inside it. A good block size tends to be fixed for a given class of hardware, so it is defined as a preprocessor macro. Here, 16 x 16 = 256 threads per block, well under the 1024-threads-per-block limit of current NVIDIA GPUs.

This multi-dimensional organization is largely for programmer convenience, since much of the data that GPU programs work on is multi-dimensional as well. I had a hunch that it might also reflect the physical layout of the cores on a GPU, but it is a logical arrangement: in hardware, a block's threads are linearized and scheduled in groups of 32 called warps.

Much of the information covered in this section is explained in greater detail on NVIDIA's own explanation of CUDA programming.

Performance Comparison

Each of the following graphs measures the time for one training iteration, excluding mini-batch selection: that is, the forward and backward propagations plus the parameter updates. Each setup was run for 50 iterations, and each figure shows the loss over time and the distribution of iteration times in microseconds.


CPU-only network using std::vector. Iteration time approximately 3430 ms.


CPU-only network using arrays. Iteration time approximately 428 ms.


CPU network using GPU matrix multiplication. Iteration time approximately 18 ms.


GPU network. Iteration time approximately 12 ms.

Relative to the original std::vector version, the final network is about 285x faster (3430 ms vs. 12 ms per iteration). Relative to the array-based CPU version, it is about 35x faster (428 ms vs. 12 ms).

How to make it even faster

There are a couple of ways to make network inference and training even faster with GPU acceleration.

Firstly, the overhead of copying data between the CPU and GPU could be dramatically decreased by running whole layers on the GPU. Currently, each individual operation in a layer runs as its own kernel, and its wrapper allocates device memory and copies the operands to and from the GPU every time (as in matadd_wrapper above). Combining each layer's operations into a single kernel would eliminate most of those round trips and could speed up network inference considerably.
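As a CPU-side illustration of what fusing a layer means, the sketch below computes relu(Wx + b) for one example in a single pass, without materializing Wx or Wx + b as separate matrices the way a chain of per-operation kernels would. It is plain C++ with hypothetical names and row-major storage; a fused CUDA kernel would typically assign each output neuron (or each output element of a batch) to its own thread, with the same loop body.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Fused linear + ReLU for one example: a[i] = max(0, b[i] + sum_j W[i][j] * x[j]).
// W is out x in, stored row-major. No intermediate Wx or Wx + b buffers are
// created; each output is accumulated, biased, and activated in one step.
std::vector<float> fused_linear_relu(const std::vector<float>& W, const std::vector<float>& b,
                                     const std::vector<float>& x) {
    const std::size_t out = b.size(), in = x.size();
    assert(W.size() == out * in);
    std::vector<float> a(out);
    for (std::size_t i = 0; i < out; i++) {  // in a CUDA version: one thread per i
        float z = b[i];
        for (std::size_t j = 0; j < in; j++) z += W[i * in + j] * x[j];
        a[i] = z > 0.0f ? z : 0.0f;
    }
    return a;
}
```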

Secondly, much of the time spent training the network goes to getting training data. This process is currently entirely CPU-controlled, but by moving it to the GPU, I suspect the training speed would increase dramatically. However, I'm not well-versed enough in GPU memory management to implement that myself, so I'll just leave it as a thought.

Source code

The source code for this project is available in its GitHub repository. The cpu-only branch contains a snapshot of the project from before the GPU parts were implemented, while the master branch contains the GPU version.