# Define the RNN model
class RNN(nn.Module):
    """GRU sequence classifier: average over dim 1, GRU, ReLU, max-pool over time, linear.

    Args:
        num_classes: Number of output logits per example.
        hidden_size: GRU hidden-state size (and input size of the classifier).
        dropout: Dropout probability applied to the pooled features before the
            classifier. It is also used between stacked GRU layers, but only
            when ``num_layers > 1``.
        embedding_dim: Size of the last input dimension (the GRU input size).
        num_layers: Number of stacked GRU layers.

    The defaults reproduce the original hard-coded configuration
    (7 classes, hidden 128, dropout 0.4, input dim 768, 1 layer).
    """

    def __init__(
        self,
        num_classes: int = 7,
        hidden_size: int = 128,
        dropout: float = 0.4,
        embedding_dim: int = 768,
        num_layers: int = 1,
    ) -> None:
        super().__init__()
        # nn.GRU applies its own dropout only *between* stacked layers; with a
        # single layer it has no effect and PyTorch emits a UserWarning, so it
        # is only forwarded when there is more than one layer.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.GRU(
            embedding_dim,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=rnn_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: 4-D tensor ``(batch, num_models, seq_len, embedding_dim)``.
                Dim 1 is averaged away before the GRU.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                "expected a 4-D input (batch, num_models, seq_len, embedding_dim), "
                f"got shape {tuple(x.shape)}"
            )
        # Average over dim 1 -> (batch, seq_len, embedding_dim).
        x = x.mean(dim=1)
        x, _ = self.rnn(x)  # (batch, seq_len, hidden_size)
        x = F.relu(x)
        # Global max-pool over the time axis -> (batch, hidden_size).
        x = x.max(dim=1).values
        x = self.dropout(x)
        return self.fc1(x)