728x90
공식 Documentation
Python List와 마찬가지로 nn.Module을 저장하는 역할을 한다.
예제
list를 nn.ModuleList()로 감싸 주면 된다.
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
nn.Sequential()과의 차이점
- nn.Sequential() : 안에 들어가는 모듈들을 연결해주고, 하나의 뉴럴넷을 정의한다. 즉, 나열된 모듈들의 output shape과 input shape이 일치해야 한다는 것.
- nn.ModuleList() : 개별적으로 모듈들이 담겨있는 리스트. 모듈들의 연결관계가 정의되지 않는다. 즉, forward 함수에서 ModulList 내의 모듈들을 이용하여 적절한 연결관계를 정의하는 과정이 필수적이다.
참고자료
https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
ModuleList — PyTorch 1.11.0 documentation
Shortcuts
pytorch.org
https://dongsarchive.tistory.com/67
nn.ModuleList vs nn.Sequential
파이토치 코드를 보다보면 자주 등장하는 두 가지 클래스다. 비슷하게 쓰이는것 같으면서도 그 차이점을 구별해라 하면 말하기 어려운데, 구글링을 해 보니 친절한 답변이 있어서 가져왔다. (링
dongsarchive.tistory.com
728x90
'딥러닝 > Pytorch' 카테고리의 다른 글
pytorch - nn.function과 nn의 차이점 (0) | 2022.06.19 |
---|---|
torch - GPU 사용하기 (0) | 2022.06.13 |
Pytorch nn.Embedding() (0) | 2022.05.30 |
Pytorch 기본 문법 - 모델 평가 (0) | 2022.05.25 |
Pytorch 기본 문법 - 모델 훈련 (0) | 2022.05.23 |