728x90
공식 Documentation
Python List와 마찬가지로 nn.Module을 저장하는 역할을 한다.
예제
list를 nn.ModuleList()로 감싸 주면 된다.
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
nn.Sequential()과의 차이점
- nn.Sequential() : 안에 들어가는 모듈들을 연결해주고, 하나의 뉴럴넷을 정의한다. 즉, 나열된 모듈들의 output shape과 input shape이 일치해야 한다는 것.
- nn.ModuleList() : 개별적으로 모듈들이 담겨있는 리스트. 모듈들의 연결관계가 정의되지 않는다. 즉, forward 함수에서 ModulList 내의 모듈들을 이용하여 적절한 연결관계를 정의하는 과정이 필수적이다.
참고자료
https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
https://dongsarchive.tistory.com/67
728x90
'딥러닝 > Pytorch' 카테고리의 다른 글
pytorch - nn.function과 nn의 차이점 (0) | 2022.06.19 |
---|---|
torch - GPU 사용하기 (0) | 2022.06.13 |
Pytorch nn.Embedding() (0) | 2022.05.30 |
Pytorch 기본 문법 - 모델 평가 (0) | 2022.05.25 |
Pytorch 기본 문법 - 모델 훈련 (0) | 2022.05.23 |