Representation Learning: Teaching Machines to See the World the Right Way
Here is a question worth sitting with: when you look at a cat, your brain does not process 65,536 individual pixel values and then decide "cat." Something more interesting happens — your visual system pulls out the features that matter: the pointed ears, the whiskers, the particular way it holds itself. Everything else gets discarded. You end up with a compact, meaningful summary of what you saw.
Representation learning asks: can we teach machines to do the same thing? Not just to process raw data, but to automatically discover which aspects of that data actually matter — and compress everything into a form that is useful for reasoning, comparing, and deciding.
This turns out to be one of the most important ideas in modern AI. Every large language model, every image classifier, every system that lets you search for images using text — all of them are, at their core, representation learning systems. Understanding representation learning means understanding the mathematical foundation of most of what works in AI today.
1. What Is a Representation, Exactly?
Start with a concrete example. A grayscale image of a cat might be 256 × 256 pixels, which gives you 65,536 numbers (an RGB image would have three times as many). That is your raw input; call it \( \mathbf{x} \). But most of those 65,536 numbers are redundant, noisy, or irrelevant to what you actually care about. What you want is a much smaller vector that captures the essence of the image.
A representation is a learned mapping from that raw input to something more compact:
\[ \mathbf{z} = f_\theta(\mathbf{x}), \qquad f_\theta : \mathbb{R}^{65536} \to \mathbb{R}^{d} \]
where \( d \ll 65536 \). The vector \( \mathbf{z} \) is called a latent representation or embedding. The function \( f_\theta \) is a neural network with learnable parameters \( \theta \). The central question is: how do we learn \( \theta \) so that \( \mathbf{z} \) is actually useful?
2. How Do We Know If a Representation Is Any Good?
This is less obvious than it sounds. You have trained some encoder \( f_\theta \) and you have a pile of embedding vectors \( \mathbf{z} \). How do you know if they are actually capturing meaningful structure?
2.1 The Linear Probing Test
The cleanest test is this: freeze the encoder entirely (do not touch \( \theta \)) and train a simple linear classifier on top of the embeddings:
\[ \hat{y} = \text{softmax}(W \mathbf{z} + \mathbf{b}) \]
If even a linear model can do well using only \( \mathbf{z} \), that means the representation has organized the data in a linearly separable way. The useful structure is sitting right on the surface, not buried in nonlinear tangles. This is called linear probing, and it is the standard benchmark for representation quality.
The intuition: if you need a complicated classifier to extract signal from \( \mathbf{z} \), the encoder did the work poorly and pushed the hard problem downstream. If a single matrix multiplication works, the encoder did its job.
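A minimal sketch of the probing idea, under loud assumptions: the data is a synthetic two-ring toy problem, `frozen_encoder` is a hand-written, hypothetical stand-in for a trained \( f_\theta \), and the probe is a least-squares linear classifier rather than a softmax layer. It only illustrates that the same linear probe can fail on raw inputs and succeed on a good representation.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_rings(n_per_class):
    """Two concentric noisy rings: not linearly separable in raw coordinates."""
    angles = rng.uniform(0.0, 2.0 * np.pi, size=2 * n_per_class)
    radii = np.concatenate([np.full(n_per_class, 1.0), np.full(n_per_class, 3.0)])
    radii = radii + rng.normal(scale=0.1, size=radii.shape)
    x = np.stack([radii * np.cos(angles), radii * np.sin(angles)], axis=1)
    y = np.concatenate([np.zeros(n_per_class, dtype=int), np.ones(n_per_class, dtype=int)])
    return x, y

def frozen_encoder(x):
    """Hypothetical frozen encoder: appends the squared radius as a feature."""
    return np.concatenate([x, (x ** 2).sum(axis=1, keepdims=True)], axis=1)

def fit_linear_probe(features, labels):
    """Least-squares linear probe with a bias term; targets are -1 / +1."""
    design = np.hstack([features, np.ones((len(features), 1))])
    targets = 2.0 * labels - 1.0
    weights, *_ = np.linalg.lstsq(design, targets, rcond=None)
    return weights

def probe_accuracy(weights, features, labels):
    design = np.hstack([features, np.ones((len(features), 1))])
    predictions = (design @ weights > 0).astype(int)
    return float((predictions == labels).mean())

x_train, y_train = make_rings(200)
x_test, y_test = make_rings(200)

# Probe on raw inputs versus probe on frozen embeddings.
acc_raw = probe_accuracy(fit_linear_probe(x_train, y_train), x_test, y_test)
acc_z = probe_accuracy(
    fit_linear_probe(frozen_encoder(x_train), y_train),
    frozen_encoder(x_test),
    y_test,
)
print(f"linear probe on raw x: {acc_raw:.2f}, on z: {acc_z:.2f}")
```

The encoder is never updated here; only the probe weights are fit, which is the defining constraint of linear probing.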
2.2 Comparing Two Models: CKA
A question that comes up more than you might expect: if two models were trained separately on the same data, did they learn similar internal representations? This matters enormously for understanding whether different architectures or training procedures converge to the same solution, or whether they find fundamentally different ones.
The tool for this is Centered Kernel Alignment (CKA). Given embeddings of the same \( n \) inputs from two models, \( Z_1, Z_2 \in \mathbb{R}^{n \times d} \), we first compute their Gram matrices \( K = Z_1 Z_1^\top \) and \( L = Z_2 Z_2^\top \), then measure their alignment after centering:
\[ \text{CKA}(K, L) = \frac{\text{HSIC}(K, L)}{\sqrt{\text{HSIC}(K, K)\,\text{HSIC}(L, L)}} \]
where HSIC is the Hilbert-Schmidt Independence Criterion, computed as:
\[ \text{HSIC}(K, L) = \frac{1}{(n-1)^2}\,\text{tr}(K H L H) \]
and \( H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^\top \) is the centering matrix. The key property: CKA = 1 means the two representations have identical structure (up to rotation and isotropic scaling), while CKA = 0 means their centered Gram matrices are completely unaligned. Crucially, it is invariant to orthogonal transformations: two models that learned the same geometry but in a rotated coordinate system will still score CKA = 1.
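The following sketch computes linear CKA directly from the Gram-matrix definition above and checks the invariance claim numerically. The embeddings are random placeholders, not outputs of real models, and the function name `linear_cka` is illustrative.

```python
import numpy as np

def linear_cka(z1: np.ndarray, z2: np.ndarray) -> float:
    """Linear CKA between two embedding matrices with the same number of rows."""
    n = z1.shape[0]
    k = z1 @ z1.T                              # Gram matrix of model 1
    l = z2 @ z2.T                              # Gram matrix of model 2
    h = np.eye(n) - np.ones((n, n)) / n        # centering matrix H

    def hsic(a, b):
        return np.trace(a @ h @ b @ h) / (n - 1) ** 2

    return float(hsic(k, l) / np.sqrt(hsic(k, k) * hsic(l, l)))

rng = np.random.default_rng(0)
z = rng.normal(size=(50, 8))

# A random orthogonal matrix from a QR decomposition, plus a global rescaling.
q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
rotated = 3.0 * z @ q
unrelated = rng.normal(size=(50, 8))

cka_rotated = linear_cka(z, rotated)
cka_unrelated = linear_cka(z, unrelated)
print(f"rotated copy: {cka_rotated:.4f}, unrelated: {cka_unrelated:.4f}")
```

Note that independent random embeddings do not score exactly zero at finite sample size, so small CKA values should be read relative to such a baseline.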
2.3 Downstream Task Performance
Beyond linear probing and structural comparison, perhaps the most practical test is simply: how well does the representation transfer? You keep the encoder frozen, take its embeddings \( \mathbf{z} \), and train a small model on top of them for a completely different task, say, an encoder trained on natural images being evaluated on medical scans. The degree to which performance degrades (or does not) tells you whether the representation captured general structure or only task-specific patterns.
A representation that transfers well to unseen tasks is not just compressing data — it is capturing something fundamental about the distribution that generated that data. This is the gold standard, and it is why large pretrained models have become such a powerful starting point across almost every applied ML problem.
3. Four Ways to Learn a Representation
There are four fundamentally different answers to the question "how do we train \( f_\theta \)?" Each makes different assumptions about what data you have available, and each has a distinct mathematical flavor.
3.1 Supervised: Use Labels
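The simplest case. You have labeled data \( (\mathbf{x}_i, y_i) \). Train end-to-end:
\[ \min_{\theta, \psi} \; \frac{1}{N} \sum_{i=1}^{N} \ell\big(g_\psi(f_\theta(\mathbf{x}_i)),\, y_i\big) \]
where \( g_\psi \) is the final classification head and \( \ell \) is a loss such as cross-entropy.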
The penultimate layer of the network — the layer just before the final classification head — becomes your representation. When you then want to apply this to a new task with different labels, you discard the final layer and keep \( f_\theta \). This is transfer learning. The bet is that features learned on a large, diverse dataset (like ImageNet) capture general structure that transfers to your specific problem.
It works remarkably well. A plausible explanation is that gradient descent, applied to a sufficiently large and diverse classification problem, is pushed toward internal representations that capture real semantic structure. It is hard to correctly classify a thousand object categories without building something like a hierarchy of visual features.
3.2 Generative: Learn by Reconstructing
What if you have no labels at all? One elegant answer: learn to compress your data, and then learn to reconstruct it. If you can squeeze \( \mathbf{x} \) into a small \( \mathbf{z} \) and then recover \( \mathbf{x} \) from \( \mathbf{z} \) alone, then \( \mathbf{z} \) must contain the essential information.
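This is the Variational Autoencoder (VAE). The encoder produces a distribution over latent vectors rather than a single point:
\[ q_\phi(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}\big(\mathbf{z};\, \boldsymbol{\mu}_\phi(\mathbf{x}),\, \text{diag}(\boldsymbol{\sigma}_\phi^2(\mathbf{x}))\big) \]
The decoder tries to reconstruct \( \mathbf{x} \) from a sample \( \mathbf{z} \sim q_\phi \). The training objective, the evidence lower bound (ELBO) to be maximized, balances two competing goals:
\[ \mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x})}\big[\log p_\theta(\mathbf{x} \mid \mathbf{z})\big] - D_{\text{KL}}\big(q_\phi(\mathbf{z} \mid \mathbf{x}) \,\|\, p(\mathbf{z})\big) \]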
The first term says: given the latent vector, recover the original input. The second term says: do not let the latent space go wild — keep it close to a standard Gaussian prior \( p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, I) \). That regularization is what gives the latent space its smooth, structured geometry. You can interpolate between two points in \( \mathbf{z} \)-space and get something sensible — that is the KL term doing its job.
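The two terms of the VAE objective can be sketched numerically as follows. This is an illustration under stated assumptions, not a full VAE: there are no trained networks, the decoder likelihood is assumed to be a unit-variance Gaussian (so the reconstruction term reduces to half the squared error, up to a constant), and the function names are hypothetical.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ) per example."""
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - log_var - 1.0, axis=-1)

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps, the usual reparameterization of q(z | x)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def negative_elbo(x, x_reconstructed, mu, log_var):
    """Batch-averaged negative ELBO, assuming a unit-variance Gaussian decoder."""
    reconstruction = 0.5 * np.sum((x - x_reconstructed) ** 2, axis=-1)
    return float(np.mean(reconstruction + kl_to_standard_normal(mu, log_var)))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))

# A posterior that equals the prior pays no KL penalty.
mu = np.zeros((4, 2))
log_var = np.zeros((4, 2))
z = reparameterize(mu, log_var, rng)
kl_at_prior = kl_to_standard_normal(mu, log_var)

# Shifting one latent mean by 1 costs 0.5 nats of KL.
kl_shifted = kl_to_standard_normal(np.array([[1.0, 0.0]]), np.zeros((1, 2)))

# Perfect reconstruction with a prior-matching posterior gives zero loss.
loss_perfect = negative_elbo(x, x, mu, log_var)
```

In a real model the KL term and the reconstruction term pull in opposite directions: an informative posterior moves away from the prior, which is exactly the trade-off described above.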
3.3 Contrastive: Learn by Comparing
This is the approach behind most modern self-supervised learning, and it has a beautifully simple core idea. Take an image. Create two different augmented versions of it — random crops, color jitters, flips. These are called a positive pair. The claim is: their representations should be similar, because they came from the same underlying image. Meanwhile, representations of completely different images — negatives — should be pushed apart.
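The loss that formalizes this (from SimCLR), written for a positive pair \( (i, j) \) among the \( 2N \) augmented views of a batch of \( N \) images, is:
\[ \mathcal{L}_{i,j} = -\log \frac{\exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_k)/\tau)} \]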
where \( \text{sim}(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u}^\top \mathbf{v}}{\|\mathbf{u}\|\|\mathbf{v}\|} \) is cosine similarity and \( \tau \) is a temperature parameter that controls sharpness. Minimize this loss and the encoder learns to map semantically similar things nearby and dissimilar things far apart — entirely without labels.
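Here is a small NumPy sketch of this contrastive loss, averaged over all 2N views in a batch. The embeddings are random placeholders standing in for encoder outputs, and the temperature of 0.1 is an assumed value chosen for illustration.

```python
import numpy as np

def nt_xent_loss(view_a, view_b, temperature=0.1):
    """Contrastive loss over a batch; row i of view_a and view_b is a positive pair."""
    n = view_a.shape[0]
    z = np.concatenate([view_a, view_b], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine similarity via dot product
    logits = z @ z.T / temperature                     # shape (2N, 2N)
    np.fill_diagonal(logits, -np.inf)                  # exclude k == i from the denominator

    # Index of the positive for each of the 2N rows.
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])

    # Numerically stable log of the denominator.
    row_max = logits.max(axis=1, keepdims=True)
    log_denominator = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    positive_logits = logits[np.arange(2 * n), positives]
    return float(np.mean(log_denominator - positive_logits))

rng = np.random.default_rng(0)
anchors = rng.normal(size=(8, 32))
augmented = anchors + 0.05 * rng.normal(size=(8, 32))   # stand-in for a second augmentation

loss_aligned = nt_xent_loss(anchors, augmented)
# Pairing each anchor with the wrong partner should be penalized much more heavily.
loss_mismatched = nt_xent_loss(anchors, np.roll(augmented, 1, axis=0))
print(f"aligned: {loss_aligned:.3f}, mismatched: {loss_mismatched:.3f}")
```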
3.4 Multiview: Learn from Natural Pairings
The real world is full of natural pairs: an image and its caption, a spoken word and its written form, a video frame and its audio. Multiview learning exploits these pairings directly.
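The classical mathematical version is Canonical Correlation Analysis (CCA). Given two views \( \mathbf{x}^{(1)} \) and \( \mathbf{x}^{(2)} \), find projections \( W_1 \) and \( W_2 \) that maximize the correlation between the projected views:
\[ \max_{W_1, W_2} \; \text{tr}\big(W_1^\top \Sigma_{12} W_2\big) \quad \text{subject to} \quad W_1^\top \Sigma_{11} W_1 = I, \;\; W_2^\top \Sigma_{22} W_2 = I \]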
where \( \Sigma_{12} \) is the cross-covariance between views, and \( \Sigma_{11}, \Sigma_{22} \) are within-view covariances. The result: representations where matching pairs of images and captions end up in the same region of space.
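One standard way to solve linear CCA is to whiten each view and take an SVD of the whitened cross-covariance. The sketch below follows that route on synthetic data. The tiny ridge term, the mixing vectors, and the function names are assumptions made for the example.

```python
import numpy as np

def inv_sqrt(mat):
    """Inverse matrix square root of a symmetric positive definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T

def linear_cca(x1, x2, ridge=1e-8):
    """Return projections W1, W2 and the canonical correlations."""
    x1 = x1 - x1.mean(axis=0)
    x2 = x2 - x2.mean(axis=0)
    n = x1.shape[0]
    s11 = x1.T @ x1 / (n - 1) + ridge * np.eye(x1.shape[1])   # within-view covariance
    s22 = x2.T @ x2 / (n - 1) + ridge * np.eye(x2.shape[1])
    s12 = x1.T @ x2 / (n - 1)                                 # cross-covariance
    s11_is, s22_is = inv_sqrt(s11), inv_sqrt(s22)
    u, correlations, vt = np.linalg.svd(s11_is @ s12 @ s22_is)
    return s11_is @ u, s22_is @ vt.T, correlations

# Two views that share a single latent signal, plus independent noise.
rng = np.random.default_rng(0)
shared = rng.normal(size=(500, 1))
view1 = shared @ np.array([[1.0, -2.0, 0.5]]) + 0.1 * rng.normal(size=(500, 3))
view2 = shared @ np.array([[0.5, 1.0, -1.0, 2.0]]) + 0.1 * rng.normal(size=(500, 4))

w1, w2, correlations = linear_cca(view1, view2)

# Empirical correlation of the first pair of projected coordinates.
p1 = (view1 - view1.mean(axis=0)) @ w1[:, 0]
p2 = (view2 - view2.mean(axis=0)) @ w2[:, 0]
top_correlation = float(np.corrcoef(p1, p2)[0, 1])
print(correlations.round(3), round(top_correlation, 3))
```

Only the first canonical correlation is large here, which matches the construction: the two views share exactly one underlying factor.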
CLIP scales this to neural networks with enormous datasets. For a batch of \( N \) image-text pairs, it maximizes similarity for the \( N \) matched pairs while minimizing it for all \( N^2 - N \) mismatched ones. The result is a shared embedding space where "a photo of a dog" and an actual photo of a dog are geometrically close — which is exactly why you can search images with text today.
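The batch objective described above can be sketched as a symmetric cross-entropy over an N × N similarity matrix, where the diagonal holds the matched pairs. This is a simplified illustration, not CLIP's actual implementation: the embeddings are random placeholders rather than encoder outputs, and the temperature is held fixed at an assumed value of 0.07.

```python
import numpy as np

def log_softmax(logits, axis):
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def clip_style_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric contrastive loss; row i of each input is a matched image-text pair."""
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = image_emb @ text_emb.T / temperature            # (N, N) similarity matrix

    # Each image should pick out its caption among N candidates, and vice versa.
    image_to_text = -np.diag(log_softmax(logits, axis=1)).mean()
    text_to_image = -np.diag(log_softmax(logits, axis=0)).mean()
    return float(0.5 * (image_to_text + text_to_image))

rng = np.random.default_rng(0)
images = rng.normal(size=(16, 32))
captions = images + 0.05 * rng.normal(size=(16, 32))   # stand-in for matching text embeddings

loss_matched = clip_style_loss(images, captions)
loss_shuffled = clip_style_loss(images, np.roll(captions, 1, axis=0))
print(f"matched: {loss_matched:.3f}, shuffled: {loss_shuffled:.3f}")
```

The off-diagonal entries of `logits` are the \( N^2 - N \) mismatched pairs; they only enter the loss through the softmax denominators.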
4. The Theory Behind It: Why Does Any of This Work?
The previous section is about methods. This one is about understanding why those methods work — or more precisely, whether they provably work.
4.1 Identifiability: Can We Recover the Truth?
Here is a sobering question. Suppose the world has some true underlying factors \( \mathbf{z}^* \) — say, the actual shape, color, and position of an object. Your encoder produces \( \mathbf{z} \). The question is: does \( \mathbf{z} \) recover \( \mathbf{z}^* \)?
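Formally, a model is identifiable if identical distributions over observed data imply identical parameters:
\[ p_\theta(\mathbf{x}) = p_{\theta'}(\mathbf{x}) \;\; \text{for all } \mathbf{x} \quad \Longrightarrow \quad \theta = \theta' \]
The uncomfortable truth is that for general VAEs, this fails. Many different latent configurations can produce the same observations. Even assuming that the latent factors are statistically independent is not enough when the mapping from latents to observations is nonlinear: without additional structure, you cannot guarantee that the learned \( \mathbf{z} \) corresponds to anything meaningful about the real world.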
Identifiable models require structural assumptions. One important example is the iVAE, which conditions the prior on an observed context variable \( \mathbf{u} \):
\[ p(\mathbf{z} \mid \mathbf{u}) = \prod_{i=1}^{d} p_i(z_i \mid \mathbf{u}) \]
This factorized conditional prior, combined with enough variation in \( \mathbf{u} \), is enough to recover the true latents up to simple indeterminacies (such as permutation and component-wise transformations) under mild conditions. The lesson: identifiability is not automatic, and without it, you might be learning representations that are useful in practice but bear no principled relationship to the ground truth.
4.2 Information Maximization: The Unifying Principle
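Perhaps the deepest theoretical unification in representation learning is this: a good representation should preserve as much information about the input as possible. Formally, maximize the mutual information between input and representation:
\[ \max_\theta \; I(\mathbf{x}; \mathbf{z}), \qquad \mathbf{z} = f_\theta(\mathbf{x}) \]
The problem is that mutual information is generally intractable to compute. The practical workaround is the InfoNCE bound, which provides a computable lower bound from a positive \( \mathbf{z}_i^{+} \) and \( N \) candidates \( \mathbf{z}_j \) (the positive included):
\[ I(\mathbf{x}; \mathbf{z}) \geq \log N + \mathbb{E}\left[ \log \frac{\exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_i^{+})/\tau)}{\sum_{j=1}^{N} \exp(\text{sim}(\mathbf{z}_i, \mathbf{z}_j)/\tau)} \right] = \log N - \mathcal{L}_{\text{InfoNCE}} \]
Look at that lower bound carefully. It looks almost exactly like the contrastive loss from Section 3.3. This is not a coincidence: it is the theoretical foundation of contrastive learning. When you train SimCLR, you are implicitly maximizing a lower bound on the mutual information between the representations of two views of the same input. Minimizing the practical loss tightens the theoretical bound, although with \( N \) candidates per example the bound can never exceed \( \log N \).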
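To make the bound concrete, the sketch below evaluates \( \log N - \mathcal{L}_{\text{InfoNCE}} \) on placeholder embeddings, using scaled cosine similarity as the critic. The temperature, the embedding sizes, and the function name are assumptions for illustration. The estimate is in nats and is capped at \( \log N \) because the loss is never negative.

```python
import numpy as np

def infonce_bound(anchors, positives, temperature=0.05):
    """Estimate log N - L_InfoNCE, where row i of each input forms a positive pair."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    scores = a @ p.T / temperature                      # critic values for all N x N pairs
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -np.diag(log_probs).mean()                   # InfoNCE loss, always >= 0
    return float(np.log(len(a)) - loss)

rng = np.random.default_rng(0)
n = 32
z = rng.normal(size=(n, 64))

# Perfectly informative pairs: the estimate saturates near log N.
bound_aligned = infonce_bound(z, z)
# Independent pairs: the estimate carries no evidence of shared information.
bound_independent = infonce_bound(z, rng.normal(size=(n, 64)))
print(f"log N = {np.log(n):.3f}, aligned = {bound_aligned:.3f}, "
      f"independent = {bound_independent:.3f}")
```

The saturation at \( \log N \) is one reason larger batches matter for contrastive methods: a small batch cannot certify more than \( \log N \) nats of shared information.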
5. A Practical Failure Mode: Representation Collapse
No discussion of representation learning is complete without talking about what goes wrong. The most insidious failure mode is representation collapse — when the encoder learns to map every input to the same (or nearly the same) vector. A constant \( \mathbf{z} \) trivially satisfies any objective that only looks at positive pairs, since the similarity between a vector and itself is always 1.
This is not a hypothetical problem. Self-supervised objectives that rely only on positive pairs are prone to it unless something explicitly prevents it. The standard solutions fall into two families:
Negative sampling: Explicitly include negative pairs in the loss (as SimCLR does). The denominator in the contrastive loss penalizes solutions where all embeddings cluster together. The catch: you need a large batch size to have enough diverse negatives, which is computationally expensive.
Architectural asymmetry: Methods like BYOL and SimSiam avoid negatives entirely by introducing asymmetry: an online branch and a target branch that are not updated by the same gradient. The stop-gradient on the target branch breaks the symmetry that would otherwise let both collapse to the same constant. Barlow Twins takes yet another approach, directly penalizing off-diagonal correlations in the cross-correlation matrix of embeddings:
\[ \mathcal{L}_{\text{BT}} = \sum_{i} (1 - \mathcal{C}_{ii})^2 + \lambda \sum_{i} \sum_{j \neq i} \mathcal{C}_{ij}^2 \]
where \( \mathcal{C} \) is the cross-correlation matrix between embeddings of two augmented views, computed along the batch dimension. The first term forces each dimension to agree across the two views (\( \mathcal{C}_{ii} = 1 \)); the second forces different dimensions to be decorrelated. Collapse is heavily penalized because a constant embedding has no variance across the batch, so \( \mathcal{C}_{ii} \) cannot reach 1 (in practice it evaluates to 0), maximally violating the first term.
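A compact NumPy sketch of the Barlow Twins objective follows, including a check of what happens under collapse. The weight `lam` and the stabilizing `eps` are assumed values for the example, and the inputs are random placeholders rather than network outputs.

```python
import numpy as np

def barlow_twins_loss(view_a, view_b, lam=5e-3, eps=1e-8):
    """Invariance term on the diagonal of C plus redundancy term off the diagonal."""
    n = view_a.shape[0]
    # Standardize each embedding dimension along the batch.
    a = (view_a - view_a.mean(axis=0)) / (view_a.std(axis=0) + eps)
    b = (view_b - view_b.mean(axis=0)) / (view_b.std(axis=0) + eps)
    c = a.T @ b / n                                       # (d, d) cross-correlation matrix
    on_diagonal = ((1.0 - np.diag(c)) ** 2).sum()
    off_diagonal = (c ** 2).sum() - (np.diag(c) ** 2).sum()
    return float(on_diagonal + lam * off_diagonal)

rng = np.random.default_rng(0)
d = 8

# Healthy case: both views agree and the dimensions are roughly decorrelated.
healthy = rng.normal(size=(1000, d))
loss_healthy = barlow_twins_loss(healthy, healthy)

# Collapsed case: every input maps to the same vector, so C is all zeros.
collapsed = np.ones((1000, d))
loss_collapsed = barlow_twins_loss(collapsed, collapsed)
print(f"healthy: {loss_healthy:.4f}, collapsed: {loss_collapsed:.4f}")
```

With a constant embedding, each of the d diagonal terms contributes \( (1 - 0)^2 = 1 \), so the collapsed loss equals the embedding dimension.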
The practical upshot: whenever you design a representation learning system, the first question to ask is "what stops this from collapsing?" If you do not have a clear answer, it probably will.
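One cheap way to monitor for collapse during training is to track the per-dimension standard deviation of L2-normalized embeddings. This is a heuristic sketch rather than a guarantee. It relies on the fact that for healthy, roughly isotropic d-dimensional embeddings this value sits near \( 1/\sqrt{d} \), while it drops toward zero under collapse. The function name and the synthetic inputs are assumptions.

```python
import numpy as np

def embedding_spread(z):
    """Mean per-dimension standard deviation of L2-normalized embeddings."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return float(z.std(axis=0).mean())

rng = np.random.default_rng(0)
d = 32

# Healthy embeddings: spread close to 1 / sqrt(d).
healthy = rng.normal(size=(256, d))
spread_healthy = embedding_spread(healthy)

# Nearly collapsed embeddings: every input maps to almost the same vector.
collapsed = np.ones((256, d)) + 1e-3 * rng.normal(size=(256, d))
spread_collapsed = embedding_spread(collapsed)

print(f"reference 1/sqrt(d) = {1 / np.sqrt(d):.3f}")
print(f"healthy = {spread_healthy:.3f}, collapsed = {spread_collapsed:.5f}")
```

Logging this number alongside the loss is useful because a collapsing encoder can still show a steadily decreasing loss.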
6. The Bigger Picture: Why Representation Learning Is the Substrate of Modern AI
It is worth stepping back and asking: why does any of this matter beyond an interesting set of algorithms?
Representation learning is not a subfield of AI — it is the substrate that everything else runs on. When a language model understands that "Paris is to France as Berlin is to Germany," that is geometry in a learned embedding space. When a medical image classifier outperforms a radiologist on some specific task, it is because the representation it learned from millions of images captures structure that no human-engineered feature ever captured. When you type a question to a search engine and get a document that never contains your exact words, that is semantic similarity computed in an embedding space.
The four approaches — supervised, generative, contrastive, multiview — are not competing alternatives. They are tools for different situations. When you have labels, use them. When you do not, contrastive or generative methods let you still learn from raw data. When your data has natural multiple views, multiview learning gives you supervision for free.
The open questions are real, though. Identifiability tells us that most representation learning methods are not provably recovering the true underlying structure — they are finding something useful without guarantees. Representation collapse lurks in every self-supervised setup. The gap between "useful in practice" and "provably correct" is where a lot of the frontier research lives.
But perhaps the most important insight is the simplest one: every AI system that processes data is, at some level, trying to find the right \( \mathbf{z} \). The question is just how — and whether it is finding the right one for the right reasons.
