Implement loss
This commit is contained in:
parent
40601130e0
commit
db5761cf4e
12
src/loss.py
Normal file
12
src/loss.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import numpy as np
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class Loss(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __call__(self, output: np.ndarray, target: np.ndarray) -> float:
|
||||||
|
""""""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_diffrential(self, output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray:
|
||||||
|
""""""
|
Loading…
x
Reference in New Issue
Block a user