Made ViT from scratch

Koshin S Hegde 2024-09-09 14:44:54 +05:30
parent 9d0dc76903
commit 7e2d968730


@@ -2,8 +2,10 @@
import pandas
import datasets
import numpy
from tqdm import tqdm
import torch
#|%%--%%| <66pBRkviU5|LMiIYmSiRI>
# |%%--%%| <66pBRkviU5|LMiIYmSiRI>
_ = datasets.load_dataset(
"Hemg/deepfake-and-real-images",
@@ -14,14 +16,325 @@ if not isinstance(_, datasets.Dataset):
    raise Exception("Something went wrong")
original_raw_dataset: datasets.Dataset = _
#|%%--%%| <LMiIYmSiRI|8Yvl492qqH>
# |%%--%%| <LMiIYmSiRI|KugfehqjKv>
_ = original_raw_dataset.to_pandas()
number_of_images = len(original_raw_dataset)
if not isinstance(_, datasets.Dataset):
    raise Exception("Something went wrong")
original_dataset: datasets.Dataset = _
# |%%--%%| <KugfehqjKv|EFzgkiWqFX>
#|%%--%%| <8Yvl492qqH|SZqCHvaVss>
number_of_images
# |%%--%%| <EFzgkiWqFX|AsQx8efYP3>
class PatchEmbedding(torch.nn.Module):
    """Turns a 2D input image into a 1D sequence of learnable embedding vectors.

    Args:
        in_channels (int): Number of color channels for the input images. Defaults to 3.
        patch_size (int): Size of patches to convert input image into. Defaults to 16.
        embedding_dim (int): Size of embedding to turn image into. Defaults to 768.
    """
    # 2. Initialize the class with appropriate variables
    def __init__(self,
                 in_channels:int=3,
                 patch_size:int=16,
                 embedding_dim:int=768):
        super().__init__()
        # 3. Create a layer to turn an image into patches
        self.patch_size = patch_size
        self.patcher = torch.nn.Conv2d(in_channels=in_channels,
                                       out_channels=embedding_dim,
                                       kernel_size=patch_size,
                                       stride=patch_size,
                                       padding=0)
        # 4. Create a layer to flatten the patch feature maps into a single dimension
        self.flatten = torch.nn.Flatten(start_dim=2, # only flatten the feature map dimensions into a single vector
                                        end_dim=3)

    # 5. Define the forward method
    def forward(self, x):
        # Create assertion to check that inputs are the correct shape
        image_resolution = x.shape[-1]
        patch_size = self.patch_size
        assert image_resolution % patch_size == 0, f"Input image size must be divisible by patch size, image shape: {image_resolution}, patch size: {patch_size}"
        # Perform the forward pass
        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)
        # 6. Make sure the output shape has the right order
        return x_flattened.permute((0, 2, 1))
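
# Quick shape check for PatchEmbedding (illustrative sketch, not part of the
# original pipeline): a 256x256 RGB image split into 16x16 patches should give
# (256 // 16) ** 2 = 256 patch tokens, each embedded into a 768-dim vector.
_patch_demo = PatchEmbedding(in_channels=3, patch_size=16, embedding_dim=768)
print(_patch_demo(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 256, 768])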
# |%%--%%| <AsQx8efYP3|PQho0pEU3x>
class MultiheadSelfAttentionBlock(torch.nn.Module):
    """Creates a multi-head self-attention block ("MSA block" for short)."""
    # 2. Initialize the class with hyperparameters from Table 1
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSA blocks
        super().__init__()
        # 3. Create the Norm layer (LN)
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        # 4. Create the Multi-Head Attention (MSA) layer
        self.multihead_attn = torch.nn.MultiheadAttention(embed_dim=embedding_dim,
                                                          num_heads=num_heads,
                                                          dropout=attn_dropout,
                                                          batch_first=True) # does our batch dimension come first?

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attn(query=x, # query embeddings
                                             key=x, # key embeddings
                                             value=x, # value embeddings
                                             need_weights=False) # do we need the weights or just the layer outputs?
        return attn_output
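
# Illustrative shape check (not part of the original pipeline): the MSA block
# is shape-preserving, mapping (batch, num_tokens, embedding_dim) to the same
# shape. The 257 tokens below stand for 256 patch tokens plus 1 class token.
_msa_demo = MultiheadSelfAttentionBlock(embedding_dim=768, num_heads=12)
print(_msa_demo(torch.randn(1, 257, 768)).shape)  # torch.Size([1, 257, 768])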
# |%%--%%| <PQho0pEU3x|X9bS6VigyO>
class MLPBlock(torch.nn.Module):
    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
        super().__init__()
        # 3. Create the Norm layer (LN)
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        # 4. Create the Multilayer perceptron (MLP) layer(s)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=embedding_dim,
                            out_features=mlp_size),
            torch.nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
                            out_features=embedding_dim), # take back to embedding_dim
            torch.nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer..."
        )

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x
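
# Illustrative shape check (not part of the original pipeline): the MLP block
# expands each token to mlp_size=3072 internally but projects back down to
# embedding_dim=768, so it is also shape-preserving.
_mlp_demo = MLPBlock(embedding_dim=768, mlp_size=3072, dropout=0.1)
print(_mlp_demo(torch.randn(1, 257, 768)).shape)  # torch.Size([1, 257, 768])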
# |%%--%%| <X9bS6VigyO|xTGaCs6wAt>
class TransformerEncoderBlock(torch.nn.Module):
    """Creates a Transformer Encoder block."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
                 attn_dropout:float=0): # Amount of dropout for attention layers
        super().__init__()
        # 3. Create MSA block (equation 2)
        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)
        # 4. Create MLP block (equation 3)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    # 5. Create a forward() method
    def forward(self, x):
        # 6. Create residual connection for MSA block (add the input to the output)
        x = self.msa_block(x) + x
        # 7. Create residual connection for MLP block (add the input to the output)
        x = self.mlp_block(x) + x
        return x
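
# Illustrative sketch (not part of the original pipeline): a single encoder
# block preserves the token sequence shape, which is what allows ViT to stack
# num_transformer_layers of them with torch.nn.Sequential below.
_encoder_demo = TransformerEncoderBlock()
print(_encoder_demo(torch.randn(1, 257, 768)).shape)  # torch.Size([1, 257, 768])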
# |%%--%%| <xTGaCs6wAt|s029U18Gni>
class ViT(torch.nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=256, # Training resolution (this dataset's images are 256x256)
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=2): # Two classes for this dataset: real vs. fake images
        super().__init__() # don't forget the super().__init__()!
        # 3. Make sure the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
        # 4. Calculate number of patches (height * width / patch_size^2)
        self.num_patches = (img_size * img_size) // patch_size**2
        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = torch.nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                                  requires_grad=True)
        # 6. Create learnable position embedding
        self.position_embedding = torch.nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                                     requires_grad=True)
        # 7. Create embedding dropout value
        self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout)
        # 8. Create patch embedding layer
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)
        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: the "*" unpacks the list of blocks into nn.Sequential's arguments
        self.transformer_encoder = torch.nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                                 num_heads=num_heads,
                                                                                 mlp_size=mlp_size,
                                                                                 mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
        # 10. Create classifier head
        self.classifier = torch.nn.Sequential(
            torch.nn.LayerNorm(normalized_shape=embedding_dim),
            torch.nn.Linear(in_features=embedding_dim,
                            out_features=num_classes)
        )

    # 11. Create a forward() method
    def forward(self, x):
        # 12. Get batch size
        batch_size = x.shape[0]
        # 13. Create class token embedding and expand it to match the batch size (equation 1)
        class_token = self.class_embedding.expand(batch_size, -1, -1) # "-1" means to infer the dimension (try this line on its own)
        # 14. Create patch embedding (equation 1)
        x = self.patch_embedding(x)
        # 15. Concat class embedding and patch embedding (equation 1)
        x = torch.cat((class_token, x), dim=1)
        # 16. Add position embedding to patch embedding (equation 1)
        x = self.position_embedding + x
        # 17. Run embedding dropout (Appendix B.1)
        x = self.embedding_dropout(x)
        # 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
        x = self.transformer_encoder(x)
        # 19. Put 0 index logit through classifier (equation 4)
        x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
        return x
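
# Illustrative end-to-end sketch (not part of the original pipeline, CPU-only):
# a forward pass on one random 256x256 RGB image should produce one logit per
# class, i.e. shape (1, 2) with the defaults above.
_vit_demo = ViT(img_size=256, num_classes=2)
print(_vit_demo(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 2])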
# |%%--%%| <s029U18Gni|D2Fnuvk9hw>
def train(model, dataloader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute loss
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = 100 * correct / total
    print(f'Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
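
# Evaluation counterpart to train() (an added sketch, not part of the original
# commit): same dataloader/criterion interface, but with no gradient updates.
def evaluate(model, dataloader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Validation Loss: {running_loss / len(dataloader):.4f}, Accuracy: {100 * correct / total:.2f}%')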
# |%%--%%| <D2Fnuvk9hw|zMpnWiyItv>
model = ViT().cuda()
# |%%--%%| <zMpnWiyItv|8WqgfhdZT8>
loss_function = torch.nn.CrossEntropyLoss()
# |%%--%%| <8WqgfhdZT8|Cl2qQDeIQV>
optimizer = torch.optim.Adam(model.parameters())
# |%%--%%| <Cl2qQDeIQV|vKz9NMcv7I>
images = []
labels = []
START_INDEX = 0
RANGE = 32
END_INDEX = START_INDEX + RANGE
images = numpy.array(
original_raw_dataset[START_INDEX:END_INDEX]["image"]
).tolist()
print(1)
images += numpy.array(
original_raw_dataset[-(START_INDEX + 1):-(END_INDEX + 1):-1]["image"]
).tolist()
print(2)
labels = numpy.array(
original_raw_dataset["label"][START_INDEX:END_INDEX]
).tolist()
print(3)
labels += numpy.array(
original_raw_dataset["label"][-(START_INDEX + 1):-(END_INDEX + 1):-1]
).tolist()
print(4)
labels = torch.tensor(labels).cuda()
# The stacked PIL images come in as (N, 256, 256, 3); permute to the
# (N, 3, 256, 256) layout the model expects. view() would silently scramble
# the channel ordering instead of rearranging the axes.
images = torch.tensor(images).cuda().permute((0, 3, 1, 2)).float()
print(5)
# |%%--%%| <vKz9NMcv7I|LYgaCybme2>
dataset = torch.utils.data.TensorDataset(images, labels)
# |%%--%%| <LYgaCybme2|KFseMis27r>
# Only RANGE * 2 = 64 samples were loaded above, so with batch_size=128 each epoch is a single batch.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
# |%%--%%| <KFseMis27r|PvIYi4SzgP>
# Train indefinitely on the loaded subset; interrupt manually to stop.
while True:
    train(model, dataloader, loss_function, optimizer)
# |%%--%%| <PvIYi4SzgP|W5EFRIN0B0>