From b6cd28db7c500a56f8961006fccca6c9549584ec Mon Sep 17 00:00:00 2001 From: kosh <kosh@kosh-web.cfd> Date: Thu, 9 May 2024 12:29:35 +0530 Subject: [PATCH] Implement CrossEntropyLoss --- src/loss.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/loss.py b/src/loss.py index 43d80e4..f73ad86 100644 --- a/src/loss.py +++ b/src/loss.py @@ -3,10 +3,22 @@ from abc import ABC, abstractmethod class Loss(ABC): + @staticmethod @abstractmethod - def __call__(self, output: np.ndarray, target: np.ndarray) -> float: + def __call__(output: np.ndarray, target: np.ndarray) -> float: """""" + @staticmethod @abstractmethod - def get_diffrential(self, output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: + def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: """""" + + +class CrossEntropyLoss(Loss): + @staticmethod + def __call__(output: np.ndarray, target: np.ndarray) -> float: + return -np.sum(target * np.log10(output), dtype=np.float32) + + @staticmethod + def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray: + return -target / (output + 0.00001)