end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

自動微分における順伝播と逆伝播(バックプロバゲーション)

自動微分における順伝播と逆伝播は、合成関数の微分によりますが、 十分には理解できていない気がしていますので、振り返り。

「ゼロから作るディープラーニング 3 」の写経です

目次

数式と、計算グラフ(svg)

p.25付近の内容です

【順伝播】 数式 a = A(x) 、b = B(a) 、y = C(b) ⇒ y = C( B( A(x) ) ) 凡例: □数式、○値 計算グラフ x A a B b C y 【逆伝播】 計算グラフ dy/dx dz/dx dy/da db/da dy/db dy/db dy/dy 数式 dy/dx = ( ( ( dy/dy・dy/db ) db/da ) da/dx )

python script

以下のurlにあるstep7~8の写経です。

https://github.com/oreilly-japan/deep-learning-from-scratch-3

Define-by-Run に該当しますかね

#!/usr/local/bin/python
# -*- coding: utf-8 -*-
import numpy as np

def main():
    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)

    # backward
    y.grad = np.array(1.0)
    y.backward()
    print(x.grad)

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None
        
    def set_creator(self, func):
        self.creator = func

    def backward(self):
        funcs = [self.creator]
        while funcs:
            f = funcs.pop()  # 1. Get a function
            x, y = f.input, f.output  # 2. Get the function's input/output
            x.grad = f.backward(y.grad)  # 3. Call the function's backward

            if x.creator is not None:
                funcs.append(x.creator)

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)
        self.input = input
        self.output = output
        return output

    def forward(self, x):
        raise NotImplementedError()

    def backward(self, gy):
        raise NotImplementedError()

class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x = self.input.data
        gx = 2 * x * gy
        return gx

class Exp(Function):
    def forward(self, x):
        y = np.exp(x) # ネイピア数e が底の指数関数
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx

if __name__ == '__main__':
    main()