modular transformers

modular transformers reduces the code required to contribute a new model from 3,000-6,000 lines to approximately 500 by letting model files inherit from existing models, while the generated output preserves the library’s “single model, single file” philosophy.

architecture

modular_<model>.py
       │
       ▼ (linter)
modeling_<model>.py (standard single-file output)

how it works

  1. modular files: contain model, processor, and configuration code that inherits from existing models
  2. linter: automatically “unravels” the inherited code into the standard single-file format
  3. output: a traditional, self-contained modeling_<model>.py for users
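
for a hypothetical model named mymodel, the directory typically ends up looking like this (only modular_mymodel.py is written by hand; the modeling and configuration files are produced by the linter):

src/transformers/models/mymodel/
├── __init__.py
├── configuration_mymodel.py   # generated
├── modeling_mymodel.py        # generated - do not edit by hand
└── modular_mymodel.py         # hand-written, the file that gets reviewed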

the modular model converter

# convert a modular file to standard format
python utils/modular_model_converter.py \
    --files-to-parse src/transformers/models/<your_model>/modular_<your_model>.py

# convert all models
python utils/modular_model_converter.py --files-to-parse all

syntax patterns

1. basic inheritance

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel, LlamaRMSNorm

class MyModelConfig(LlamaConfig):
    pass  # uses the parent as-is, just renames

class MyModelModel(LlamaModel):
    pass  # linter copies the parent implementation and updates references

class MyModelRMSNorm(LlamaRMSNorm):
    pass  # linter copies the parent implementation and updates references
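
for example, the empty MyModelRMSNorm above becomes a full standalone class in the generated modeling file, roughly the following (the exact body tracks LlamaRMSNorm and can change between library versions):

# generated modeling_mymodel.py (sketch)
import torch
from torch import nn


class MyModelRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)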

2. super calls and unraveling

class MyModelAttention(LlamaAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)  # parent body inserted here
        self.new_layer = NewLayer(config)
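
to make the inlining concrete, a purely hypothetical sketch: if the parent __init__ consisted of just two assignments, the generated class would read:

from torch import nn


class MyModelAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx              # parent body, inserted in place of super().__init__
        self.hidden_size = config.hidden_size   # parent body, inserted in place of super().__init__
        self.new_layer = NewLayer(config)       # new code from the modular file, kept afterwards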

3. attribute modification

adding attributes:

class Olmo2Attention(OlmoAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.q_norm = Olmo2RMSNorm(config.hidden_size // config.num_attention_heads)
        self.k_norm = Olmo2RMSNorm(config.hidden_size // config.num_attention_heads)

removing attributes:

class Olmo2Config(OlmoConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        del self.clip_qkv  # removes the assignment
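
the effect shows up on the generated configuration class; assuming an installed transformers version that ships Olmo2:

from transformers import Olmo2Config, OlmoConfig

assert hasattr(OlmoConfig(), "clip_qkv")       # the parent config still sets the attribute
assert not hasattr(Olmo2Config(), "clip_qkv")  # the del removed the assignment from the generated code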

4. the **super_kwargs pattern

class MyModelForCausalLM(LlamaForCausalLM):
    @my_new_decorator
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)

the linter expands this to the full signature automatically.
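
for illustration, the expanded method in the generated file looks roughly like the sketch below; the parameter list follows LlamaForCausalLM.forward and varies between library versions, and the elided body is the parent's forward copied in full:

# inside the generated MyModelForCausalLM (sketch)
@my_new_decorator
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    cache_position=None,
    **kwargs,
):
    ...  # full body of LlamaForCausalLM.forward, with Llama references renamed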

dependency tracing

automatic inference

class Olmo2DecoderLayer(OlmoDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.mlp = Olmo2MLP(config)  # triggers automatic creation

if Olmo2MLP isn't explicitly defined, the linter automatically creates:

class Olmo2MLP(OlmoMLP):
    pass  # empty inheritance, just renames

function dependencies

the linter:

  1. copies imported function definitions
  2. traces dependencies (e.g., rotate_half, repeat_kv)
  3. copies those too, recursively
  4. renames everything consistently
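
a sketch with a hypothetical model name: nothing beyond the class itself has to be imported, the helpers come along automatically:

# modular_mymodel.py (sketch)
from transformers.models.llama.modeling_llama import LlamaAttention


class MyModelAttention(LlamaAttention):
    pass

# the copied LlamaAttention code calls apply_rotary_pos_emb, so the generated
# modeling_mymodel.py receives its own copy of that function; apply_rotary_pos_emb
# in turn calls rotate_half, so rotate_half is copied as well, leaving the
# generated file with no imports from modeling_llama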

real-world examples

olmo2 (complex multi-parent inheritance)

# configuration: inherit from OlmoConfig
class Olmo2Config(OlmoConfig):
    def __init__(self, rms_norm_eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.rms_norm_eps = rms_norm_eps
        del self.clip_qkv  # clip_qkv is not used by Olmo2

# normalization: inherit from Llama
class Olmo2RMSNorm(LlamaRMSNorm):
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

# attention: inherit from Olmo, add Q/K normalization
class Olmo2Attention(OlmoAttention):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        head_dim = config.hidden_size // config.num_attention_heads
        self.q_norm = Olmo2RMSNorm(head_dim, eps=config.rms_norm_eps)
        self.k_norm = Olmo2RMSNorm(head_dim, eps=config.rms_norm_eps)

qwen3 (minimal inheritance)

class Qwen3Config(Qwen2Config):
    pass  # no changes needed

class Qwen3RMSNorm(Qwen2RMSNorm):
    pass  # just rename

class Qwen3MLP(GemmaMLP):
    pass  # inherit from Gemma instead of Qwen2

best practices

1. use consistent naming

# good - consistent prefix
class MyModelConfig(LlamaConfig): pass
class MyModelAttention(LlamaAttention): pass

# bad - inconsistent prefixes
class MyModelAwesomeAttention(LlamaAttention): pass  # wrong

2. multimodal models

class MyModelTextConfig(LlamaConfig): pass
class MyModelVisionConfig(CLIPVisionConfig): pass
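
a possible sketch for the composite top-level config that ties the two together; the sub_configs mapping follows the convention used by recent library versions, and the names here are hypothetical rather than a prescribed API:

from transformers.configuration_utils import PretrainedConfig


class MyModelConfig(PretrainedConfig):
    model_type = "mymodel"
    sub_configs = {"text_config": MyModelTextConfig, "vision_config": MyModelVisionConfig}

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        super().__init__(**kwargs)
        # dict handling, defaults, and validation omitted in this sketch
        self.text_config = text_config if text_config is not None else MyModelTextConfig()
        self.vision_config = vision_config if vision_config is not None else MyModelVisionConfig()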

3. leverage empty inheritance

class MyModelRMSNorm(LlamaRMSNorm):
    pass  # perfectly valid - just renames

4. testing your modular file

# generate files
python utils/modular_model_converter.py \
    --files-to-parse src/transformers/models/mymodel/modular_mymodel.py

# verify correctness
python utils/check_modular_conversion.py \
    --files src/transformers/models/mymodel/modular_mymodel.py
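
in a standard development checkout, the modular check also runs with the rest of the repository consistency suite:

# runs check_modular_conversion.py together with the other consistency checks
make repo-consistency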

benefits

before modular                 after modular
3,000-6,000 lines per model    ~500 lines per model
copy-paste code duplication    inheritance-based sharing
implementation divergence      shared maintenance
high contribution barrier      lower contribution barrier

specific improvements:

  • 80-90% code reduction for new model contributions
  • easier maintenance through shared code
  • faster integration of new architectures
  • prevents implementation divergence

known limitations (rc0)

  1. config docstrings require complete override
  2. try-except imports may cause issues
  3. single-level flattening only
  4. multi-level hierarchies require explicit intermediate definitions
