Also can be seen at: https://medium.com/@heyulong3d/vision-transformers-vit-experiments-using-pytorch-and-pytorch-lightning-61e26738d9dd?sk=8326d1c2706380c7599c67e53d2e2b5c
Overview
This article implements a Vision Transformer (ViT) from scratch using PyTorch and PyTorch Lightning. It also covers insightful experiments with different patch sizes, model sizes, attention heads, and other improvements such as overlapping patch embedding, all on the CIFAR-10 dataset.
This article focuses more on practice and experiments than on theory. References [1–5] provide a solid background on how the Vision Transformer works theoretically.
Basic Setup
I use Google Colab for this project (A100 GPU), and I also use Google Drive to store statistics permanently.
DEBUG_MODE = False  # If True, checkpoints/logs are saved to temporary local storage instead of Google Drive
MAX_EPOCH = 200 if not DEBUG_MODE else 10
# Install required packages (Colab only, can skip locally if already installed)
!pip install pytorch-lightning torchmetrics torchvision seaborn --quiet
# Imports and settings
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torchvision
import matplotlib.pyplot as plt
import seaborn as sns
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger
pl.seed_everything(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# DATASET_PATH and CHECKPOINT_PATH for saving files
DATASET_PATH = "./data/"
if DEBUG_MODE:
CHECKPOINT_PATH = "./saved_models/ViT/"
else:
from google.colab import drive
drive.mount('/content/drive')
CHECKPOINT_PATH = "/content/drive/MyDrive/colab_checkpoints/ViT/"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
print("Device:", torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
When DEBUG_MODE is turned off, the notebook mounts my Google Drive and sets the checkpoint path to a folder there.
Baseline ViT Implementation
These are my baseline arguments:
patch_size = 2
num_patches = (32 // patch_size) ** 2
model_kwargs = {
"embed_dim": 256,
"hidden_dim": 512,
"num_heads": 8,
"num_layers": 6,
"patch_size": patch_size,
"num_channels": 3,
"num_patches": num_patches,
"num_classes": 10,
"dropout": 0.2,
}
lr = 3e-4
The data module below downloads the dataset and splits it into train/validation/test sets.
# CIFAR10DataModule, and visualization
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, data_dir="./data", batch_size=128, num_workers=2):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784]),
])
self.train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
transforms.ToTensor(),
transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784]),
])
def prepare_data(self):
# Download CIFAR10 only once (done on one process in distributed training)
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)
def setup(self, stage=None):
train_dataset = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
val_dataset = CIFAR10(self.data_dir, train=True, transform=self.test_transform)
test_dataset = CIFAR10(self.data_dir, train=False, transform=self.test_transform)
# Deterministic split for reproducibility
torch.manual_seed(42)
self.train_set, _ = random_split(train_dataset, [45000, 5000])
torch.manual_seed(42)
_, self.val_set = random_split(val_dataset, [45000, 5000])
self.test_set = test_dataset
def train_dataloader(self):
return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=self.num_workers)
def val_dataloader(self):
return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers)
def test_dataloader(self):
return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers)
# Instantiate the datamodule
datamodule = CIFAR10DataModule(data_dir=DATASET_PATH, batch_size=128)
# Setup and visualize data (optional)
datamodule.prepare_data()
datamodule.setup()
# Visualize a few samples from validation set
NUM_IMAGES = 4
CIFAR_images = torch.stack([datamodule.val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)
plt.figure(figsize=(8, 8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
where train_dataset is the augmented version (with train_transform) and val_dataset is the non-augmented version (with test_transform).
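As a quick, optional sanity check (my own addition, not required for training), we can verify that the two seeded random_split calls produce complementary index sets, so the train and validation subsets never overlap:
# Optional: confirm the seeded splits are complementary (no leakage between train and val)
train_idx = set(datamodule.train_set.indices)
val_idx = set(datamodule.val_set.indices)
assert len(train_idx & val_idx) == 0       # no shared samples
assert len(train_idx | val_idx) == 50000   # together they cover the full CIFAR-10 train set
print(len(train_idx), len(val_idx))        # 45000 5000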
# Patch utility function
def img_to_patch(x, patch_size, flatten_channels=True):
"""
Args:
x: Tensor of shape [B, C, H, W]
patch_size: int
flatten_channels: bool
Returns:
Patches of x as (B, num_patches, C*patch_size*patch_size)
"""
B, C, H, W = x.shape
x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
x = x.flatten(1, 2) # [B, H'*W', C, p_H, p_W]
if flatten_channels:
x = x.flatten(2, 4) # [B, H'*W', C*p_H*p_W]
return x
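A quick shape check makes the patching easier to follow; with the baseline patch_size of 2, a CIFAR-10 batch of shape [B, 3, 32, 32] becomes 256 tokens of 12 values each:
# Sanity check of img_to_patch on a dummy CIFAR-10-sized batch
dummy = torch.randn(8, 3, 32, 32)            # [B, C, H, W]
patches = img_to_patch(dummy, patch_size=2)  # (32/2)^2 = 256 patches, each 3*2*2 = 12 values
print(patches.shape)                         # torch.Size([8, 256, 12])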
# Model: Attention Block and Vision Transformer
class AttentionBlock(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
super().__init__()
self.layer_norm_1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.layer_norm_2 = nn.LayerNorm(embed_dim)
self.linear = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x):
inp_x = self.layer_norm_1(x)
x = x + self.attn(inp_x, inp_x, inp_x)[0]
x = x + self.linear(self.layer_norm_2(x))
return x
class VisionTransformer(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers,
num_classes, patch_size, num_patches, dropout=0.0):
super().__init__()
self.patch_size = patch_size
self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
self.transformer = nn.Sequential(
*(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
)
self.mlp_head = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, num_classes)
)
self.dropout = nn.Dropout(dropout)
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
def forward(self, x):
x = img_to_patch(x, self.patch_size) # [B, num_patches, patch_dim]
B, T, _ = x.shape
x = self.input_layer(x) # [B, num_patches, embed_dim]
cls_token = self.cls_token.repeat(B, 1, 1)
x = torch.cat([cls_token, x], dim=1)
x = x + self.pos_embedding[:, :T+1]
x = self.dropout(x)
x = x.transpose(0, 1) # Transformer expects [S, B, E]
x = self.transformer(x)
cls = x[0]
out = self.mlp_head(cls)
return out
The above defines the ViT’s architecture.
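Before wiring it into Lightning, a single forward pass with the baseline model_kwargs is a cheap way to confirm the shapes line up (this is just a smoke test, not part of training):
# Smoke test: one forward pass with random data
vit = VisionTransformer(**model_kwargs)
with torch.no_grad():
    logits = vit(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10]), one logit per CIFAR-10 class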
# PyTorch LightningModule for ViT
class ViT(pl.LightningModule):
def __init__(self, model_kwargs, lr):
super().__init__()
self.save_hyperparameters()
self.model = VisionTransformer(**model_kwargs)
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
return [optimizer], [lr_scheduler]
def _calculate_loss(self, batch, mode="train"):
imgs, labels = batch
preds = self.model(imgs)
loss = F.cross_entropy(preds, labels)
acc = (preds.argmax(dim=-1) == labels).float().mean()
self.log(f"{mode}_loss", loss)
self.log(f"{mode}_acc", acc)
return loss
def training_step(self, batch, batch_idx):
loss = self._calculate_loss(batch, mode="train")
return loss
def validation_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="val")
def test_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="test")
The above defines the training scheme of ViT.
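Optionally, a fast_dev_run can verify the whole LightningModule wiring (forward, loss, logging, optimizer) on a single batch before committing to a 200-epoch run; this is just a sketch of a check I find useful, not part of the reported experiments:
# Optional one-batch dry run to catch wiring errors early
debug_trainer = pl.Trainer(fast_dev_run=True, accelerator="auto", devices=1, logger=False)
debug_trainer.fit(ViT(model_kwargs=model_kwargs, lr=lr), datamodule=datamodule)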
# Training and evaluation function
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
def train_model(
model_kwargs,
lr,
datamodule,
max_epochs=180,
logger=None,
log_every_n_steps=20,
verbose=True,
):
"""
Trains a Vision Transformer with a LightningDataModule.
Returns: (model, result_dict)
"""
import os
import pytorch_lightning as pl
# Callbacks
callbacks = [
ModelCheckpoint(
save_weights_only=True,
monitor="val_acc",
mode="max",
),
LearningRateMonitor("epoch"),
]
# Logger
if logger is None:
from pytorch_lightning.loggers import TensorBoardLogger
        logger = TensorBoardLogger(CHECKPOINT_PATH, name="vit_default")
# Create model
model = ViT(model_kwargs=model_kwargs, lr=lr)
# Trainer
trainer = pl.Trainer(
max_epochs=max_epochs,
accelerator="auto",
devices=1,
logger=logger,
callbacks=callbacks,
log_every_n_steps=log_every_n_steps,
enable_progress_bar=True,
)
# Train
trainer.fit(model, datamodule=datamodule)
# Load best model from checkpoint
    # best_model_path points to the checkpoint with the highest validation accuracy (val_acc)
    # achieved during training, as tracked by the ModelCheckpoint callback.
best_path = trainer.checkpoint_callback.best_model_path
model = ViT.load_from_checkpoint(best_path)
if verbose:
print(f"Loaded best model from {best_path}")
    # Evaluate the best checkpoint on the test set (this runs test_step)
val_result = trainer.test(model, datamodule=datamodule, verbose=False)
result = {
"val_acc": val_result[0]["test_acc"] if val_result else None,
}
if verbose:
print("Validation accuracy:", result["val_acc"])
return model, result
The above defines a training function.
After that, we can call this function to train a baseline ViT model:
BASELINE_LOGDIR = os.path.join(CHECKPOINT_PATH, "lightning_logs_base")
custom_logger = TensorBoardLogger(
save_dir=BASELINE_LOGDIR,
name="", # Keep experiment name blank for cleaner folder, or set your own
version=None # Use default versioning if needed
)
model, results = train_model(
model_kwargs=model_kwargs,
lr=lr,
datamodule=datamodule,
max_epochs=MAX_EPOCH,
logger=custom_logger,
)
print("Baseline results:", results)
After training finishes, use the commands below to visualise the statistics.
%load_ext tensorboard
%tensorboard --logdir $BASELINE_LOGDIR
After 200 epochs, the training/validation accuracy and loss curves look like this:

The above picture shows that training converged smoothly, with slight overfitting visible in the rising validation loss.
And finally, the test accuracy is 0.7624.

Ablation Studies on ViT Architecture
Define an ablation training function and then call it.
def ablation_study_tensorboard(model_kwargs_base, lr, datamodule, max_epochs=100):
ablation_results = []
# Patch size ablation
patch_sizes = [2, 4, 8]
for patch_size in patch_sizes:
model_kwargs = model_kwargs_base.copy()
model_kwargs.update({"patch_size": patch_size, "num_patches": (32 // patch_size) ** 2})
exp_name = f"patch_{patch_size}"
logger = TensorBoardLogger(ABLATION_LOGDIR, name=exp_name)
_, result = train_model(
model_kwargs=model_kwargs,
lr=lr,
datamodule=datamodule,
max_epochs=max_epochs,
logger=logger,
verbose=True,
)
ablation_results.append({"ablation": "patch_size", "value": patch_size, **result})
# Embedding dimension ablation
embed_dims = [96, 192, 256, 512, 1024]
for embed_dim in embed_dims:
model_kwargs = model_kwargs_base.copy()
model_kwargs.update({"embed_dim": embed_dim, "hidden_dim": embed_dim * 2})
exp_name = f"embed_{embed_dim}"
logger = TensorBoardLogger(ABLATION_LOGDIR, name=exp_name)
_, result = train_model(
model_kwargs=model_kwargs,
lr=lr,
datamodule=datamodule,
max_epochs=max_epochs,
logger=logger,
verbose=True,
)
ablation_results.append({"ablation": "embed_dim", "value": embed_dim, **result})
# Number of layers ablation
num_layers_list = [4, 6, 8, 10, 12]
for num_layers in num_layers_list:
model_kwargs = model_kwargs_base.copy()
model_kwargs.update({"num_layers": num_layers})
exp_name = f"layers_{num_layers}"
logger = TensorBoardLogger(ABLATION_LOGDIR, name=exp_name)
_, result = train_model(
model_kwargs=model_kwargs,
lr=lr,
datamodule=datamodule,
max_epochs=max_epochs,
logger=logger,
verbose=True,
)
ablation_results.append({"ablation": "num_layers", "value": num_layers, **result})
# Number of heads ablation
    num_heads_list = [1, 4, 8, 16, 32, 64]  # each must divide embed_dim (256)
for num_heads in num_heads_list:
model_kwargs = model_kwargs_base.copy()
model_kwargs.update({"num_heads": num_heads})
exp_name = f"heads_{num_heads}"
logger = TensorBoardLogger(ABLATION_LOGDIR, name=exp_name)
_, result = train_model(
model_kwargs=model_kwargs,
lr=lr,
datamodule=datamodule,
max_epochs=max_epochs,
logger=logger,
verbose=True,
)
ablation_results.append({"ablation": "num_heads", "value": num_heads, **result})
df = pd.DataFrame(ablation_results)
print(df)
return df
ABLATION_LOGDIR = os.path.join(CHECKPOINT_PATH, "lightning_logs_ablation")
ablation_df = ablation_study_tensorboard(model_kwargs, lr, datamodule, max_epochs=MAX_EPOCH)
This sweeps over different hyper-parameters (patch size, embedding dimension, number of layers, and number of heads) and trains a model for each setting, one by one.
Use the commands below to visualise the results:
%load_ext tensorboard
%tensorboard --logdir $ABLATION_LOGDIR
Ablation studies — patch size


This indicates:
- Insight: patch size defines resolution of “words”
- Smaller patches (patch_2, patch_4) achieve higher accuracy because of more fine-grained spatial information
- Smaller patches mean longer sequences, making training more expensive
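To make the cost argument concrete, here is a rough calculation (not from the original experiments) of how the token count and the quadratic self-attention cost grow as the patch size shrinks on a 32×32 image:
# Rough cost illustration: tokens = (32/patch)^2, attention pairs ~ tokens^2 per layer
for p in [2, 4, 8]:
    n_tokens = (32 // p) ** 2
    print(f"patch {p}: {n_tokens} tokens, ~{n_tokens**2:,} attention pairs per layer")
# patch 2: 256 tokens, ~65,536 attention pairs per layer
# patch 4: 64 tokens, ~4,096 attention pairs per layer
# patch 8: 16 tokens, ~256 attention pairs per layer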
Ablation studies — Embedding size


This indicates:
- Insight: embedding size controls width (how much information each token carries)
- Smaller embeddings may lack the capacity to represent enough features
- Bigger is not always better: embed_1024 performs worse (likely due to overfitting or optimization instability)
- You need the “just right” zone
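A back-of-the-envelope parameter count (ignoring biases, layer norms, the patch embedding, and the classification head) shows how quickly capacity grows with width in this ablation, where hidden_dim = 2 * embed_dim and num_layers = 6:
# Approximate encoder parameters: 4*d^2 for attention projections + 2*d*h for the FFN, per layer
def approx_encoder_params(embed_dim, hidden_dim, num_layers=6):
    attn = 4 * embed_dim * embed_dim
    ffn = 2 * embed_dim * hidden_dim
    return num_layers * (attn + ffn)

for d in [96, 192, 256, 512, 1024]:
    print(f"embed_dim {d}: ~{approx_encoder_params(d, 2 * d) / 1e6:.1f}M encoder params")
# roughly 0.4M, 1.8M, 3.1M, 12.6M, and 50.3M respectively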
Ablation studies — Layer number


This indicates:
- Insight: layer number controls depth (how many times tokens are processed)
- Too big: 10/12 layers perform slightly worse, possibly due to overfitting or optimization difficulty (deeper is not always better!)
- Too small: 4 layers underperforms, indicating insufficient capacity
Ablation studies — Attention heads


This indicates:
- Insight: the number of attention heads defines how many independent relationships can be modeled between tokens
- 1 head severely limits the model’s ability to capture varied relationships
- Too many heads may make each head too “narrow,” thus introducing optimization difficulties.
- 4–16 heads perform similarly; the difference is within noise
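One way to see the “narrow heads” effect: each head operates on a subspace of size embed_dim / num_heads, and PyTorch requires embed_dim to be divisible by num_heads (which is why the head counts above are all divisors of 256):
# Per-head dimension for the baseline embed_dim of 256
embed_dim = 256
for h in [1, 4, 8, 16, 32, 64]:
    print(f"{h} heads -> {embed_dim // h} dims per head")
# 1 head keeps the full 256 dims; 64 heads leave only 4 dims per head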
The best combination vs. the baseline
Based on the previous observations, we update the arguments to try to find the best combination:
# run combination
model_kwargs2 = {
"embed_dim": 512,
"hidden_dim": 512,
"num_heads": 16,
"num_layers": 8,
"patch_size": 4,
"num_channels": 3,
"num_patches": 64,
"num_classes": 10,
"dropout": 0.2,
}
lr2 = 3e-4
COMB_LOGDIR = os.path.join(CHECKPOINT_PATH, "lightning_logs_comb")
custom_logger2 = TensorBoardLogger(
save_dir=COMB_LOGDIR,
name="E512H16L8P4", # Keep experiment name blank for cleaner folder, or set your own
version=None # Use default versioning if needed
)
model, results = train_model(
model_kwargs=model_kwargs2,
    lr=lr2,
datamodule=datamodule,
max_epochs=MAX_EPOCH,
logger=custom_logger2,
)
print("Combination results:", results)

We tested many combinations and found the best:
- Embedding dim: 512
- Num of heads: 8
- Num of layers: 6
- Patch size: 4
The test accuracy improved from 0.7624 to 0.7804.
Other Improvements
Overlapping patches

This method is very simple: let the stride be smaller than the patch size so that neighbouring tokens share overlapping information. Each token therefore carries more information than with non-overlapping patches, and this redundancy can help the model learn fine details.
However, this method also produces more tokens, so it increases training cost.
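As a rough illustration (using the standard convolutional token-count formula, not code from the experiments below), halving the stride for a 4×4 patch on a 32×32 image increases the number of tokens by roughly 3.5×:
# Token count for a 32x32 image: kernel = patch_size, stride controls the overlap
def n_tokens(img=32, patch=4, stride=4):
    return ((img - patch) // stride + 1) ** 2

print(n_tokens(patch=4, stride=4))  # 64 tokens (non-overlapping baseline)
print(n_tokens(patch=4, stride=2))  # 225 tokens (overlapping, ~3.5x more)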
Cross ViT

Previously we tested different patch sizes, so one might ask: can patches of different sizes complement each other? The answer is yes!
Smaller patches tend to capture local features while bigger patches tend to capture global ones. CrossViT uses two branches (an S-Branch and an L-Branch) to process patches of different sizes, then uses cross attention to combine the features at different resolutions.
How does this cross-attention mechanism work? Simply put, both branches produce a feature block, call them Feature_S and Feature_L. Cross attention uses Feature_S to compute Query_S, uses Feature_L to compute Key_L and Value_L, and combines Query_S, Key_L, and Value_L into Output_S. The same applies to the other branch: it uses Feature_L to compute Query_L, uses Feature_S to compute Key_S and Value_S, and combines Query_L, Key_S, and Value_S into Output_L.
Therefore, Output_S and Output_L each carry information from the other branch, and the two outputs are finally concatenated to compute the result.
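The sketch below is my own minimal illustration, not the CrossViT implementation (in the actual model only each branch’s CLS token queries the other branch’s patch tokens, and the branch dimensions differ); it just shows the query-from-one-branch, key/value-from-the-other pattern with hypothetical tensors projected to a common dimension:
import torch
import torch.nn as nn

dim = 256
feat_s = torch.randn(8, 64, dim)   # small-patch branch tokens [B, N_s, D]
feat_l = torch.randn(8, 16, dim)   # large-patch branch tokens [B, N_l, D]

# S-branch queries attend to L-branch keys/values, and vice versa
cross_attn_s = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
cross_attn_l = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
out_s, _ = cross_attn_s(query=feat_s, key=feat_l, value=feat_l)  # Output_S
out_l, _ = cross_attn_l(query=feat_l, key=feat_s, value=feat_s)  # Output_L

# Pool each branch and concatenate, as described above
fused = torch.cat([out_s.mean(dim=1), out_l.mean(dim=1)], dim=-1)  # [B, 2*D]
print(fused.shape)  # torch.Size([8, 512])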
Implementation & Results
For these experiments, I did not run on Google Colab because I did not have enough resources at the time, so I used my local machine (Windows 11 + NVIDIA 4080 Ti) to run the script.
First, install dependencies:
pip install pytorch-lightning torchmetrics torchvision seaborn --quiet
git clone https://github.com/lucidrains/vit-pytorch.git
cd vit-pytorch
pip install .
The code below defines classes and functions:
DEBUG_MODE = False  # Logs will be saved temporarily instead of permanently
MAX_EPOCH = 200 if not DEBUG_MODE else 5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import numpy as np
pl.seed_everything(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
DATASET_PATH = "./data/"
CHECKPOINT_PATH = "./saved_models/ViT/section_3_3"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
print("Device:", torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
from cifar10 import CIFAR10DataModule  # the CIFAR10DataModule defined earlier, saved locally as cifar10.py
datamodule = CIFAR10DataModule(batch_size=128)
class OverlappingPatchEmbedding(nn.Module):
def __init__(self, img_size=32, patch_size=4, stride=2, in_chans=3, embed_dim=256):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
def forward(self, x):
x = self.proj(x) # (B, embed_dim, H_patch, W_patch)
x = x.flatten(2).transpose(1, 2) # (B, num_patches, embed_dim)
return x
class StandardPatchEmbedding(nn.Module):
def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class AttentionBlock(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ffn = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim),
nn.Dropout(dropout)
)
def forward(self, x):
inp_x = self.norm1(x)
x = x + self.attn(inp_x, inp_x, inp_x)[0]
x = x + self.ffn(self.norm2(x))
return x
class VisionTransformer(nn.Module):
def __init__(
self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes,
patch_size, num_patches, dropout=0.0, patch_type="standard", stride=4):
super().__init__()
# Patch embedding
if patch_type == "overlap":
self.patch_embed = OverlappingPatchEmbedding(
img_size=32, patch_size=patch_size, stride=stride,
in_chans=num_channels, embed_dim=embed_dim)
# Calculate number of patches dynamically:
num_patches = ((32 - patch_size)//stride + 1) ** 2
else:
self.patch_embed = StandardPatchEmbedding(
img_size=32, patch_size=patch_size, in_chans=num_channels, embed_dim=embed_dim)
num_patches = (32 // patch_size) ** 2
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
self.dropout = nn.Dropout(dropout)
self.transformer = nn.Sequential(*[
AttentionBlock(embed_dim, hidden_dim, num_heads, dropout)
for _ in range(num_layers)
])
self.mlp_head = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, num_classes)
)
def forward(self, x):
x = self.patch_embed(x) # (B, num_patches, embed_dim)
B, T, _ = x.shape
cls_token = self.cls_token.repeat(B, 1, 1)
x = torch.cat([cls_token, x], dim=1)
x = x + self.pos_embedding[:, :T+1]
x = self.dropout(x)
x = x.transpose(0, 1) # [S, B, E]
x = self.transformer(x)
cls = x[0]
out = self.mlp_head(cls)
return out
class ViT(pl.LightningModule):
def __init__(self, model_kwargs, lr):
super().__init__()
self.save_hyperparameters()
self.model = VisionTransformer(**model_kwargs)
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
return [optimizer], [lr_scheduler]
def _calculate_loss(self, batch, mode="train"):
imgs, labels = batch
preds = self.model(imgs)
loss = F.cross_entropy(preds, labels)
acc = (preds.argmax(dim=-1) == labels).float().mean()
self.log(f"{mode}_loss", loss)
self.log(f"{mode}_acc", acc)
return loss
def training_step(self, batch, batch_idx):
loss = self._calculate_loss(batch, mode="train")
return loss
def validation_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="val")
def test_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="test")
class ViT2(pl.LightningModule):
def __init__(self, model_kwargs, lr):
super().__init__()
from vit_pytorch.cross_vit import CrossViT
self.model = CrossViT(
image_size = 32,
num_classes = 10,
depth = 4, # number of multi-scale encoding blocks
sm_dim = 192, # high res dimension
sm_patch_size = 4, # high res patch size (should be smaller than lg_patch_size)
sm_enc_depth = 2, # high res depth
sm_enc_heads = 8, # high res heads
sm_enc_mlp_dim = 1024, # high res feedforward dimension
lg_dim = 384, # low res dimension
lg_patch_size = 16, # low res patch size
lg_enc_depth = 3, # low res depth
lg_enc_heads = 8, # low res heads
lg_enc_mlp_dim = 1024, # low res feedforward dimensions
cross_attn_depth = 2, # cross attention rounds
cross_attn_heads = 8, # cross attention heads
dropout = 0.1,
emb_dropout = 0.1
)
self.save_hyperparameters()
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
return [optimizer], [lr_scheduler]
def _calculate_loss(self, batch, mode="train"):
imgs, labels = batch
preds = self.model(imgs)
loss = F.cross_entropy(preds, labels)
acc = (preds.argmax(dim=-1) == labels).float().mean()
self.log(f"{mode}_loss", loss)
self.log(f"{mode}_acc", acc)
return loss
def training_step(self, batch, batch_idx):
loss = self._calculate_loss(batch, mode="train")
return loss
def validation_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="val")
def test_step(self, batch, batch_idx):
self._calculate_loss(batch, mode="test")
def train_model(
model_kwargs,
lr,
datamodule,
max_epochs=180,
logger=None,
log_every_n_steps=20,
verbose=True,
):
callbacks = [
ModelCheckpoint(save_weights_only=True, monitor="val_acc", mode="max"),
LearningRateMonitor("epoch"),
]
if logger is None:
logger = TensorBoardLogger(CHECKPOINT_PATH, name="vit_default")
    model = ViT2(model_kwargs=model_kwargs, lr=lr)  # this variant always builds the CrossViT wrapper (ViT2); model_kwargs are only stored as hyperparameters
trainer = pl.Trainer(
default_root_dir=CHECKPOINT_PATH,
max_epochs=max_epochs,
accelerator="auto",
devices=1,
logger=logger,
callbacks=callbacks,
log_every_n_steps=log_every_n_steps,
enable_progress_bar=True,
)
trainer.fit(model, datamodule=datamodule)
best_path = trainer.checkpoint_callback.best_model_path
model = ViT2.load_from_checkpoint(best_path)
if verbose:
print(f"Loaded best model from {best_path}")
val_result = trainer.test(model, datamodule=datamodule, verbose=False)
result = {"val_acc": val_result[0]["test_acc"] if val_result else None}
if verbose:
print("Validation accuracy:", result["val_acc"])
return model, result
def train():
# Instantiate the datamodule
datamodule = CIFAR10DataModule(batch_size=128)
# Setup and visualize data (optional)
datamodule.prepare_data()
datamodule.setup()
"""
## Run Experiments
### A. Standard ViT (already done)
"""
"""
## Run Experiments
### B. Use Overlapping Patch Embedding
### C. Use Cross Vit (no overlap)
### D: Use CrossViT + Overlapping (Combine All Enhancements)
"""
model_kwargs = {
"embed_dim": 512,
"hidden_dim": 512,
"num_heads": 8,
"num_layers": 6,
"patch_size": 4,
"num_channels": 3,
"num_patches": 64,
"num_classes": 10,
"dropout": 0.2,
"patch_type": "standard",
"stride": 4,
}
lr = 3e-4
    model_kwargs_cross = model_kwargs.copy()  # ViT2 hard-codes the CrossViT config, so these kwargs are only stored as hyperparameters
    logger = TensorBoardLogger(CHECKPOINT_PATH, name='CrossViT')
    model, results = train_model(
        model_kwargs=model_kwargs_cross,
lr=lr,
datamodule=datamodule,
max_epochs=MAX_EPOCH,
logger=logger
)
print("results:", results)
# ==== unified models for ViT and CrossViT (with optional overlapping patches) ====
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
# -------------------------------------------------
# Positional embedding helper: 2D sin/cos (shape-safe for any HxW)
# -------------------------------------------------
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32, device=None):
"""Returns [h*w, dim] sin-cos positional embedding."""
assert (dim % 4) == 0, "pos emb dim must be multiple of 4"
y, x = torch.meshgrid(torch.arange(h, device=device),
torch.arange(w, device=device), indexing="ij")
omega = torch.arange(dim // 4, device=device) / max(1, (dim // 4 - 1))
omega = temperature ** -omega
y = y.reshape(-1, 1) * omega.reshape(1, -1)
x = x.reshape(-1, 1) * omega.reshape(1, -1)
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1).to(dtype)
return pos_emb # [h*w, dim]
# -------------------------------------------------
# Vanilla ViT with selectable patch embed (standard or overlapping)
# -------------------------------------------------
class StandardPatchEmbedding(nn.Module):
"""Kernel == stride == patch_size (non-overlapping)."""
def __init__(self, in_chans, patch_size, embed_dim):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
self.ln = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.proj(x) # [B, C, H', W']
x = x.permute(0, 2, 3, 1).flatten(1, 2) # [B, N, C]
x = self.ln(x)
return x
class OverlapPatchEmbedding(nn.Module):
"""Kernel == patch_size, stride < patch_size (overlapping)."""
def __init__(self, in_chans, patch_size, stride, embed_dim):
super().__init__()
assert stride > 0 and stride <= patch_size
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, bias=False)
self.ln = nn.LayerNorm(embed_dim)
self.patch_size = patch_size
self.stride = stride
def forward(self, x):
x = self.proj(x) # [B, D, H', W']
x = x.permute(0, 2, 3, 1).flatten(1, 2) # [B, N, D]
x = self.ln(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)
self.norm2 = nn.LayerNorm(embed_dim)
self.ffn = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x): # x: [S, B, E]
y = self.norm1(x)
x = x + self.attn(y, y, y)[0]
x = x + self.ffn(self.norm2(x))
return x
class VanillaViT(nn.Module):
def __init__(
self,
image_size=32,
in_chans=3,
patch_size=4,
embed_dim=256,
hidden_dim=512,
depth=6,
heads=8,
dropout=0.2,
num_classes=10,
overlap=False,
stride=None,
pos_type="learned" # "learned" or "sincos"
):
super().__init__()
# patch embed
if overlap:
stride = stride if stride is not None else patch_size // 2
self.patch_embed = OverlapPatchEmbedding(in_chans, patch_size, stride, embed_dim)
num_patches = ((image_size - patch_size) // stride + 1) ** 2
else:
self.patch_embed = StandardPatchEmbedding(in_chans, patch_size, embed_dim)
num_patches = (image_size // patch_size) ** 2
# CLS + pos
self.cls = nn.Parameter(torch.randn(1, 1, embed_dim))
self.dropout = nn.Dropout(dropout)
self.pos_type = pos_type
if pos_type == "learned":
self.pos = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))
else:
# sin-cos buffer (no grad), created once
h = w = int(math.sqrt(num_patches))
pe = posemb_sincos_2d(h, w, embed_dim)
self.register_buffer("pos", pe, persistent=False) # [N, E]
# transformer
self.blocks = nn.Sequential(*[
AttentionBlock(embed_dim, hidden_dim, heads, dropout=dropout)
for _ in range(depth)
])
self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
def forward(self, x): # x: [B, C, H, W]
tok = self.patch_embed(x) # [B, N, E]
B, N, E = tok.shape
cls = self.cls.expand(B, -1, -1)
x = torch.cat([cls, tok], dim=1) # [B, 1+N, E]
if isinstance(self.pos, nn.Parameter): # learned
x = x + self.pos[:, : (N + 1)]
else:
# sincos buffer: self.pos is [N, E]
x[:, 1:] = x[:, 1:] + self.pos.unsqueeze(0).to(x.dtype).to(x.device)
x = self.dropout(x)
x = x.transpose(0, 1) # [S, B, E]
x = self.blocks(x) # transformer encoder
cls = x[0] # [B, E]
return self.head(cls)
# -------------------------------------------------
# CrossViT (original) and CrossViT with overlapping patch embedders
# -------------------------------------------------
# Reuse the CrossViT building blocks from lucidrains/vit-pytorch; we only swap the image
# embedders for overlapping ones when overlap is enabled, and keep cv.MultiScaleEncoder as-is.
import vit_pytorch.cross_vit as cv
class OverlapImageEmbedder(nn.Module):
"""Drop-in replacement for cv.ImageEmbedder that uses Conv2d (kernel=patch, stride=stride) + 2D sin/cos pos emb."""
def __init__(self, dim, image_size, patch_size, stride, channels=3, dropout=0.1):
super().__init__()
assert stride > 0 and stride <= patch_size
self.proj = nn.Conv2d(channels, dim, kernel_size=patch_size, stride=stride, bias=False)
self.ln = nn.LayerNorm(dim)
# token grid
H = (image_size - patch_size) // stride + 1
W = H
pe = posemb_sincos_2d(H, W, dim) # [H*W, D]
self.register_buffer("pos_embed", pe, persistent=False)
self.cls = nn.Parameter(torch.randn(1, 1, dim))
self.drop = nn.Dropout(dropout)
def forward(self, img):
x = self.proj(img) # [B, D, H', W']
x = x.permute(0, 2, 3, 1).flatten(1, 2) # [B, N, D]
x = self.ln(x)
# add sin/cos pos
x = x + self.pos_embed.unsqueeze(0).to(x.dtype).to(x.device)
cls = self.cls.expand(x.size(0), -1, -1)
x = torch.cat([cls, x], dim=1) # [B, 1+N, D]
return self.drop(x)
class CrossViTOverlap(nn.Module):
"""CrossViT where both streams use overlapping patch tokenizers."""
def __init__(
self,
*,
image_size,
num_classes,
sm_dim,
lg_dim,
sm_patch_size,
lg_patch_size,
sm_stride,
lg_stride,
sm_enc_depth=2,
sm_enc_heads=8,
sm_enc_mlp_dim=2048,
sm_enc_dim_head=64,
lg_enc_depth=3,
lg_enc_heads=8,
lg_enc_mlp_dim=2048,
lg_enc_dim_head=64,
cross_attn_depth=2,
cross_attn_heads=8,
cross_attn_dim_head=64,
depth=3,
dropout=0.1,
emb_dropout=0.1,
channels=3,
):
super().__init__()
# overlapping embedders for both scales
self.sm_image_embedder = OverlapImageEmbedder(dim=sm_dim, image_size=image_size,
patch_size=sm_patch_size, stride=sm_stride,
channels=channels, dropout=emb_dropout)
self.lg_image_embedder = OverlapImageEmbedder(dim=lg_dim, image_size=image_size,
patch_size=lg_patch_size, stride=lg_stride,
channels=channels, dropout=emb_dropout)
# keep the rest of CrossViT unchanged – reuse MultiScaleEncoder
self.multi_scale_encoder = cv.MultiScaleEncoder(
depth=depth,
sm_dim=sm_dim, lg_dim=lg_dim,
cross_attn_heads=cross_attn_heads, cross_attn_dim_head=cross_attn_dim_head, cross_attn_depth=cross_attn_depth,
sm_enc_params=dict(depth=sm_enc_depth, heads=sm_enc_heads, mlp_dim=sm_enc_mlp_dim, dim_head=sm_enc_dim_head),
lg_enc_params=dict(depth=lg_enc_depth, heads=lg_enc_heads, mlp_dim=lg_enc_mlp_dim, dim_head=lg_enc_dim_head),
dropout=dropout
)
self.sm_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
self.lg_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))
def forward(self, img):
sm = self.sm_image_embedder(img)
lg = self.lg_image_embedder(img)
sm, lg = self.multi_scale_encoder(sm, lg)
sm_cls, lg_cls = sm[:, 0], lg[:, 0]
return self.sm_head(sm_cls) + self.lg_head(lg_cls)
# -------------------------------------------------
# Builder: pick 'vit' | 'cross_vit' | 'cross_vit_overlap' from one dict
# -------------------------------------------------
def build_model(cfg):
model_type = cfg.get("model_type", "vit")
if model_type == "vit":
return VanillaViT(
image_size=cfg.get("image_size", 32),
in_chans=cfg.get("in_chans", 3),
patch_size=cfg.get("patch_size", 4),
embed_dim=cfg.get("embed_dim", 256),
hidden_dim=cfg.get("hidden_dim", 512),
depth=cfg.get("depth", 6),
heads=cfg.get("heads", 8),
dropout=cfg.get("dropout", 0.2),
num_classes=cfg.get("num_classes", 10),
overlap=cfg.get("overlap", False),
stride=cfg.get("stride"),
pos_type=cfg.get("pos_type", "learned"),
)
elif model_type == "cross_vit":
# pure CrossViT (no overlap) – use original class
return cv.CrossViT(
image_size=cfg.get("image_size", 32),
num_classes=cfg.get("num_classes", 10),
depth=cfg.get("depth", 3),
sm_dim=cfg.get("sm_dim", 192),
sm_patch_size=cfg.get("sm_patch_size", 4), # must divide image_size
sm_enc_depth=cfg.get("sm_enc_depth", 2),
sm_enc_heads=cfg.get("sm_enc_heads", 8),
sm_enc_mlp_dim=cfg.get("sm_enc_mlp_dim", 1024),
lg_dim=cfg.get("lg_dim", 384),
lg_patch_size=cfg.get("lg_patch_size", 16), # must divide image_size
lg_enc_depth=cfg.get("lg_enc_depth", 3),
lg_enc_heads=cfg.get("lg_enc_heads", 8),
lg_enc_mlp_dim=cfg.get("lg_enc_mlp_dim", 1024),
cross_attn_depth=cfg.get("cross_attn_depth", 2),
cross_attn_heads=cfg.get("cross_attn_heads", 8),
dropout=cfg.get("dropout", 0.1),
emb_dropout=cfg.get("emb_dropout", 0.1),
)
elif model_type == "cross_vit_overlap":
# CrossViT + overlapping on small and/or large stream
return CrossViTOverlap(
image_size=cfg.get("image_size", 32),
num_classes=cfg.get("num_classes", 10),
depth=cfg.get("depth", 3),
sm_dim=cfg.get("sm_dim", 192),
sm_patch_size=cfg.get("sm_patch_size", 4),
sm_stride=cfg.get("sm_stride", 2), # <= overlap control
sm_enc_depth=cfg.get("sm_enc_depth", 2),
sm_enc_heads=cfg.get("sm_enc_heads", 8),
sm_enc_mlp_dim=cfg.get("sm_enc_mlp_dim", 1024),
lg_dim=cfg.get("lg_dim", 384),
lg_patch_size=cfg.get("lg_patch_size", 16),
lg_stride=cfg.get("lg_stride", 8), # <= overlap control
lg_enc_depth=cfg.get("lg_enc_depth", 3),
lg_enc_heads=cfg.get("lg_enc_heads", 8),
lg_enc_mlp_dim=cfg.get("lg_enc_mlp_dim", 1024),
cross_attn_depth=cfg.get("cross_attn_depth", 2),
cross_attn_heads=cfg.get("cross_attn_heads", 8),
dropout=cfg.get("dropout", 0.1),
emb_dropout=cfg.get("emb_dropout", 0.1),
)
else:
raise ValueError(f"Unknown model_type: {model_type}")
# -------------------------------------------------
# Lightning wrapper (unified)
# -------------------------------------------------
import pytorch_lightning as pl
import torch.optim as optim
class LitUnifiedViT(pl.LightningModule):
def __init__(self, cfg, lr=3e-4):
super().__init__()
self.save_hyperparameters()
self.model = build_model(cfg)
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
opt = optim.AdamW(self.parameters(), lr=self.hparams.lr)
sch = optim.lr_scheduler.MultiStepLR(opt, milestones=[100, 150], gamma=0.1)
return [opt], [sch]
def _step(self, batch, tag):
imgs, labels = batch
logits = self(imgs)
loss = F.cross_entropy(logits, labels)
acc = (logits.argmax(dim=-1) == labels).float().mean()
self.log(f"{tag}_loss", loss, prog_bar=True)
self.log(f"{tag}_acc", acc, prog_bar=True)
return loss
def training_step(self, batch, batch_idx): return self._step(batch, "train")
def validation_step(self, batch, batch_idx): self._step(batch, "val")
def test_step(self, batch, batch_idx): self._step(batch, "test")
Then the training logic:
# from cifar10 import CIFAR10DataModule
dm = CIFAR10DataModule(batch_size=128)
dm.prepare_data(); dm.setup()
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
def train_one(cfg, run_name, max_epochs=MAX_EPOCH, lr=3e-4, precision="16-mixed"):
logger = TensorBoardLogger(save_dir="./saved_models/ViT/section_3_3", name=run_name)
callbacks = [
ModelCheckpoint(monitor="val_acc", mode="max", save_weights_only=True),
LearningRateMonitor(logging_interval="epoch"),
]
lit = LitUnifiedViT(cfg, lr=lr)
trainer = pl.Trainer(
accelerator="auto", devices=1,
precision=precision, # mixed precision to avoid OOM on overlap / CrossViT
max_epochs=max_epochs, logger=logger, callbacks=callbacks
)
trainer.fit(lit, datamodule=dm)
best = LitUnifiedViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, cfg=cfg, lr=lr)
test_metrics = trainer.test(best, datamodule=dm, verbose=False)[0]
print(run_name, "=>", {k: round(float(v), 4) for k, v in test_metrics.items()})
return test_metrics
# 1) Overlapping ViT
cfg_vit_overlap = dict(
model_type="vit", image_size=32, num_classes=10,
patch_size=4, overlap=True, stride=2, pos_type="sincos", # sin/cos avoids pos-emb size mismatch
embed_dim=512, hidden_dim=512, depth=6, heads=8, dropout=0.2
)
m1 = train_one(cfg_vit_overlap, run_name="vit_overlap")
# 2) CrossViT (no overlap)
cfg_cross = dict(
model_type="cross_vit", image_size=32, num_classes=10,
depth=3,
sm_dim=192, sm_patch_size=4, sm_enc_depth=2, sm_enc_heads=8, sm_enc_mlp_dim=1024,
lg_dim=384, lg_patch_size=16, lg_enc_depth=3, lg_enc_heads=8, lg_enc_mlp_dim=1024,
dropout=0.1, emb_dropout=0.1
)
m2 = train_one(cfg_cross, run_name="cross_vit")
# 3) CrossViT + Overlap (token counts grow fast; if OOM, lower sm_dim / depth or batch size)
cfg_cross_overlap = dict(
model_type="cross_vit_overlap", image_size=32, num_classes=10, depth=3,
sm_dim=192, sm_patch_size=4, sm_stride=2, sm_enc_depth=2, sm_enc_heads=8, sm_enc_mlp_dim=1024,
lg_dim=384, lg_patch_size=16, lg_stride=8, lg_enc_depth=3, lg_enc_heads=8, lg_enc_mlp_dim=1024,
dropout=0.1, emb_dropout=0.1
)
m3 = train_one(cfg_cross_overlap, run_name="cross_vit_overlap")
The results:


So the test accuracy improved from 0.7804 to 0.8738 on this dataset.
Conclusion
We explored a basic ViT, ablation studies, and recent enhancements.
The ablations show we should find the “just right” zone for:
- Patch size → controls “words” resolution
- Model size → controls the features’ “width” (embedding dimension) and the network’s “depth” (number of layers)
- Attention heads → defines how many independent relationships exist between tokens
The best model uses: CrossViT + Overlap
- Overlap → richer local info
- CrossViT → combines features at different resolutions
- Combining both → enhances local features and the holistic representation
That’s all. If you found this article helpful, please show your support by clicking the clap icon 👏 and following me 🙏.
References
- Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., … & Houlsby, N. (2020). An image is worth 16×16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
- Vision Transformer: What It Is & How It Works [2024 Guide]
- Transformer Explained
- Transformers, the tech behind LLMs | Deep Learning Chapter 5
- Vision Transformer (ViT): Tutorial + Baseline
- Building a Vision Transformer Model From Scratch
- lucidrains/vit-pytorch
- Tutorial 11: Vision Transformers
- https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
- Wang, W., Xie, E., Li, X., Fan, D. P., Song, K., Liang, D., … & Shao, L. (2022). Pvt v2: Improved baselines with pyramid vision transformer. Computational visual media, 8(3), 415–424.
- Chen, C. F. R., Fan, Q., & Panda, R. (2021). Crossvit: Cross-attention multi-scale vision transformer for image classification. In Proceedings of the IEEE/CVF international conference on computer vision (pp. 357–366).
- Arun, D., Ozturk, K., Bowyer, K. W., & Flynn, P. (2025). Improved Ear Verification with Vision Transformers and Overlapping Patches. arXiv preprint arXiv:2503.23275.