Impliment layer abstract class
This commit is contained in:
parent
6f855c426d
commit
c09dc9f7ec
0
src/.gitkeep
Normal file
0
src/.gitkeep
Normal file
36
src/layer.py
Normal file
36
src/layer.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import numpy as np
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class Layer(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
x = inputs
|
||||||
|
if should_cache = True,
|
||||||
|
additional caching will be done.
|
||||||
|
Set this to true and then call forward right before calling backward
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def parameters(self) -> tuple[np.ndarray, ...]:
|
||||||
|
"""
|
||||||
|
Returns the different parameters.
|
||||||
|
The order is defined as per the sub class's convinience
|
||||||
|
"""
|
||||||
|
|
||||||
|
@parameters.setter
|
||||||
|
@abstractmethod
|
||||||
|
def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Write to parameters property
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def d_output_wrt_inputs(self) -> np.ndarray:
|
||||||
|
pass
|
Loading…
x
Reference in New Issue
Block a user