[Kernel] Why is CrossEntropyLoss Faster with Triton?
Published in 2025
Cross Entropy Loss is one of the most widely used loss functions in training deep learning models. However, for large language models (LLMs) with vocab sizes in the hundreds of thousands, this computation can become a performance bottleneck. In this article, we will explain why a Triton-based implementation of CrossEntropyLoss can be significantly faster, using both mathematical reasoning and code examples.
1. Objective: Compute Cross Entropy Loss
Given:
Logits $x \in \mathbb{R}^{N \times V}$
- $N$: number of total tokens (batch size × sequence length)
- $V$: vocabulary size (e.g., 32000, 256000)
Labels $y \in \mathbb{N}^N$: correct class indices per token
The cross entropy loss per token is defined as:
\[\mathcal{L}_i = -\log \left( \frac{\exp(x_{i, y_i})}{\sum_{j=1}^V \exp(x_{i, j})} \right) = \log\left( \sum_{j=1}^V \exp(x_{i,j}) \right) - x_{i,y_i}\]
2. PyTorch's Standard Method
Conceptually, computing CrossEntropyLoss in PyTorch amounts to the following steps:
logits = model(...)                                # [N, V]
probs = torch.softmax(logits, dim=-1)              # [N, V]
loss = -torch.log(probs[torch.arange(N), labels])  # [N] per-token losses
In other words:
- Compute softmax over all classes: $\exp(x) / \sum \exp(x)$
- Extract the probability corresponding to the correct class
- Apply negative log to compute the loss
This approach requires computing softmax across all V classes, which involves exp, sum, div, and log operations.
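For reference, the same per-token losses are normally obtained through the built-in functional API. A minimal sketch, assuming random data and purely illustrative sizes for N and V:

import torch
import torch.nn.functional as F

N, V = 8, 32000                        # illustrative token count and vocab size
logits = torch.randn(N, V)
labels = torch.randint(0, V, (N,))

# Naive three-step version from above (per-token losses).
probs = torch.softmax(logits, dim=-1)
naive = -torch.log(probs[torch.arange(N), labels])

# Built-in API; reduction="none" keeps per-token losses for comparison.
builtin = F.cross_entropy(logits, labels, reduction="none")
print(torch.allclose(naive, builtin, atol=1e-5))   # True, up to floating-point error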
3. Triton’s Approach: The Core Idea
The Triton-based _cross_entropy_forward implementation optimizes the computation using the identity:
Core Equation:
\[\mathcal{L}_i = \underbrace{\log\left(\sum_j \exp(x_{i,j})\right)}_{\text{logsumexp}} - \underbrace{x_{i, y_i}}_{\text{correct logit}}\]
Instead of computing the full softmax, we compute logsumexp and subtract the logit of the correct class.
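In PyTorch notation, this identity can be written directly with torch.logsumexp and a gather of the correct-class logits. A short sketch with illustrative sizes (not the Triton kernel itself):

import torch

N, V = 8, 32000                                    # illustrative sizes
x = torch.randn(N, V)                              # logits
y = torch.randint(0, V, (N,))                      # labels

# loss_i = logsumexp_j(x_{i,j}) - x_{i, y_i}; no probability tensor is materialized
loss = torch.logsumexp(x, dim=-1) - x[torch.arange(N), y]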
Numerical Stability:
\[\log \sum_j \exp(x_j) = c + \log \sum_j \exp(x_j - c), \quad \text{with } c = \max_j x_j\]
This form avoids overflow issues from large exponentials.
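To see why the max shift matters, here is a small sketch; the specific values are only an illustration of float32 overflow:

import torch

x = torch.tensor([1000.0, 999.0, 998.0])

naive = torch.log(torch.exp(x).sum())             # inf: exp(1000.) overflows float32
c = x.max()
stable = c + torch.log(torch.exp(x - c).sum())    # ≈ 1000.4076
print(naive, stable, torch.logsumexp(x, dim=0))   # the last two agree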
4. Matrix-Level Comparison
Suppose:
x = [2.0, 0.5, -1.0, 3.0, 0.1] # logits
label = 3
Standard PyTorch:
- Compute softmax(x)
- Compute loss = -log(p[label])
Triton:
- max(x) = 3.0
- logsumexp = 3.0 + log(sum(exp(x - 3.0)))
- loss = logsumexp - x[3]
Both are mathematically equivalent, but the Triton method avoids materializing the full softmax and performs fewer operations; a numerical check follows below.
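Checking the small example above numerically, a quick sketch in PyTorch:

import torch

x = torch.tensor([2.0, 0.5, -1.0, 3.0, 0.1])
label = 3

# Standard path: softmax, then negative log of the correct-class probability.
loss_softmax = -torch.log(torch.softmax(x, dim=0)[label])

# Triton-style path: max-shifted logsumexp minus the correct logit.
c = x.max()
loss_lse = c + torch.log(torch.exp(x - c).sum()) - x[label]

print(loss_softmax.item(), loss_lse.item())   # both ≈ 0.42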
5. Triton Kernel: Code Example
import triton
import triton.language as tl

@triton.jit
def _cross_entropy_forward(
    logits_ptr, logits_row_stride,
    loss_ptr, logsumexp_ptr, labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one token row of the [N, V] logits.
    row = tl.program_id(0)
    logits_ptr += row * logits_row_stride
    label = tl.load(labels_ptr + row)

    # Load the whole vocabulary row; out-of-range lanes get -inf so they
    # contribute nothing to the max or to the sum of exponentials.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)

    # Numerically stable logsumexp: shift by the row maximum.
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label != -100:  # -100 marks ignored (e.g. padding) tokens
        x_label = tl.load(logits_ptr + label).to(tl.float32)
        loss = logsumexp - x_label
    else:
        loss = 0.0
    tl.store(loss_ptr + row, loss)
    tl.store(logsumexp_ptr + row, logsumexp)  # kept for the backward pass
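A kernel like this is launched with one program per token row. The host-side wrapper below is an illustrative sketch, not the exact library code; the function name cross_entropy_forward and the choice of BLOCK_SIZE via triton.next_power_of_2 are assumptions:

import torch
import triton

def cross_entropy_forward(logits: torch.Tensor, labels: torch.Tensor):
    # Per-token losses and logsumexp for [N, V] logits (illustrative wrapper).
    N, V = logits.shape
    losses = torch.empty(N, dtype=torch.float32, device=logits.device)
    logsumexp = torch.empty(N, dtype=torch.float32, device=logits.device)
    BLOCK_SIZE = triton.next_power_of_2(V)    # one block covers the whole vocab row
    _cross_entropy_forward[(N,)](             # grid: one program per token row
        logits, logits.stride(0),
        losses, logsumexp, labels,
        VOCAB_SIZE=V,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return losses, logsumexp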
6. Performance Comparison
| Operation | PyTorch | Triton |
|---|---|---|
| exp | $N \times V$ | $N \times V$ |
| div | $N \times V$ | not required |
| log | $N$ | $N$ |
| gather | $N$ | $N$ |
Triton’s approach skips division and softmax normalization, saving compute and memory bandwidth.
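To check this on your own hardware, here is a minimal benchmark sketch, assuming a CUDA device and the illustrative cross_entropy_forward wrapper from above (absolute timings will vary by GPU):

import torch
import torch.nn.functional as F

N, V = 4096, 128256                  # e.g., a Llama-3-sized vocabulary
logits = torch.randn(N, V, device="cuda")
labels = torch.randint(0, V, (N,), device="cuda")

def bench(fn, iters=50):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters       # ms per call

print("torch :", bench(lambda: F.cross_entropy(logits, labels, reduction="none")))
print("triton:", bench(lambda: cross_entropy_forward(logits, labels)))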
7. Conclusion
The Triton-based CrossEntropyLoss:
- Is fully parallelized across token rows on the GPU
- Computes only the values needed for the loss
- Stays numerically stable via max-shifted logsumexp
- Scales to large vocabulary sizes