TrOCR: A Mathematical Explanation
TrOCR (Transformer-based Optical Character Recognition) maps an input image \( \mathbf{x} \)
into a text sequence \( \mathbf{y} = (y_1, \dots, y_T) \). It models the conditional probability:
$$
p(\mathbf{y} \mid \mathbf{x}) = \prod_{t=1}^{T} p(y_t \mid y_{<t}, \mathbf{x})
$$
Notation:
\( \mathbf{x} \): input image,
\( y_t \): predicted token at step \( t \),
\( y_{<t} \): sequence of previously generated tokens,
\( T \): output sequence length.
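To make the factorization concrete, here is a tiny numerical sketch (the per-step probabilities are hypothetical, not taken from a real model) showing that the sequence probability is the product of the per-step conditionals, or equivalently a sum of log-probabilities:

```python
import math

# Hypothetical per-step conditionals p(y_t | y_<t, x) for a 4-token output.
step_probs = [0.9, 0.8, 0.95, 0.7]

# Sequence probability: product over time steps.
seq_prob = math.prod(step_probs)

# Equivalent log-space form, which is what the training loss in Section 3 sums.
seq_log_prob = sum(math.log(p) for p in step_probs)

print(seq_prob)                # ~0.479
print(math.exp(seq_log_prob))  # same value, recovered from the log-sum
```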
TrOCR consists of two major components:
- A Vision Transformer (ViT) as the encoder
- A Text Transformer as the decoder
1. Encoder — Vision Transformer (ViT)
1.1 Image to Patch Embeddings
Input image:
$$
\mathbf{x} \in \mathbb{R}^{H \times W \times 3}
$$
The image is divided into \( N = \frac{H \times W}{P^2} \) patches of size \( P \times P \).
Each patch is flattened and projected into a D-dimensional embedding:
$$
\mathbf{z}_0^{(i)} = \mathbf{W}_p \, \text{vec}(\mathbf{x}^{(i)}) + \mathbf{e}_{\text{pos}}^{(i)}, \quad i = 1, \dots, N
$$
Notation:
\( \mathbf{W}_p \): linear projection matrix,
\( \mathbf{e}_{\text{pos}}^{(i)} \): positional embedding,
\( D \): embedding dimension,
\( N \): number of patches.
$$
\mathbf{Z}_0 = [\mathbf{z}_0^{(1)}, \dots, \mathbf{z}_0^{(N)}]^\top \in \mathbb{R}^{N \times D}
$$
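The patchify-and-project step can be sketched in a few lines of PyTorch. This is an illustrative reimplementation, not TrOCR's released code: the parameter names (`patch_size`, `embed_dim`) are mine, and the 384×384 input with 16×16 patches is just one common configuration. The strided `Conv2d` trick is mathematically equivalent to flattening each patch and multiplying by \( \mathbf{W}_p \).

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into P x P patches and project each to a D-dim embedding."""
    def __init__(self, img_size=384, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2   # N = HW / P^2
        # Conv2d with kernel = stride = P is equivalent to W_p @ vec(patch).
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):                      # x: (B, 3, H, W)
        z = self.proj(x)                       # (B, D, H/P, W/P)
        z = z.flatten(2).transpose(1, 2)       # (B, N, D)
        return z + self.pos_embed              # add positional embeddings e_pos

z0 = PatchEmbedding()(torch.randn(1, 3, 384, 384))
print(z0.shape)  # torch.Size([1, 576, 768]) -> Z_0 with N = 576, D = 768
```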
1.2 Transformer Encoder Layers
Each encoder layer \( \ell = 1, \dots, L_e \) transforms \( \mathbf{Z}_{\ell-1} \to \mathbf{Z}_\ell \).
(a) Multi-Head Self-Attention (MSA)
$$
\text{MSA}(\mathbf{Z}) = [\text{head}_1; \dots; \text{head}_h] \mathbf{W}_O
$$
$$
\text{head}_i = \text{Softmax}\!\left(\frac{\mathbf{Q}_i \mathbf{K}_i^\top}{\sqrt{d_k}}\right)\mathbf{V}_i
$$
Notation:
\( h \): number of attention heads,
\( \mathbf{Q}_i = \mathbf{Z}\mathbf{W}_Q^{(i)} \),
\( \mathbf{K}_i = \mathbf{Z}\mathbf{W}_K^{(i)} \),
\( \mathbf{V}_i = \mathbf{Z}\mathbf{W}_V^{(i)} \),
\( \mathbf{W}_O \): output projection,
\( d_k = D / h \): dimensionality of each head.
Each head attends to different representation subspaces, allowing the model to capture diverse relationships.
$$
\mathbf{A}_\ell = \text{MSA}(\mathbf{Z}_{\ell-1})
$$
Residual connection and normalization:
$$
\mathbf{H}_\ell = \text{LayerNorm}(\mathbf{Z}_{\ell-1} + \mathbf{A}_\ell)
$$
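The per-head computation and the concatenation through \( \mathbf{W}_O \) can be sketched as follows (single sequence, no batch dimension, for clarity; the stacked-projection layout is one common implementation choice, and the layer names are mine):

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.h, self.d_k = num_heads, dim // num_heads
        self.w_q = nn.Linear(dim, dim, bias=False)   # W_Q for all heads, stacked
        self.w_k = nn.Linear(dim, dim, bias=False)   # W_K
        self.w_v = nn.Linear(dim, dim, bias=False)   # W_V
        self.w_o = nn.Linear(dim, dim, bias=False)   # W_O

    def forward(self, z):                            # z: (N, D)
        n, d = z.shape
        # Project and split into h heads of width d_k.
        q = self.w_q(z).view(n, self.h, self.d_k).transpose(0, 1)  # (h, N, d_k)
        k = self.w_k(z).view(n, self.h, self.d_k).transpose(0, 1)
        v = self.w_v(z).view(n, self.h, self.d_k).transpose(0, 1)
        # Scaled dot-product attention, computed per head.
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.d_k ** 0.5, dim=-1)
        heads = attn @ v                                            # (h, N, d_k)
        # Concatenate heads and apply the output projection W_O.
        return self.w_o(heads.transpose(0, 1).reshape(n, d))        # (N, D)

out = MultiHeadSelfAttention()(torch.randn(576, 768))
print(out.shape)  # torch.Size([576, 768])
```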
(b) Feed-Forward Network (FFN)
$$
\mathbf{Z}_\ell = \text{LayerNorm}\!\left(\mathbf{H}_\ell + \text{FFN}(\mathbf{H}_\ell)\right)
$$
The FFN applies two linear layers with a ReLU (or GELU) nonlinearity in between, position-wise:
\( \text{FFN}(\mathbf{h}) = \text{ReLU}(\mathbf{h}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2 \).
After \( L_e \) layers, the encoder output, which serves as the memory for the decoder's cross-attention, is:
$$
\mathbf{H}_{\text{enc}} = \mathbf{Z}_{L_e} \in \mathbb{R}^{N \times D}
$$
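Putting (a) and (b) together, one encoder layer can be sketched with PyTorch's built-in attention module, and stacking \( L_e \) such layers on \( \mathbf{Z}_0 \) yields \( \mathbf{H}_{\text{enc}} \). This follows the post-norm layout of the equations above; released ViT/TrOCR checkpoints typically use a pre-norm variant, so treat this as an illustration of the formulas rather than of the exact production architecture:

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer matching the equations above."""
    def __init__(self, dim=768, hidden=3072, num_heads=12):
        super().__init__()
        self.msa = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(),
                                 nn.Linear(hidden, dim))
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, z):                               # z: (B, N, D)
        a, _ = self.msa(z, z, z, need_weights=False)    # A_l = MSA(Z_{l-1})
        h = self.norm1(z + a)                           # H_l = LN(Z_{l-1} + A_l)
        return self.norm2(h + self.ffn(h))              # Z_l = LN(H_l + FFN(H_l))

# Stacking L_e layers on Z_0 produces the encoder output H_enc.
layers = nn.ModuleList([EncoderLayer() for _ in range(12)])
z = torch.randn(1, 576, 768)
for layer in layers:
    z = layer(z)
h_enc = z
print(h_enc.shape)  # torch.Size([1, 576, 768]) -> H_enc in R^{N x D}
```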
2. Decoder — Text Transformer
2.1 Token Embeddings
Each decoder input token \( y_i \) (a previously generated token at inference time, or a token of the right-shifted ground-truth sequence during training) is embedded as:
$$
\mathbf{s}_0^{(i)} = \mathbf{E}_{\text{text}}(y_i) + \mathbf{e}_{\text{pos}}^{(i)}
$$
Notation:
\( \mathbf{E}_{\text{text}} \): text embedding matrix,
\( \mathbf{e}_{\text{pos}}^{(i)} \): positional embedding for token \( i \).
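A minimal sketch of this embedding step (the `vocab_size` and `max_len` values are illustrative assumptions, and the positional embeddings are shown as learned parameters):

```python
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size=50265, dim=768, max_len=512):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, dim)               # E_text
        self.pos = nn.Parameter(torch.zeros(1, max_len, dim))  # e_pos (learned)

    def forward(self, y):                        # y: (B, T') token ids
        return self.tok(y) + self.pos[:, : y.size(1)]          # S_0: (B, T', D)

s0 = TokenEmbedding()(torch.randint(0, 50265, (1, 7)))
print(s0.shape)  # torch.Size([1, 7, 768])
```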
2.2 Transformer Decoder Layers
Each decoder layer \( m = 1, \dots, L_d \) transforms \( \mathbf{S}_{m-1} \to \mathbf{S}_m \) through three sub-layers: masked self-attention, cross-attention over the encoder output, and a feed-forward network.
(a) Masked Self-Attention
$$
\mathbf{A}_m^{\text{self}} =
\text{Softmax}\!\left(
\frac{\mathbf{Q}_m^{(s)} \mathbf{K}_m^{(s)\top}}{\sqrt{d_k}} + \mathbf{M}
\right)\mathbf{V}_m^{(s)}
$$
Residual connection and normalization:
$$
\mathbf{H}_m^{(1)} = \text{LayerNorm}\!\left(\mathbf{S}_{m-1} + \mathbf{A}_m^{\text{self}}\right)
$$
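The additive mask \( \mathbf{M} \) is simply an upper-triangular matrix of \( -\infty \) above the diagonal and \( 0 \) elsewhere. A brief single-head sketch (function names are mine):

```python
import torch

def causal_mask(t):
    """Additive mask M: 0 on/below the diagonal, -inf above (future positions)."""
    return torch.triu(torch.full((t, t), float("-inf")), diagonal=1)

def masked_self_attention(q, k, v):           # q, k, v: (T', d_k)
    d_k = q.size(-1)
    scores = q @ k.T / d_k ** 0.5 + causal_mask(q.size(0))
    return torch.softmax(scores, dim=-1) @ v  # row t only attends to positions <= t

print(causal_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```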
(b) Cross-Attention
$$
\mathbf{A}_m^{\text{cross}} =
\text{Softmax}\!\left(
\frac{\mathbf{Q}_m^{(c)} \mathbf{K}_{\text{enc}}^\top}{\sqrt{d_k}}
\right)\mathbf{V}_{\text{enc}}
$$
Here \( \mathbf{Q}_m^{(c)} = \mathbf{H}_m^{(1)}\mathbf{W}_Q^{(c)} \) comes from the decoder, while \( \mathbf{K}_{\text{enc}} = \mathbf{H}_{\text{enc}}\mathbf{W}_K^{(c)} \) and \( \mathbf{V}_{\text{enc}} = \mathbf{H}_{\text{enc}}\mathbf{W}_V^{(c)} \) come from the encoder output, so each text position attends over all image patches.
Residual connection and normalization:
$$
\mathbf{H}_m^{(2)} = \text{LayerNorm}\!\left(\mathbf{H}_m^{(1)} + \mathbf{A}_m^{\text{cross}}\right)
$$
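Cross-attention differs from self-attention only in where the keys and values come from. A single-head sketch (weight names are illustrative, and no causal mask is needed since every text position may look at every patch):

```python
import torch

def cross_attention(h_dec, h_enc, w_q, w_k, w_v):
    """h_dec: (T', D) decoder states; h_enc: (N, D) encoder output H_enc."""
    q = h_dec @ w_q                          # Q_m^(c): (T', d_k)
    k = h_enc @ w_k                          # K_enc:   (N, d_k)
    v = h_enc @ w_v                          # V_enc:   (N, d_k)
    d_k = q.size(-1)
    attn = torch.softmax(q @ k.T / d_k ** 0.5, dim=-1)  # (T', N): text -> patches
    return attn @ v                                      # (T', d_k)

D, d_k = 768, 64
out = cross_attention(torch.randn(5, D), torch.randn(576, D),
                      torch.randn(D, d_k), torch.randn(D, d_k), torch.randn(D, d_k))
print(out.shape)  # torch.Size([5, 64])
```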
(c) Feed-Forward Network
$$
\mathbf{S}_m = \text{LayerNorm}\!\left(
\mathbf{H}_m^{(2)} + \text{FFN}(\mathbf{H}_m^{(2)})
\right)
$$
Notation (Decoder section):
\( \mathbf{S}_{m-1} \in \mathbb{R}^{T' \times D} \): decoder input states to layer \(m\) ( \(T'\) = current sequence length).
\( \mathbf{Q}_m^{(s)}, \mathbf{K}_m^{(s)}, \mathbf{V}_m^{(s)} \): queries/keys/values for masked self-attention; computed as \( \mathbf{Q}_m^{(s)}=\mathbf{S}_{m-1}\mathbf{W}_Q^{(s)} \), \( \mathbf{K}_m^{(s)}=\mathbf{S}_{m-1}\mathbf{W}_K^{(s)} \), \( \mathbf{V}_m^{(s)}=\mathbf{S}_{m-1}\mathbf{W}_V^{(s)} \). Each has shape \(T' \times d_k\).
\( \mathbf{Q}_m^{(c)} \): query for cross-attention (from decoder); \( \mathbf{K}_{\text{enc}}, \mathbf{V}_{\text{enc}} \) come from the encoder and have shapes \(N \times d_k\).
\( \mathbf{M} \): causal mask used in masked self-attention (values \(0\) or \(-\infty\)) to prevent attending to future positions.
\( \mathbf{H}_m^{(1)}, \mathbf{H}_m^{(2)} \in \mathbb{R}^{T' \times D} \): intermediate decoder hidden states after masked self-attention and after cross-attention, respectively.
\( \mathbf{S}_m \in \mathbb{R}^{T' \times D} \): output states of decoder layer \(m\).
\( \mathbf{s}_t \in \mathbb{R}^{D} \): decoder state at time-step \(t\) used for token prediction.
\( \mathbf{W}_o \in \mathbb{R}^{V \times D} \): output projection to vocabulary logits ( \(V\) = vocab size ).
\( D \): model (embedding) dimension; \( h \): number of attention heads; \( d_k = D / h \).
2.3 Token Prediction
$$
p(y_t \mid y_{<t}, \mathbf{x}) = \text{Softmax}(\mathbf{W}_o \mathbf{s}_t)
$$
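The prediction head is just a linear map from the decoder state to vocabulary logits followed by a softmax; a brief sketch with random tensors (the vocabulary size is an illustrative assumption):

```python
import torch

D, V = 768, 50265                  # model dimension and (illustrative) vocab size
s_t = torch.randn(D)               # final decoder state at step t
W_o = torch.randn(V, D)            # output projection to vocabulary logits

logits = W_o @ s_t                 # (V,) unnormalized scores
probs = torch.softmax(logits, dim=-1)
print(probs.sum())                 # ~1.0: a distribution over the vocabulary
```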
3. Training Objective
The model is trained with teacher forcing to minimize the negative log-likelihood (cross-entropy) of the ground-truth token sequence \( \mathbf{y}^{*} \):
$$
\mathcal{L} = -\sum_{t=1}^{T} \log p(y_t^{*} \mid y_{<t}^{*}, \mathbf{x})
$$
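In code this is the standard token-level cross-entropy loss over the decoder logits; a brief sketch with random tensors (padding handling via `ignore_index` is an implementation detail I am assuming, not part of the formula):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 7, 50265
logits = torch.randn(B, T, V)            # W_o s_t for every position t
targets = torch.randint(0, V, (B, T))    # ground-truth tokens y_t^*

# Negative log-likelihood of the reference tokens, averaged over positions.
loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))
print(loss)  # equals the mean of -log p(y_t^* | y_<t^*, x)
```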
4. Inference
At inference time, tokens are generated autoregressively, each prediction conditioned on the image and the previously generated tokens:
$$
\hat{y}_t = \arg\max_{y_t} p(y_t \mid \hat{y}_{<t}, \mathbf{x})
$$
The \( \arg\max \) above corresponds to greedy decoding; in practice, beam search over several candidate sequences is commonly used instead.
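For completeness, a minimal usage sketch with the Hugging Face `transformers` implementation of TrOCR (the checkpoint name, beam setting, and the local image file `line.png` are assumptions chosen for illustration):

```python
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

image = Image.open("line.png").convert("RGB")   # a single text-line image
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# Autoregressive decoding; num_beams=1 is greedy (argmax), larger values use beam search.
generated_ids = model.generate(pixel_values, num_beams=4, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```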
References
Li, M., Lv, T., et al. (2021). TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models. arXiv preprint arXiv:2109.10282.