Smashing Large Language Models

This tutorial demonstrates how to use the pruna package to optimize any custom large language model. We will use the facebook/opt-125m model as an example.

Loading the LLM

First, load the model and its tokenizer, and tokenize an example prompt.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

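model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in their stored dtype and move the model to the GPU
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto").cuda()

# Tokenize a prompt and move the input tensors to the same device as the model
text = "The 45th president of the United States of America is"
ins = tokenizer(text, return_tensors="pt").to('cuda')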

Initializing the Smash Config

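Next, initialize the smash_config. It tells pruna which task the model solves and how to compress it; here, the weights are quantized to 4 bits with GPTQ.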

from pruna_engine.SmashConfig import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['task'] = 'text_text_generation'
smash_config['quantizers'] = ['gptq']
smash_config["tokenizer_name"] = model_id  # name of the tokenizer to use during quantization
smash_config["n_quantization_bits"] = 4
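Setting n_quantization_bits to 4 means each quantized weight is stored as one of 2**4 = 16 integer levels. GPTQ goes further than plain rounding: it quantizes weights column by column and adjusts the remaining weights to compensate for the rounding error, using calibration data. The sketch below only shows simple round-to-nearest asymmetric quantization, to make the 4-bit grid concrete; quantize_group and dequantize_group are hypothetical helpers, not part of pruna, and this is not GPTQ itself.

```python
def quantize_group(weights, n_bits=4):
    """Round-to-nearest asymmetric quantization of one weight group (not GPTQ)."""
    levels = 2 ** n_bits - 1  # 15 steps -> 16 representable values for 4 bits
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / levels if w_max > w_min else 1.0
    zero_point = round(-w_min / scale)
    # Map each float to an integer code, clamped to [0, levels]
    codes = [min(levels, max(0, round(w / scale) + zero_point)) for w in weights]
    return codes, scale, zero_point


def dequantize_group(codes, scale, zero_point):
    """Recover approximate float weights from the integer codes."""
    return [(q - zero_point) * scale for q in codes]


weights = [-0.42, -0.10, 0.03, 0.27, 0.55]
codes, scale, zp = quantize_group(weights, n_bits=4)
restored = dequantize_group(codes, scale, zp)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(codes)
print(f"max abs error: {max_error:.4f} (step size {scale:.4f})")
```

Each weight moves by at most about half a quantization step; GPTQ's error compensation exists to reduce the effect of exactly this rounding on the layer's outputs.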

Smashing the Model

Now, smash the model.

from pruna.smash import smash

# Smash the model
smashed_model = smash(
    model=model,
    dataloader="WikiText",
    api_key='<your-api-key>',  # replace <your-api-key> with your actual API key
    smash_config=smash_config,
)
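GPTQ needs a small set of calibration sequences to decide, layer by layer, which rounding choices best preserve each layer's outputs. The dataloader="WikiText" argument presumably tells pruna which text to draw these from; how pruna actually builds the samples is not described here. The sketch below, using a hypothetical make_calibration_windows helper and fake token IDs, only illustrates the common pattern of cutting one long token stream into fixed-length windows.

```python
import random


def make_calibration_windows(token_ids, seq_len, n_samples, seed=0):
    """Sample n_samples random windows of seq_len tokens from one long token stream.

    Hypothetical helper for illustration; not part of pruna.
    """
    if len(token_ids) < seq_len:
        raise ValueError("token stream must be at least seq_len tokens long")
    rng = random.Random(seed)  # fixed seed -> reproducible calibration set
    windows = []
    for _ in range(n_samples):
        start = rng.randint(0, len(token_ids) - seq_len)  # upper bound is inclusive
        windows.append(token_ids[start:start + seq_len])
    return windows


# Stand-in for a tokenized corpus such as WikiText (fake IDs, not real data)
stream = list(range(1000))
calib = make_calibration_windows(stream, seq_len=128, n_samples=4)
print([len(w) for w in calib])
```

Fixing the seed keeps the calibration set, and therefore the quantized weights, reproducible across runs.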

Don’t forget to replace <your-api-key> with the API key provided by PrunaAI.

Running the Model

Finally, run the smashed model on the tokenized prompt to generate a continuation.

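# Generate token IDs, then decode them back to text and display the result
outputs = smashed_model.generate(**ins)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))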
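The generate call returns token IDs rather than text. The simplest decoding strategy, greedy search, repeatedly appends the highest-scoring next token until an end-of-sequence token or a length limit is reached. The toy sketch below replaces the model with a hypothetical bigram score table purely to show that loop; the decoding settings that pruna or transformers actually use for this model may differ.

```python
def greedy_generate(next_token_scores, prompt_ids, max_new_tokens, eos_id):
    """Greedy search: append the argmax token until EOS or the length limit."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        scores = next_token_scores(ids)  # stand-in for a model forward pass
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


# Toy "model": scores depend only on the last token (hypothetical bigram table)
VOCAB = ["<eos>", "the", "cat", "sat", "down"]
BIGRAM = {
    1: [0.0, 0.1, 0.8, 0.05, 0.05],  # after "the" -> "cat"
    2: [0.1, 0.0, 0.0, 0.9, 0.0],    # after "cat" -> "sat"
    3: [0.2, 0.0, 0.0, 0.0, 0.8],    # after "sat" -> "down"
    4: [0.9, 0.05, 0.0, 0.0, 0.05],  # after "down" -> "<eos>"
}


def toy_scores(ids):
    # Unknown last token: put all the score on <eos>
    return BIGRAM.get(ids[-1], [1.0, 0.0, 0.0, 0.0, 0.0])


out = greedy_generate(toy_scores, prompt_ids=[1], max_new_tokens=10, eos_id=0)
print(" ".join(VOCAB[i] for i in out))
```

A real model computes the scores with a full forward pass over all previous tokens, which is the step that quantization and compilation aim to make cheaper.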

Wrap Up

Congratulations! You have successfully smashed a large language model. You can reuse the same workflow to optimize any custom LLM: only the first step (Loading the LLM) and the last step (Running the Model) need to change to fit your use case. Additionally, you can use the quantizer ‘llm-int8’ instead of ‘gptq’ to quantize the model, or the compiler ‘transformers-fast’ to compile it for faster inference.
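To get a rough feel for what the quantizers save, the weight storage is about parameters × bits per weight / 8 bytes. The sketch below applies that to a nominal 125 million parameters (the size in the model's name, not an exact count), assumes a 16-bit baseline, and ignores per-group scales, zero points, and any layers left unquantized, so real footprints will be somewhat higher.

```python
def weight_megabytes(n_params, bits_per_weight):
    """Approximate weight storage in MB (1 MB = 10**6 bytes), ignoring quantization metadata."""
    return n_params * bits_per_weight / 8 / 1e6


N_PARAMS = 125_000_000  # nominal size of facebook/opt-125m
for label, bits in [("16-bit baseline", 16), ("8-bit (e.g. llm-int8)", 8), ("4-bit (e.g. gptq)", 4)]:
    print(f"{label:>22}: {weight_megabytes(N_PARAMS, bits):6.1f} MB")
```

Halving the bits halves the weight memory, which is why 4-bit GPTQ stores roughly a quarter of the 16-bit weights.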