diff --git a/src/layer.py b/src/layer.py index 556aaf4..9b4a602 100644 --- a/src/layer.py +++ b/src/layer.py @@ -34,3 +34,35 @@ class Layer(ABC): @abstractmethod def d_output_wrt_inputs(self) -> np.ndarray: pass + + +class Dense(Layer): + __w: np.ndarray + __b: np.ndarray + + def __init__(self, input_size: int, output_size: int) -> None: + super().__init__() + self.__w = np.random.random((output_size, input_size)) + self.__b = np.random.random((output_size)) + + def forward(self, x: np.ndarray) -> np.ndarray: + return np.dot(self.__w, x.T).T + self.__b + + @property + def parameters(self) -> tuple[np.ndarray, np.ndarray]: + return (self.__w, self.__b) + + @parameters.setter + def parameters(self, parameters: tuple[np.ndarray, ...]) -> None: + self.__w = parameters[0] + self.__b = parameters[1] + + def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + d_out_wrt_w = input + d_out_wrt_b = output + """ + return (inputs, np.array([1])) + + def d_output_wrt_inputs(self) -> np.ndarray: + return self.__w