From b6cd28db7c500a56f8961006fccca6c9549584ec Mon Sep 17 00:00:00 2001
From: kosh <kosh@kosh-web.cfd>
Date: Thu, 9 May 2024 12:29:35 +0530
Subject: [PATCH] Implement CrossEntropyLoss

---
 src/loss.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/loss.py b/src/loss.py
index 43d80e4..f73ad86 100644
--- a/src/loss.py
+++ b/src/loss.py
@@ -3,10 +3,22 @@ from abc import ABC, abstractmethod
 
 
 class Loss(ABC):
+    @staticmethod
     @abstractmethod
-    def __call__(self, output: np.ndarray, target: np.ndarray) -> float:
+    def __call__(output: np.ndarray, target: np.ndarray) -> float:
         """"""
 
+    @staticmethod
     @abstractmethod
-    def get_diffrential(self, output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray:
+    def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray:
         """"""
+
+
+class CrossEntropyLoss(Loss):
+    @staticmethod
+    def __call__(output: np.ndarray, target: np.ndarray) -> float:
+        return -np.sum(target * np.log10(output), dtype=np.float32)
+
+    @staticmethod
+    def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray| None = None) -> np.ndarray:
+        return -target / (output + 0.00001)