From c09dc9f7ec2fa3876fc97621b6b956af5f080507 Mon Sep 17 00:00:00 2001 From: kosh Date: Wed, 8 May 2024 16:39:22 +0530 Subject: [PATCH] Impliment layer abstract class --- src/.gitkeep | 0 src/layer.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 src/.gitkeep create mode 100644 src/layer.py diff --git a/src/.gitkeep b/src/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/layer.py b/src/layer.py new file mode 100644 index 0000000..556aaf4 --- /dev/null +++ b/src/layer.py @@ -0,0 +1,36 @@ +import numpy as np +from abc import ABC, abstractmethod + + +class Layer(ABC): + @abstractmethod + def forward(self, x: np.ndarray) -> np.ndarray: + """ + x = inputs + if should_cache = True, + additional caching will be done. + Set this to true and then call forward right before calling backward + """ + + @property + @abstractmethod + def parameters(self) -> tuple[np.ndarray, ...]: + """ + Returns the different parameters. + The order is defined as per the sub class's convinience + """ + + @parameters.setter + @abstractmethod + def parameters(self, parameters: tuple[np.ndarray, ...]) -> None: + """ + Write to parameters property + """ + + @abstractmethod + def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]: + pass + + @abstractmethod + def d_output_wrt_inputs(self) -> np.ndarray: + pass