Clean names in loss.py
This commit is contained in:
parent
b6cd28db7c
commit
3fe69b3869
@ -2,7 +2,7 @@ import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Loss(ABC):
|
||||
class LossFunction(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def __call__(output: np.ndarray, target: np.ndarray) -> float:
|
||||
@ -14,7 +14,7 @@ class Loss(ABC):
|
||||
""""""
|
||||
|
||||
|
||||
class CrossEntropyLoss(Loss):
|
||||
class CategoricalCrossEntropyLossFunction(LossFunction):
|
||||
@staticmethod
|
||||
def __call__(output: np.ndarray, target: np.ndarray) -> float:
|
||||
return -np.sum(target * np.log10(output), dtype=np.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user