def __init__(self, layers:Collection[int], drops:Collection[float]):
    """Build `self.layers`, a sequence of `bn_drop_lin` blocks.

    Each adjacent pair `(n_in, n_out)` taken from `layers` produces one block, and
    `drops` supplies the `p` argument passed to `bn_drop_lin` for that block, so
    `drops` must hold exactly one value per adjacent pair.

    Raises `ValueError` if `len(drops)` differs from the number of adjacent pairs.
    """
    super().__init__()
    # Fewer than two sizes means there is no (n_in, n_out) pair to build.
    n_blocks = max(len(layers) - 1, 0)
    if len(drops) != n_blocks:
        # Without this check `zip` below stops at the shortest input and silently
        # drops the trailing blocks instead of reporting the mismatch.
        raise ValueError(
            "Number of layers and dropout values do not match: expected {} dropout values, got {}.".format(
                n_blocks, len(drops)))
    # Every block except the last receives a ReLU (list multiplication repeats the
    # same module instance); the last block gets no activation.
    activs = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]
    mod_layers = []
    for n_in,n_out,p,actn in zip(layers[:-1], layers[1:], drops, activs):
        mod_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)
    self.layers = nn.Sequential(*mod_layers)
-
def pool(self, x:Tensor, bs:int, is_max:bool):
    "Max- or average-pool `x` over its leading (seq_len) dimension and return the result viewed as `(bs, -1)`."
    # The adaptive 1d pools reduce the last dimension, so move dimension 0 to the end first.
    permuted = x.permute(1, 2, 0)
    if is_max:
        pooled = F.adaptive_max_pool1d(permuted, (1,))
    else:
        pooled = F.adaptive_avg_pool1d(permuted, (1,))
    # The pooled dimension now has size 1; fold it away into a 2d tensor.
    return pooled.view(bs, -1)
-
- def forward(self, input:Tuple[Tensor,Tensor]) -> Tuple[Tensor,Tensor,Tensor]:
- raw_outputs, outputs = input
- output = outputs[-1]
- sl,bs,_ = output.size()
- avgpool = self.pool(output, bs, False)
- mxpool = self.pool(output, bs, True)
- x = torch.cat([output[-1], mxpool, avgpool], 1)
- x = self.layers(x)