개발자로 후회없는 삶 살기
Pytorch PART.논문을 코드로 구현하는 능력(ResNet) 본문
서론
이번 과제로 논문 구현이 취미셨던 김태훈 님의 마음을 느껴봅니다.
본론
- 전체 코드
https://github.com/SangBeom-Hahn/Practical_Pytorch/blob/main/model/model.py
1. 논문 읽기
-> 1회독
Abstract > Introduction > figures > conclusion
빠른 구현 가능성 파악과 프로토타입 제작을 위해서, 전체 논문을 읽는 것이 아니라 논문을 기재한 기관이 어디인지, 코드를 제공하는지 확인하고 위 4단계를 빠르게 속독합니다.
-> 2 회독
Task > Approach > Experiments
1 회독을 완료하고 코드도 실행해 보았다면, 더 높은 성능과 이슈 사항을 파악하기 위해 2 회독을 진행합니다. 저자가 발견한 문제들을 파악하고, 성능을 높이기 위한 데이터 전처리 기법을 적용하며, 연구 환경과 비슷한 결과물을 내기 위해 노력합니다.
2. 모델 구현
논문을 읽었으면 어떤 구조로 모델을 짜야할지 대략적으로 이해했을 것입니다. 이를 토대로 논문 구현을 시작합니다.
- 모델 구조
모델 구조와 표를 보면 모델의 전체적인 구조를 파악할 수 있습니다. 그림을 보면 stride 개수가 안 나와있는데 이것을 논문을 읽음으로써 이해한 상태여야 논문을 구현할 수 있습니다.
conv1 : kernal_size = 7/ kernel 개수 = 64/ stride = 2
maxpool : size = 3/ stride = 2
conv2 : kerner_size = 3, 개수 = 64, stride = ?/ conv2 레이어 2개가 1 세트가 3개(x3)
conv3 : kerner_size = 3, 개수 = 128, stride = ?/ conv3 레이어 2개가 1 세트가 4개(x4)
conv4 : kerner_size = 3, 개수 = 256, stride = ?/ conv4 레이어 2개가 1 세트가 6개(x6)
conv5 : kerner_size = 3, 개수 = 512, stride = ?/ conv5 레이어 2개가 1 세트가 3개(x3)
We perform downsampling directly by convolutional layers that have a stride of 2. The network ends with a global average pooling layer and a 1000-way fully-connected layer with softmax.
논문에 모델 중간에서 stride 2로 다운 샘플링을 했다고 나와있습니다. 전체 논문을 읽어보면 중간에 다운 샘플링을 하고 커널을 늘려 channel dimension을 늘렸다고 나와있습니다.
그러면 /2가 stride 2이고 나머지는 1인 것을 알 수 있고
전체 레이어를 이해할 수 있으며 마지막에 1000차원의 fc layer가 있음을 알고 구현하면 됩니다.
-> 모델 구현
깊고 반복되는 네트워크를 구현할 때는 반복되는 작은 block부터 구현해서 쌓아 올리는 게 좋습니다. 또한 모델을 구현할 때 모델이 복잡해지면 코드로 타이핑하기 힘들어서 모듈을 각각 클래스화 시켜 간단하게 작성하는 것이 좋습니다. 여기서는 Conv, Residual, 모델을 클래스화 하였습니다. 저는 모델 내부에 list를 두고 모듈을 append 한 후 마지막에 Sequential add_module 합니다.
시퀀셜의 경우 반드시 모듈 타입이어야 넣을 수 있고 forward에서 호출하면 연속적으로 호출됩니다. 따라서 속성으로 각 레이어들을 모듈 타입으로 정의하고 시퀀셜에 넣을 수 있도록 구현합니다.
We adopt batch normalization (BN)
주어진 논문에 BN과 Relu를 사용했다고 명시되어 있는데
Normalization이나 activation처럼 optional인 경우에는 파라미터로 받고 if문을 사용합니다. BN과 DropOut, Relu 등도 논문에 정확히 어디에 사용하였다고 명시되어 있진 않습니다. 따라서, 기본적으로 모델 구현에 필수적인 되는 부분은 무조건적으로 사용된다고 이해하고 가야 합니다.
short cut의 경우도 모듈 타입으로 선언하고 시퀀셜에 넣으면 forward에서 호출됩니다.
def forward(self, x, short_cut=True):
if short_cut:
short_cut_x = self.short_cut(x)
else:
short_cut_x = x
return short_cut_x + self.resblk(x)
하지만 입력 x에 따로 적용해야 하는 것이므로 forward에서 분기문을 사용하여 별도로 처리할 수 있습니다. 전체 코드는 깃허브를 참고하시길 바랍니다.
논문을 보면 maxpooling에 대한 얘기 역시 나와있지 않지만 기본적으로 CNN 모델을 만들 때 Conv 레이어와 Pool 레이어의 반복은 공식과도 같습니다. maxpooling을 정의하지 않을 경우 정확도가 많이 떨어지는 것을 확인할 수 있습니다.
'[AI] > [딥러닝 | 이슈해결]' 카테고리의 다른 글
Augmentation PART.albumentation 활용 (2) | 2023.12.19 |
---|---|
[최적화] GPU 풀 구현기 1 (0) | 2023.12.09 |
Pytorch PART.데이터 로더, 폴더 활용(Collate, Sampler 등) (0) | 2023.12.04 |
Pytorch PART.Pytorch 탬플릿 Base 파일 관리 및 활용 (0) | 2023.12.04 |
PyTorch PART.PyTorch 탬플릿 Config 파일 활용 (0) | 2023.11.19 |