From 7e2d9687302b1df22e79d0af8741b616ecaf9a65 Mon Sep 17 00:00:00 2001
From: kosh <kosh@kosh-web.cfd>
Date: Mon, 9 Sep 2024 14:44:54 +0530
Subject: [PATCH] Made ViT from scratch

---
 src/main.py | 327 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 320 insertions(+), 7 deletions(-)

diff --git a/src/main.py b/src/main.py
index 84c5cb6..57a8869 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,8 +2,10 @@
 import pandas
 import datasets
 import numpy
+from tqdm import tqdm
+import torch
 
-#|%%--%%| <66pBRkviU5|LMiIYmSiRI>
+# |%%--%%| <66pBRkviU5|LMiIYmSiRI>
 
 _ = datasets.load_dataset(
         "Hemg/deepfake-and-real-images",
@@ -14,14 +16,325 @@ if not isinstance(_, datasets.Dataset):
     raise Exception("Something went wrong")
 original_raw_dataset: datasets.Dataset = _
 
-#|%%--%%| <LMiIYmSiRI|8Yvl492qqH>
+# |%%--%%| <LMiIYmSiRI|KugfehqjKv>
 
-_ = original_raw_dataset.to_pandas()
+number_of_images = len(original_raw_dataset)
 
-if not isinstance(_, datasets.Dataset):
-    raise Exception("Something went wrong")
-original_dataset: datasets.Dataset = _
+# |%%--%%| <KugfehqjKv|EFzgkiWqFX>
 
-#|%%--%%| <8Yvl492qqH|SZqCHvaVss>
+number_of_images
+
+# |%%--%%| <EFzgkiWqFX|AsQx8efYP3>
+
+class PatchEmbedding(torch.nn.Module):
+    """Turns a 2D input image into a 1D sequence learnable embedding vector.
+    
+    Args:
+        in_channels (int): Number of color channels for the input images. Defaults to 3.
+        patch_size (int): Size of patches to convert input image into. Defaults to 16.
+        embedding_dim (int): Size of embedding to turn image into. Defaults to 768.
+    """ 
+    # 2. Initialize the class with appropriate variables
+    def __init__(self, 
+                 in_channels:int=3,
+                 patch_size:int=16,
+                 embedding_dim:int=768):
+        super().__init__()
+        
+        # 3. Create a layer to turn an image into patches
+        self.patch_size = patch_size
+        self.patcher = torch.nn.Conv2d(in_channels=in_channels,
+                                 out_channels=embedding_dim,
+                                 kernel_size=patch_size,
+                                 stride=patch_size,
+                                 padding=0)
+
+        # 4. Create a layer to flatten the patch feature maps into a single dimension
+        self.flatten = torch.nn.Flatten(start_dim=2, # only flatten the feature map dimensions into a single vector
+                                  end_dim=3)
+
+    # 5. Define the forward method 
+    def forward(self, x):
+        # Create assertion to check that inputs are the correct shape
+        image_resolution = x.shape[-1]
+        patch_size = self.patch_size
+        assert image_resolution % patch_size == 0, f"Input image size must be divisble by patch size, image shape: {image_resolution}, patch size: {patch_size}"
+        
+        # Perform the forward pass
+        x_patched = self.patcher(x)
+        x_flattened = self.flatten(x_patched) 
+        
+        # 6. Make sure the output shape has the right order 
+        return x_flattened.permute((0, 2, 1))
+
+# |%%--%%| <AsQx8efYP3|PQho0pEU3x>
+
+class MultiheadSelfAttentionBlock(torch.nn.Module):
+    """Creates a multi-head self-attention block ("MSA block" for short).
+    """
+    # 2. Initialize the class with hyperparameters from Table 1
+    def __init__(self,
+                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
+                 num_heads:int=12, # Heads from Table 1 for ViT-Base
+                 attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSABlocks
+        super().__init__()
+        
+        # 3. Create the Norm layer (LN)
+        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+        
+        # 4. Create the Multi-Head Attention (MSA) layer
+        self.multihead_attn = torch.nn.MultiheadAttention(embed_dim=embedding_dim,
+                                                    num_heads=num_heads,
+                                                    dropout=attn_dropout,
+                                                    batch_first=True) # does our batch dimension come first?
+        
+    # 5. Create a forward() method to pass the data throguh the layers
+    def forward(self, x):
+        x = self.layer_norm(x)
+        attn_output, _ = self.multihead_attn(query=x, # query embeddings 
+                                             key=x, # key embeddings
+                                             value=x, # value embeddings
+                                             need_weights=False) # do we need the weights or just the layer outputs?
+        return attn_output
+
+# |%%--%%| <PQho0pEU3x|X9bS6VigyO>
+
+class MLPBlock(torch.nn.Module):
+    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
+                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
+                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
+        super().__init__()
+        
+        # 3. Create the Norm layer (LN)
+        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+        
+        # 4. Create the Multilayer perceptron (MLP) layer(s)
+        self.mlp = torch.nn.Sequential(
+            torch.nn.Linear(in_features=embedding_dim,
+                      out_features=mlp_size),
+            torch.nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
+            torch.nn.Dropout(p=dropout),
+            torch.nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
+                      out_features=embedding_dim), # take back to embedding_dim
+            torch.nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
+        )
+    
+    # 5. Create a forward() method to pass the data throguh the layers
+    def forward(self, x):
+        x = self.layer_norm(x)
+        x = self.mlp(x)
+        return x
+
+# |%%--%%| <X9bS6VigyO|xTGaCs6wAt>
+
+class TransformerEncoderBlock(torch.nn.Module):
+    """Creates a Transformer Encoder block."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
+                 num_heads:int=12, # Heads from Table 1 for ViT-Base
+                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
+                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
+                 attn_dropout:float=0): # Amount of dropout for attention layers
+        super().__init__()
+
+        # 3. Create MSA block (equation 2)
+        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
+                                                     num_heads=num_heads,
+                                                     attn_dropout=attn_dropout)
+        
+        # 4. Create MLP block (equation 3)
+        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
+                                   mlp_size=mlp_size,
+                                   dropout=mlp_dropout)
+        
+    # 5. Create a forward() method  
+    def forward(self, x):
+        
+        # 6. Create residual connection for MSA block (add the input to the output)
+        x =  self.msa_block(x) + x 
+        
+        # 7. Create residual connection for MLP block (add the input to the output)
+        x = self.mlp_block(x) + x 
+        
+        return x
+
+# |%%--%%| <xTGaCs6wAt|s029U18Gni>
+
+class ViT(torch.nn.Module):
+    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 img_size:int=256, # Training resolution from Table 3 in ViT paper
+                 in_channels:int=3, # Number of channels in input image
+                 patch_size:int=16, # Patch size
+                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
+                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
+                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
+                 num_heads:int=12, # Heads from Table 1 for ViT-Base
+                 attn_dropout:float=0, # Dropout for attention projection
+                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers 
+                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
+                 num_classes:int=2): # Default for ImageNet but can customize this
+        super().__init__() # don't forget the super().__init__()!
+        
+        # 3. Make the image size is divisble by the patch size 
+        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
+        
+        # 4. Calculate number of patches (height * width/patch^2)
+        self.num_patches = (img_size * img_size) // patch_size**2
+                 
+        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
+        self.class_embedding = torch.nn.Parameter(data=torch.randn(1, 1, embedding_dim),
+                                            requires_grad=True)
+        
+        # 6. Create learnable position embedding
+        self.position_embedding = torch.nn.Parameter(data=torch.randn(1, (self.num_patches)+1, embedding_dim),
+                                               requires_grad=True)
+                
+        # 7. Create embedding dropout value
+        self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout)
+        
+        # 8. Create patch embedding layer
+        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
+                                              patch_size=patch_size,
+                                              embedding_dim=embedding_dim)
+        
+        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential()) 
+        # Note: The "*" means "all"
+        self.transformer_encoder = torch.nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
+                                                                            num_heads=num_heads,
+                                                                            mlp_size=mlp_size,
+                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
+       
+        # 10. Create classifier head
+        self.classifier = torch.nn.Sequential(
+            torch.nn.LayerNorm(normalized_shape=embedding_dim),
+            torch.nn.Linear(in_features=embedding_dim, 
+                      out_features=num_classes)
+        )
+    
+    # 11. Create a forward() method
+    def forward(self, x):
+        
+        # 12. Get batch size
+        batch_size = x.shape[0]
+        
+        # 13. Create class token embedding and expand it to match the batch size (equation 1)
+        class_token = self.class_embedding.expand(batch_size, -1, -1) # "-1" means to infer the dimension (try this line on its own)
+
+        # 14. Create patch embedding (equation 1)
+        x = self.patch_embedding(x)
+
+        # 15. Concat class embedding and patch embedding (equation 1)
+        x = torch.cat((class_token, x), dim=1)
+
+        # 16. Add position embedding to patch embedding (equation 1) 
+        x = self.position_embedding + x
+
+        # 17. Run embedding dropout (Appendix B.1)
+        x = self.embedding_dropout(x)
+
+        # 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
+        x = self.transformer_encoder(x)
+
+        # 19. Put 0 index logit through classifier (equation 4)
+        x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
+
+        return x       
+
+# |%%--%%| <s029U18Gni|D2Fnuvk9hw>
+
+def train(model, dataloader, criterion, optimizer):
+    model.train()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    for images, labels in dataloader:
+        images, labels = images.cuda(), labels.cuda()
+        
+        # Zero the parameter gradients
+        optimizer.zero_grad()
+        
+        # Forward pass
+        outputs = model(images)
+        
+        # Compute loss
+        loss = criterion(outputs, labels)
+        
+        # Backward pass and optimize
+        loss.backward()
+        optimizer.step()
+        
+        # Statistics
+        running_loss += loss.item()
+        _, predicted = torch.max(outputs, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+    
+    epoch_loss = running_loss / len(dataloader)
+    epoch_accuracy = 100 * correct / total
+    print(f'Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
+
+
+# |%%--%%| <D2Fnuvk9hw|zMpnWiyItv>
+
+model = ViT().cuda()
+
+# |%%--%%| <zMpnWiyItv|8WqgfhdZT8>
+
+loss_function = torch.nn.CrossEntropyLoss()
+
+# |%%--%%| <8WqgfhdZT8|Cl2qQDeIQV>
+
+optimizer = torch.optim.Adam(model.parameters())
+
+# |%%--%%| <Cl2qQDeIQV|vKz9NMcv7I>
+
+images = []
+labels = []
+START_INDEX = 0
+RANGE = 32
+END_INDEX = START_INDEX + RANGE
+images = numpy.array(
+    original_raw_dataset[START_INDEX:END_INDEX]["image"]
+    ).tolist()
+print(1)
+images += numpy.array(
+    original_raw_dataset[-(START_INDEX + 1):-(END_INDEX + 1):-1]["image"]
+).tolist()
+print(2)
+
+labels = numpy.array(
+    original_raw_dataset["label"][START_INDEX:END_INDEX]
+    ).tolist()
+print(3)
+labels += numpy.array(
+    original_raw_dataset["label"][-(START_INDEX + 1):-(END_INDEX + 1):-1]
+).tolist()
+print(4)
+
+labels = torch.tensor(labels).cuda()
+images = torch.tensor(images).cuda().view((RANGE * 2, 3, 256, 256)).float()
+print(5)
+
+# |%%--%%| <vKz9NMcv7I|LYgaCybme2>
+
+dataset = torch.utils.data.TensorDataset(images, labels)
+
+# |%%--%%| <LYgaCybme2|KFseMis27r>
+
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
+
+# |%%--%%| <KFseMis27r|PvIYi4SzgP>
+
+while True:
+    train(model, dataloader, loss_function, optimizer)
+
+# |%%--%%| <PvIYi4SzgP|W5EFRIN0B0>