At work, we focus on optimizing LLM serving, and one topic that comes up repeatedly is kernel optimization. I want to share some insights into what a kernel actually is and where it fits in the stack, because believe it or not, every modern LLM and diffusion model is ultimately powered by kernels running on a GPU.
I have some familiarity with Compute Unified Device Architecture (CUDA) and I also happen to have an NVIDIA Blackwell GPU in my workstation, so in this post I will explain what a kernel is and walk through writing one from scratch in CUDA.
Kernel - A GPU function
In simplest terms, a kernel is a function that runs on a GPU. You decide what it does – the logic is programmable using frameworks like CUDA for NVIDIA, ROCm for AMD, MUSA for Moore Threads, and Metal for Apple Silicon.
Kernels are hardware-dependent. A kernel written for NVIDIA GPUs will not run on AMD or Apple hardware because the underlying device architecture differs. However, the operations they perform can be exactly the same. A kernel for vector addition should produce identical results regardless of the hardware, even though the implementation details vary.
In this post, we will focus on kernels written in CUDA for NVIDIA GPUs. That said, other vendors typically keep their APIs and design patterns similar to CUDA, so the concepts transfer well.
Host and Device boundary
A typical machine learning system has the following components:
- A CPU
- Main memory (RAM)
- A Graphics Processing Unit (GPU)
  - Compute cores
  - Video memory (VRAM)
Programs you write normally execute on the CPU with allocations in RAM. Kernels, on the other hand, execute on the GPU with allocations in VRAM. In CUDA terminology, the GPU is referred to as the “Device” and the rest of the system (CPU + RAM) as the “Host”. When writing kernels, you need to keep these two worlds separate in your mental model – think of yourself as sitting on the GPU.
Thinking about parallelism
GPUs excel at massively parallel computation. That’s why they are used for workloads that benefit from parallelism – from block hash checking in crypto mining to training and serving LLMs.
However, managing thousands of concurrent threads introduces challenges. You want your kernel to be as parallel as possible, but you also need to worry about consistency and race conditions. CUDA provides abstractions that make this manageable.
Threads, Warps and Blocks
A GPU thread is a single stream of execution, similar to a CPU thread. Threads are grouped into blocks for management purposes. Each thread in a block can share data with other threads in that block through fast shared memory, while threads across different blocks can only communicate through slower global memory.
Under the hood, threads within a block are further organized into warps – groups of 32 threads that execute instructions in lockstep on the hardware. A warp is the actual unit of execution on an NVIDIA GPU’s streaming multiprocessor (SM). While you don’t manage warps directly when writing kernels, being aware of them matters for performance: if threads within a warp diverge (e.g., take different branches in an if statement), the warp must execute both paths sequentially, reducing efficiency.
When you launch a GPU kernel, you specify the grid size (number of blocks) and the block size (threads per block).
Grid (blocksPerGrid = 4, blockDim.x = 4)
+-----------------+-----------------+-----------------+-----------------+
| Block 0 | Block 1 | Block 2 | Block 3 |
| blockIdx.x = 0 | blockIdx.x = 1 | blockIdx.x = 2 | blockIdx.x = 3 |
| | | | |
| T0 T1 T2 T3 | T0 T1 T2 T3 | T0 T1 T2 T3 | T0 T1 T2 T3 |
+-----------------+-----------------+-----------------+-----------------+
Global index = blockIdx.x * blockDim.x + threadIdx.x
Block 0: 0*4+0=0 0*4+1=1 0*4+2=2 0*4+3=3
Block 1: 1*4+0=4 1*4+1=5 1*4+2=6 1*4+3=7
Block 2: 2*4+0=8 2*4+1=9 2*4+2=10 2*4+3=11
Block 3: 3*4+0=12 3*4+1=13 3*4+2=14 3*4+3=15
Writing our first kernel - Blurring a 1D array
NVIDIA GPU kernels are written in C++ using the CUDA framework. Before writing our first kernel, let’s define what it will do.
We will perform a blur operation on a 1D array of floats. Given an input array $A$ of length $n$, the blur operation produces an output array $O$ of the same length, where each element is the sum of itself and its immediate neighbors.
Mathematically,
$$ O_i = A_{i-1} + A_i + A_{i+1}, \quad \text{for } i = 0, 1, \ldots, n-1 $$

where

$$ A_j = 0 \quad \text{if } j < 0 \text{ or } j \geq n $$

This is essentially a 1D convolution with a uniform kernel $K = [1, 1, 1]$, or more formally:

$$ O_i = \sum_{k=-1}^{1} A_{i+k} $$

For example:
Input = [2, 3, 5]
Output = [5, 10, 8]
Sounds simple, right? But writing a kernel for it will teach us about thread synchronization, shared memory, and other parallel programming concepts. Let’s jump in.
We first allocate memory for our input and output on the CPU using a standard malloc call, and then populate the input array.
#include <cuda_runtime.h>
#include <iostream>
#include <cstdlib>
int main() {
int n = 10000; // Number of elements in the input
size_t size = n * sizeof(float);
float *h_in = (float *)malloc(size);
float *h_out = (float *)malloc(size);
for (int i=0;i<n;i++) h_in[i] = i+1;
}
That’s it for the CPU side. Next, we allocate memory of the same size on the GPU and copy the input data from host to device. Note that kernels don’t return values – instead, they write results to a device pointer that we provide.
float *d_in, *d_out;
cudaMalloc((void **)&d_in, size);
cudaMalloc((void **)&d_out, size);
cudaMemcpy(d_in, h_in, size, cudaMemcpyHostToDevice);
Now we define the kernel and launch it from main, passing d_in and d_out. Upon completion, the results will be in d_out.
int main() {
// rest of the code
int threadsPerBlock = 256;
int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;
blur<<<blocksPerGrid, threadsPerBlock>>>(d_in, d_out, n);
}
A kernel launch is always accompanied by execution configuration – at minimum, the number of blocks and threads per block (and as we will see later, optionally the size of shared memory). The <<<blocksPerGrid, threadsPerBlock>>> syntax is CUDA-specific and tells the GPU how to distribute the work.
We chose 256 threads per block, which is a common default. It’s a multiple of the warp size (32), ensures good occupancy across most NVIDIA architectures, and stays well within the hardware limit of 1024 threads per block. The number of blocks is computed from the input size to ensure every element gets a thread.
Now, let’s implement the actual kernel.
__global__ void blur(const float *d_in, float *d_out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float low = (i > 0) ? d_in[i-1] : 0.0f;
float high = (i < n-1) ? d_in[i+1] : 0.0f;
d_out[i] = low + d_in[i] + high;
}
}
The __global__ qualifier marks this function as a kernel – a function that is called from host (CPU) code but executes on the device (GPU). CUDA has two other qualifiers: __device__ for functions that run on the GPU and are callable only from other GPU functions, and __host__ for regular CPU functions (the default).
When the kernel is launched, thread blocks are scheduled across the GPU’s streaming multiprocessors (SMs). All threads within a block execute concurrently on the same SM, while blocks themselves may or may not run simultaneously depending on available resources. Each thread performs a single independent operation; together, all threads complete the task. In this kernel, each thread computes the blurred output for its own index.
Inside the kernel, built-in variables threadIdx, blockDim, and blockIdx provide the thread’s index within its block, the block size, and the block’s index within the grid respectively. Together, they compute a unique global index for each thread. These indices can be multi-dimensional, but for our 1-D case, .x is sufficient.
Once the kernel finishes, we copy the results back from device to host and clean up.
cudaMemcpy(h_out, d_out, size, cudaMemcpyDeviceToHost);
for (int i=0; i<n; i++) {
std::cout << h_out[i] << " ";
}
cudaFree(d_in);
cudaFree(d_out);
free(h_in);
free(h_out);
You have written your first kernel! Save it as blur.cu and compile with nvcc:
nvcc blur.cu -o blur
A note on error handling: For brevity, the code above omits error checking entirely. In practice, every CUDA API call (cudaMalloc, cudaMemcpy, etc.) returns a cudaError_t status code, and kernel launches can also fail silently. You should always check for errors using cudaGetLastError() after a kernel launch and verify the return value of API calls. A common pattern is to wrap calls in a CUDA_CHECK macro that prints the error and aborts on failure – this will save you hours of debugging.
Modernizing
You might find the explicit memory allocation and copying a bit verbose – that was intentional. I wanted to show what actually happens under the hood. In practice, CUDA provides higher-level primitives that reduce this boilerplate significantly.
Managed memory allocation
Instead of allocating memory separately on host and device and manually copying between them, CUDA provides cudaMallocManaged. This returns a single pointer accessible from both CPU and GPU. The CUDA runtime handles data migration transparently using page faults and other OS-level mechanisms.
Cooperative Groups
Instead of relying on the raw built-in variables (threadIdx, blockIdx, etc.) inside the kernel, CUDA’s cooperative_groups API provides a more structured way to work with thread groups. It also exposes synchronization primitives that we will use later.
Together, these two changes simplify the program considerably:
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
__global__ void blur(const float *d_in, float *d_out, int n) {
cg::thread_block cta = cg::this_thread_block();
int i = cta.group_index().x * cta.group_dim().x + cta.thread_index().x;
// same code
}
int main(){
float *d_in, *d_out;
int n = 10000;
size_t size = sizeof(float) * n;
cudaMallocManaged(&d_in, size);
cudaMallocManaged(&d_out, size);
for (int i = 0; i < n; i++) d_in[i] = i + 1;
// same launch logic
cudaDeviceSynchronize();
// read the results from d_out;
}
cudaDeviceSynchronize() is needed because kernel launches are asynchronous – the CPU continues executing immediately after the launch and might try to read from d_out before the GPU has finished writing results. This call blocks until the kernel completes.
This wasn’t needed in our earlier version because cudaMemcpy is itself a blocking operation that implicitly synchronizes. With managed memory, there is no explicit copy, so the synchronization must be done manually.
Sharing memory
In our kernel so far, every thread reads from global memory three times: once for its own element, once for the previous, and once for the next. But the neighboring threads are already reading those same values for their own computation. This means we are accessing global memory roughly 3x more than necessary. In kernel programming, you want to squeeze every last drop of performance, so redundant memory accesses matter.
CUDA provides shared memory (__shared__), a fast on-chip memory that is local to each thread block. If threads x-1 and x+1 are in the same block as thread x, they can all share data through shared memory instead of going back to slow global memory each time.
The idea is straightforward:
- Each thread loads its own element from global memory into shared memory – one access per thread.
- The first and last threads in each block additionally load the neighboring elements that fall outside the block’s range.
- A synchronization barrier ensures all loads complete before any thread reads from shared memory.
Let’s modify our kernel to use shared memory. We allocate a temporary buffer temp of size blockDim + 2 – two extra slots for the left and right neighbors.
__global__ void blur(const float *d_in, float *d_out, int n) {
cg::thread_block cta = cg::this_thread_block();
extern __shared__ float temp[];
// Global index for global data
// Local index for temp
int gidx = cta.group_index().x * cta.group_dim().x + cta.thread_index().x;
int lidx = cta.thread_index().x + 1;
if (gidx < n) {
temp[lidx] = d_in[gidx]; // access its own global data
} else {
temp[lidx] = 0.0f; // zero out slots for out-of-bounds threads
}
// If first thread of a block, read global for previous
if (lidx == 1) {
temp[0] = (gidx > 0) ? d_in[gidx-1] : 0.0f;
}
// If last thread of a block, read global for next
if (lidx == cta.group_dim().x) {
temp[lidx+1] = (gidx < n - 1) ? d_in[gidx+1] : 0.0f;
}
// barrier to make sure all threads have filled the shared buffer
cg::sync(cta);
if (gidx < n) {
d_out[gidx] = temp[lidx-1] + temp[lidx] + temp[lidx+1];
}
}
We also need to pass the shared memory size when launching the kernel:
int sharedMemSize = (threadsPerBlock + 2) * sizeof(float);
blur<<<blocksPerGrid, threadsPerBlock, sharedMemSize>>>(d_in, d_out, n);
The third parameter in the <<<>>> launch syntax specifies the amount of dynamically allocated shared memory per block in bytes. This is what backs the extern __shared__ declaration inside the kernel.
We now have an optimized CUDA kernel for blurring a 1D array.
Complete code in a single file (blur.cu):
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <iostream>
namespace cg = cooperative_groups;
// Kernel: each thread computes one blurred output element using shared memory.
// Shared memory (temp[]) has blockDim.x + 2 slots:
// temp[0] = left neighbor from previous block (or 0 if out of bounds)
// temp[1..blockDim] = elements owned by this block's threads
// temp[blockDim+1] = right neighbor from next block (or 0 if out of bounds)
__global__ void blur(const float *d_in, float *d_out, int n) {
cg::thread_block cta = cg::this_thread_block();
// Dynamically allocated shared memory, size passed at kernel launch
extern __shared__ float temp[];
// gidx: global index into d_in / d_out
// lidx: local index into temp (offset by 1 to leave room for left neighbor)
int gidx = cta.group_index().x * cta.group_dim().x + cta.thread_index().x;
int lidx = cta.thread_index().x + 1;
// Each thread loads its own element from global memory into shared memory
if (gidx < n) {
temp[lidx] = d_in[gidx];
} else {
temp[lidx] = 0.0f;
}
// First thread in the block loads the left neighbor element
if (lidx == 1) {
temp[0] = (gidx > 0) ? d_in[gidx - 1] : 0.0f;
}
// Last thread in the block loads the right neighbor element
if (lidx == cta.group_dim().x) {
temp[lidx + 1] = (gidx < n - 1) ? d_in[gidx + 1] : 0.0f;
}
// Wait for all threads to finish loading before reading shared memory
cg::sync(cta);
// Compute blurred output: sum of left neighbor, self, and right neighbor
if (gidx < n) {
d_out[gidx] = temp[lidx - 1] + temp[lidx] + temp[lidx + 1];
}
}
int main() {
int n = 10000;
size_t size = n * sizeof(float);
// Managed memory: accessible from both CPU and GPU
float *d_in, *d_out;
cudaMallocManaged(&d_in, size);
cudaMallocManaged(&d_out, size);
// Populate input: [1, 2, 3, ..., 10000]
for (int i = 0; i < n; i++) d_in[i] = i + 1;
// 256 threads per block: multiple of warp size (32), good default occupancy
int threadsPerBlock = 256;
// Compute grid size: enough blocks to cover all n elements
int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;
// Shared memory per block: blockDim + 2 elements (left and right neighbors)
int sharedMemSize = (threadsPerBlock + 2) * sizeof(float);
// Launch kernel: <<<blocks, threads, shared memory bytes>>>
blur<<<blocksPerGrid, threadsPerBlock, sharedMemSize>>>(d_in, d_out, n);
// Block CPU until GPU kernel completes (needed because managed memory
// has no implicit sync unlike cudaMemcpy)
cudaDeviceSynchronize();
for (int i = 0; i < n; i++) {
std::cout << d_out[i] << " ";
}
std::cout << std::endl;
cudaFree(d_in);
cudaFree(d_out);
return 0;
}
Compile and run with:
nvcc blur.cu -o blur && ./blur