From 7e2d9687302b1df22e79d0af8741b616ecaf9a65 Mon Sep 17 00:00:00 2001
From: kosh <kosh@kosh-web.cfd>
Date: Mon, 9 Sep 2024 14:44:54 +0530
Subject: [PATCH] Made ViT from scratch

---
 src/main.py | 327 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 320 insertions(+), 7 deletions(-)

diff --git a/src/main.py b/src/main.py
index 84c5cb6..57a8869 100644
--- a/src/main.py
+++ b/src/main.py
@@ -2,8 +2,10 @@
 import pandas
 import datasets
 import numpy
+from tqdm import tqdm
+import torch
 
-#|%%--%%| <66pBRkviU5|LMiIYmSiRI>
+# |%%--%%| <66pBRkviU5|LMiIYmSiRI>
 
 _ = datasets.load_dataset(
     "Hemg/deepfake-and-real-images",
@@ -14,14 +16,325 @@
 if not isinstance(_, datasets.Dataset):
     raise Exception("Something went wrong")
 original_raw_dataset: datasets.Dataset = _
 
-#|%%--%%| <LMiIYmSiRI|8Yvl492qqH>
+# |%%--%%| <LMiIYmSiRI|KugfehqjKv>
 
-_ = original_raw_dataset.to_pandas()
+number_of_images = len(original_raw_dataset)
 
-if not isinstance(_, datasets.Dataset):
-    raise Exception("Something went wrong")
-original_dataset: datasets.Dataset = _
+# |%%--%%| <KugfehqjKv|EFzgkiWqFX>
 
-#|%%--%%| <8Yvl492qqH|SZqCHvaVss>
+number_of_images
+
+# |%%--%%| <EFzgkiWqFX|AsQx8efYP3>
+
+class PatchEmbedding(torch.nn.Module):
+    """Turns a 2D input image into a 1D sequence of learnable embedding vectors.
+
+    Args:
+        in_channels (int): Number of color channels for the input images. Defaults to 3.
+        patch_size (int): Size of patches to convert input image into. Defaults to 16.
+        embedding_dim (int): Size of embedding to turn image into. Defaults to 768.
+    """
+    # 2. Initialize the class with appropriate variables
+    def __init__(self,
+                 in_channels:int=3,
+                 patch_size:int=16,
+                 embedding_dim:int=768):
+        super().__init__()
+
+        # 3. Create a layer to turn an image into patches
+        self.patch_size = patch_size
+        self.patcher = torch.nn.Conv2d(in_channels=in_channels,
+                                       out_channels=embedding_dim,
+                                       kernel_size=patch_size,
+                                       stride=patch_size,
+                                       padding=0)
+
+        # 4. Create a layer to flatten the patch feature maps into a single dimension
+        self.flatten = torch.nn.Flatten(start_dim=2,  # only flatten the feature map dimensions into a single vector
+                                        end_dim=3)
+
+    # 5. Define the forward method
+    def forward(self, x):
+        # Create assertion to check that inputs are the correct shape
+        image_resolution = x.shape[-1]
+        patch_size = self.patch_size
+        assert image_resolution % patch_size == 0, f"Input image size must be divisible by patch size, image shape: {image_resolution}, patch size: {patch_size}"
+
+        # Perform the forward pass
+        x_patched = self.patcher(x)
+        x_flattened = self.flatten(x_patched)
+
+        # 6. Make sure the output shape has the right order
+        return x_flattened.permute((0, 2, 1))
+
+# |%%--%%| <AsQx8efYP3|PQho0pEU3x>
+
+class MultiheadSelfAttentionBlock(torch.nn.Module):
+    """Creates a multi-head self-attention block ("MSA block" for short).
+    """
+    # 2. Initialize the class with hyperparameters from Table 1
+    def __init__(self,
+                 embedding_dim:int=768,  # Hidden size D from Table 1 for ViT-Base
+                 num_heads:int=12,  # Heads from Table 1 for ViT-Base
+                 attn_dropout:float=0):  # the paper doesn't appear to use dropout in MSA blocks
+        super().__init__()
+
+        # 3. Create the Norm layer (LN)
+        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+
+        # 4. Create the Multi-Head Attention (MSA) layer
+        self.multihead_attn = torch.nn.MultiheadAttention(embed_dim=embedding_dim,
+                                                          num_heads=num_heads,
+                                                          dropout=attn_dropout,
+                                                          batch_first=True)  # our batch dimension comes first
+
+    # 5. Create a forward() method to pass the data through the layers
+    def forward(self, x):
+        x = self.layer_norm(x)
+        attn_output, _ = self.multihead_attn(query=x,  # query embeddings
+                                             key=x,  # key embeddings
+                                             value=x,  # value embeddings
+                                             need_weights=False)  # we only need the layer outputs, not the attention weights
+        return attn_output
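+
+# Quick shape sanity check (an optional sketch): with the defaults above and a
+# dummy 256x256 RGB image, PatchEmbedding yields a sequence of 256 patch
+# embeddings of size 768, and the MSA block preserves that
+# (batch, num_patches, embedding_dim) shape.
+_demo_patches = PatchEmbedding()(torch.randn(1, 3, 256, 256))
+assert _demo_patches.shape == (1, 256, 768)
+assert MultiheadSelfAttentionBlock()(_demo_patches).shape == (1, 256, 768)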
+
+# |%%--%%| <PQho0pEU3x|X9bS6VigyO>
+
+class MLPBlock(torch.nn.Module):
+    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 embedding_dim:int=768,  # Hidden Size D from Table 1 for ViT-Base
+                 mlp_size:int=3072,  # MLP size from Table 1 for ViT-Base
+                 dropout:float=0.1):  # Dropout from Table 3 for ViT-Base
+        super().__init__()
+
+        # 3. Create the Norm layer (LN)
+        self.layer_norm = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+
+        # 4. Create the Multilayer perceptron (MLP) layer(s)
+        self.mlp = torch.nn.Sequential(
+            torch.nn.Linear(in_features=embedding_dim,
+                            out_features=mlp_size),
+            torch.nn.GELU(),  # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
+            torch.nn.Dropout(p=dropout),
+            torch.nn.Linear(in_features=mlp_size,  # needs to take same in_features as out_features of layer above
+                            out_features=embedding_dim),  # take back to embedding_dim
+            torch.nn.Dropout(p=dropout)  # "Dropout, when used, is applied after every dense layer..."
+        )
+
+    # 5. Create a forward() method to pass the data through the layers
+    def forward(self, x):
+        x = self.layer_norm(x)
+        x = self.mlp(x)
+        return x
+
+# |%%--%%| <X9bS6VigyO|xTGaCs6wAt>
+
+class TransformerEncoderBlock(torch.nn.Module):
+    """Creates a Transformer Encoder block."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 embedding_dim:int=768,  # Hidden size D from Table 1 for ViT-Base
+                 num_heads:int=12,  # Heads from Table 1 for ViT-Base
+                 mlp_size:int=3072,  # MLP size from Table 1 for ViT-Base
+                 mlp_dropout:float=0.1,  # Amount of dropout for dense layers from Table 3 for ViT-Base
+                 attn_dropout:float=0):  # Amount of dropout for attention layers
+        super().__init__()
+
+        # 3. Create MSA block (equation 2)
+        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
+                                                     num_heads=num_heads,
+                                                     attn_dropout=attn_dropout)
+
+        # 4. Create MLP block (equation 3)
+        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
+                                  mlp_size=mlp_size,
+                                  dropout=mlp_dropout)
+
+    # 5. Create a forward() method
+    def forward(self, x):
+
+        # 6. Create residual connection for MSA block (add the input to the output)
+        x = self.msa_block(x) + x
+
+        # 7. Create residual connection for MLP block (add the input to the output)
+        x = self.mlp_block(x) + x
+
+        return x
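+
+# A second small check (optional sketch): MLPBlock and TransformerEncoderBlock
+# are shape-preserving, which is what lets the blocks be stacked with
+# nn.Sequential in the ViT class below.
+_demo_tokens = torch.randn(1, 257, 768)  # 256 patch embeddings + 1 class token
+assert MLPBlock()(_demo_tokens).shape == (1, 257, 768)
+assert TransformerEncoderBlock()(_demo_tokens).shape == (1, 257, 768)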
+
+# |%%--%%| <xTGaCs6wAt|s029U18Gni>
+
+class ViT(torch.nn.Module):
+    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
+    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
+    def __init__(self,
+                 img_size:int=256,  # Training resolution; set to 256 to match the 256x256 dataset images used below
+                 in_channels:int=3,  # Number of channels in input image
+                 patch_size:int=16,  # Patch size
+                 num_transformer_layers:int=12,  # Layers from Table 1 for ViT-Base
+                 embedding_dim:int=768,  # Hidden size D from Table 1 for ViT-Base
+                 mlp_size:int=3072,  # MLP size from Table 1 for ViT-Base
+                 num_heads:int=12,  # Heads from Table 1 for ViT-Base
+                 attn_dropout:float=0,  # Dropout for attention projection
+                 mlp_dropout:float=0.1,  # Dropout for dense/MLP layers
+                 embedding_dropout:float=0.1,  # Dropout for patch and position embeddings
+                 num_classes:int=2):  # Two classes here (real vs. fake), but this can be customized
+        super().__init__()  # don't forget the super().__init__()!
+
+        # 3. Make sure the image size is divisible by the patch size
+        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."
+
+        # 4. Calculate number of patches (height * width / patch_size^2)
+        self.num_patches = (img_size * img_size) // patch_size**2
+
+        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
+        self.class_embedding = torch.nn.Parameter(data=torch.randn(1, 1, embedding_dim),
+                                                  requires_grad=True)
+
+        # 6. Create learnable position embedding
+        self.position_embedding = torch.nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
+                                                     requires_grad=True)
+
+        # 7. Create embedding dropout value
+        self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout)
+
+        # 8. Create patch embedding layer
+        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
+                                              patch_size=patch_size,
+                                              embedding_dim=embedding_dim)
+
+        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
+        # Note: The "*" unpacks the list so each block becomes a separate module in the Sequential
+        self.transformer_encoder = torch.nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
+                                                                                 num_heads=num_heads,
+                                                                                 mlp_size=mlp_size,
+                                                                                 mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
+
+        # 10. Create classifier head
+        self.classifier = torch.nn.Sequential(
+            torch.nn.LayerNorm(normalized_shape=embedding_dim),
+            torch.nn.Linear(in_features=embedding_dim,
+                            out_features=num_classes)
+        )
+
+    # 11. Create a forward() method
+    def forward(self, x):
+
+        # 12. Get batch size
+        batch_size = x.shape[0]
+
+        # 13. Create class token embedding and expand it to match the batch size (equation 1)
+        class_token = self.class_embedding.expand(batch_size, -1, -1)  # "-1" means to infer the dimension
+
+        # 14. Create patch embedding (equation 1)
+        x = self.patch_embedding(x)
+
+        # 15. Concat class embedding and patch embedding (equation 1)
+        x = torch.cat((class_token, x), dim=1)
+
+        # 16. Add position embedding to patch embedding (equation 1)
+        x = self.position_embedding + x
+
+        # 17. Run embedding dropout (Appendix B.1)
+        x = self.embedding_dropout(x)
+
+        # 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
+        x = self.transformer_encoder(x)
+
+        # 19. Put the 0-index (class token) output through the classifier (equation 4)
+        x = self.classifier(x[:, 0])  # run on the class token position of each sample in the batch
+
+        return x
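+
+# Optional end-to-end sketch: with the defaults above (256x256 images, 16x16
+# patches, 2 classes), one dummy image maps to one logit per class.
+assert ViT()(torch.randn(1, 3, 256, 256)).shape == (1, 2)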
+
+# |%%--%%| <s029U18Gni|D2Fnuvk9hw>
+
+def train(model, dataloader, criterion, optimizer):
+    model.train()
+    running_loss = 0.0
+    correct = 0
+    total = 0
+
+    for images, labels in dataloader:
+        images, labels = images.cuda(), labels.cuda()
+
+        # Zero the parameter gradients
+        optimizer.zero_grad()
+
+        # Forward pass
+        outputs = model(images)
+
+        # Compute loss
+        loss = criterion(outputs, labels)
+
+        # Backward pass and optimize
+        loss.backward()
+        optimizer.step()
+
+        # Statistics
+        running_loss += loss.item()
+        _, predicted = torch.max(outputs, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+    epoch_loss = running_loss / len(dataloader)
+    epoch_accuracy = 100 * correct / total
+    print(f'Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
+
+
+# |%%--%%| <D2Fnuvk9hw|zMpnWiyItv>
+
+model = ViT().cuda()
+
+# |%%--%%| <zMpnWiyItv|8WqgfhdZT8>
+
+loss_function = torch.nn.CrossEntropyLoss()
+
+# |%%--%%| <8WqgfhdZT8|Cl2qQDeIQV>
+
+optimizer = torch.optim.Adam(model.parameters())
+
+# |%%--%%| <Cl2qQDeIQV|vKz9NMcv7I>
+
+# Build a small subset from both ends of the dataset: the first RANGE examples
+# plus the last RANGE examples (the tail is read with a reversed slice).
+# The bare print()s below are just progress markers for the slow conversions.
+START_INDEX = 0
+RANGE = 32
+END_INDEX = START_INDEX + RANGE
+
+images = numpy.array(
+    original_raw_dataset[START_INDEX:END_INDEX]["image"]
+).tolist()
+print(1)
+images += numpy.array(
+    original_raw_dataset[-(START_INDEX + 1):-(END_INDEX + 1):-1]["image"]
+).tolist()
+print(2)
+
+labels = numpy.array(
+    original_raw_dataset["label"][START_INDEX:END_INDEX]
+).tolist()
+print(3)
+labels += numpy.array(
+    original_raw_dataset["label"][-(START_INDEX + 1):-(END_INDEX + 1):-1]
+).tolist()
+print(4)
+
+labels = torch.tensor(labels).cuda()
+# Images come in as (N, H, W, C); permute to (N, C, H, W) for the Conv2d patcher
+# (a plain .view() here would scramble the channel layout).
+images = torch.tensor(images).cuda().permute(0, 3, 1, 2).float()
+print(5)
+
+# |%%--%%| <vKz9NMcv7I|LYgaCybme2>
+
+dataset = torch.utils.data.TensorDataset(images, labels)
+
+# |%%--%%| <LYgaCybme2|KFseMis27r>
+
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
+
+# |%%--%%| <KFseMis27r|PvIYi4SzgP>
+
+# Train indefinitely; interrupt the cell to stop.
+while True:
+    train(model, dataloader, loss_function, optimizer)
+
+# |%%--%%| <PvIYi4SzgP|W5EFRIN0B0>
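+
+# A minimal evaluation sketch for this last cell: reuse the same dataloader
+# (the only data loaded so far, i.e. the training subset) to get the model's
+# accuracy without updating it. Treat the number as a smoke test, not a real
+# held-out metric.
+model.eval()
+with torch.no_grad():
+    eval_correct = 0
+    eval_total = 0
+    for eval_images, eval_labels in dataloader:
+        eval_logits = model(eval_images)
+        eval_correct += (eval_logits.argmax(dim=1) == eval_labels).sum().item()
+        eval_total += eval_labels.size(0)
+print(f"Accuracy on the loaded subset: {100 * eval_correct / eval_total:.2f}%")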