# Modular transformers
Modular transformers reduces the code required to contribute new models from 3,000-6,000 lines to approximately 500 lines by enabling inheritance across model files while maintaining the library’s “single model, single file” philosophy.
## Architecture
```
modular_<model>.py
        │
        ▼  (linter)
modeling_<model>.py  (standard single-file output)
```

## How it works
- Modular files: contain model, processor, and configuration code written with inheritance
- Linter: automatically “unravels” modular files into the standard format
- Output: a traditional single-file `modeling.py` for users
## The modular model converter
```bash
# convert a modular file to standard format
python utils/modular_model_converter.py \
    --files-to-parse src/transformers/models/<your_model>/modular_<your_model>.py

# convert all models
python utils/modular_model_converter.py --files-to-parse all
```

## Syntax patterns
### 1. Basic inheritance
```python
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel, LlamaRMSNorm


class MyModelConfig(LlamaConfig):
    pass  # uses parent as-is, just renames


class MyModelRMSNorm(LlamaRMSNorm):
    pass  # linter copies parent and updates references
```
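To make “copies parent and updates references” concrete: assuming `LlamaRMSNorm` keeps its standard implementation, the generated modeling file would contain a fully expanded copy along these lines (a sketch, not the exact output; the body depends on the library version):

```python
# sketch of the generated class: the parent body is copied and the name rewritten
import torch
from torch import nn


class MyModelRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # RMS-normalize in float32, then scale and cast back to the input dtype
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```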
### 2. Super calls and unraveling

```python
class MyModelAttention(LlamaAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)  # parent body inserted here
        self.new_layer = NewLayer(config)
```
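As with the RMSNorm sketch above, the generated class no longer references `LlamaAttention`: the parent’s `__init__` body is copied in with names rewritten, and the new assignment is appended after it. An abbreviated, illustrative sketch (`NewLayer` is the hypothetical module from the snippet above, and the copied attributes depend on the `LlamaAttention` implementation in your library version):

```python
from torch import nn


# abbreviated sketch of the generated class (illustrative only)
class MyModelAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        # -- copied from LlamaAttention.__init__, names rewritten --
        self.config = config
        self.layer_idx = layer_idx
        # ... q_proj / k_proj / v_proj / o_proj and the remaining parent attributes ...
        # -- code added in the modular file follows the parent body --
        self.new_layer = NewLayer(config)  # NewLayer: hypothetical module from the example
```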
### 3. Attribute modification

Adding attributes:
```python
class Olmo2Attention(OlmoAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.q_norm = Olmo2RMSNorm(config.hidden_size // config.num_attention_heads)
        self.k_norm = Olmo2RMSNorm(config.hidden_size // config.num_attention_heads)
```

Removing attributes:
```python
class Olmo2Config(OlmoConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        del self.clip_qkv  # removes the assignment
```

### 4. The `**super_kwargs` pattern
```python
class MyModelForCausalLM(LlamaForCausalLM):
    @my_new_decorator
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
```

The linter automatically expands `**super_kwargs` to the parent’s full `forward` signature.
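For example, if the parent is `LlamaForCausalLM`, the generated `forward` carries the parent’s full parameter list and body. A heavily abbreviated sketch; the exact parameters are taken from the parent and vary across library versions:

```python
# abbreviated sketch of the generated MyModelForCausalLM.forward
# (parameter list is version-dependent)
@my_new_decorator
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    **kwargs,
):
    ...  # body copied from LlamaForCausalLM.forward, with names rewritten
```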
## Dependency tracing

### Automatic inference
```python
class Olmo2DecoderLayer(OlmoDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.mlp = Olmo2MLP(config)  # triggers automatic creation
```

If `Olmo2MLP` isn’t explicitly defined, the linter automatically creates:
```python
class Olmo2MLP(OlmoMLP):
    pass  # empty inheritance, just renames
```

### Function dependencies
The linter:

- copies imported function definitions
- traces dependencies (e.g., `rotate_half`, `repeat_kv`)
- copies those too, recursively (sketched below)
- renames everything consistently
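For example, if the attention code copied into the generated file calls `apply_rotary_pos_emb` and `repeat_kv`, the linter copies both definitions; and because `apply_rotary_pos_emb` itself calls `rotate_half`, that helper is copied as well. The bodies below follow the standard Llama helpers and may differ slightly across versions:

```python
import torch


# helpers pulled into the generated modeling file because the copied attention code uses them
def rotate_half(x):
    """Rotates half the hidden dims of the input (used by rotary embeddings)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies rotary position embeddings to the query and key tensors."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```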
## Real-world examples

### Olmo2 (complex multi-parent inheritance)
```python
# configuration: inherit from OlmoConfig
class Olmo2Config(OlmoConfig):
    def __init__(self, rms_norm_eps=1e-5, clip_qkv=None, **kwargs):
        super().__init__(clip_qkv=clip_qkv, **kwargs)
        self.rms_norm_eps = rms_norm_eps
        del self.clip_qkv


# normalization: inherit from Llama
class Olmo2RMSNorm(LlamaRMSNorm):
    def forward(self, hidden_states):
        output = hidden_states * torch.rsqrt(
            hidden_states.pow(2).mean(-1, keepdim=True) + self.variance_epsilon
        )
        return output.to(hidden_states.dtype) * self.weight


# attention: inherit from Olmo, add Q/K normalization
class Olmo2Attention(OlmoAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        head_dim = config.hidden_size // config.num_attention_heads
        self.q_norm = Olmo2RMSNorm(head_dim, eps=config.rms_norm_eps)
        self.k_norm = Olmo2RMSNorm(head_dim, eps=config.rms_norm_eps)
        del self.clip_qkv
```

### Qwen3 (minimal inheritance)
```python
class Qwen3Config(Qwen2Config):
    pass  # no changes needed


class Qwen3RMSNorm(Qwen2RMSNorm):
    pass  # just rename


class Qwen3MLP(GemmaMLP):
    pass  # inherit from Gemma instead of Qwen2
```

## Best practices
### 1. Use consistent naming
```python
# good - consistent prefix
class MyModelConfig(LlamaConfig): pass
class MyModelAttention(LlamaAttention): pass

# bad - inconsistent prefixes
class MyModelAwesomeAttention(LlamaAttention): pass  # wrong
```

### 2. Multimodal models
```python
class MyModelTextConfig(LlamaConfig): pass
class MyModelVisionConfig(CLIPVisionConfig): pass
```
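These sub-configs are typically composed by a top-level config. A hedged sketch of what that composition could look like; the class layout and defaults here are illustrative, not taken from the source:

```python
from transformers import PretrainedConfig


class MyModelConfig(PretrainedConfig):
    """Hypothetical top-level config composing the text and vision sub-configs."""

    model_type = "mymodel"

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        super().__init__(**kwargs)
        # accept either config objects or plain dicts (as loaded from a JSON file)
        if isinstance(text_config, dict):
            text_config = MyModelTextConfig(**text_config)
        if isinstance(vision_config, dict):
            vision_config = MyModelVisionConfig(**vision_config)
        self.text_config = text_config if text_config is not None else MyModelTextConfig()
        self.vision_config = vision_config if vision_config is not None else MyModelVisionConfig()
```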
### 3. Leverage empty inheritance

```python
class MyModelRMSNorm(LlamaRMSNorm):
    pass  # perfectly valid - just renames
```

### 4. Testing your modular file
```bash
# generate files
python utils/modular_model_converter.py \
    --files-to-parse src/transformers/models/mymodel/modular_mymodel.py

# verify correctness
python utils/check_modular_conversion.py \
    --files src/transformers/models/mymodel/modular_mymodel.py
```

## Benefits
| Before modular | After modular |
|---|---|
| 3,000-6,000 lines per model | ~500 lines per model |
| Copy-paste code duplication | Inheritance-based sharing |
| Implementation divergence | Shared maintenance |
| High contribution barrier | Lower contribution barrier |
Specific improvements:

- 80-90% code reduction for new model contributions
- Easier maintenance through shared code
- Faster integration of new architectures
- Prevents implementation divergence
## Known limitations (rc0)
- Config docstrings require a complete override
- `try`/`except` imports may cause issues
- Single-level flattening only
- Multi-level hierarchies require explicit intermediate definitions (see the sketch below)
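One plausible reading of the last two items, with hypothetical model names: chaining modular files so that a new class inherits from a class that is itself only defined via modular inheritance creates a multi-level hierarchy that rc0 does not flatten on its own, so the intermediate class must be spelled out explicitly.

```python
# hypothetical names, sketching the multi-level case

# modular_mybase.py (one modular level already)
class MyBaseAttention(LlamaAttention):
    pass


# modular_mymodel.py -- two levels away from the concrete LlamaAttention;
# in rc0 the intermediate class has to be defined explicitly rather than
# resolved through another modular file
class MyModelAttention(MyBaseAttention):
    pass
```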