
Gated State Space

This repo contains a pretrained model for the Gated State Space paper. The model was trained on the C4 dataset, using Lucidrains' implementation (commit) of the architecture. The main benefit of this model is its ability to scale beyond the training context length: as the authors note in the paper, they trained on a 4k sequence length, but the model generalized beyond that length. I have written a blog post on how I set up the training here.

A Wandb report is available at this link.

How to use the model

Since it is not based on the transformers library, it is a bit tricky to use the model out of the box. Here are the general steps:

  1. pip install gated-state-spaces-pytorch
  2. Download the model weights from here.
  3. Download the config from here.
  4. Use the following code to patch the original model (`GatedStateSpacesLM` and `AutoregressiveWrapper` come from the `gated-state-spaces-pytorch` package; `f_emb` is the embedding dimension from the config):
    from torch import nn

    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            **config
        ),
    )
    # the checkpoint was trained with a LayerNorm in front of the output head,
    # so wrap the original projection in an nn.Sequential
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
  5. Load the state dict: model.load_state_dict(torch.load('model.pt'))
  6. If you want to fine-tune the model, you can freeze the (tied) embedding and output weights, where `emb` is the pretrained embedding matrix:
    model.net.token_emb.weight.requires_grad_(False)
    model.net.token_emb.weight.copy_(emb)

    model.net.to_logits[1].weight.requires_grad_(False)
    model.net.to_logits[1].weight.copy_(emb)
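
The mechanics of steps 4 and 6 can be sketched with plain `torch` stand-ins (toy sizes; `f_emb`, `vocab`, and `emb` here are placeholders for the real config values and pretrained embedding matrix):

```python
import torch
from torch import nn

f_emb, vocab = 64, 256           # toy sizes; the real values come from the config
emb = torch.randn(vocab, f_emb)  # placeholder for the pretrained embedding matrix

# step 4: the checkpoint expects a LayerNorm in front of the output head,
# so the original projection gets wrapped in an nn.Sequential
to_logits = nn.Linear(f_emb, vocab, bias=False)  # stand-in for model.net.to_logits
to_logits = nn.Sequential(nn.LayerNorm(f_emb), to_logits)

# step 6: freeze the embedding and output weights and copy in the
# pretrained matrix (they are tied to the same weights)
token_emb = nn.Embedding(vocab, f_emb)  # stand-in for model.net.token_emb
token_emb.weight.requires_grad_(False)
token_emb.weight.copy_(emb)

to_logits[1].weight.requires_grad_(False)
to_logits[1].weight.copy_(emb)

hidden = torch.randn(2, 10, f_emb)  # fake hidden states from the backbone
logits = to_logits(hidden)
print(logits.shape)                 # torch.Size([2, 10, 256])
```

Because `requires_grad` is disabled before `copy_`, the in-place copy does not need a `torch.no_grad()` context, and the frozen weights stay out of the optimizer's gradient updates.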

Training code is available in this repo; the training script is at this link.
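
At its core, the training objective is standard next-token cross-entropy (lucidrains' `AutoregressiveWrapper` computes this loss internally when called on a batch of tokens). A self-contained toy sketch of one training step, with a tiny stand-in model rather than the actual GSS architecture:

```python
import torch
from torch import nn

# tiny stand-in for the language model (NOT the GSS architecture)
class TinyLM(nn.Module):
    def __init__(self, vocab=256, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab)

    def forward(self, x):
        return self.proj(self.emb(x))  # (batch, seq, vocab) logits

model = TinyLM()
opt = torch.optim.Adam(model.parameters(), lr=3e-4)

tokens = torch.randint(0, 256, (4, 33))   # a batch of token ids
inp, tgt = tokens[:, :-1], tokens[:, 1:]  # shift by one for next-token prediction

logits = model(inp)
# cross_entropy expects (batch, classes, seq) for sequence targets
loss = nn.functional.cross_entropy(logits.transpose(1, 2), tgt)
loss.backward()
opt.step()
```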

Training Information

Here are the details of the training:

Fine-Tuning Info:

model2.pt is available as a fine-tuned version with a longer context length.