diff --git a/src/layer.py b/src/layer.py
index 2a25336..e3e9b20 100644
--- a/src/layer.py
+++ b/src/layer.py
@@ -3,8 +3,11 @@
 from abc import ABC, abstractmethod
 
 class Layer(ABC):
+    _cached_inputs: np.ndarray
+
     @abstractmethod
     def forward(self, x: np.ndarray) -> np.ndarray:
+        self._cached_inputs = x
         """
         x = inputs
         if should_cache = True,
@@ -14,7 +17,7 @@ class Layer(ABC):
 
     @property
     @abstractmethod
-    def parameters(self) -> tuple[np.ndarray, ...]:
+    def parameters(self) -> list[np.ndarray]:
         """
         Returns the different parameters. The order is defined as per the
         sub class's convinience
@@ -22,17 +25,17 @@ class Layer(ABC):
 
     @parameters.setter
     @abstractmethod
-    def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
+    def parameters(self, parameters: list[np.ndarray]) -> None:
         """
         Write to parameters property
         """
 
     @abstractmethod
-    def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
+    def d_output_wrt_parameters(self) -> list[np.ndarray]:
         pass
 
     @abstractmethod
-    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
+    def d_output_wrt_inputs(self) -> np.ndarray:
         pass
 
 
@@ -46,25 +49,26 @@
         self.__b = np.random.random((output_size))
 
     def forward(self, x: np.ndarray) -> np.ndarray:
+        super().forward(x)
         return np.dot(self.__w, x.T).T + self.__b
 
     @property
-    def parameters(self) -> tuple[np.ndarray, np.ndarray]:
-        return (self.__w, self.__b)
+    def parameters(self) -> list[np.ndarray]:
+        return [self.__w, self.__b]
 
     @parameters.setter
-    def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
+    def parameters(self, parameters: list[np.ndarray]) -> None:
         self.__w = parameters[0]
         self.__b = parameters[1]
 
-    def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    def d_output_wrt_parameters(self) -> list[np.ndarray]:
         """
         d_out_wrt_w = input
         d_out_wrt_b = output
         """
-        return (inputs, np.array([1]))
+        return [self._cached_inputs, np.array([1])]
 
-    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
+    def d_output_wrt_inputs(self) -> np.ndarray:
         return self.__w
 
 
@@ -73,15 +77,16 @@ class ReLU(Layer):
-        return (x > 0) * x
+        super().forward(x)
+        return (x > 0) * x
 
     @property
-    def parameters(self) -> tuple[np.ndarray, ...]:
-        return ()
+    def parameters(self) -> list[np.ndarray]:
+        return []
 
     @parameters.setter
-    def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
+    def parameters(self, parameters: list[np.ndarray]) -> None:
         return
 
-    def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
-        return ()
+    def d_output_wrt_parameters(self) -> list[np.ndarray]:
+        return []
 
-    def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
-        return (x > 0) * 1.0
+    def d_output_wrt_inputs(self) -> np.ndarray:
+        return (self._cached_inputs > 0) * 1.0