Implement SGD optimizer
This commit is contained in:
parent 98b3b4e18a
commit 9b4f3073f3

src/optimizer.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from enum import Enum
from layer import Layer
from neural_network import NeuralNetwork
import numpy as np
from loss import LossFunction


class SGD:
    __layers: list[Layer]
    __loss_function: LossFunction

    def __init__(self, layers: list[Layer], loss_function: LossFunction) -> None:
        self.__layers = layers
        self.__loss_function = loss_function

    def train(self, output: np.ndarray, target: np.ndarray, lr=0.001):
        """Run one SGD update step. Forward needs to be called before this step."""
        # Gradient of the loss with respect to the network output.
        d_loss_wrt_output = self.__loss_function.get_diffrential(output=output, target=target)
        # Walk the layers in reverse, chaining the gradient back through the network.
        for layer in self.__layers[::-1]:
            # Chain rule: dL/dparams = dL/doutput * doutput/dparams.
            d_loss_wrt_parameters = d_loss_wrt_output * layer.d_output_wrt_parameters()
            delta_parameters = lr * d_loss_wrt_parameters
            # Gradient descent: step each parameter against its gradient.
            for i in range(len(delta_parameters)):
                layer.parameters[i] -= delta_parameters[i]
            # Propagate the gradient through this layer to the one below it.
            d_loss_wrt_output = d_loss_wrt_output * layer.d_output_wrt_inputs()
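
The commit relies on interfaces from layer.py and loss.py that are not part of this diff: a Layer exposing parameters, d_output_wrt_parameters(), and d_output_wrt_inputs(), and a LossFunction exposing get_diffrential(output=..., target=...). Below is a minimal sketch of how the optimizer might be driven; IdentityScaleLayer and MSELoss are hypothetical stand-ins written to match those call sites, not code from this repository.

import numpy as np
from optimizer import SGD  # assumes src/ is on the import path


# Hypothetical stand-in for the repo's Layer: y = w * x, one learnable
# scale per element, so every Jacobian is diagonal and the optimizer's
# elementwise products are exact.
class IdentityScaleLayer:
    def __init__(self, size: int) -> None:
        self.parameters = np.ones(size)  # the weights w
        self._last_input = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._last_input = x             # cached for the backward pass
        return self.parameters * x

    def d_output_wrt_parameters(self) -> np.ndarray:
        return self._last_input          # dy/dw = x

    def d_output_wrt_inputs(self) -> np.ndarray:
        return self.parameters           # dy/dx = w


# Hypothetical stand-in for the repo's LossFunction; the method name is
# spelled exactly as at the call site in optimizer.py.
class MSELoss:
    def get_diffrential(self, output: np.ndarray, target: np.ndarray) -> np.ndarray:
        return 2 * (output - target)     # d/d_output of (output - target)**2


layers = [IdentityScaleLayer(3), IdentityScaleLayer(3)]
optimizer = SGD(layers, MSELoss())

x = np.array([1.0, 2.0, 3.0])
target = np.array([2.0, 4.0, 6.0])
for _ in range(200):
    out = x
    for layer in layers:
        out = layer.forward(out)         # forward pass must precede train()
    optimizer.train(out, target, lr=0.01)

Note that train() combines gradients with elementwise multiplication, so this sketch sticks to layers whose Jacobians are diagonal; a dense layer would need matrix products in d_output_wrt_parameters() and d_output_wrt_inputs() instead.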