2025.12.06

Direct Preference Optimization (DPO)

Bradley-Terry Model

DPO starts from the Bradley-Terry model, which connects rewards to preferences. This model expresses the probability that a human prefers response $y_w$ (winner) over $y_l$ (loser) given a prompt $x$:

$$p(y_w \succ y_l | x) = \sigma\big(r(x, y_w) - r(x, y_l)\big)$$

where $r(x, y)$ is a latent reward function ($r(x, y_w)$ is the reward of the preferred response, $r(x, y_l)$ the reward of the rejected one) and $\sigma$ is the sigmoid function.
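
For example, if $r(x, y_w) - r(x, y_l) = 1$, then $\sigma(1) \approx 0.73$, so the preferred response is chosen about 73% of the time; equal rewards give exactly 50%.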

Human feedback comes as preferences over pairs of model samples:

$$\mathcal{D} = \{(x_i,\, y_w^i,\, y_l^i)\}$$

Given this, we can define the loss function for the reward model:

$$\mathcal{L}_{\text{R}}(\phi, \mathcal{D}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big) \right]$$

where $\phi$ are the parameters of the reward model.

The loss function is essentially binary classification — the logit is just the difference in rewards.
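
As a quick side illustration (not part of the DPO pipeline itself), here is a minimal sketch of this reward-model loss in PyTorch, assuming chosen_rewards and rejected_rewards are batches of scalar rewards $r_\phi(x, y_w)$ and $r_\phi(x, y_l)$ produced by some reward model; the function name is illustrative:

import torch
import torch.nn.functional as F

def reward_model_loss(chosen_rewards, rejected_rewards):
    # Bradley-Terry negative log-likelihood: -log sigmoid(r_w - r_l).
    # Equivalent to binary cross-entropy where the logit is the reward
    # difference and the "chosen wins" label is always 1.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()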

Deriving the DPO Objective

Instead of training a separate reward model (as in RLHF), DPO uses a theoretical result connecting the optimal policy $\pi^*$ to the reward function.

The standard RLHF objective is:

$$\max_{\pi_\theta} \underbrace{\mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y|x)}\big[r_\phi(x, y)\big]}_{\text{expected reward}} - \underbrace{\beta \, D_{KL}\big[\pi_\theta(y|x) \,\|\, \pi_{\text{ref}}(y|x)\big]}_{\text{KL penalty}}$$

In other words: maximize expected reward, minus a KL penalty to keep the policy close to the reference.

This objective has a closed-form solution for the optimal policy $\pi^*$:

$$\pi^*(y|x) = \underbrace{\frac{1}{Z(x)}}_{\text{normalize}} \pi_{\text{ref}}(y|x) \underbrace{\exp\left(\frac{1}{\beta} r(x, y)\right)}_{\text{reward scaling}}$$

where $Z(x)$ is the partition function:

$$Z(x) = \sum_{y \in \mathcal{Y}} \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right)$$

It looks almost the same as the unnormalized term above, but here we sum over all possible responses $y \in \mathcal{Y}$, which makes $Z(x)$ expensive to compute directly.

We can rearrange this to express the reward $r(x, y)$ in terms of the optimal policy $\pi^*$, the reference policy $\pi_{\text{ref}}$, and the partition function $Z(x)$:

$$r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)$$
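
This follows by taking the log of both sides of the closed-form solution and solving for $r(x, y)$:

$$\log \pi^*(y|x) = \log \pi_{\text{ref}}(y|x) + \frac{1}{\beta} r(x, y) - \log Z(x) \quad\Longrightarrow\quad r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)$$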

DPO Loss

Substituting this reward expression into the Bradley-Terry model (the $Z(x)$ terms cancel out, what a relief!), we get the preference probability in terms of the policy:

$$p(y_w \succ y_l | x) = \sigma\left(\beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)$$
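
To see the cancellation explicitly, the argument of the sigmoid is the reward difference, and $\beta \log Z(x)$ appears in both rewards with the same $x$:

$$r(x, y_w) - r(x, y_l) = \beta \log \frac{\pi^*(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi^*(y_l|x)}{\pi_{\text{ref}}(y_l|x)} + \underbrace{\beta \log Z(x) - \beta \log Z(x)}_{=\,0}$$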

Now we can train the policy $\pi_\theta$ directly by minimizing the negative log-likelihood of the preference data:

$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right)\right]$$

This bypasses the need for a separate reward model entirely. Pretty cool, huh?



Implementation

import torch
import torch.nn.functional as F

Log Probabilities

We need to compute the log probability of a sequence, $\log \pi(y|x)$. Since the model is autoregressive, this is the sum of the log probabilities of each token given the history.

def get_log_probs(model, input_ids, attention_mask, labels):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

Align logits and labels. The model predicts the NEXT token, so logits[t] corresponds to labels[t+1]. We remove the last logit (no next token) and the first label (no prediction).

    logits = logits[:, :-1, :]
    labels = labels[:, 1:]

Create a mask for padding tokens (where labels are -100). We build the mask before replacing those values, because -100 is not a valid tensor index for gather but we still need to know which positions to ignore.

    mask = (labels != -100).float()
    

Replace -100 with 0 (or any valid index) to prevent index errors in gather. We clone labels to ensure we don't modify the input tensor in place.

    labels_safe = labels.clone()
    labels_safe[labels == -100] = 0

Compute log probabilities for all vocabulary tokens at each position.

    log_probs = F.log_softmax(logits, dim=-1)
    

Select the log probability of the ACTUAL token that appeared in the sequence. gather dim=-1 selects the value at the index specified by labels_safe.

    per_token_log_probs = torch.gather(log_probs, dim=-1, index=labels_safe.unsqueeze(-1)).squeeze(-1)

Multiply by mask to zero out padding tokens, then sum over the sequence. Result shape: (batch_size,)

    return (per_token_log_probs * mask).sum(dim=-1)
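
For context, here is a minimal sketch of how input_ids, attention_mask, and labels for a single (prompt, response) pair might be prepared, assuming the common convention that prompt and padding positions are set to -100 so only response tokens contribute to the sum above. The helper name and the right-padding scheme are illustrative assumptions, not something prescribed by DPO:

def build_example(prompt_ids, response_ids, pad_token_id, max_len):
    # Illustrative helper (assumed convention): concatenate prompt and response,
    # and mask prompt + padding positions with -100 in the labels so that
    # get_log_probs sums log-probabilities over response tokens only.
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [-100] * len(prompt_ids) + list(response_ids)
    attention_mask = [1] * len(input_ids)

    # Right-pad to a fixed length.
    pad_len = max_len - len(input_ids)
    input_ids += [pad_token_id] * pad_len
    labels += [-100] * pad_len
    attention_mask += [0] * pad_len

    return (
        torch.tensor(input_ids),
        torch.tensor(attention_mask),
        torch.tensor(labels),
    )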

DPO Loss Function

This implements the DPO objective derived above. It minimizes the negative log-likelihood of the preference data under the implicit reward model.

Also, we can use subtraction instead of the ratio, since $\log \frac{a}{b} = \log a - \log b$.

def dpo_loss(
    policy_model,
    reference_model,
    chosen_ids, chosen_mask, chosen_labels,
    rejected_ids, rejected_mask, rejected_labels,
    beta: float = 0.1,
):

Compute policy log probabilities for the chosen and rejected responses. This gives us $\log \pi_\theta(y_w|x)$ and $\log \pi_\theta(y_l|x)$.

    pi_chosen = get_log_probs(policy_model, chosen_ids, chosen_mask, chosen_labels)
    pi_rejected = get_log_probs(policy_model, rejected_ids, rejected_mask, rejected_labels)

Compute reference log probabilities with the frozen model. This gives us $\log \pi_{\text{ref}}(y_w|x)$ and $\log \pi_{\text{ref}}(y_l|x)$. We use no_grad() because we don't update the reference model.

    with torch.no_grad():
        ref_chosen = get_log_probs(reference_model, chosen_ids, chosen_mask, chosen_labels)
        ref_rejected = get_log_probs(reference_model, rejected_ids, rejected_mask, rejected_labels)

Calculate the log-ratios $\log \pi_\theta(y|x) - \log \pi_{\text{ref}}(y|x)$ for the chosen and rejected responses. The implicit reward is $\hat{r}(x, y) = \beta \left(\log \pi_\theta(y|x) - \log \pi_{\text{ref}}(y|x)\right)$; the factor $\beta$ is applied in the next step.

    chosen_logratios = pi_chosen - ref_chosen
    rejected_logratios = pi_rejected - ref_rejected

Compute the Bradley-Terry logits.

$$u = \hat{r}(x, y_w) - \hat{r}(x, y_l) = \beta \left( \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right)$$
    logits = beta * (chosen_logratios - rejected_logratios)
    
Compute the negative log-likelihood loss:

$$\mathcal{L} = -\log \sigma(u) = -\log \frac{1}{1 + e^{-u}}$$

F.logsigmoid(x) computes log(1 / (1 + exp(-x))) stably.

    loss = -F.logsigmoid(logits).mean()

    return loss

Training Loop

Standard PyTorch training step, passing both chosen and rejected sequences.

def train_step(policy_model, reference_model, batch, optimizer, beta=0.1):
    optimizer.zero_grad()
    
    loss = dpo_loss(
        policy_model, reference_model,
        batch["chosen_input_ids"], batch["chosen_attention_mask"], batch["chosen_labels"],
        batch["rejected_input_ids"], batch["rejected_attention_mask"], batch["rejected_labels"],
        beta=beta,
    )
    
    loss.backward()
    optimizer.step()
    return loss.item()
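
And a rough sketch of how these pieces might be wired up end to end, assuming a Hugging Face causal LM fine-tuned with SFT as the starting point; the checkpoint path, learning rate, and dataloader are placeholders, not values from the original:

from transformers import AutoModelForCausalLM

policy_model = AutoModelForCausalLM.from_pretrained("path/to/sft-checkpoint")     # placeholder path
reference_model = AutoModelForCausalLM.from_pretrained("path/to/sft-checkpoint")  # frozen copy

# The reference model only provides log pi_ref; freeze it.
reference_model.eval()
for p in reference_model.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-6)

for batch in dataloader:  # assumed to yield the chosen_*/rejected_* tensors used above
    loss = train_step(policy_model, reference_model, batch, optimizer, beta=0.1)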