Changed parameter type from tuple to list in layer.py
This commit is contained in:
parent
3fe69b3869
commit
98b3b4e18a
38
src/layer.py
38
src/layer.py
@ -3,8 +3,11 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
|
|
||||||
class Layer(ABC):
|
class Layer(ABC):
|
||||||
|
_cached_inputs: np.ndarray
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||||
|
self._cached_inputs = x
|
||||||
"""
|
"""
|
||||||
x = inputs
|
x = inputs
|
||||||
if should_cache = True,
|
if should_cache = True,
|
||||||
@ -14,7 +17,7 @@ class Layer(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parameters(self) -> tuple[np.ndarray, ...]:
|
def parameters(self) -> list[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Returns the different parameters.
|
Returns the different parameters.
|
||||||
The order is defined as per the sub class's convinience
|
The order is defined as per the sub class's convinience
|
||||||
@ -22,17 +25,17 @@ class Layer(ABC):
|
|||||||
|
|
||||||
@parameters.setter
|
@parameters.setter
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
|
def parameters(self, parameters: list[np.ndarray]) -> None:
|
||||||
"""
|
"""
|
||||||
Write to parameters property
|
Write to parameters property
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
|
def d_output_wrt_parameters(self) -> list[np.ndarray]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
|
def d_output_wrt_inputs(self) -> np.ndarray:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -46,25 +49,26 @@ class Dense(Layer):
|
|||||||
self.__b = np.random.random((output_size))
|
self.__b = np.random.random((output_size))
|
||||||
|
|
||||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||||
|
super().forward(x)
|
||||||
return np.dot(self.__w, x.T).T + self.__b
|
return np.dot(self.__w, x.T).T + self.__b
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> tuple[np.ndarray, np.ndarray]:
|
def parameters(self) -> list[np.ndarray]:
|
||||||
return (self.__w, self.__b)
|
return [self.__w, self.__b]
|
||||||
|
|
||||||
@parameters.setter
|
@parameters.setter
|
||||||
def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
|
def parameters(self, parameters: list[np.ndarray]) -> None:
|
||||||
self.__w = parameters[0]
|
self.__w = parameters[0]
|
||||||
self.__b = parameters[1]
|
self.__b = parameters[1]
|
||||||
|
|
||||||
def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
def d_output_wrt_parameters(self) -> list[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
d_out_wrt_w = input
|
d_out_wrt_w = input
|
||||||
d_out_wrt_b = output
|
d_out_wrt_b = output
|
||||||
"""
|
"""
|
||||||
return (inputs, np.array([1]))
|
return [self._cached_inputs, np.array([1])]
|
||||||
|
|
||||||
def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
|
def d_output_wrt_inputs(self) -> np.ndarray:
|
||||||
return self.__w
|
return self.__w
|
||||||
|
|
||||||
|
|
||||||
@ -73,15 +77,15 @@ class ReLU(Layer):
|
|||||||
return (x > 0) * x
|
return (x > 0) * x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> tuple[np.ndarray, ...]:
|
def parameters(self) -> list[np.ndarray]:
|
||||||
return ()
|
return []
|
||||||
|
|
||||||
@parameters.setter
|
@parameters.setter
|
||||||
def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
|
def parameters(self, parameters: list[np.ndarray]) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
|
def d_output_wrt_parameters(self) -> list[np.ndarray]:
|
||||||
return ()
|
return []
|
||||||
|
|
||||||
def d_output_wrt_inputs(self, x: np.ndarray) -> np.ndarray:
|
def d_output_wrt_inputs(self) -> np.ndarray:
|
||||||
return (x > 0) * 1.0
|
return (self._cached_inputs > 0) * 1.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user