diff --git a/src/loss.py b/src/loss.py index 43d80e4..f73ad86 100644 --- a/src/loss.py +++ b/src/loss.py @@ -3,10 +3,22 @@ from abc import ABC, abstractmethod class Loss(ABC): + @staticmethod @abstractmethod - def __call__(self, output: np.ndarray, target: np.ndarray) -> float: + def __call__(output: np.ndarray, target: np.ndarray) -> float: """""" + @staticmethod @abstractmethod - def get_diffrential(self, output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: + def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: """""" + + +class CrossEntropyLoss(Loss): + @staticmethod + def __call__(output: np.ndarray, target: np.ndarray) -> float: + return -np.sum(target * np.log10(output), dtype=np.float32) + + @staticmethod + def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: + return -target / (output + 0.00001)