[Kernel] Why is CrossEntropyLoss Faster with Triton?

Published in , 2025


Cross Entropy Loss is one of the most widely used loss functions in training deep learning models. However, for large language models (LLMs) with vocab sizes in the hundreds of thousands, this computation can become a performance bottleneck. In this article, we will explain why a Triton-based implementation of CrossEntropyLoss can be significantly faster, using both mathematical reasoning and code examples.


1. Objective: Compute Cross Entropy Loss

Given:

  • Logits $x \in \mathbb{R}^{N \times V}$

    • $N$: total number of tokens (batch size × sequence length)
    • $V$: vocabulary size (e.g., 32000, 256000)
  • Labels $y \in \{0, \dots, V-1\}^N$: index of the correct class for each token

The cross entropy loss per token is defined as:

\[\mathcal{L}_i = -\log \left( \frac{\exp(x_{i, y_i})}{\sum_{j=1}^V \exp(x_{i, j})} \right) = \log\left( \sum_{j=1}^V \exp(x_{i,j}) \right) - x_{i,y_i}\]
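The second equality follows from $\log(a/b) = \log a - \log b$ and $\log \exp(z) = z$. Written out step by step:

\[\begin{aligned}
\mathcal{L}_i &= -\log \left( \frac{\exp(x_{i, y_i})}{\sum_{j=1}^V \exp(x_{i, j})} \right) \\
&= -\left( \log \exp(x_{i, y_i}) - \log \sum_{j=1}^V \exp(x_{i, j}) \right) \\
&= \log\left( \sum_{j=1}^V \exp(x_{i,j}) \right) - x_{i,y_i}
\end{aligned}\]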

2. PyTorch’s Standard Method

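Conceptually, a standard (unfused) implementation computes CrossEntropyLoss as follows. PyTorch's built-in F.cross_entropy actually calls log_softmax followed by nll_loss, which skips the explicit division. It still materializes an $N \times V$ intermediate, however.

import torch

N, V = 8, 32000                        # tokens, vocab size (illustrative)
logits = torch.randn(N, V)             # stands in for model(...), shape [N, V]
labels = torch.randint(0, V, (N,))     # correct class index per token
probs = torch.softmax(logits, dim=-1)  # [N, V], written out in full
loss = -torch.log(probs[torch.arange(N), labels])  # [N] per-token losses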

In other words:

  1. Compute softmax over all classes: $\exp(x) / \sum \exp(x)$
  2. Extract the probability corresponding to the correct class
  3. Apply negative log to compute the loss

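This approach computes softmax across all $V$ classes, which means an exp, a sum, and a division for every element. It writes the full $N \times V$ probability tensor to memory and then reads back only $N$ of those values.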
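The three steps can be sketched without any framework. This is a minimal sketch, not PyTorch's actual code. It uses nested Python lists in place of tensors, and the function name naive_cross_entropy is hypothetical. It makes the full $N \times V$ probability matrix explicit.

```python
import math

def naive_cross_entropy(logits, labels):
    """Unfused reference: softmax over every class, then gather, then -log.

    logits: N rows of V floats; labels: N class indices.
    Returns the list of per-token losses.
    """
    # Step 1: full softmax, building an N x V probability matrix.
    probs = []
    for row in logits:
        c = max(row)  # subtract the row max so exp cannot overflow
        exps = [math.exp(v - c) for v in row]
        total = sum(exps)
        probs.append([e / total for e in exps])  # V divisions per row
    # Step 2: gather the probability of the correct class for each token.
    picked = [probs[i][labels[i]] for i in range(len(labels))]
    # Step 3: negative log.
    return [-math.log(p) for p in picked]
```

Only $N$ of the $N \times V$ probabilities are ever used. The rest are computed and stored only to be discarded.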


3. Triton’s Approach: The Core Idea

The Triton-based _cross_entropy_forward implementation optimizes the computation using the identity:

Core Equation:

\[\mathcal{L}_i = \underbrace{\log\left(\sum_j \exp(x_{i,j})\right)}_{\text{logsumexp}} - \underbrace{x_{i, y_i}}_{\text{correct logit}}\]

Instead of computing the full softmax, we compute logsumexp and subtract the logit of the correct class.
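The core equation translates directly into code. In this sketch (hypothetical function name, plain Python lists), each row reduces to a single logsumexp scalar minus one gathered logit, and no probability matrix is built. It is deliberately left unshifted to mirror the equation; the max-shift described next addresses overflow.

```python
import math

def logsumexp_cross_entropy(logits, labels):
    """Per-token loss via L_i = logsumexp(x_i) - x_{i, y_i} (unshifted)."""
    losses = []
    for row, y in zip(logits, labels):
        lse = math.log(sum(math.exp(v) for v in row))  # one scalar per row
        losses.append(lse - row[y])                     # subtract the correct logit
    return losses
```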

Numerical Stability:

\[\log \sum_j \exp(x_j) = c + \log \sum_j \exp(x_j - c), \quad \text{with } c = \max_j x_j\]

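Because every $x_j - c \le 0$, each exponential lies in $(0, 1]$ and the term for the maximum equals 1, so the sum lies in $[1, V]$. It cannot overflow, and the logarithm is never taken of zero.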
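A small illustration of why the shift matters, using hypothetical helper names. Python's math.exp raises OverflowError on exp(1000); on a GPU, the same float32 overflow would silently produce inf instead.

```python
import math

def logsumexp_unshifted(row):
    return math.log(sum(math.exp(v) for v in row))

def logsumexp_stable(row):
    c = max(row)  # c = max_j x_j
    # every v - c <= 0, so each exp is in (0, 1] and the sum is in [1, len(row)]
    return c + math.log(sum(math.exp(v - c) for v in row))

row = [1000.0, 999.0, 998.0]
try:
    logsumexp_unshifted(row)
except OverflowError:
    print("unshifted: exp(1000) overflows")  # math.exp raises instead of returning inf
print(logsumexp_stable(row))
```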


4. Worked Example: A Single Row

Suppose:

x = [2.0, 0.5, -1.0, 3.0, 0.1]  # logits
label = 3  # correct class index (0-based), so x[label] = 3.0

Standard PyTorch:

  • Compute p = softmax(x)
  • Compute loss = -log(p[label])

Triton:

  • max(x) = 3.0
  • logsumexp = 3.0 + log(sum(exp(x - 3.0)))
  • loss = logsumexp - x[3]

Both are mathematically equivalent. The Triton method, however, never forms the normalized probability vector, so it skips the per-element division and the intermediate softmax output.
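The same example can be checked numerically with plain Python. This is a sketch using only the standard library, with variable names chosen for illustration.

```python
import math

x = [2.0, 0.5, -1.0, 3.0, 0.1]  # logits from the example above
label = 3

# Standard: softmax, then -log of the correct-class probability
total = sum(math.exp(v) for v in x)
p = [math.exp(v) / total for v in x]
loss_standard = -math.log(p[label])

# Triton-style: max-shifted logsumexp minus the correct logit
c = max(x)  # 3.0
lse = c + math.log(sum(math.exp(v - c) for v in x))
loss_triton = lse - x[label]

print(f"softmax path:   {loss_standard:.4f}")  # 0.4209
print(f"logsumexp path: {loss_triton:.4f}")    # 0.4209
```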


5. Triton Kernel: Code Example

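import triton
import triton.language as tl

@triton.jit
def _cross_entropy_forward(
    logits_ptr, logits_row_stride,
    loss_ptr, logsumexp_ptr, labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # power of 2 >= VOCAB_SIZE, so one block covers a whole row
):
    # One program instance handles one token row.
    row = tl.program_id(0)
    # 64-bit offset so row * stride cannot overflow int32 for large N * V
    logits_ptr += row.to(tl.int64) * logits_row_stride
    label = tl.load(labels_ptr + row).to(tl.int32)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE
    # Padding lanes read as -inf: they never win the max and add exp(-inf) = 0 to the sum.
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)

    # Numerically stable logsumexp with max-shift
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label != -100:  # -100 is PyTorch's default ignore_index
        x_label = tl.load(logits_ptr + label).to(tl.float32)
        loss = logsumexp - x_label
    else:
        loss = 0.0
    tl.store(loss_ptr + row, loss)
    # Kept for the backward pass, where softmax = exp(x - logsumexp)
    tl.store(logsumexp_ptr + row, logsumexp)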
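To see what each program instance does without a GPU, here is a pure-Python mimic of the kernel's per-row logic. It is an emulation sketch, not Triton code; the names emulate_cross_entropy_forward and next_power_of_2 are hypothetical. Each loop iteration plays the role of one tl.program_id(0). It also shows the -inf padding from the masked load and the -100 ignore check.

```python
import math

IGNORE_INDEX = -100  # matches the label check in the kernel

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

def emulate_cross_entropy_forward(logits, labels):
    """Mimic one program per row; returns (loss, logsumexp) lists."""
    vocab_size = len(logits[0])
    block_size = next_power_of_2(vocab_size)  # BLOCK_SIZE must cover the whole row
    losses, lses = [], []
    for row, label in zip(logits, labels):  # one iteration ~ one program instance
        # masked load: lanes >= VOCAB_SIZE read as -inf
        block = [row[k] if k < vocab_size else -math.inf for k in range(block_size)]
        c = max(block)
        lse = c + math.log(sum(math.exp(v - c) for v in block))  # exp(-inf) == 0.0
        losses.append(lse - row[label] if label != IGNORE_INDEX else 0.0)
        lses.append(lse)
    return losses, lses
```

On a GPU, a plausible launch (an assumption, not verified here) uses one program per row: `_cross_entropy_forward[(N,)](logits, logits.stride(0), loss, logsumexp, labels, VOCAB_SIZE=V, BLOCK_SIZE=triton.next_power_of_2(V))`, where loss and logsumexp are float32 tensors of shape [N].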

6. Performance Comparison

| Operation | PyTorch | Triton |
|---|---|---|
| exp | $N \times V$ | $N \times V$ |
| div | $N \times V$ | not required |
| log | $N$ | $N$ |
| gather | $N$ | $N$ |

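Triton's approach skips the division and never writes a normalized softmax tensor. Since both columns need the same number of exp operations, the larger saving is typically memory bandwidth rather than arithmetic. The fused kernel reads each logit once and writes only two floats per row (loss and logsumexp). The unfused path, by contrast, also writes an $N \times V$ intermediate to global memory and reads from it again.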
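A back-of-envelope estimate of global-memory traffic makes the bandwidth argument concrete. This is a rough model, not a measurement. It assumes bf16 logits (2 bytes) and a softmax output of the same dtype, fp32 (4-byte) per-row outputs, and illustrative sizes (N = 8192 tokens, V = 256000). It ignores caches and any extra passes a softmax kernel may make over its input. The function name hbm_bytes is hypothetical.

```python
def hbm_bytes(n_tokens, vocab, logit_bytes=2, out_bytes=4):
    """Rough global-memory bytes moved: (unfused, fused). Ignores caching and extra passes."""
    read_logits = n_tokens * vocab * logit_bytes
    # Unfused: write N x V probabilities, gather N of them, write N losses
    unfused = (read_logits
               + n_tokens * vocab * logit_bytes
               + n_tokens * logit_bytes
               + n_tokens * out_bytes)
    # Fused: read logits once, write loss and logsumexp per row
    fused = read_logits + 2 * n_tokens * out_bytes
    return unfused, fused

unfused, fused = hbm_bytes(8192, 256000)
print(f"unfused ~ {unfused / 1e9:.2f} GB")  # ~8.39 GB
print(f"fused   ~ {fused / 1e9:.2f} GB")    # ~4.19 GB
```

Under these assumptions, the intermediate roughly doubles the traffic, and this operation is bandwidth-bound at such sizes.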


7. Conclusion

The Triton-based CrossEntropyLoss is:

  • Fully parallelized across token rows on the GPU
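  • Minimal in its work, computing only the values the loss needs (a row max, a logsumexp, and one gathered logit)
  • Numerically stable via logsumexp with max-shift
  • Scalable to large vocab sizes, though the single-block kernel above needs BLOCK_SIZE ≥ V; for very large vocabularies, each row can be split into chunks whose partial logsumexps are merged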
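Scaling to very large vocabularies usually means a whole row no longer fits in one block. A common remedy, sketched here and not part of the kernel above, is to compute a logsumexp per chunk and merge the results with the same max-shift trick: $\log(e^a + e^b) = m + \log(e^{a-m} + e^{b-m})$ with $m = \max(a, b)$. The function names and chunk sizes below are hypothetical.

```python
import math

def logsumexp(row):
    c = max(row)
    return c + math.log(sum(math.exp(v - c) for v in row))

def merge_lse(a, b):
    """Combine logsumexps of two disjoint chunks: log(e^a + e^b), max-shifted."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def chunked_logsumexp(row, chunk_size):
    """Logsumexp over a row processed chunk by chunk, as a multi-block kernel might."""
    total = -math.inf  # identity element: merge_lse(-inf, b) == b
    for start in range(0, len(row), chunk_size):
        total = merge_lse(total, logsumexp(row[start:start + chunk_size]))
    return total
```

Because the merge is associative, chunks can be reduced in any grouping, and the loss is still chunked_logsumexp(row) minus the correct logit.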