Implement CrossEntropyLoss
parent db5761cf4e
commit b6cd28db7c
src/loss.py: 16 changed lines
@@ -3,10 +3,22 @@ from abc import ABC, abstractmethod
 class Loss(ABC):
 
+    @staticmethod
     @abstractmethod
-    def __call__(self, output: np.ndarray, target: np.ndarray) -> float:
+    def __call__(output: np.ndarray, target: np.ndarray) -> float:
         """"""
 
+    @staticmethod
     @abstractmethod
-    def get_diffrential(self, output: np.ndarray, target: np.ndarray, loss: np.ndarray | None = None) -> np.ndarray:
+    def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray | None = None) -> np.ndarray:
         """"""
 
+
+class CrossEntropyLoss(Loss):
+
+    @staticmethod
+    def __call__(output: np.ndarray, target: np.ndarray) -> float:
+        return -np.sum(target * np.log10(output), dtype=np.float32)
+
+    @staticmethod
+    def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray | None = None) -> np.ndarray:
+        return -target / (output + 0.00001)
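A note on the committed math: the gradient returned by get_diffrential, -target / (output + 1e-5), is the derivative of -sum(target * ln(output)), i.e. it assumes the natural logarithm, while __call__ computes the forward value with np.log10. A minimal sketch of a consistent pair (natural log in both directions, reusing the same small epsilon guard) could look like the following. The class name NaturalLogCrossEntropy and the example arrays are illustrative only and not part of this commit; the method name get_diffrential is kept with the commit's spelling to match its interface.

    import numpy as np

    class NaturalLogCrossEntropy:
        """Illustrative variant: natural-log forward pass matching the -target / output gradient."""

        EPS = 1e-5  # same epsilon the commit uses to guard against log(0) / division by zero

        @staticmethod
        def __call__(output: np.ndarray, target: np.ndarray) -> float:
            # -sum(t * ln(p)); ln (not log10) is what the gradient below differentiates
            return float(-np.sum(target * np.log(output + NaturalLogCrossEntropy.EPS)))

        @staticmethod
        def get_diffrential(output: np.ndarray, target: np.ndarray, loss: np.ndarray | None = None) -> np.ndarray:
            # d/d(output) of -sum(t * ln(p)) is -t / p, guarded with EPS as in the commit
            return -target / (output + NaturalLogCrossEntropy.EPS)

Usage, assuming output is a probability vector (e.g. post-softmax) and target is one-hot:

    probs = np.array([0.7, 0.2, 0.1], dtype=np.float32)
    onehot = np.array([1.0, 0.0, 0.0], dtype=np.float32)
    loss_fn = NaturalLogCrossEntropy()
    loss_value = loss_fn(probs, onehot)            # ~0.357, i.e. -ln(0.7)
    grad = loss_fn.get_diffrential(probs, onehot)  # ~[-1.43, 0.0, 0.0]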