From 06c981253c96b939ad1712d2237de3cb0399fdf6 Mon Sep 17 00:00:00 2001
From: kosh
Date: Wed, 8 May 2024 18:49:39 +0530
Subject: [PATCH] Implement ReLU layer

---
 src/layer.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/layer.py b/src/layer.py
index 9b4a602..2a25336 100644
--- a/src/layer.py
+++ b/src/layer.py
@@ -32,7 +32,7 @@ class Layer(ABC):
         pass
 
     @abstractmethod
-    def d_output_wrt_inputs(self) -> np.ndarray:
+    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
         pass
 
 
@@ -64,5 +64,24 @@ class Dense(Layer):
         """
         return (inputs, np.array([1]))
 
-    def d_output_wrt_inputs(self) -> np.ndarray:
+    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
         return self.__w
+
+
+class ReLU(Layer):
+    def forward(self, x: np.ndarray) -> np.ndarray:
+        return (x > 0) * x
+
+    @property
+    def parameters(self) -> tuple[np.ndarray, ...]:
+        return ()
+
+    @parameters.setter
+    def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
+        return
+
+    def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
+        return ()
+
+    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
+        return (x > 0) * 1.0
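
A quick sanity check of the new layer (a minimal sketch, not part of the patch; it assumes src/layer.py is importable as layer and that ReLU implements every abstract method of Layer, so it can be instantiated directly):

    import numpy as np
    from layer import ReLU  # assumed import path; the patch only touches src/layer.py

    relu = ReLU()
    x = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])

    # forward: (x > 0) * x zeroes out negatives and passes positives through
    print(relu.forward(x))                  # [0.  0.  0.  1.5 3. ]

    # derivative w.r.t. inputs: 1 where x > 0, else 0 (the subgradient at x == 0 is taken as 0)
    print(relu.d_output_wrt_inputs(x))      # [0. 0. 0. 1. 1.]

    # ReLU has no trainable parameters, so the parameter API returns empty tuples
    print(relu.parameters)                  # ()
    print(relu.d_output_wrt_parameters(x))  # ()

Note that (x > 0) * x relies on NumPy broadcasting a boolean mask and is equivalent to np.maximum(x, 0); the boolean-mask form has the advantage that the same mask expression gives the derivative in d_output_wrt_inputs.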