Define the RNN model


import torch
import torch.nn as nn
import torch.nn.functional as F


class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        num_classes = 7
        hidden_size = 128
        dropout = 0.4
        embedding_dim = 768
        num_layers = 1  # GRU dropout only acts between stacked layers; increase num_layers (e.g. to 3) for it to take effect
        self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch_size, num_models, seq_len, embedding_dim)
        # Average the embeddings over the model dimension (num_models collapses to 1).
        mean_x = torch.mean(x, dim=1, keepdim=True)
        batch_size, num_models, seq_len, embedding_dim = mean_x.shape
        x = mean_x.reshape(batch_size * num_models, seq_len, embedding_dim)
        
        # GRU output: (batch_size, seq_len, hidden_size)
        x, _ = self.rnn(x)

        x = F.relu(x)
        # Max-pool over the time dimension -> (batch_size, hidden_size)
        x = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(2)
        
        x = self.dropout(x)
        logit = self.fc1(x)
        return logit
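
A minimal sanity check, assuming the input batch is shaped (batch_size, num_models, seq_len, embedding_dim) as the forward pass expects; the batch size, number of models, and sequence length below are illustrative, only embedding_dim=768 and the 7 output classes come from the model definition above.

model = RNN()
dummy = torch.randn(4, 2, 50, 768)  # (batch_size, num_models, seq_len, embedding_dim) -- illustrative shapes
logits = model(dummy)
print(logits.shape)  # torch.Size([4, 7]): one logit per class for each example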