End-to-End Medical Image AI: From Raw DICOM to Clinical Deployment
There is a question worth sitting with before writing a single line of code: when a radiologist reads a CT scan, they do not run three separate mental programs — one for loading the DICOM files, another for deciding which Hounsfield window to apply, a third for actually spotting the tumour. Something more practised happens. A trained visual system and decades of anatomical knowledge combine seamlessly into a single interpretive act. The pixels become a diagnosis.
Building a machine that replicates even a fraction of that process requires solving a surprising number of sub-problems in the right order. This post walks through every stage of the pipeline — from where to get publicly available MRI and CT data, through preprocessing, model architecture, loss functions, training mechanics, and all the way to clinical deployment and ongoing monitoring. No stage is skipped, no equation is hand-waved. If you have read the mathematics before, this should solidify it. If you have not, this is the derivation.
1. The Problem and the Notation
A medical imaging task takes a three-dimensional volume of numbers and produces structured clinical outputs. Let us fix notation that will persist through the entire post.
A single 3D volume (one CT scan or one MRI acquisition) is a tensor:

\[ \mathbf{V} \in \mathbb{R}^{D \times H \times W}, \]
where \( D \) is the number of slices (depth), \( H \) is the in-plane height, and \( W \) is the in-plane width. Each element \( \mathbf{V}_{d,h,w} \) is one voxel (volumetric pixel). The physical size of each voxel is described by the voxel spacing:

\[ \mathcal{S} = (s_z, s_y, s_x) \in \mathbb{R}_{>0}^{3} \quad \text{(mm per voxel along each axis)}. \]
For a typical chest CT, \( s_x = s_y \approx 0.7\,\text{mm} \) in-plane and \( s_z \in [1,5]\,\text{mm} \) along the slice axis — meaning a single volume might span \( 300 \times 512 \times 512 \) with highly anisotropic spacing. This anisotropy is one of the most important facts about medical imaging and must be handled explicitly in every preprocessing step.
A batch of \( B \) preprocessed volumes (after resampling to isotropic spacing and normalising to \( C \) channels) is:

\[ \mathbf{X} \in \mathbb{R}^{B \times C \times D' \times H' \times W'}, \]
where \( C = 1 \) for single-modality CT or \( C = 4 \) for multi-sequence brain MRI (T1, T2, FLAIR, T1ce as in BraTS). The targets depend on the task:
- Segmentation: \( \mathbf{Y} \in \{0, 1, \dots, C_{\text{seg}}-1\}^{B \times D' \times H' \times W'} \), a per-voxel class label over \( C_{\text{seg}} \) classes (background + anatomical structures).
- Classification: \( \mathbf{y} \in \{0, \dots, K-1\}^B \), one label per volume (e.g., benign vs. malignant).
- Detection: \( \{(\mathbf{b}_i, s_i)\}_{i=1}^{N} \), a set of axis-aligned 3D bounding boxes and confidence scores per volume.
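To make the notation concrete, here is a quick sketch (with made-up dimensions) of how a raw volume, its spacing, and a preprocessed batch relate:

```python
import numpy as np

# A raw chest CT: 300 slices of 512x512, anisotropic spacing (sz, sy, sx) in mm
V = np.zeros((300, 512, 512), dtype=np.float32)
spacing = (3.0, 0.7, 0.7)

# Physical extent of the scan along each axis, in mm
extent_mm = tuple(n * s for n, s in zip(V.shape, spacing))
print(extent_mm)  # ≈ (900.0, 358.4, 358.4)

# A preprocessed batch: B volumes, C channels (4 for BraTS MRI), D'xH'xW' voxels
B, C = 2, 4
X = np.zeros((B, C, 128, 128, 128), dtype=np.float32)
print(X.shape)  # (2, 4, 128, 128, 128)
```

Note how anisotropic the raw scan is: 900 mm of coverage along \( z \) from only 300 slices, versus 358 mm in-plane from 512 voxels.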
2. Data Collection — What is Freely Available
Before any model can be trained, data must be found. The good news is that the medical imaging community has been systematically releasing large, well-annotated public datasets for over a decade. The Cancer Imaging Archive (TCIA) alone hosts over 150 collections totalling hundreds of terabytes of de-identified patient scans. Below is a map of the most useful sources, organised by anatomy and task.
2.1 Brain MRI Datasets
Brain MRI is the most dataset-rich area in medical imaging. The BraTS (Brain Tumour Segmentation) challenge, running annually since 2012, provides multi-sequence scans (T1, T2, FLAIR, T1ce) with expert-consensus tumour annotations across three sub-regions: whole tumour (WT), tumour core (TC), and enhancing tumour (ET). BraTS 2024 contains approximately 4,500 cases. The ADNI dataset covers Alzheimer's progression with longitudinal T1 MRI. OpenNeuro and OASIS-3 offer large healthy-brain cohorts for normative modelling.
2.2 Chest CT Datasets
The LUNA16 challenge provides 888 chest CTs from LIDC-IDRI with nodule annotations from four radiologists — making it the standard benchmark for pulmonary nodule detection. RSNA Pneumonia (2018 Kaggle challenge) offers 30,000 chest X-rays plus a CT extension. DeepLesion from the NIH covers 32,000 lesion annotations across body parts from 10,000 CT volumes. The COVID-19 CT Lung and Infection Segmentation dataset provides annotated lung and consolidation masks.
2.3 Abdominal CT Datasets
LiTS (Liver Tumour Segmentation) provides 201 CT scans with liver and liver-tumour masks — the primary benchmark for abdominal organ segmentation. KiTS23 (Kidney and Kidney Tumour Segmentation 2023) provides 599 cases with three classes: kidney, tumour, cyst. The NIH Pancreas CT dataset offers 82 scans (small but clinically challenging). CHAOS provides multi-organ abdominal CT and MRI for cross-modality comparison.
2.4 Accessing the Data
Most datasets require a data use agreement (DUA) signed through TCIA's portal or Kaggle account creation. Downloads are handled via:
pip install tcia_utils pydicom nibabel
# List all available TCIA collections
from tcia_utils import nbia
collections = nbia.getCollectionValues()
# Download a specific collection (e.g., LIDC-IDRI)
series_data = nbia.getSeries(collection="LIDC-IDRI")
nbia.downloadSeries(series_data, path="./data/lidc/")
# BraTS 2024: available via Synapse or HuggingFace
from datasets import load_dataset
ds = load_dataset("openmedical/brats2024")
3. Data Exploration and Format
Before writing a single preprocessing function, it is essential to understand the two dominant file formats and the physical meaning of the numbers stored inside them.
3.1 DICOM — The Clinical Standard
DICOM (Digital Imaging and Communications in Medicine) is not a single file but a folder of slice files, one per cross-sectional image. Each file carries the pixel data of that slice plus hundreds of embedded metadata tags. The most critical tags for preprocessing are:
| DICOM Tag | Attribute | Example Value | Why It Matters |
|---|---|---|---|
| (0018,0050) | SliceThickness | 1.5 mm | Sets voxel spacing along z-axis |
| (0028,0030) | PixelSpacing | [0.703, 0.703] mm | In-plane voxel size |
| (0028,1053) | RescaleSlope | 1.0 | Used to convert stored pixels to HU |
| (0028,1052) | RescaleIntercept | -1024 | Used to convert stored pixels to HU |
| (0028,1050) | WindowCenter | 40 | Scanner-suggested display window |
| (0028,1051) | WindowWidth | 400 | Scanner-suggested display window |
| (0020,0013) | InstanceNumber | 47 | Slice ordering — must sort before stacking |
| (0008,0060) | Modality | CT or MR | Determines normalisation strategy |
The raw stored pixel value \( p \) must be converted to a physically meaningful unit before any processing. For CT, the conversion yields Hounsfield Units (HU):

\[ \text{HU} = s_{\text{slope}} \cdot p + b_{\text{intercept}}, \]
where \( s_{\text{slope}} \) is RescaleSlope (usually 1.0 for CT) and \( b_{\text{intercept}} \) is RescaleIntercept (usually −1024 for CT). Hounsfield Units are calibrated to the linear attenuation coefficient of water:

\[ \text{HU} = 1000 \cdot \frac{\mu - \mu_{\text{water}}}{\mu_{\text{water}} - \mu_{\text{air}}}, \]

where \( \mu \) is the linear X-ray attenuation coefficient. By construction, water is always 0 HU and air is always −1000 HU, regardless of scanner manufacturer or acquisition parameters. This is why CT is much easier to normalise than MRI.
Reading a DICOM series in Python:
import pydicom, numpy as np, os
def load_ct_volume(folder):
slices = [pydicom.dcmread(os.path.join(folder, f))
for f in os.listdir(folder) if f.endswith(".dcm")]
slices.sort(key=lambda s: float(s.InstanceNumber))
# Convert each slice to HU
volume = np.stack([
s.pixel_array * float(s.RescaleSlope) + float(s.RescaleIntercept)
for s in slices
], axis=0).astype(np.float32) # shape: (D, H, W)
spacing = (
float(slices[0].SliceThickness),
float(slices[0].PixelSpacing[0]),
float(slices[0].PixelSpacing[1])
)
return volume, spacing # HU array + voxel spacing in mm
3.2 NIfTI — The Research Standard
NIfTI (Neuroimaging Informatics Technology Initiative) stores the entire 3D volume in a single .nii or .nii.gz file. The header encodes the affine matrix \( \mathbf{A} \in \mathbb{R}^{4 \times 4} \), which maps voxel indices to mm coordinates in scanner space:

\[ \begin{pmatrix} x \\ y \\ z \\ 1 \end{pmatrix} = \mathbf{A} \begin{pmatrix} i \\ j \\ k \\ 1 \end{pmatrix}. \]
For an axis-aligned acquisition, the diagonal elements of \( \mathbf{A} \) give the voxel spacings and their signs encode orientation (RAS vs LAS convention); in general, the spacings are the norms of the first three columns. Reading with nibabel:
import nibabel as nib
img = nib.load("brain.nii.gz")
vol = img.get_fdata().astype(np.float32) # (H, W, D) — note axis order!
spacing = img.header.get_zooms()[:3] # (sx, sy, sz) in mm
affine = img.affine # 4x4 world transform
A critical gotcha: NIfTI stores data as \( (H, W, D) \) while PyTorch expects \( (C, D, H, W) \). Always permute axes explicitly after loading.
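That permutation is a one-liner; here is a minimal sketch using a dummy array in place of `img.get_fdata()` (the full reorientation to a canonical frame also involves the affine, which is omitted here):

```python
import numpy as np

# Dummy stand-in for img.get_fdata(): NIfTI-order array (H, W, D)
vol_hwd = np.random.rand(240, 240, 155).astype(np.float32)

# Reorder to (D, H, W), then add a channel axis -> (C, D, H, W)
vol_dhw = np.transpose(vol_hwd, (2, 0, 1))
vol_cdhw = vol_dhw[np.newaxis, ...]
print(vol_cdhw.shape)  # (1, 155, 240, 240)
```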
3.3 CT Hounsfield Scale — What the Numbers Mean
Unlike MRI, CT values have a physical calibration. Every tissue type occupies a known HU range, and knowing these ranges is what makes manual windowing decisions interpretable:

| Tissue | Typical HU Range |
|---|---|
| Air | −1000 |
| Lung parenchyma | −900 to −500 |
| Fat | −100 to −50 |
| Water | 0 |
| White matter | 20 to 30 |
| Grey matter | 35 to 45 |
| Muscle | 10 to 40 |
| Liver | 40 to 60 |
| Acute haemorrhage | 50 to 100 |
| Trabecular bone | 300 to 800 |
| Cortical bone | 1000+ |
| Metal implants | > 2000 |
3.4 MRI — Relative Signal, No Absolute Scale
Unlike CT, MRI signal intensity has no universal physical calibration. The value at a voxel depends on the scanner manufacturer, field strength (1.5T, 3T, 7T), pulse sequence parameters (TR, TE, flip angle), and the specific coil configuration. The same brain tissue can produce a bright voxel at 3T and a dim one at 1.5T.
The MRI signal for a T1-weighted spin-echo sequence is approximately:

\[ S(\mathbf{x}) \;\propto\; \rho(\mathbf{x}) \left( 1 - e^{-\text{TR}/T_1} \right) e^{-\text{TE}/T_2^*}, \]

where \( \rho \) is proton density, \( T_1 \) is the longitudinal relaxation time (how fast tissue recovers after an RF pulse), \( T_2^* \) is the effective transverse relaxation time, TR is the repetition time, and TE is the echo time. Different TR/TE choices select for different tissue contrasts. Because of all these variable factors, all MRI preprocessing must normalise within each scan, not across scans.
4. The Preprocessing Pipeline — The Full Mathematics
Raw DICOM cannot be fed to a neural network. Voxel spacings differ across scanners, value ranges differ across protocols, orientations differ across acquisition conventions, and non-brain tissue in MRI confounds normalization. The preprocessing pipeline standardises all of these dimensions.
4.1 Resampling to Isotropic Spacing
The resampling step is the most computationally important. Given a volume with original spacing \( \mathcal{S} = (s_z, s_y, s_x) \) mm and target isotropic spacing \( s^* \) mm, the new dimensions are:

\[ D' = \left\lfloor \frac{D \cdot s_z}{s^*} \right\rfloor, \qquad H' = \left\lfloor \frac{H \cdot s_y}{s^*} \right\rfloor, \qquad W' = \left\lfloor \frac{W \cdot s_x}{s^*} \right\rfloor. \]
The value at each new voxel \( (d, h, w) \) is computed by trilinear interpolation. Let \( \phi_d = d \cdot s^*/s_z \) be the fractional position in the original volume along depth, and define \( d_0 = \lfloor \phi_d \rfloor \), \( d_1 = d_0 + 1 \), \( \alpha_d = \phi_d - d_0 \). Similarly for \( h, w \). Then:

\[ \mathbf{V}'_{d,h,w} = \sum_{i \in \{0,1\}} \sum_{j \in \{0,1\}} \sum_{k \in \{0,1\}} w_d^{(i)}\, w_h^{(j)}\, w_w^{(k)}\, \mathbf{V}_{d_i, h_j, w_k}, \]
where the interpolation weights are \( w_d^{(0)} = 1 - \alpha_d \), \( w_d^{(1)} = \alpha_d \) (and similarly for the other axes). Segmentation masks must be resampled with nearest-neighbour interpolation to avoid interpolating between class labels.
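The interpolation above can be written directly in NumPy. This is a didactic sketch, not a production resampler (libraries like SimpleITK or MONAI's `Spacingd` do this efficiently); out-of-range positions are clamped to the volume edge:

```python
import numpy as np

def resample_trilinear(vol, spacing, target=1.0):
    """Resample a (D,H,W) volume to isotropic `target` mm via trilinear interpolation."""
    sz, sy, sx = spacing
    D, H, W = vol.shape
    nD, nH, nW = int(D * sz / target), int(H * sy / target), int(W * sx / target)
    # Fractional source coordinates for every output voxel along each axis
    zs = np.arange(nD) * target / sz
    ys = np.arange(nH) * target / sy
    xs = np.arange(nW) * target / sx
    z0 = np.minimum(zs.astype(int), D - 2); az = np.clip(zs - z0, 0.0, 1.0)
    y0 = np.minimum(ys.astype(int), H - 2); ay = np.clip(ys - y0, 0.0, 1.0)
    x0 = np.minimum(xs.astype(int), W - 2); ax = np.clip(xs - x0, 0.0, 1.0)
    out = np.zeros((nD, nH, nW), dtype=np.float64)
    # Sum over the 8 corners with separable weights w^(0) = 1 - alpha, w^(1) = alpha
    for i in (0, 1):
        for j in (0, 1):
            for k in (0, 1):
                wz = (az if i else 1 - az)[:, None, None]
                wy = (ay if j else 1 - ay)[None, :, None]
                wx = (ax if k else 1 - ax)[None, None, :]
                out += wz * wy * wx * vol[np.ix_(z0 + i, y0 + j, x0 + k)]
    return out

vol = np.random.rand(10, 8, 8)
iso = resample_trilinear(vol, spacing=(3.0, 1.0, 1.0), target=1.0)
print(iso.shape)  # (30, 8, 8) — the 3 mm axis gains resolution
```

For a segmentation mask, the weighted sum would be replaced by a nearest-neighbour index lookup, exactly as the text prescribes.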
4.2 CT Windowing
A chest CT volume contains HU values ranging from −1000 (air in lungs) to +3000 (metal implants). For any given task, only a narrow range is clinically relevant. Windowing clips and rescales to that range:

\[ \text{win}(v) = \operatorname{clip}\!\left( \frac{v - (W_C - W_W/2)}{W_W},\; 0,\; 1 \right), \]

where \( W_C \) is the window centre (also called window level) and \( W_W \) is the window width. Standard clinical windows are:
| Task | Window Centre | Window Width | HU Range |
|---|---|---|---|
| Brain CT | 40 | 80 | 0 to 80 |
| Lung CT | −600 | 1500 | −1350 to 150 |
| Abdomen CT | 60 | 400 | −140 to 260 |
| Bone CT | 700 | 3000 | −800 to 2200 |
| Liver CT | 60 | 150 | −15 to 135 |
For models reading multi-window representations, it is common to treat each window as a separate input channel — creating a 3-channel input from a single-modality CT scan analogous to RGB channels.
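A sketch of the windowing formula and the multi-window channel trick (window centre/width values taken from the table above):

```python
import numpy as np

def apply_window(hu, center, width):
    """Clip HU to [center - width/2, center + width/2] and rescale to [0, 1]."""
    lo = center - width / 2.0
    return np.clip((hu - lo) / width, 0.0, 1.0)

hu = np.array([-1000.0, -600.0, 0.0, 40.0, 400.0])  # air, lung, water, soft tissue, bone
print(apply_window(hu, center=-600, width=1500))    # lung window

# Three windows stacked as channels: a 3-channel "RGB-like" input from one CT
vol = np.random.uniform(-1000, 1000, size=(4, 4, 4))
windows = [(-600, 1500), (60, 400), (700, 3000)]    # lung, abdomen, bone
multi = np.stack([apply_window(vol, c, w) for c, w in windows], axis=0)
print(multi.shape)  # (3, 4, 4, 4)
```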
4.3 MRI Intensity Normalization
Since MRI has no absolute scale, we normalise within each scan. The standard approach is z-score normalisation over the foreground mask \( \Omega \) (the set of non-zero voxels, or the brain mask after skull stripping):

\[ \hat{v}(\mathbf{x}) = \frac{v(\mathbf{x}) - \mu_\Omega}{\sigma_\Omega}, \qquad \mu_\Omega = \frac{1}{|\Omega|} \sum_{\mathbf{x} \in \Omega} v(\mathbf{x}), \qquad \sigma_\Omega^2 = \frac{1}{|\Omega|} \sum_{\mathbf{x} \in \Omega} \big( v(\mathbf{x}) - \mu_\Omega \big)^2. \]

This is computed per-scan, per-channel. Computing statistics over the full volume (including the zero-valued air background) drags the mean towards zero and distorts the standard deviation, leading to poor normalisation.
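A sketch of per-scan foreground normalisation; note that the statistics come only from the non-zero mask:

```python
import numpy as np

def zscore_foreground(vol, eps=1e-8):
    """Z-score normalise using statistics of the non-zero (foreground) voxels only."""
    mask = vol != 0
    mu, sigma = vol[mask].mean(), vol[mask].std()
    out = np.zeros_like(vol)
    out[mask] = (vol[mask] - mu) / (sigma + eps)
    return out

vol = np.zeros((32, 32, 32), dtype=np.float32)
vol[8:24, 8:24, 8:24] = np.random.normal(300.0, 50.0, (16, 16, 16))  # "brain" blob
norm = zscore_foreground(vol)
print(norm[vol != 0].mean(), norm[vol != 0].std())  # ≈ 0.0, ≈ 1.0
```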
4.4 N4 Bias Field Correction (MRI)
MRI suffers from a slow, smooth intensity inhomogeneity caused by imperfect RF coil sensitivity profiles. A voxel in the same tissue type can be 30% brighter near the coil than far from it. The bias field model is:

\[ v(\mathbf{x}) = u(\mathbf{x})\, B(\mathbf{x}) + n(\mathbf{x}), \]

where \( B(\mathbf{x}) \) is the multiplicative bias field (slowly varying, close to 1 everywhere) and \( n(\mathbf{x}) \) is additive noise. In the log domain:

\[ \log v(\mathbf{x}) \approx \log u(\mathbf{x}) + \log B(\mathbf{x}). \]
N4 (N4ITK) estimates \( B \) iteratively by minimising an energy of the form:

\[ E[\hat{B}] = E_{\text{data}}\big( \log v - \log \hat{B} \big) + \lambda \int \big\| \nabla^2 \log \hat{B}(\mathbf{x}) \big\|^2 \, \mathrm{d}\mathbf{x}, \]

shown here schematically: the data term rewards a sharp intensity histogram of the corrected image (N4 iteratively sharpens it by deconvolution), while the second term penalises non-smooth bias fields (enforced by fitting \( \log \hat{B} \) with coarse B-splines). In practice:
import SimpleITK as sitk
def n4_correct(sitk_image):
mask = sitk.OtsuThreshold(sitk_image, 0, 1, 200)
corrector = sitk.N4BiasFieldCorrectionImageFilter()
corrector.SetMaximumNumberOfIterations([50, 50, 50, 50])
return corrector.Execute(sitk_image, mask)
img = sitk.ReadImage("brain.nii.gz", sitk.sitkFloat32)
img_corrected = n4_correct(img)
4.5 The Full MONAI Preprocessing Compose
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd,
    Spacingd, NormalizeIntensityd, ScaleIntensityRanged,
    CropForegroundd, ConcatItemsd, Compose
)
# CT segmentation preprocessing (e.g., LiTS liver)
ct_preprocess = Compose([
LoadImaged(keys=["image","label"]),
EnsureChannelFirstd(keys=["image","label"]),
Orientationd(keys=["image","label"], axcodes="RAS"),
Spacingd(keys=["image","label"],
pixdim=(1.0, 1.0, 1.0),
mode=("bilinear","nearest")),
ScaleIntensityRanged(keys=["image"],
a_min=-200, a_max=250,
b_min=0.0, b_max=1.0, clip=True),
CropForegroundd(keys=["image","label"], source_key="image"),
])
# MRI brain preprocessing (e.g., BraTS multi-sequence)
mri_preprocess = Compose([
LoadImaged(keys=["T1","T2","FLAIR","T1ce","label"]),
EnsureChannelFirstd(keys=["T1","T2","FLAIR","T1ce","label"]),
ConcatItemsd(keys=["T1","T2","FLAIR","T1ce"], name="image"),
Orientationd(keys=["image","label"], axcodes="RAS"),
Spacingd(keys=["image","label"],
pixdim=(1.0, 1.0, 1.0),
mode=("bilinear","nearest")),
NormalizeIntensityd(keys=["image"],
nonzero=True, channel_wise=True),
])
5. Dataset Splits and Augmentation
5.1 The Patient-Level Split Rule
The most critical rule in medical imaging ML: always split at the patient level, never at the image or slice level. A single patient can contribute dozens or hundreds of slices or patches. If patient A's slice 47 is in training and slice 48 is in validation, the model has effectively seen the patient — this is data leakage and will produce optimistically wrong validation metrics. Formally, let \( \mathcal{P} = \{p_1, p_2, \dots, p_N\} \) be the set of unique patients. We partition:

\[ \mathcal{P} = \mathcal{P}_{\text{train}} \cup \mathcal{P}_{\text{val}} \cup \mathcal{P}_{\text{test}}, \qquad \mathcal{P}_{\text{train}} \cap \mathcal{P}_{\text{val}} = \mathcal{P}_{\text{train}} \cap \mathcal{P}_{\text{test}} = \mathcal{P}_{\text{val}} \cap \mathcal{P}_{\text{test}} = \varnothing, \]

and every image of patient \( p \) goes to the split containing \( p \).
Typical split: 70% train, 15% val, 15% test. Stratify by: disease label, scanner manufacturer, imaging site, and — for longitudinal datasets — ensure no patient's follow-up scan ends up in a different split from their baseline.
For small datasets (\( N < 300 \) patients), use \( k \)-fold cross-validation with GroupKFold(n_splits=5, groups=patient_ids) from scikit-learn, where the grouping variable is the patient ID.
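The leakage rule takes only a few lines to enforce. This sketch (pure Python, hypothetical patient IDs) splits slice records by patient and verifies that no patient straddles two splits; `GroupKFold` provides the same guarantee automatically for cross-validation:

```python
import random

# Each slice record carries its patient ID (hypothetical data: 100 patients x 40 slices)
slices = [{"patient": f"P{i:03d}", "slice": s} for i in range(100) for s in range(40)]

patients = sorted({rec["patient"] for rec in slices})
random.Random(42).shuffle(patients)

n = len(patients)
train_p = set(patients[: int(0.7 * n)])
val_p   = set(patients[int(0.7 * n): int(0.85 * n)])
test_p  = set(patients[int(0.85 * n):])

train = [r for r in slices if r["patient"] in train_p]
val   = [r for r in slices if r["patient"] in val_p]
test  = [r for r in slices if r["patient"] in test_p]

# No patient appears in more than one split
assert not (train_p & val_p) and not (train_p & test_p) and not (val_p & test_p)
print(len(train), len(val), len(test))  # 2800 600 600
```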
5.2 Patch-Based Training — The Memory Solution
A full 3D CT volume at 1 mm spacing occupies approximately \( 300 \times 512 \times 512 \times 4 \approx 300\,\text{MB}\) in float32. With a batch of four and a U-Net's intermediate feature maps, activation memory easily exceeds 80 GB — more than any single GPU provides. The solution is patch-based training: randomly sample sub-volumes.
For a volume of size \( D^* \times H^* \times W^* \) and a patch of size \( d_p \times h_p \times w_p \), the top-left corner of each patch is drawn uniformly:

\[ z_0 \sim \mathcal{U}\{0, \dots, D^* - d_p\}, \qquad y_0 \sim \mathcal{U}\{0, \dots, H^* - h_p\}, \qquad x_0 \sim \mathcal{U}\{0, \dots, W^* - w_p\}. \]
Standard patch sizes: \( 128^3 \) voxels (aggressive models), \( 96^3 \) (conservative), \( 64^3 \) (minimum for fine structures). With foreground oversampling, 2 out of every 3 patches are forced to contain at least one foreground (tumour) voxel — otherwise rare lesions are almost never seen during training.
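A sketch of corner sampling with the 2-in-3 foreground oversampling described above; in practice `fg_coords` would be precomputed from the label mask:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_patch_corner(vol_shape, patch, fg_coords=None, fg_prob=2/3):
    """Sample a patch corner; with probability fg_prob, cover a foreground voxel."""
    D, H, W = vol_shape
    dp, hp, wp = patch
    if fg_coords is not None and rng.random() < fg_prob:
        # Pick a random foreground voxel and centre the patch on it (clamped to bounds)
        z, y, x = fg_coords[rng.integers(len(fg_coords))]
        return (int(np.clip(z - dp // 2, 0, D - dp)),
                int(np.clip(y - hp // 2, 0, H - hp)),
                int(np.clip(x - wp // 2, 0, W - wp)))
    # Otherwise: uniform corner over all valid positions
    return (int(rng.integers(0, D - dp + 1)),
            int(rng.integers(0, H - hp + 1)),
            int(rng.integers(0, W - wp + 1)))

corner = sample_patch_corner((300, 512, 512), (128, 128, 128),
                             fg_coords=[(150, 250, 250)])
print(corner)
```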
5.3 Augmentation — Full Mathematics
Augmentation is the single most important regulariser for small medical datasets. Three families of transforms are applied in sequence.
Spatial / geometric augmentations. A general 3D affine transformation is expressed as a \( 4 \times 4 \) homogeneous matrix:

\[ \mathbf{T} = \mathbf{T}_{\text{trans}} \, \mathbf{R}_z(\theta_z) \, \mathbf{R}_y(\theta_y) \, \mathbf{R}_x(\theta_x) \, \mathbf{S}_{\text{scale}}, \]

where the rotation matrices about each axis are:

\[ \mathbf{R}_x(\theta_x) = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & \cos\theta_x & -\sin\theta_x & 0 \\ 0 & \sin\theta_x & \cos\theta_x & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix}, \qquad \mathbf{R}_z(\theta_z) = \begin{pmatrix} \cos\theta_z & -\sin\theta_z & 0 & 0 \\ \sin\theta_z & \cos\theta_z & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{pmatrix}, \]

with \( \mathbf{R}_y \) defined analogously about the \( y \)-axis,
and \( \mathbf{S}_{\text{scale}} = \text{diag}(\gamma, \gamma, \gamma, 1) \) with \( \gamma \sim \mathcal{U}(0.85, 1.25) \). The output voxel at location \( \mathbf{x}' \) takes the value from the input at \( \mathbf{T}^{-1}\mathbf{x}' \), interpolated trilinearly. Rotation angles \( \theta_x, \theta_y, \theta_z \sim \mathcal{U}(-30°, 30°) \) for brain, \( \pm 10° \) for abdominal (anatomy has strong up/down prior).
Elastic deformation. A smooth random displacement field \( \boldsymbol{\delta}: \mathbb{R}^3 \to \mathbb{R}^3 \) is constructed from uniform noise convolved with a Gaussian:

\[ \boldsymbol{\delta}(\mathbf{x}) = \alpha \cdot \big( \mathcal{G}_\sigma * \mathbf{u} \big)(\mathbf{x}), \qquad \mathbf{u}(\mathbf{x}) \sim \mathcal{U}(-1, 1)^3, \]

where \( \mathcal{G}_\sigma \) is a Gaussian kernel with standard deviation \( \sigma \) (larger \( \sigma \) = smoother deformation) and \( \alpha \) controls displacement magnitude. Typical values: \( \alpha \in [100, 200] \) voxels, \( \sigma \in [5, 8] \) voxels. The deformed volume is:

\[ \mathbf{V}'(\mathbf{x}) = \mathbf{V}\big( \mathbf{x} + \boldsymbol{\delta}(\mathbf{x}) \big). \]
Intensity augmentations. Applied after spatial augmentations to avoid HU scale drift:
Gamma correction: models non-linear scanner response. With intensities scaled to \( [0, 1] \):

\[ v' = v^{\gamma}, \qquad \gamma \sim \mathcal{U}(0.7, 1.5). \]

Gaussian noise: models low-dose CT quantum noise and MRI thermal noise:

\[ v' = v + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \sigma_n^2). \]

Simulated MRI bias field: a synthetic slowly varying multiplicative field is applied during training to make the model robust to real scanner inhomogeneity:

\[ v'(\mathbf{x}) = v(\mathbf{x}) \cdot \exp\!\left( \sum_k c_k \, \phi_k(\mathbf{x}) \right), \qquad c_k \sim \mathcal{N}(0, \sigma_b^2), \]

where \( \phi_k \) are low-frequency basis functions (e.g., polynomials up to degree 3).
from monai.transforms import (
    RandSpatialCropd, RandFlipd, RandRotate90d,
    RandAffined, Rand3DElasticd,
    RandGaussianNoised, RandAdjustContrastd,
    RandGibbsNoised, RandBiasFieldd, Compose
)
train_aug = Compose([
RandSpatialCropd(keys=["image","label"],
roi_size=(128,128,128),
random_center=True),
RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=1),
RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=2),
RandRotate90d(keys=["image","label"], prob=0.5),
RandAffined(keys=["image","label"],
prob=0.3,
rotate_range=(0.52, 0.52, 0.52),
scale_range=(0.2, 0.2, 0.2),
mode=("bilinear","nearest")),
    Rand3DElasticd(keys=["image","label"],
                   sigma_range=(5,8),
                   magnitude_range=(100,200),
                   prob=0.3,
                   mode=("bilinear","nearest")),
RandGaussianNoised(keys=["image"], prob=0.2, std=0.05),
RandAdjustContrastd(keys=["image"], prob=0.3,
gamma=(0.7,1.5)),
RandBiasFieldd(keys=["image"], prob=0.3), # MRI only
])
6. Architecture Design — The 3D U-Net in Full Detail
The U-Net, introduced by Ronneberger et al. in 2015 and extended to 3D by Çiçek et al. in 2016, remains the dominant architecture for medical image segmentation. Its defining innovation is the skip connection: instead of compressing all information into a single bottleneck vector, intermediate feature maps from the encoder are concatenated with corresponding decoder feature maps at the same spatial resolution. This lets the decoder recover fine-grained spatial detail that the pooling operations would otherwise discard.
6.1 Encoder Block — Double Convolution with Instance Normalisation
Each encoder level consists of two successive 3D convolutions with normalisation and activation. We prefer Instance Normalisation over Batch Normalisation in 3D medical imaging because batch sizes are small (often 2 or 4 volumes) and IN computes statistics per channel per sample, making it independent of batch size:

\[ \mathbf{h}_\ell = \mathrm{LReLU}\Big( \mathrm{IN}\big( \mathrm{Conv}_{3\times3\times3} \big( \mathrm{LReLU}\big( \mathrm{IN}\big( \mathrm{Conv}_{3\times3\times3} ( \mathbf{h}_{\ell-1}^{\downarrow} ) \big) \big) \big) \big) \Big), \]

where \( \mathbf{h}_{\ell-1}^{\downarrow} \) is the downsampled output of the previous level. The Instance Normalisation at each position \( (c, d, h, w) \) is:

\[ \mathrm{IN}(\mathbf{h})_{c,d,h,w} = \gamma_c \, \frac{\mathbf{h}_{c,d,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \varepsilon}} + \beta_c, \]
where \( \mu_c = \frac{1}{D'H'W'}\sum_{d,h,w} \mathbf{h}_{c,d,h,w} \) is computed over the spatial dimensions of a single channel of a single sample, and \( \gamma_c, \beta_c \) are learnable scale and shift parameters.
6.2 Downsampling — Strided Convolution
nnU-Net replaces max pooling with a 2×2×2 strided convolution, which is learnable and has been shown to perform slightly better:

\[ \mathbf{h}_\ell^{\downarrow} = \mathrm{Conv}_{2\times2\times2,\ \text{stride}=2}(\mathbf{h}_\ell). \]
This halves the spatial size in each dimension: \( (D, H, W) \to (D/2, H/2, W/2) \). If the original spacing is anisotropic (e.g., 0.7 mm in-plane, 3 mm axial), nnU-Net uses non-isotropic pooling — only downsampling in-plane until spatial sizes become comparable across dimensions.
6.3 Skip Connections and Decoder Upsampling
The decoder at level \( \ell \) first upsamples the feature map from the level below using a transposed convolution:

\[ \mathbf{u}_\ell = \mathrm{ConvT}_{2\times2\times2,\ \text{stride}=2}(\mathbf{d}_{\ell+1}). \]

This doubles the spatial size. The skip connection concatenates the encoder features at the same level:

\[ \mathbf{d}_\ell^{\text{in}} = \big[ \mathbf{h}_\ell \,;\, \mathbf{u}_\ell \big]. \]
The channel count after concatenation is \( C_\ell + C_{\ell+1}/2 \) (encoder channels + upsampled decoder channels). This concatenated tensor is processed by another double convolution, identical in structure to the encoder block.
6.4 Output Head
At the finest decoder level, a \( 1\times1\times1 \) convolution maps from 32 feature channels to \( C_{\text{seg}} \) class logits \( \mathbf{z} \), followed by a softmax:

\[ \hat{\mathbf{P}} = \mathrm{softmax}\big( \mathrm{Conv}_{1\times1\times1}(\mathbf{d}_1) \big), \qquad \hat{P}_{c,d,h,w} = \frac{\exp(z_{c,d,h,w})}{\sum_{c'} \exp(z_{c',d,h,w})}. \]
The softmax ensures predictions are probability distributions over classes at each voxel.
6.5 Deep Supervision (nnU-Net)
nnU-Net adds auxiliary output heads at every decoder level except the coarsest. Each head is a \( 1\times1\times1 \) convolution producing class probabilities at the resolution of that level. The combined loss uses geometrically decreasing weights:

\[ \mathcal{L}_{\text{DS}} = \sum_{\ell=0}^{L-1} w_\ell \, \mathcal{L}\big( \hat{\mathbf{P}}_\ell, \mathbf{Y}_\ell^{\downarrow} \big), \qquad w_\ell = \frac{2^{-\ell}}{\sum_{\ell'} 2^{-\ell'}}, \]
where \( \mathbf{Y}_\ell^{\downarrow} \) is the ground-truth mask downsampled to the resolution of level \( \ell \). Deep supervision prevents vanishing gradients in the early encoder layers — those layers are now directly connected to a loss, not just indirectly through the bottleneck.
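The geometric weighting is easy to sanity-check numerically; the per-level loss values below are placeholders:

```python
import numpy as np

n_heads = 4                                          # decoder levels with an output head
raw = np.array([2.0 ** -l for l in range(n_heads)])  # 1, 1/2, 1/4, 1/8
weights = raw / raw.sum()                            # normalise to sum to 1
print(weights.round(4))                              # [0.5333 0.2667 0.1333 0.0667]

per_level_loss = np.array([0.30, 0.45, 0.60, 0.80])  # hypothetical L(P_l, Y_l)
total = float((weights * per_level_loss).sum())
print(round(total, 4))  # 0.4133
```

Note how more than half the total gradient signal comes from the full-resolution head, with each coarser head contributing half as much as the one above it.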
6.6 Model in Code
from monai.networks.nets import UNet
model = UNet(
spatial_dims=3, # 3D convolutions throughout
in_channels=1, # 1 for CT, 4 for BraTS MRI
out_channels=3, # background + 2 tumour regions
channels=(32, 64, 128, 256, 320), # feature channels per level
strides=(2, 2, 2, 2), # downsampling at each level
num_res_units=2, # residual units in each block
norm="INSTANCE", # instance norm for small batches
dropout=0.1,
).to(device)
# nnU-Net self-configuring version
from nnunetv2.run.run_training import run_training
# nnU-Net reads your dataset, sets spacing, patch size, architecture
# automatically based on median patient spacing and GPU memory
run_training("Dataset001_BraTS", "3d_fullres", fold=0)
7. The Loss Function — Three Objectives for Segmentation
Choosing the loss function is where the mathematics of medical image segmentation diverges most sharply from natural image tasks. The fundamental problem is class imbalance: in a 128³ brain tumour segmentation patch, the tumour enhancement region might occupy 0.5% of voxels. Training with vanilla cross-entropy on such data produces models that learn to predict "background" for everything — achieving 99.5% pixel accuracy while being clinically useless.
7.1 Multi-Class Dice Loss
The Dice Similarity Coefficient (DSC) measures the overlap between two sets. For a single class \( c \), with \( p_{ic} \in [0,1] \) the predicted probability at voxel \( i \) and \( g_{ic} \in \{0,1\} \) the ground-truth indicator:

\[ \mathrm{DSC}_c = \frac{2 \sum_i p_{ic}\, g_{ic}}{\sum_i p_{ic} + \sum_i g_{ic}}. \]

The differentiable soft Dice loss (averaged over all classes):

\[ \mathcal{L}_{\text{Dice}} = 1 - \frac{1}{C_{\text{seg}}} \sum_{c} \frac{2 \sum_i p_{ic}\, g_{ic} + \varepsilon}{\sum_i p_{ic}^2 + \sum_i g_{ic}^2 + \varepsilon}. \]
Note the squared denominator (\( p_{ic}^2 \) rather than \( p_{ic} \)): this is the square-form Dice, which has a smoother gradient landscape. The smoothing constant \( \varepsilon = 1 \) prevents division by zero when both prediction and ground truth are zero (empty class).
The key property: the Dice loss is invariant to class imbalance. Whether the foreground occupies 1% or 50% of voxels, a model predicting all zeros achieves \( \mathcal{L}_\text{Dice} = 1 \) (worst possible), forcing the model to actually find the structure. Vanilla cross-entropy gives a low loss to the all-zeros predictor (99% correct on a 1%-foreground dataset).
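A NumPy sketch of the square-form soft Dice for a single foreground class, showing the all-zeros pathology the text describes:

```python
import numpy as np

def soft_dice_loss(p, g, eps=1.0):
    """Square-form soft Dice loss for one class; p in [0,1], g in {0,1}."""
    num = 2.0 * (p * g).sum() + eps
    den = (p ** 2).sum() + (g ** 2).sum() + eps
    return 1.0 - num / den

g = np.zeros(10000); g[:100] = 1           # 1% foreground
p_zeros = np.zeros(10000)                  # "predict background everywhere"
p_good  = g.copy()                         # perfect prediction

print((p_zeros == g).mean())               # 0.99 — accuracy looks great...
print(round(soft_dice_loss(p_zeros, g), 4))  # 0.9901 — ...but Dice loss is near-worst
print(round(soft_dice_loss(p_good, g), 4))   # 0.0
```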
7.2 Focal Loss
Focal loss (Lin et al., RetinaNet 2017) modulates the cross-entropy by a factor that downweights well-classified (easy) examples, concentrating training on hard, misclassified voxels:

\[ \mathcal{L}_{\text{focal}} = - \alpha_t \, (1 - p_t)^{\gamma} \log p_t, \]

where \( p_t \) is the model's estimated probability for the correct class:

\[ p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise.} \end{cases} \]
Standard hyperparameters: \( \alpha = 0.25 \), \( \gamma = 2 \). When the model confidently predicts the correct class (\( p_t \to 1 \)), the factor \( (1 - p_t)^\gamma \to 0 \), so the easy example contributes negligibly to the loss. When the model is confused (\( p_t \to 0.5 \)), the factor approaches 1 and the loss is close to standard cross-entropy.
7.3 Tversky Loss — For Highly Imbalanced Tiny Lesions
For tasks with extremely rare structures (MS lesions, microbleeds, coronary plaques), the Tversky loss allows separate penalty weights for false positives and false negatives:

\[ \mathcal{L}_{\text{Tversky}} = 1 - \frac{\sum_i p_i g_i + \varepsilon}{\sum_i p_i g_i + \alpha \sum_i p_i (1 - g_i) + \beta \sum_i (1 - p_i)\, g_i + \varepsilon}, \]
with \( \alpha = 0.3 \) (FP weight) and \( \beta = 0.7 \) (FN weight). Setting \( \alpha = \beta = 0.5 \) recovers the Dice loss. The higher \( \beta \) means missing a lesion voxel (false negative) is penalised more than incorrectly labelling a background voxel as lesion (false positive) — the right clinical prior for screening applications.
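A quick numerical check that \( \alpha = \beta = 0.5 \) recovers the (non-squared) Dice form, with synthetic predictions and a rare-ish lesion mask:

```python
import numpy as np

def tversky_loss(p, g, alpha, beta, eps=0.0):
    """Tversky loss: alpha weights false positives, beta weights false negatives."""
    tp = (p * g).sum()
    fp = (p * (1 - g)).sum()
    fn = ((1 - p) * g).sum()
    return 1.0 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)

rng = np.random.default_rng(1)
p = rng.random(1000)
g = (rng.random(1000) > 0.9).astype(float)  # ~10% foreground

dice = 1.0 - 2.0 * (p * g).sum() / (p.sum() + g.sum())
print(np.isclose(tversky_loss(p, g, 0.5, 0.5), dice))  # True
```

The identity follows because \( \alpha \sum p(1-g) + \beta \sum (1-p)g \) collapses to \( \tfrac{1}{2}(\sum p + \sum g) - \text{TP} \) when \( \alpha = \beta = \tfrac{1}{2} \).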
7.4 The Combined Loss
The standard training objective combines Dice (which handles global overlap) and cross-entropy or Focal (which provides stable per-voxel gradients, especially in early training when predictions are random):

\[ \mathcal{L} = \mathcal{L}_{\text{Dice}} + \mathcal{L}_{\text{focal}}. \]

This is the DiceFocal loss, which is the default in MONAI for 3D segmentation. For classification tasks, standard binary cross-entropy is used with the logit output of a global average pooled encoder:

\[ \mathcal{L}_{\text{BCE}} = -\big[ y \log \hat{p} + (1 - y) \log (1 - \hat{p}) \big]. \]
8. Optimization — How the Weights Actually Move
8.1 Optimizer Choice
nnU-Net, surprisingly, uses SGD with Nesterov momentum rather than Adam, and outperforms Adam on most medical tasks. The Nesterov update is:

\[ \mathbf{v}_{t+1} = \mu\, \mathbf{v}_t - \eta_t\, \nabla_{\theta} \mathcal{L}\big( \theta_t + \mu\, \mathbf{v}_t \big), \qquad \theta_{t+1} = \theta_t + \mathbf{v}_{t+1}, \]

with momentum \( \mu = 0.99 \) and weight decay \( \lambda_\text{wd} = 3 \times 10^{-5} \). For transformer-based models (UNETR, Swin-UNet), AdamW is preferred:

\[ \mathbf{m}_{t+1} = \beta_1 \mathbf{m}_t + (1 - \beta_1)\, \mathbf{g}_t, \qquad \mathbf{s}_{t+1} = \beta_2 \mathbf{s}_t + (1 - \beta_2)\, \mathbf{g}_t^2, \]

\[ \theta_{t+1} = \theta_t - \eta_t \frac{\hat{\mathbf{m}}_{t+1}}{\sqrt{\hat{\mathbf{s}}_{t+1}} + \varepsilon} - \eta_t \lambda_{\text{wd}}\, \theta_t, \]

where \( \hat{\mathbf{m}}, \hat{\mathbf{s}} \) are the bias-corrected first and second moments,
with \( \beta_1 = 0.9, \beta_2 = 0.999, \varepsilon = 10^{-8} \). The weight decay term \( -\eta_t \lambda_{\text{wd}} \theta_t \) (applied directly to weights, not to the gradient-adaptive learning rate) is AdamW's key distinction from Adam, providing proper L2 regularisation.
8.2 Learning Rate Schedule
nnU-Net uses a polynomial decay schedule over a fixed number of training iterations \( T_{\max} \):

\[ \eta_t = \eta_0 \left( 1 - \frac{t}{T_{\max}} \right)^{0.9}, \]

with \( \eta_0 = 0.01 \) for SGD and \( \eta_0 = 10^{-4} \) for AdamW. This decays more aggressively than cosine annealing in the early phase and more gently late — empirically better for 3D medical tasks. For fine-tuning from a pretrained backbone, a common choice is a reduced backbone learning rate (e.g., \( \eta_0 / 10 \)), with the backbone frozen entirely during an initial warmup phase.
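The schedule is a one-liner; a sketch with the nnU-Net defaults:

```python
def poly_lr(t, eta0=0.01, t_max=1000, power=0.9):
    """nnU-Net style polynomial learning-rate decay."""
    return eta0 * (1.0 - t / t_max) ** power

print(poly_lr(0))     # 0.01
print(poly_lr(500))   # ≈ 0.00536 — already below half of eta0 at the midpoint
print(poly_lr(1000))  # 0.0
```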
8.3 Gradient Clipping
In 3D networks with many levels, gradients can explode during early training, especially through skip connections. Global gradient norm clipping bounds the update:

\[ \mathbf{g} \leftarrow \mathbf{g} \cdot \min\left( 1, \frac{\tau}{\|\mathbf{g}\|_2} \right). \]
nnU-Net uses \( \tau = 12 \); transformer-based models typically use \( \tau = 1.0 \).
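A NumPy sketch of global-norm clipping (this is what `torch.nn.utils.clip_grad_norm_` does across all parameter tensors):

```python
import numpy as np

def clip_global_norm(grads, tau):
    """Scale every gradient by min(1, tau / ||g||_2), norm taken over all parameters."""
    total = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, tau / total) if total > 0 else 1.0
    return [g * scale for g in grads], total

grads = [np.full(100, 3.0), np.full(44, 3.0)]   # ||g||_2 = 3 * sqrt(144) = 36
clipped, norm = clip_global_norm(grads, tau=12.0)
print(norm)                                      # 36.0
print(np.sqrt(sum((g ** 2).sum() for g in clipped)))  # ≈ 12.0 after clipping
```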
8.4 Exponential Moving Average (EMA)
After a warmup period (e.g., 150 epochs), a shadow copy of the parameters is updated as a running exponential average:

\[ \theta^{\text{EMA}}_{t+1} = \beta_{\text{EMA}}\, \theta^{\text{EMA}}_t + (1 - \beta_{\text{EMA}})\, \theta_{t+1}, \qquad \beta_{\text{EMA}} \approx 0.999. \]
The EMA weights are smoother and generalize better than the raw weights, because they average over many recent iterations and are less sensitive to the noise of any single mini-batch. Validation and inference always use the EMA weights after they become available.
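A minimal sketch of the update rule, showing its effective averaging horizon of roughly \( 1/(1-\beta_{\text{EMA}}) \) steps:

```python
def ema_update(ema, theta, beta=0.999):
    """theta_ema <- beta * theta_ema + (1 - beta) * theta, per parameter."""
    return [beta * e + (1.0 - beta) * t for e, t in zip(ema, theta)]

ema, theta = [0.0], [1.0]          # raw weights jump to 1.0 and stay there
for _ in range(1000):              # 1/(1 - beta) = 1000 steps
    ema = ema_update(ema, theta)
print(round(ema[0], 4))            # 0.6323 — roughly 1 - e^{-1} of the way there
```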
8.5 Training Loop
for epoch in range(1000):
# Phase control
if epoch < 100:
freeze(backbone)
else:
unfreeze(backbone)
if epoch >= 150:
ema.enabled = True
# Foreground-oversampled patch sampling
set_oversample_ratio(fg_prob=0.667)
for batch in train_loader:
images = batch["image"].to(device) # B×C×128³
labels = batch["label"].to(device) # B×1×128³
# Forward
with torch.cuda.amp.autocast():
logits = model(images) # B×Cseg×128³
loss = dice_focal_loss(logits, labels)
# Backward
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 12.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
ema.update()
scheduler.step()
    # Validate with EMA weights using sliding-window inference
    model_ema.eval()
    dice_metric.reset()
    with torch.no_grad():
        for val_batch in val_loader:
            val_images = val_batch["image"].to(device)
            val_labels = val_batch["label"].to(device)
            val_logits = sliding_window_inference(
                val_images, roi_size=(128,128,128),
                sw_batch_size=4, predictor=model_ema, overlap=0.5)
            dice_metric(y_pred=val_logits.argmax(dim=1, keepdim=True),
                        y=val_labels)
    val_dice = dice_metric.aggregate().item()
if val_dice > best_dice:
best_dice = val_dice
torch.save(ema.state_dict(), "best_model.pth")
9. Inference — Sliding Window and Test-Time Augmentation
9.1 Sliding Window Inference
A full CT volume (e.g., 300×512×512 at 1 mm) does not fit in GPU memory during inference. Sliding window inference partitions the volume into overlapping patches, runs each through the model, and aggregates predictions with Gaussian weighting:
For a volume of size \( D \times H \times W \) and patch size \( d_p \times h_p \times w_p \), the stride in each dimension with overlap ratio \( r \) is:

\[ \Delta_d = \lfloor (1 - r)\, d_p \rfloor, \qquad \Delta_h = \lfloor (1 - r)\, h_p \rfloor, \qquad \Delta_w = \lfloor (1 - r)\, w_p \rfloor. \]

Typical overlap: \( r = 0.5 \). The aggregated prediction at each voxel is:

\[ \hat{\mathbf{P}}(\mathbf{x}) = \frac{\sum_i w_i(\mathbf{x})\, \hat{\mathbf{P}}_i(\mathbf{x})}{\sum_i w_i(\mathbf{x})}, \]

where the Gaussian importance weight for patch \( i \) centred at \( \mathbf{c}_i \) is:

\[ w_i(\mathbf{x}) = \exp\left( - \frac{\| \mathbf{x} - \mathbf{c}_i \|_2^2}{2 \sigma_w^2} \right). \]
The Gaussian weighting downweights predictions near the patch boundary (which have less context) and upweights predictions near the patch centre.
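A sketch of the two ingredients along one axis: stride computation (with a final patch snapped to the volume edge so nothing is missed) and the centre-peaked Gaussian profile. The `sigma_scale=0.125` choice is an assumption, roughly matching the common "sigma = patch/8" convention:

```python
import numpy as np

def patch_starts(size, patch, overlap=0.5):
    """Start indices covering one axis with the given overlap ratio."""
    stride = max(1, int(patch * (1 - overlap)))
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:      # ensure the final patch touches the edge
        starts.append(size - patch)
    return starts

def gaussian_weight(patch, sigma_scale=0.125):
    """1D Gaussian importance profile, peaked at the patch centre."""
    c = (patch - 1) / 2.0
    x = np.arange(patch)
    return np.exp(-((x - c) ** 2) / (2 * (patch * sigma_scale) ** 2))

print(patch_starts(300, 128, overlap=0.5))  # [0, 64, 128, 172]
w = gaussian_weight(128)
print(w.argmax())                           # 63 — weight peaks at the patch centre
```

The full 3D weight map is the outer product of three such 1D profiles, one per axis.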
9.2 Test-Time Augmentation (TTA)
Applying multiple augmented versions of the input volume and averaging their predictions improves performance by 0.5–1.5% Dice at the cost of \( N_{\text{TTA}} \) forward passes. Standard TTA for 3D medical images applies all 8 combinations of axis-flips:

\[ \hat{\mathbf{P}}_{\text{TTA}} = \frac{1}{|\mathcal{F}|} \sum_{f \in \mathcal{F}} \mathcal{F}_f^{-1}\Big( \mathrm{model}\big( \mathcal{F}_f(\mathbf{V}) \big) \Big), \]

where \( \mathcal{F} = \{\text{identity}, \text{flip}_x, \text{flip}_y, \text{flip}_z, \text{flip}_{xy}, \text{flip}_{xz}, \text{flip}_{yz}, \text{flip}_{xyz}\} \) and \( \mathcal{F}_f^{-1} \) reverses the augmentation on the prediction. The argmax over the averaged probabilities is taken at the end, not per augmentation.
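A NumPy sketch of flip-TTA with a stand-in "model" (a simple scaling, which is flip-equivariant by construction, so TTA leaves its output unchanged — a useful sanity check that each flip is correctly inverted):

```python
import numpy as np
from itertools import chain, combinations

AXES = (0, 1, 2)
# All 8 subsets of the three axes: (), (0,), (1,), (2,), (0,1), (0,2), (1,2), (0,1,2)
flip_sets = list(chain.from_iterable(combinations(AXES, r) for r in range(4)))

def tta_predict(model, vol):
    """Average model outputs over all 8 axis-flip combinations, un-flipping each one."""
    preds = []
    for axes in flip_sets:
        flipped = np.flip(vol, axis=axes) if axes else vol
        pred = model(flipped)
        preds.append(np.flip(pred, axis=axes) if axes else pred)  # invert the flip
    return np.mean(preds, axis=0)

model = lambda v: v * 0.5          # stand-in "model", equivariant to flips
vol = np.random.rand(8, 8, 8)
out = tta_predict(model, vol)
print(np.allclose(out, vol * 0.5))  # True
```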
10. Evaluation Metrics — The Clinical Standard
10.1 Dice Score (DSC)
The Dice Similarity Coefficient is the primary metric for segmentation. At inference, predictions are thresholded at 0.5 to produce binary masks \( A = \{\mathbf{x} : \hat{p}(\mathbf{x}) > 0.5\} \) and \( B = \{\mathbf{x} : y(\mathbf{x}) = 1\} \):

\[ \mathrm{DSC}(A, B) = \frac{2\, |A \cap B|}{|A| + |B|}. \]
State-of-the-art benchmarks for context: liver segmentation (LiTS) → DSC ≈ 0.96; tumour (LiTS) → DSC ≈ 0.70; whole brain tumour (BraTS) → DSC ≈ 0.92; tumour core → DSC ≈ 0.88; enhancing tumour → DSC ≈ 0.81.
10.2 Hausdorff Distance 95 (HD₉₅)
The standard Hausdorff distance is sensitive to outliers (a single stray predicted voxel far from the tumour yields an enormous distance). The 95th-percentile variant is more robust:

\[ \mathrm{HD}_{95}(A, B) = \max\Big\{ \operatorname*{perc_{95}}_{a \in \partial A} d(a, \partial B), \;\; \operatorname*{perc_{95}}_{b \in \partial B} d(b, \partial A) \Big\}, \]
where \( \partial A \) is the surface (boundary voxels) of set \( A \), \( d(a, B) = \min_{b \in B} \|a - b\|_2 \) is the point-to-surface distance in mm, and \( \text{perc}_{95} \) is the 95th percentile. HD₉₅ captures the worst remaining 5% of boundary errors after removing the most extreme outliers.
10.3 Average Symmetric Surface Distance (ASSD)
ASSD reports the mean boundary error rather than the 95th-percentile:

\[ \mathrm{ASSD}(A, B) = \frac{1}{|\partial A| + |\partial B|} \left( \sum_{a \in \partial A} d(a, \partial B) + \sum_{b \in \partial B} d(b, \partial A) \right). \]
ASSD is reported alongside HD₉₅ as a complementary measure: a model can have low ASSD (mostly good boundaries) but high HD₉₅ (occasional large errors on difficult regions).
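Both surface metrics reduce to point-to-set distances between boundary voxel sets. A small sketch on 2D point sets (boundary extraction from masks is omitted), showing how the percentile tames a single outlier that would dominate the plain Hausdorff distance:

```python
import numpy as np

def directed_dists(A, B):
    """For each point in A, Euclidean distance to the nearest point in B."""
    diff = A[:, None, :] - B[None, :, :]
    return np.sqrt((diff ** 2).sum(-1)).min(axis=1)

def hd95(A, B):
    return max(np.percentile(directed_dists(A, B), 95),
               np.percentile(directed_dists(B, A), 95))

def assd(A, B):
    dAB, dBA = directed_dists(A, B), directed_dists(B, A)
    return (dAB.sum() + dBA.sum()) / (len(A) + len(B))

# Two nearly identical "boundaries" 1 mm apart, plus one far outlier in A
A = np.array([[i, 0.0] for i in range(100)] + [[50.0, 40.0]])
B = np.array([[i, 1.0] for i in range(100)])
print(directed_dists(A, B).max())  # 39.0 — plain Hausdorff is dominated by the outlier
print(hd95(A, B))                  # 1.0  — the outlier falls in the discarded top 5%
print(round(assd(A, B), 3))        # 1.189
```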
10.4 Detection: FROC Curve
For lesion detection tasks (nodule detection in LUNA16), the standard metric is the Free-Response ROC (FROC) curve, which plots sensitivity versus average number of false positives per scan. The competition metric is the mean sensitivity at 7 predefined false positive rates:

\[ \mathrm{CPM} = \frac{1}{7} \sum_{f \in \{1/8,\, 1/4,\, 1/2,\, 1,\, 2,\, 4,\, 8\}} \mathrm{Sens}(f). \]
10.5 Inter-Rater Agreement as the Performance Ceiling
Human expert performance on medical imaging tasks is itself imperfect. Two radiologists labelling the same tumour agree with a Dice score of roughly 0.75–0.85 depending on tumour type. This inter-rater agreement sets the theoretical ceiling for model performance — a model that exceeds inter-rater agreement is likely overfit to one annotator's style. Always compute and report inter-rater Dice alongside model Dice to give the metric meaning.
11. Deployment — From .pth File to Clinical System
11.1 ONNX Export and TensorRT Optimization
```python
import torch

# Export to ONNX (with dynamic batch and spatial dims)
dummy_input = torch.randn(1, 1, 128, 128, 128).to(device)
torch.onnx.export(
    model_ema,
    dummy_input,
    "model.onnx",
    opset_version=17,
    input_names=["volume"],
    output_names=["segmentation"],
    dynamic_axes={
        "volume": {0: "batch", 2: "depth", 3: "height", 4: "width"},
        "segmentation": {0: "batch", 2: "depth", 3: "height", 4: "width"},
    },
    do_constant_folding=True,
)
```
```python
# TensorRT INT8 calibration + engine build (TensorRT 8.x API)
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    parser.parse(f.read())
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = MyCalibrator(calib_loader)  # user-supplied IInt8Calibrator
engine_bytes = builder.build_serialized_network(network, config)
```
11.2 DICOM Integration — Returning Results to PACS
The model output (a segmentation mask) must be converted back to DICOM-native formats before it can appear in the radiologist's workflow. Two formats are used:
- DICOM SEG: Encodes the segmentation mask as a DICOM object with one frame per annotated slice. Each segment has a coded category (e.g., "Liver", coded from SNOMED-CT).
- DICOM SR (Structured Report): A text-based DICOM object encoding quantitative measurements — tumour volume (mL), longest diameter (mm), Dice vs. prior scan — that populates the radiologist's worklist automatically.
```python
from monai.deploy.core import Application
from monai.deploy.operators import (
    DICOMDataLoaderOperator,
    DICOMSeriesToVolumeOperator,
    InferenceOperator,
    DICOMSegmentationWriterOperator,
)

class LiverSegApp(Application):
    def compose(self):
        load = DICOMDataLoaderOperator()
        convert = DICOMSeriesToVolumeOperator()
        infer = InferenceOperator(model_path="model.onnx",
                                  preprocess=ct_preprocess,
                                  postprocess=sliding_window)
        write = DICOMSegmentationWriterOperator(
            segment_labels=["Background", "Liver", "Tumour"])
        self.add_flow(load, convert)
        self.add_flow(convert, infer)
        self.add_flow(infer, write)

app = LiverSegApp()
app.run()
```
11.3 Regulatory — FDA and CE Mark
Any medical AI system intended for clinical decision support is a Software as a Medical Device (SaMD) and must undergo regulatory clearance. In the United States this means an FDA 510(k) pre-market notification (demonstrating substantial equivalence to an existing cleared device) or a De Novo pathway (for novel device types). In the European Union, the Medical Device Regulation (MDR) requires a CE mark with conformity assessment by a notified body.
The regulatory submission must include: a clinical validation study on a demographically diverse prospective cohort; a detailed description of the intended use and contraindications; an Instructions for Use (IFU) document; a Post-Market Surveillance (PMS) plan; software lifecycle documentation (IEC 62304); and a risk management file (ISO 14971).
12. Monitoring, Drift, and Maintenance
Deploying a model is not the end — it is the beginning of ongoing maintenance. Medical AI models degrade silently when hospital equipment changes, acquisition protocols drift, or patient demographics shift. The mathematics of distribution shift provides a language for detecting these failures before they cause clinical harm.
12.1 Population Stability Index (PSI)
The Population Stability Index quantifies how much the input distribution has changed between training and production. The input (e.g., mean HU per scan, slice thickness, scanner manufacturer) is discretised into \( B \) buckets. Let \( E_j \) be the training-time proportion in bucket \( j \) and \( A_j \) the production proportion:

\[ \text{PSI} = \sum_{j=1}^{B} (A_j - E_j) \ln\frac{A_j}{E_j} \]

Note that PSI is a symmetrised KL divergence: \( \text{PSI} = D_{\text{KL}}(A \,\|\, E) + D_{\text{KL}}(E \,\|\, A) \). Interpretation thresholds:
| PSI value | Interpretation | Action |
|---|---|---|
| < 0.1 | No significant change | Continue monitoring |
| 0.1 – 0.25 | Moderate shift | Investigate; consider recalibration |
| > 0.25 | Major shift | Retrain on new domain data immediately |
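The computation itself is a few lines. A minimal sketch (the epsilon guard against empty buckets is an illustrative choice, not part of the PSI definition):

```python
import numpy as np

def psi(expected_counts, actual_counts, eps=1e-6):
    """PSI between training-time (expected) and production (actual) bucket counts."""
    e = np.asarray(expected_counts, float)
    a = np.asarray(actual_counts, float)
    e = np.clip(e / e.sum(), eps, None)   # proportions, guarded against log(0)
    a = np.clip(a / a.sum(), eps, None)
    return float(np.sum((a - e) * np.log(a / e)))
```

Identical distributions give PSI = 0; a reversed distribution over four buckets lands well past the 0.25 "retrain" threshold.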
12.2 Prediction Distribution Drift — KL Divergence
Monitor the distribution of model output predictions over time. Let \( P_{\text{train}} \) be the prediction distribution on the training set and \( P_{\text{new}} \) the distribution on the current production window (e.g., the last 200 scans). The KL divergence is:

\[ D_{\text{KL}}(P_{\text{new}} \,\|\, P_{\text{train}}) = \sum_{c} P_{\text{new}}(c) \log \frac{P_{\text{new}}(c)}{P_{\text{train}}(c)} \]

where \( c \) ranges over bins of a summary statistic such as the predicted foreground fraction per scan. A sudden increase in \( D_{\text{KL}} \) signals that the model is outputting systematically different predictions — a leading indicator of degraded performance before any labelled validation data confirms it.
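In practice this is computed on histograms of the summary statistic. A sketch, with bin count and smoothing epsilon as illustrative choices:

```python
import numpy as np

def kl_drift(train_values, prod_values, bins=20, eps=1e-6):
    """KL(P_new || P_train) over histograms of a per-scan statistic."""
    train_values = np.asarray(train_values, float)
    prod_values = np.asarray(prod_values, float)
    lo = min(train_values.min(), prod_values.min())
    hi = max(train_values.max(), prod_values.max())
    p, edges = np.histogram(prod_values, bins=bins, range=(lo, hi))
    q, _ = np.histogram(train_values, bins=edges)
    p = p / p.sum() + eps                 # smooth to avoid log(0)
    q = q / q.sum() + eps
    return float(np.sum(p * np.log(p / q)))
```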
12.3 The Continuous Learning Loop
# Monitoring + retraining loop (runs monthly)
1. COLLECT: Pull last 30 days of production scans
2. FLAG: Run PSI on input statistics (HU mean, spacing, scanner)
3. SAMPLE: If PSI > 0.10, flag 50 random scans for radiologist review
4. LABEL: Radiologist corrects model predictions (active learning)
5. VALIDATE: New labelled set → compute Dice vs. prospective hold-out
6. RETRAIN: Combine original training data + new data;
   use a replay buffer (30% old, 70% new) to avoid forgetting
7. SHADOW: Deploy new model alongside old; both predict the same scans
8. A/B: Promote only if the new model matches or beats the old on
   Dice (higher) and HD95 (lower) over 2 weeks
9. PROMOTE: Replace production model; archive old weights with a version tag
10. LOG: Update model card, version registry, FDA PMS report
12.4 What to Monitor in Production
| Signal | Metric | Alert threshold |
|---|---|---|
| Input images | Mean HU, slice thickness, scanner model | PSI > 0.10 |
| Predictions | Mean foreground fraction per scan | ±2σ from training baseline |
| Confidence | Mean max-softmax probability per voxel | Drop > 5% |
| Latency | Inference time per volume (p95) | > 90 s |
| Failures | Rate of preprocessing errors or NaN outputs | > 1% |
| Clinical | Radiologist correction rate (active learning) | > 20% |
13. The Bigger Picture — Why Each Piece Had to Be There
It is worth stepping back and asking why every decision in this pipeline was made — not what was done, but why the simpler alternative failed.
Why resample to isotropic spacing? A 3×3×3 convolutional kernel assumes cubic voxels. When the z-spacing is 5 mm and the in-plane spacing is 0.7 mm, a 3×3×3 kernel covers 2.1 mm in-plane but 15 mm axially — it is not looking at the same anatomical neighbourhood in all directions. Isotropy is the prerequisite for any 3D convolution to be spatially meaningful.
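The resampling itself is one call to a trilinear zoom; a minimal sketch (the 1 mm target and z-y-x spacing order are illustrative choices):

```python
import numpy as np
from scipy.ndimage import zoom

def resample_isotropic(vol, spacing_zyx, target_mm=1.0):
    """Resample a (D, H, W) volume to isotropic `target_mm` spacing."""
    factors = [s / target_mm for s in spacing_zyx]  # e.g. (5.0, 0.7, 0.7) mm
    return zoom(vol, factors, order=1)              # order=1 -> trilinear
```

A 60-slice volume at 5 mm spacing becomes roughly 300 slices at 1 mm, after which a 3×3×3 kernel spans 3 mm in every direction.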
Why Dice loss instead of cross-entropy? Because the cross-entropy gradient at a voxel is independent of whether any other voxel is correctly classified. On a 1%-foreground dataset, the background gradient signal dominates by 99:1. Dice directly measures the quantity we care about — overlap — and its gradient is distributed over the entire foreground region.
Why Instance Normalization instead of Batch Normalization? Because 3D batch sizes are 2–4 volumes. Batch Normalization statistics computed over 2 samples are noisy and unstable. Instance Normalization computes statistics per-channel per-sample, making it independent of batch size.
Why skip connections? Because spatial detail is progressively lost through pooling. A tumour boundary that was 1 voxel wide at 128×128 is invisible at 8×8. The encoder retains this fine-grained information; the skip connection routes it directly to the decoder at the matching resolution, bypassing the bottleneck entirely for this spatial detail.
Why sliding window inference? Because the full volume does not fit in GPU memory. Why Gaussian weighting? Because predictions near patch boundaries have less surrounding context and are less reliable. The Gaussian weight encodes this uncertainty geometrically.
Why EMA weights? Because the raw training weights oscillate around a good solution as the optimizer bounces around the loss landscape. The exponential moving average smooths out these oscillations, acting as an implicit ensemble of the model at many recent iterations.
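The EMA update is a single in-place blend per parameter tensor. A framework-agnostic sketch (the parameter lists and 0.999 decay stand in for a real model's weights):

```python
import numpy as np

def ema_update(ema_params, params, decay=0.999):
    """In place: w_ema <- decay * w_ema + (1 - decay) * w."""
    for e, p in zip(ema_params, params):
        e *= decay
        e += (1.0 - decay) * p
```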
Why monitor PSI in production? Because the model has no way to tell you it is wrong — it will produce confident-looking predictions on out-of-distribution inputs. PSI gives the system a way to raise a flag before a radiologist is misled.
Every piece of this pipeline was added to solve a specific failure mode of a simpler design. Isotropic resampling was added because anisotropic convolutions failed. Dice loss because cross-entropy produced all-background models. Instance Norm because Batch Norm collapsed on size-2 batches. Sliding window because full-volume inference ran out of memory. EMA because checkpoint instability hurt generalisation. PSI monitoring because silent degradation is the most dangerous kind of model failure.
Medical imaging is the field where getting the engineering right has the most direct impact on human outcomes. The mathematics above is not abstract — it is the formal language for ensuring that a model trained on data from one hospital actually helps patients at another.