카테고리 없음

7-6. CNN 구현하기

코랑이 2022. 1. 5. 16:46

단순한 CNN 구현하기

단순 CNN 구성

  • SimpleConvNet 초기화(1/3)
    - 코드가 길어지므로 3단으로 나누어 설명
  • class SimpleConvNet: def __init__(self, input_dim = (1, 28, 28) conv_param={'filter_num':30, 'filter_size':5, 'pad':0, 'stride':1}, hidden_size=100, output_size=10, weight_init_std=0.01): filter_num = conv_param['filter_num'] filter_size = conv_param['filter_size'] filter_pad = conv_param['pad'] filter_stride = conv_param['stride'] input_ize = input_dim[1] conv_output_size = (input_size - filter_size + 2 * filter_pad) \ filter_stride + 1 pool_output_size = int(filter_num * (conv_output_size/2) * (conv_output_size/2))​

- 여기서, 합성곱계층 hyperparam은 딕셔너리 형태로 주어짐(conv_param). (ex. {'filter_num':30, 'filter_size':5} )

- 초기화 인수로 주어진 conv_param을 딕셔너리에서 꺼냄 

- 합성곱 계층 출력 크기를 계산함

 

  • SimpleConvNet 초기화(2/3)
    	self.params = {}
            self.params['W1'] = weight_init_std * \
            					np.random.randn(filter_num, input_dim[0], filter_size, filter_size)
            self.params['b1'] = np.zeros(filter_num)
            self.params['W2'] = weight_init_std * \
            					np.random.randn(filter_num, input_dim[0], filter_size, filter_size)
            self.params['b2'] = np.zeros(filter_num)
            self.params['W3'] = weight_init_std * \
            					np.random.randn(filter_num, input_dim[0], filter_size, filter_size)
            self.params['b3'] = np.zeros(filter_num)

- 가중치 초기화 과정

- 매개변수들을 인스턴스 변수 params 딕셔너리에 저장

- 1번째 층의 합성곱계층 가중치:W1, 편향:b1 

- 2번째 층의 합성곱계층 가중치:W2, 편향:b2 

- 3번째 층의 합성곱계층 가중치:W3, 편향:b3 

 

  • SimpleConvNet 초기화(3/3)
        self.layers = OrderedDict()
        self.layers['Conv1'] = Convolution(self.params['W1'],
        									self.params['b1']
                                            conv_params['stride']
                                            conv_parmas['pad']
        self.layers['Relu1'] = Relu()
        self.layers['pool1'] = Pooling([ppl_h=2, pool_w=2, stride=2)
        self.layers['Affine1'] = Affine(self.params['W2']
        								self.params['b2'])
        self.layers['Relu2'] = Relu()
        self.layers['Affine2'] = Affine(self.params['W3']
        								self.params['b3']
                                        
        self.last_layer = SoftmaxWithLoss()

- CNN 구성하는 계층들을 생성

- layers 변수는 순서가있는 딕셔너리임.(orderdict)

- 마지막 softmaxWithLoss()는 last_layer라는 별도 변수로 저장

 

  • predict, loss 함수
        def predict(self, x):
            for layer in self.layers.values():
                x = layer.forward(x)
            return x
            
        def loss(self, x, t):
            y = self.predict(x)
            return self.last_layer.forward(y, t)​

- x는 입력데이터, t는 정답 레이블.

- predict()는 초기화 시 layers에 추가한 계층을 맨 앞부터 차례로 forward 메서드로 호출하여 다음 계층으로 전달

- loss()는 predict()의 결과를 인수로 받아, 마지막층의 forward 메서드를 호출.

-> 즉, 첫 계층부터 마지막 계층까지 forward를 처리함.

 

 

  • 오차역전파법으로 기울기 구하기
        def gradient(self, x, t):
            # 순전파
            self.loss(x, t)
            
            # 역전파
            dout = 1
            dout = self.last_layer.backward(dout)
            
            layers = list(self.layers.values())
            layers.reverse()
            
            for layer in layers:
                dout = layer.backward(dout)
                
            # 결과 저장
            grads = {}
            grads['W1'] = self.layers['Conv1'].dW
            grads['b1'] = self.layers['Conv1'].db
            grads['W2'] = self.layers['Affine1'].dW
            grads['b2'] = self.layers['Affine1'].dW
            grads['W3'] = self.layers['Affine2'].dW
            grads['b3'] = self.layers['Affine2'].dW
            
            return grads

- 매개변수의 기울기는 오차 역전파법으로 구함.

- 이 과정은 순전파-역전파 반복

- 마지막으로 grads 라는 딕셔너리 변수에 각 가중치 매개변수 업데이트

 

 

 

요약

- 단순한 CNN 구조에서, 앞단은 Conv-ReLU-pooing 으로 이루어짐

- SimpleConvNet 구현 시 초기화함수에서 conv_param 이라는 딕셔너리에 초기 인수 넣어줌.
- SimpleConvNet 구현 시 params 라는 딕셔너리에 1,2,3 각 층에 Weight, bias(W,b) 넣어줌

- SimpleConvNet 구현 시 layers라는 순서딕셔너리에 위에서 초기화시킨 인수들을 넣어주고 계층을 쌓음

- loss 함수에서 predict()를 계속 호출하며 loss 를 줄여나가는 방향으로 학습

- predict() 에서는 각 계층을 forward 메서드로 호출함.