This tutorial shows how to train encoding models on the LeBel assembly using language model features. These features are high-dimensional representations extracted from transformer models such as GPT-2; they capture semantic, syntactic, and contextual information that can be highly predictive of brain activity.
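Conceptually, for each word in the stimulus the extractor runs the preceding context through the transformer and keeps one hidden-state vector. A minimal sketch of the lookback windowing and last-token selection, with a random stand-in for the real GPT-2 forward pass (the helper names and shapes are illustrative, not the library's API):

```python
import numpy as np

rng = np.random.default_rng(0)

def fake_transformer(token_ids):
    """Stand-in for a GPT-2 forward pass: returns one hidden-state
    vector per input token (random here, 768-dim like gpt2-small)."""
    return rng.standard_normal((len(token_ids), 768))

def extract_features(token_ids, lookback=256):
    """For each token, run the preceding `lookback` tokens through the
    model and keep only the last token's hidden state (last_token=True)."""
    feats = []
    for t in range(len(token_ids)):
        context = token_ids[max(0, t - lookback + 1) : t + 1]
        hidden = fake_transformer(context)   # (len(context), 768)
        feats.append(hidden[-1])             # last token only
    return np.stack(feats)

story_tokens = list(range(10))               # toy token ids
features = extract_features(story_tokens, lookback=4)
print(features.shape)  # (10, 768)
```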
```python
from encoding.assembly.assembly_loader import load_assembly

# Load the pre-packaged LeBel assembly
assembly = load_assembly("assembly_lebel_uts03.pkl")
```
```python
from encoding.features.factory import FeatureExtractorFactory

extractor = FeatureExtractorFactory.create_extractor(
    modality="language_model",
    model_name="gpt2-small",  # can be changed to other models
    config={
        "model_name": "gpt2-small",
        "layer_idx": 9,        # layer to extract features from
        "last_token": True,    # use only the last token's features
        "lookback": 256,       # context lookback (in tokens)
        "context_type": "fullcontext",
    },
    cache_dir="cache_language_model",
)
```
```python
from encoding.downsample.downsampling import Downsampler
from encoding.models.nested_cv import NestedCVModel

downsampler = Downsampler()
model = NestedCVModel(model_name="ridge_regression")

# FIR delays for hemodynamic response modeling
fir_delays = [1, 2, 3, 4]
```
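The FIR delays duplicate each feature at several temporal lags (in TRs) so the regression can fit a per-voxel hemodynamic response. A minimal sketch of building such a delayed design matrix (a generic construction, not necessarily the trainer's exact implementation):

```python
import numpy as np

def make_delayed(features, delays):
    """Stack time-shifted copies of `features` (n_trs, n_features)
    column-wise, one copy per delay in TRs; rows with no valid
    shifted data are zero-padded."""
    n_trs, n_feats = features.shape
    delayed = np.zeros((n_trs, n_feats * len(delays)))
    for i, d in enumerate(delays):
        cols = slice(i * n_feats, (i + 1) * n_feats)
        if d > 0:
            delayed[d:, cols] = features[:-d]
        else:
            delayed[:, cols] = features
    return delayed

X = np.arange(12, dtype=float).reshape(6, 2)   # 6 TRs, 2 features
Xd = make_delayed(X, delays=[1, 2, 3, 4])
print(Xd.shape)  # (6, 8)
```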
```python
# Trimming configuration for the LeBel dataset
trimming_config = {
    "train_features_start": 10,
    "train_features_end": -5,
    "train_targets_start": 0,
    "train_targets_end": None,
    "test_features_start": 50,
    "test_features_end": -5,
    "test_targets_start": 40,
    "test_targets_end": None,
}

# No additional downsampling configuration needed
downsample_config = {}
```
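The trimming configuration drops TRs at the start and end of each run so that features and targets align. A sketch of how such start/end indices would be applied, assuming they behave as standard Python slice bounds with `None` meaning "through the end" (the array sizes are made up for illustration):

```python
import numpy as np

trimming_config = {
    "train_features_start": 10, "train_features_end": -5,
    "train_targets_start": 0,   "train_targets_end": None,
}

def trim(array, start, end):
    """Apply start/end as a plain slice along the time axis."""
    return array[start:end]

features = np.zeros((300, 768))   # 300 TRs of stimulus features
targets = np.zeros((285, 1000))   # 285 TRs of BOLD responses

f = trim(features, trimming_config["train_features_start"],
         trimming_config["train_features_end"])
t = trim(targets, trimming_config["train_targets_start"],
         trimming_config["train_targets_end"])
print(f.shape[0], t.shape[0])  # 285 285
```

After trimming, the feature and target arrays have matching numbers of TRs, which is what the regression requires.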
```python
from encoding.trainer import AbstractTrainer

trainer = AbstractTrainer(
    assembly=assembly,
    feature_extractors=[extractor],
    downsampler=downsampler,
    model=model,
    fir_delays=fir_delays,
    trimming_config=trimming_config,
    use_train_test_split=True,
    logger_backend="wandb",
    wandb_project_name="lebel-language-model",
    dataset_type="lebel",
    results_dir="results",
    layer_idx=9,    # pass layer_idx to trainer
    lookback=256,   # pass lookback to trainer
)
```
```python
metrics = trainer.train()
print(f"Median correlation: {metrics.get('median_score', float('nan')):.4f}")
```
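The reported `median_score` is typically the median voxelwise Pearson correlation between predicted and measured responses on held-out data. A generic sketch of that computation on synthetic data (not necessarily the trainer's exact implementation):

```python
import numpy as np

def voxelwise_correlation(pred, actual):
    """Pearson r between predicted and actual time series, per voxel.

    pred, actual: (n_trs, n_voxels) arrays.
    """
    p = pred - pred.mean(axis=0)
    a = actual - actual.mean(axis=0)
    num = (p * a).sum(axis=0)
    den = np.sqrt((p ** 2).sum(axis=0) * (a ** 2).sum(axis=0))
    return num / den

rng = np.random.default_rng(0)
actual = rng.standard_normal((100, 50))
pred = actual + rng.standard_normal((100, 50))  # noisy predictions
median_score = float(np.median(voxelwise_correlation(pred, actual)))
print(f"Median correlation: {median_score:.4f}")
```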
Language model feature extraction is controlled by the following parameters:

- `modality`: `"language_model"` - specifies the feature type
- `model_name`: `"gpt2-small"` - transformer model to use
- `layer_idx`: `9` - which layer to extract features from
- `last_token`: `True` - use only the last token's features (recommended)
- `lookback`: `256` - context window size
- `context_type`: `"fullcontext"` - how to handle context
- `cache_dir`: `"cache_language_model"` - directory for caching

The language model extractor caches computed features on disk, which makes it efficient to experiment with different layers without recomputing features.
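The details of the caching layer are not shown here; a minimal sketch of the general idea, saving features on disk keyed by every setting that changes the output (the file layout and helper name are illustrative, not the library's API):

```python
import hashlib
import os

import numpy as np

def cached_features(cache_dir, model_name, layer_idx, lookback, compute):
    """Load features from disk if present; otherwise compute and save.

    The cache key combines model, layer, and lookback, so switching
    layer_idx never reuses features computed for another layer.
    """
    os.makedirs(cache_dir, exist_ok=True)
    key = f"{model_name}_layer{layer_idx}_lb{lookback}"
    fname = hashlib.md5(key.encode()).hexdigest() + ".npy"
    path = os.path.join(cache_dir, fname)
    if os.path.exists(path):
        return np.load(path)          # cache hit: skip computation
    feats = compute()
    np.save(path, feats)              # cache miss: compute and store
    return feats

# First call computes and writes; a second call with the same key
# would load from disk instead.
feats = cached_features("cache_demo", "gpt2-small", 9, 256,
                        compute=lambda: np.ones((5, 768)))
print(feats.shape)  # (5, 768)
```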