728x90

공식 Documentation

Python List와 마찬가지로 nn.Module을 저장하는 역할을 한다.

예제

list를 nn.ModuleList()로 감싸 주면 된다.

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

 

nn.Sequential()과의 차이점

  • nn.Sequential() : 안에 들어가는 모듈들을 연결해주고, 하나의 뉴럴넷을 정의한다. 즉, 나열된 모듈들의 output shape과 input shape이 일치해야 한다는 것. 
  • nn.ModuleList() : 개별적으로 모듈들이 담겨있는 리스트. 모듈들의 연결관계가 정의되지 않는다. 즉, forward 함수에서 ModulList 내의 모듈들을 이용하여 적절한 연결관계를 정의하는 과정이 필수적이다.

 

참고자료

https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html

 

ModuleList — PyTorch 1.11.0 documentation

Shortcuts

pytorch.org

https://dongsarchive.tistory.com/67

 

nn.ModuleList vs nn.Sequential

파이토치 코드를 보다보면 자주 등장하는 두 가지 클래스다. 비슷하게 쓰이는것 같으면서도 그 차이점을 구별해라 하면 말하기 어려운데, 구글링을 해 보니 친절한 답변이 있어서 가져왔다. (링

dongsarchive.tistory.com

 

728x90

'딥러닝 > Pytorch' 카테고리의 다른 글

pytorch - nn.function과 nn의 차이점  (0) 2022.06.19
torch - GPU 사용하기  (0) 2022.06.13
Pytorch nn.Embedding()  (0) 2022.05.30
Pytorch 기본 문법 - 모델 평가  (0) 2022.05.25
Pytorch 기본 문법 - 모델 훈련  (0) 2022.05.23

+ Recent posts