Smashing Automatic Speech Recognition Models with x-fast

This tutorial demonstrates how to use the pruna package to optimize any custom Whisper model. The result is a smashed Whisper model that transcribes audio faster. We will use the openai/whisper-large-v3 model as an example.

Loading the ASR model

First, load your ASR model.

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="eager"
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)
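
Before smashing, it can be useful to confirm the model landed on the expected device and dtype. This quick check uses only standard PyTorch calls:

# Quick sanity check: parameter count, device, and dtype of the loaded model.
p = next(model.parameters())
print(f"{sum(q.numel() for q in model.parameters()) / 1e9:.2f}B parameters, device={p.device}, dtype={p.dtype}")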

Initializing the Smash Config

Next, initialize the smash_config.

from pruna_engine.SmashConfig import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['task'] = 'audio_text_transcription'
smash_config['compilers'] = ['x-fast']
# uncomment the following line to quantize the model to 16 bits
# smash_config['quantizers'] = ['half']

Smashing the Model

Now, smash the model.

from pruna.smash import smash

# Smash the model
smashed_model = smash(
    model=model,
    api_key='<your-api-key>',  # replace <your-api-key> with your actual API key
    smash_config=smash_config,
)

Don’t forget to replace the api_key with the one provided by PrunaAI.
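
If you prefer not to hardcode the key in your script, you can read it from an environment variable instead. The variable name PRUNA_API_KEY below is just a convention for this sketch, not something the package requires:

import os

# Hypothetical convention: keep the key out of source control by reading it
# from an environment variable, then pass api_key=api_key to smash().
api_key = os.environ.get("PRUNA_API_KEY")
if api_key is None:
    raise RuntimeError("Set PRUNA_API_KEY to your PrunaAI API key.")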

Preparing the Input

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features.to(device, dtype=torch_dtype)
prompt = processor.get_decoder_prompt_ids(language="english", task="transcribe")
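
The dataset sample above is only for demonstration. To transcribe your own recording, load a local audio file instead. This is a minimal sketch assuming the librosa package and a file named audio.wav (both are assumptions, not part of this tutorial):

import librosa

# Assumption: "audio.wav" is your own recording. librosa resamples it to
# the 16 kHz rate that Whisper's processor expects.
audio_array, sampling_rate = librosa.load("audio.wav", sr=16000)
input_features = processor(
    audio_array, sampling_rate=sampling_rate, return_tensors="pt"
).input_features.to(device, dtype=torch_dtype)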

Running the Model

Finally, run the model to transcribe the audio file. The first call may take a while because the additional compilation happens at that point, but subsequent calls will be very fast.

# Transcribe the audio and display the result
results = smashed_model.generate(input_features, forced_decoder_ids=prompt)
processor.batch_decode(results, skip_special_tokens=True)
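
To see the warm-up effect for yourself, you can time the first call against a later one. This is a minimal sketch using Python's time module; the synchronize calls make sure GPU work is included in the measurement:

import time

def timed_generate(features):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # flush any pending GPU work first
    start = time.perf_counter()
    out = smashed_model.generate(features, forced_decoder_ids=prompt)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for generation to actually finish
    return out, time.perf_counter() - start

_, first = timed_generate(input_features)  # includes one-off compilation
_, warm = timed_generate(input_features)   # compiled kernels are reused
print(f"first call: {first:.2f}s, warm call: {warm:.2f}s")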

Wrap Up

Congratulations! You have successfully smashed an ASR model. You can now use the pruna package to optimize any custom ASR model. The only parts you should need to modify for your own use case are step 1 (loading the model) and step 5 (running it on your inputs).
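
For example, adapting step 1 to a smaller checkpoint only requires changing the model identifier; distil-whisper/distil-large-v3 is one possible alternative, and the rest of the tutorial stays the same:

# Any Whisper-style checkpoint from the Hugging Face Hub should work here;
# distil-whisper/distil-large-v3 is just one example of a smaller model.
model_id = "distil-whisper/distil-large-v3"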