Impliment layer abstract class
This commit is contained in:
parent
6f855c426d
commit
c09dc9f7ec
0
src/.gitkeep
Normal file
0
src/.gitkeep
Normal file
36
src/layer.py
Normal file
36
src/layer.py
Normal file
@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Layer(ABC):
|
||||
@abstractmethod
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
x = inputs
|
||||
if should_cache = True,
|
||||
additional caching will be done.
|
||||
Set this to true and then call forward right before calling backward
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> tuple[np.ndarray, ...]:
|
||||
"""
|
||||
Returns the different parameters.
|
||||
The order is defined as per the sub class's convinience
|
||||
"""
|
||||
|
||||
@parameters.setter
|
||||
@abstractmethod
|
||||
def parameters(self, parameters: tuple[np.ndarray, ...]) -> None:
|
||||
"""
|
||||
Write to parameters property
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def d_output_wrt_parameters(self, inputs: np.ndarray) -> tuple[np.ndarray, ...]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def d_output_wrt_inputs(self) -> np.ndarray:
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user