Implement SGD optimizer
parent 98b3b4e18a
commit 9b4f3073f3
src/optimizer.py (new file, 25 additions)
@@ -0,0 +1,25 @@
import numpy as np

from layer import Layer
from loss import LossFunction


class SGD:
    __layers: list[Layer]
    __loss_function: LossFunction

    def __init__(self, layers: list[Layer], loss_function: LossFunction) -> None:
        self.__layers = layers
        self.__loss_function = loss_function

    def train(self, output: np.ndarray, target: np.ndarray, lr: float = 0.001) -> None:
        """Run one gradient-descent step. Forward must be called before this step."""
        # Derivative of the loss with respect to the network output.
        d_loss_wrt_output = self.__loss_function.get_diffrential(output=output, target=target)
        # Walk the layers back to front, applying the chain rule at each one.
        for layer in self.__layers[::-1]:
            d_loss_wrt_parameters = d_loss_wrt_output * layer.d_output_wrt_parameters()
            delta_parameters = lr * d_loss_wrt_parameters
            # Step each parameter against its gradient.
            for i in range(len(delta_parameters)):
                layer.parameters[i] -= delta_parameters[i]
            # Propagate the loss gradient through this layer to the one below.
            d_loss_wrt_output = d_loss_wrt_output * layer.d_output_wrt_inputs()
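
Below is a minimal usage sketch of the interfaces this optimizer assumes. SquaredError and ScalarLinear are hypothetical stand-ins for the repo's actual LossFunction and Layer classes (neither is part of this commit): the optimizer only relies on get_diffrential, d_output_wrt_parameters, d_output_wrt_inputs, and an indexable parameters attribute.

import numpy as np

# Hypothetical loss: L = (output - target) ** 2.
class SquaredError:
    def get_diffrential(self, output, target):
        return 2.0 * (output - target)  # dL/d_output

# Hypothetical layer: output = w * x + b, with parameters = [w, b].
class ScalarLinear:
    def __init__(self) -> None:
        self.parameters = np.array([0.5, 0.0])
        self._last_input = 0.0  # cached by forward(), used for the gradients

    def forward(self, x: float) -> float:
        self._last_input = x
        return self.parameters[0] * x + self.parameters[1]

    def d_output_wrt_parameters(self) -> np.ndarray:
        return np.array([self._last_input, 1.0])  # [d_out/dw, d_out/db]

    def d_output_wrt_inputs(self) -> float:
        return self.parameters[0]  # d_out/dx = w

layer = ScalarLinear()
sgd = SGD(layers=[layer], loss_function=SquaredError())
for _ in range(200):
    output = layer.forward(2.0)  # forward must run before each train() step
    sgd.train(output=output, target=3.0, lr=0.01)
print(layer.parameters)  # w, b drift toward values with 2 * w + b ≈ 3

With a single layer the final multiplication by d_output_wrt_inputs() is unused, but it is what lets the same loop reach layers deeper in the stack.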