generated_from_trainer

wav2vec2-large-xls-r-300m-german-with-lm

This model is a fine-tuned version of facebook/wav2vec2-xls-r-300m on the German set of the Common Voice dataset. It achieves a word error rate (WER) of 8.8 percent on the evaluation set.

Model description

German wav2vec2-xls-r-300m trained on the full train set of the Common Voice dataset, with an n-gram language model. The full code is available in my GitHub repository.

Citation

Feel free to cite this work as:

@misc{mfleck/wav2vec2-large-xls-r-300m-german-with-lm,
  title={XLS-R-300 Wav2Vec2 German with language model},
  author={Fleck, Michael},
  publisher={Hugging Face},
  journal={Hugging Face Hub},
  howpublished={\url{https://huggingface.co/mfleck/wav2vec2-large-xls-r-300m-german-with-lm}},
  year={2022}
}

Intended uses & limitations

Inference Usage

from transformers import pipeline

# Build the pipeline from the published checkpoint; only the model id is passed here.
pipe = pipeline(model="mfleck/wav2vec2-large-xls-r-300m-german-with-lm")
# Transcribe a local audio file; chunk_length_s / stride_length_s are given in seconds.
output = pipe("/path/to/file.wav",chunk_length_s=5, stride_length_s=1)
# The transcript is read from the "text" entry of the result.
print(output["text"])

Training and evaluation data

Script used for training (takes about 80 hours on a single A100 40GB)

import random
import re
import json
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import numpy as np
import torch
# import soundfile

from datasets import load_dataset, load_metric, Audio
from dataclasses import dataclass, field

from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, TrainingArguments, Trainer, Wav2Vec2ForCTC


'''
    Most parts of this script are following the tutorial: https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
'''


# The train and validation splits are used together for training; the test split is used for evaluation
common_voice_train = load_dataset("common_voice", "de", split="train+validation")
# Use train dataset with less training data
#common_voice_train = load_dataset("common_voice", "de", split="train[:3%]")
common_voice_test = load_dataset("common_voice", "de", split="test")


# Remove unused columns (same list for both splits)
unused_columns = ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
common_voice_train = common_voice_train.remove_columns(unused_columns)
common_voice_test = common_voice_test.remove_columns(unused_columns)


# Keep only examples whose sentence contains no character outside the set listed in `regex`;
# the training set size is printed before and after filtering
print(len(common_voice_train))
regex = "[^A-Za-zäöüÄÖÜß,?.! ]+"
common_voice_train = common_voice_train.filter(lambda example: re.search(regex, example['sentence']) is None)
common_voice_test = common_voice_test.filter(lambda example: re.search(regex, example['sentence']) is None)
print(len(common_voice_train))


# Remove special chars from transcripts
# Punctuation and quote characters that are stripped from every transcript.
# Written as a raw string so the backslashes reach the regex engine unchanged;
# as a normal string literal, sequences such as "\," are invalid escapes that
# Python warns about. The set of matched characters is the same as before.
chars_to_remove_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
def remove_special_characters(batch):
    """Strip punctuation from a transcript and lower-case it.

    Args:
        batch: Mapping with a string under the key "sentence".

    Returns:
        The same `batch` object, with "sentence" replaced by the cleaned text.
    """
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch
# Clean the transcripts of both splits, each mapped with 10 worker processes
common_voice_train, common_voice_test = (
    split.map(remove_special_characters, num_proc=10)
    for split in (common_voice_train, common_voice_test)
)


# Show some random transcripts to verify that preprocessing worked as expected
def show_random_elements(dataset, num_examples=10):
    """Print `num_examples` distinct, randomly chosen rows of `dataset`.

    Args:
        dataset: Object that supports `len()` and indexing with a list of
            integer positions.
        num_examples: Number of distinct rows to print.

    Raises:
        AssertionError: If `num_examples` is larger than `len(dataset)`.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so the indices are unique
    # without a retry loop
    picks = random.sample(range(len(dataset)), num_examples)

    print(str(dataset[picks]))
# Print a sample of the training data without the "path" and "audio" columns (default: 10 rows)
show_random_elements(common_voice_train.remove_columns(["path","audio"]))


# Extract all chars which exist in the datasets (the wav2vec2 special tokens are added to the vocabulary further below)
def extract_all_chars(batch):
    """Collect the distinct characters used by all sentences of a batch.

    Args:
        batch: Mapping whose "sentence" entry is a list of strings.

    Returns:
        dict with "vocab" (a one-element list holding the list of unique
        characters) and "all_text" (a one-element list holding the sentences
        joined by single spaces).
    """
    joined = " ".join(batch["sentence"])
    return {"vocab": [list(set(joined))], "all_text": [joined]}
# Run the extraction over each split with batch_size=-1, so that index 0 of the "vocab" column holds the split's characters
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

# Give every character that occurs in either split an id, assigned in sorted order
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
# The space keeps its id but is re-keyed as "|", the word delimiter token given to the tokenizer
vocab_dict["|"] = vocab_dict.pop(" ")
# The special tokens take the next two free ids
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
# Explicit encoding so the file does not depend on the platform default
with open('vocab.json', 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab_dict, vocab_file)



# Create tokenizer (from the vocab.json in the working directory) and repo at Huggingface
repo_name = "wav2vec2-large-xls-r-300m-german-with-lm"
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    "./",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
tokenizer.push_to_hub(repo_name)
print("pushed to hub")



# Create feature extractor and processor
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)


# Cast audio column of both splits to a 16 kHz Audio feature
common_voice_train = common_voice_train.cast_column(
    "audio", Audio(sampling_rate=16_000)
)
common_voice_test = common_voice_test.cast_column(
    "audio", Audio(sampling_rate=16_000)
)


# Convert audio signal to array and 16khz sampling rate
def prepare_dataset(batch):
    """Turn one example into model inputs and label ids.

    Reads `batch["audio"]` (its "array" and "sampling_rate" entries) and
    `batch["sentence"]`; stores "input_values", "input_length" and "labels"
    in `batch` and returns it.
    """
    sample = batch["audio"]
    processed = processor(sample["array"], sampling_rate=sample["sampling_rate"])

    # the processor output is batched; keep only its first entry
    values = processed.input_values[0]
    batch["input_values"] = values
    # Debugging aid: write the values to a wav file to check that the audio was loaded correctly, e.g.
    # soundfile.write("/home/debian/trainnew/test.wav", values, sample["sampling_rate"])
    batch["input_length"] = len(values)

    # the label ids are produced with the processor switched to its target mode
    with processor.as_target_processor():
        encoded = processor(batch["sentence"])
        batch["labels"] = encoded.input_ids
    return batch

# Replace all original columns of each split with the values written by prepare_dataset
train_columns = common_voice_train.column_names
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=train_columns)
test_columns = common_voice_test.column_names
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=test_columns)
print("dataset prepared")




@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that pads inputs and labels of a batch at collation time.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor whose ``pad`` method is used for both the inputs and the labels.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Padding strategy, forwarded unchanged to ``processor.pad``:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch (no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: pad to a maximum length specified with the argument :obj:`max_length`, or to
              the maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding (i.e. the batch can contain sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # inputs and labels have different lengths and need different padding
        # methods, so they are collected into separate lists
        audio_inputs = []
        label_inputs = []
        for feature in features:
            audio_inputs.append({"input_values": feature["input_values"]})
            label_inputs.append({"input_ids": feature["labels"]})

        batch = self.processor.pad(audio_inputs, padding=self.padding, return_tensors="pt")
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(label_inputs, padding=self.padding, return_tensors="pt")

        # positions whose attention mask is not 1 are padding; they are set to
        # -100 so that they are ignored by the loss
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)
        return batch

# Collator instance that is handed to the Trainer; padding=True is passed explicitly
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)


# Use word error rate as metric
wer_metric = load_metric("wer")
def compute_metrics(pred):
    """Compute the word error rate of a set of predictions.

    Args:
        pred: Object exposing `predictions` and `label_ids`. `label_ids` is
            modified in place: every entry equal to -100 is overwritten with
            the tokenizer's pad token id.

    Returns:
        dict: ``{"wer": <value returned by the WER metric>}``.
    """
    label_ids = pred.label_ids
    # -100 entries are turned back into the pad token id before decoding
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # pick the highest-scoring index along the last axis
    predicted_ids = np.argmax(pred.predictions, axis=-1)
    predicted_text = processor.batch_decode(predicted_ids)
    # we do not want to group tokens when computing the metrics
    reference_text = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=predicted_text, references=reference_text)}



# Model and training parameters
# Load the pretrained facebook/wav2vec2-xls-r-300m checkpoint with the keyword overrides below
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    # dropout / masking settings used for this run
    attention_dropout=0.094,
    hidden_dropout=0.01,
    feat_proj_dropout=0.04,
    mask_time_prob=0.08,
    layerdrop=0.04,
    ctc_loss_reduction="mean", 
    # pad id and vocabulary size are taken from the tokenizer of the processor
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)
# NOTE(review): newer transformers releases appear to deprecate this method in favour of
# freeze_feature_encoder() - confirm against the installed version before changing it
model.freeze_feature_extractor()

# Settings for the fine-tuning run; the output directory is named after the hub repository
training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  # 32 examples per device, with gradients accumulated over 2 steps
  per_device_train_batch_size=32,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=20,
  gradient_checkpointing=True,
  fp16=True,
  # saving and evaluation use the same interval of 5000 steps
  save_steps=5000,
  eval_steps=5000,
  logging_steps=100,
  learning_rate=1e-4,
  warmup_steps=500,
  save_total_limit=3,
  push_to_hub=True,
)

# Combine model, collator, settings, metric function and the prepared splits
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    # the test split serves as the evaluation set during training
    eval_dataset=common_voice_test,
    # the feature extractor (not the text tokenizer) is passed in the `tokenizer` argument
    tokenizer=processor.feature_extractor,
)

# Start fine tuning
trainer.train()

# When done push final model to Huggingface hub
trainer.push_to_hub()

The model achieves a word error rate (WER) of 8.8% using the following script:

import argparse
import re
from typing import Dict

import torch
from datasets import Audio, Dataset, load_dataset, load_metric

from transformers import AutoFeatureExtractor, pipeline



# load dataset
dataset = load_dataset("common_voice", "de", split="test")
# use only 1% of data
#dataset = load_dataset("common_voice", "de", split="test[:1%]")


# load processor: the feature extractor of the published model supplies the sampling rate
model_id = "mfleck/wav2vec2-large-xls-r-300m-german-with-lm"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
sampling_rate = feature_extractor.sampling_rate

dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

# load eval pipeline
# device=0 means GPU, use device=-1 for CPU
asr = pipeline("automatic-speech-recognition", model=model_id, device=0)

# Keep only examples whose sentence contains no character outside the set listed in `regex`
regex = "[^A-Za-zäöüÄÖÜß,?.! ]+"
dataset = dataset.filter(lambda example: re.search(regex, example['sentence']) is None)

# Punctuation and quote characters removed from the reference transcripts.
# Written as a raw string so the backslashes reach the regex engine unchanged;
# as a normal string literal, sequences such as "\," are invalid escapes that
# Python warns about. The set of matched characters is the same as before.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
# map function to decode audio
def map_to_pred(batch):
    """Transcribe one example and store prediction and normalised reference.

    Args:
        batch: Example providing `batch["audio"]["array"]` and `batch["sentence"]`.

    Returns:
        The same `batch`, with "prediction" (the "text" entry of the pipeline
        result) and "target" added.
    """
    prediction = asr(batch["audio"]["array"], chunk_length_s=5, stride_length_s=1)

    # Print automatic generated transcript
    #print(str(prediction))

    batch["prediction"] = prediction["text"]
    text = batch["sentence"]
    # reference: lower-cased sentence without the ignored characters, plus one trailing space
    # NOTE(review): the trailing space is kept exactly as in the reported evaluation;
    # presumably it mirrors the pipeline output - confirm before removing it
    batch["target"] = re.sub(chars_to_ignore_regex, "", text.lower()) + " "

    return batch

# run inference on all examples; only the columns written by map_to_pred are kept
result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
targets = result["target"]
predictions = result["prediction"]

# load metric
wer = load_metric("wer")
cer = load_metric("cer")

# compute metrics
wer_result = wer.compute(references=targets, predictions=predictions)
cer_result = cer.compute(references=targets, predictions=predictions)

# print results
result_str = f"WER: {wer_result}\nCER: {cer_result}"
print(result_str)

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

Training results

Training Loss Epoch Step Validation Loss Wer
0.1396 1.42 5000 0.1449 0.1479
0.1169 2.83 10000 0.1285 0.1286
0.0938 4.25 15000 0.1277 0.1230
0.0924 5.67 20000 0.1305 0.1191
0.0765 7.09 25000 0.1256 0.1158
0.0749 8.5 30000 0.1186 0.1092
0.066 9.92 35000 0.1173 0.1068
0.0581 11.34 40000 0.1225 0.1030
0.0582 12.75 45000 0.1153 0.0999
0.0507 14.17 50000 0.1182 0.0971
0.0491 15.59 55000 0.1136 0.0939
0.045 17.01 60000 0.1140 0.0914
0.0395 18.42 65000 0.1160 0.0902
0.037 19.84 70000 0.1148 0.0882

Framework versions