diff --git a/src/layer.py b/src/layer.py index 9b4a602..2a25336 100644 --- a/src/layer.py +++ b/src/layer.py @@ -32,7 +32,7 @@ class Layer(ABC): pass @abstractmethod - def d_output_wrt_inputs(self) -> np.ndarray: + def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray: pass @@ -64,5 +64,24 @@ class Dense(Layer): """ return (inputs, np.array([1])) - def d_output_wrt_inputs(self) -> np.ndarray: + def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray: return self.__w + + +class ReLU(Layer): + def forward(self, x: np.ndarray) -> np.ndarray: + return (x > 0) * x + + @property + def parameters(self) -> tuple[np.ndarray, ...]: + return () + + @parameters.setter + def parameters(self, parameters: tuple[np.ndarray, ...]) -> None: + return + + def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]: + return () + + def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray: + return (x > 0) * 1.0