Flexible Probe System¶
This document describes the new flexible probe system that replaces the old rigid linear probe approach. The new system allows you to configure different types of probes with various aggregation and processing strategies.
Overview¶
The new probe system provides:
Multiple probe types: Linear, MLP, LSTM, Attention, and Transformer probes, plus weighted versions of each
Flexible aggregation: Mean, max, concatenation, CLS token, or no aggregation
Input processing options: Flatten, sequence, pooled, or no processing
Probe-specific parameters: Hidden dimensions, attention heads, LSTM configuration, etc.
Training overrides: Per-probe learning rates, batch sizes, epochs, etc.
Backward compatibility: Legacy configurations still work automatically
Probe Types¶
1. Linear Probe ("linear")¶
Simple linear classification layer. Good for baseline performance.
probe_config:
  probe_type: "linear"
  aggregation: "mean"
  input_processing: "pooled"
  target_layers: ["layer_12"]
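As a rough sketch of what this configuration amounts to (assuming a PyTorch backbone and an embedding that has already been mean-aggregated and pooled; the class and argument names below are illustrative, not the package's actual API):

import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    """Illustrative sketch: a single linear head over a pooled embedding."""

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, pooled_embedding: torch.Tensor) -> torch.Tensor:
        # pooled_embedding: (batch, embedding_dim), e.g. the mean-aggregated,
        # pooled representation of layer_12
        return self.classifier(pooled_embedding)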
2. MLP Probe ("mlp")¶
Multi-layer perceptron with configurable hidden dimensions.
probe_config:
  probe_type: "mlp"
  aggregation: "mean"
  input_processing: "pooled"
  target_layers: ["layer_8", "layer_12"]
  hidden_dims: [512, 256]
  dropout_rate: 0.2
  activation: "gelu"
3. LSTM Probe ("lstm")¶
Long Short-Term Memory network for sequence modeling.
probe_config:
  probe_type: "lstm"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  lstm_hidden_size: 256
  num_layers: 2
  bidirectional: true
  max_sequence_length: 1000
4. Attention Probe ("attention")¶
Attention mechanism for sequence modeling.
probe_config:
  probe_type: "attention"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_10"]
  num_heads: 8
  attention_dim: 512
  num_layers: 2
  max_sequence_length: 800
  use_positional_encoding: true
5. Transformer Probe ("transformer")¶
Full transformer architecture for complex sequence modeling.
probe_config:
  probe_type: "transformer"
  aggregation: "none"
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 12
  attention_dim: 768
  num_layers: 4
  max_sequence_length: 1200
  use_positional_encoding: true
Weighted Probe Types¶
Weighted probe types are enhanced versions of the standard probes that combine embeddings from multiple layers through a learned weighted sum. Instead of a separate projection head per layer, each weighted probe uses a single architecture head and learns the optimal weights for combining the layer embeddings.
6. Weighted Linear Probe ("weighted_linear")¶
Single linear classifier with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_linear"
  aggregation: "none"  # Required for weighted probes
  input_processing: "pooled"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  freeze_backbone: true
7. Weighted MLP Probe ("weighted_mlp")¶
Single MLP with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_mlp"
  aggregation: "none"  # Required for weighted probes
  input_processing: "pooled"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  hidden_dims: [512, 256]
  dropout_rate: 0.2
  activation: "gelu"
  freeze_backbone: true
8. Weighted LSTM Probe ("weighted_lstm")¶
Single LSTM with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_lstm"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  lstm_hidden_size: 128
  num_layers: 2
  bidirectional: true
  max_sequence_length: 1000
  use_positional_encoding: false
  dropout_rate: 0.3
  freeze_backbone: true
9. Weighted Attention Probe ("weighted_attention")¶
Single attention mechanism with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_attention"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 8
  attention_dim: 256
  num_layers: 2
  max_sequence_length: 800
  use_positional_encoding: false
  dropout_rate: 0.3
  freeze_backbone: true
10. Weighted Minimal Attention Probe ("weighted_attention_minimal")¶
Single minimal attention mechanism with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_attention_minimal"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 4
  freeze_backbone: true
11. Weighted Transformer Probe ("weighted_transformer")¶
Single transformer encoder with learned weights for combining multiple layer embeddings.
probe_config:
  probe_type: "weighted_transformer"
  aggregation: "none"  # Required for weighted probes
  input_processing: "sequence"
  target_layers: ["layer_4", "layer_6", "layer_8", "layer_10", "layer_12"]
  num_heads: 12
  attention_dim: 768
  num_layers: 4
  max_sequence_length: 1200
  use_positional_encoding: true
  dropout_rate: 0.3
  freeze_backbone: true
Key Features of Weighted Probes¶
Single Architecture Head: Each weighted probe uses one architecture component (linear, MLP, LSTM, attention, transformer) instead of a separate projection head per layer
Learned Weighted Sum: Uses nn.Parameter to learn optimal weights for combining multiple layer embeddings
Dimension Validation: Ensures all embeddings have the same dimension for the weighted sum aggregation
Weight Debugging: All weighted probes implement a print_learned_weights() method that shows which layers are most important
Efficiency: More efficient than multiple projection heads while maintaining or improving performance
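The weighted combination itself is small. A minimal sketch of the idea (assuming PyTorch; in the actual probes this sits in front of the single architecture head, and the class name here is illustrative):

import torch
import torch.nn as nn

class LayerWeightedSum(nn.Module):
    """Illustrative sketch of the learned weighted sum over layer embeddings."""

    def __init__(self, num_layers: int):
        super().__init__()
        # One learnable weight per target layer (nn.Parameter), normalized with softmax
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))

    def forward(self, layer_embeddings: torch.Tensor) -> torch.Tensor:
        # layer_embeddings: (num_layers, batch, ..., dim); all layers must share
        # the same embedding dimension for the weighted sum to be valid
        weights = torch.softmax(self.layer_weights, dim=0)
        shape = (-1,) + (1,) * (layer_embeddings.dim() - 1)
        return (weights.view(shape) * layer_embeddings).sum(dim=0)

    def print_learned_weights(self, layer_names):
        # Roughly what print_learned_weights() reports: the relative importance of each layer
        weights = torch.softmax(self.layer_weights, dim=0).tolist()
        for name, weight in zip(layer_names, weights):
            print(f"{name}: {weight:.3f}")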
Requirements for Weighted Probes¶
Aggregation: Must use aggregation: "none" to enable learned weights
Multiple Layers: Requires multiple target layers to learn meaningful weights
Same Dimensions: All layer embeddings must have the same dimension for the weighted sum
Aggregation Methods¶
"mean"¶
Average embeddings across layers (default for backward compatibility).
"max"¶
Take maximum values across layers.
"concat"¶
Concatenate embeddings from all layers (requires larger probe networks).
"cls_token"¶
Use only the CLS token from sequence-based models.
"none"¶
No aggregation - pass embeddings directly to sequence-based probes.
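For intuition, the aggregation options correspond roughly to the following tensor operations over stacked per-layer embeddings (a sketch, not necessarily how the system implements them):

import torch

def aggregate(layer_embeddings: torch.Tensor, method: str) -> torch.Tensor:
    """layer_embeddings: (num_layers, batch, seq_len, dim). Illustrative only."""
    if method == "mean":
        return layer_embeddings.mean(dim=0)               # average across layers
    if method == "max":
        return layer_embeddings.max(dim=0).values         # element-wise max across layers
    if method == "concat":
        return torch.cat(list(layer_embeddings), dim=-1)  # wider feature dimension
    if method == "cls_token":
        # keep only the CLS position; averaging across layers here is an assumption
        return layer_embeddings[:, :, 0, :].mean(dim=0)
    if method == "none":
        return layer_embeddings                           # leave layers for the probe to combine
    raise ValueError(f"Unknown aggregation: {method}")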
Input Processing Methods¶
"pooled"¶
Pool embeddings to fixed dimension (default for backward compatibility).
"sequence"¶
Keep sequence structure for sequence-based probes.
"flatten"¶
Flatten all dimensions into a single vector.
"none"¶
No processing - use embeddings as-is.
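The processing options mainly differ in the shape handed to the probe. A rough sketch, assuming an aggregated embedding of shape (batch, seq_len, dim):

import torch

def process_input(embeddings: torch.Tensor, method: str) -> torch.Tensor:
    """embeddings: (batch, seq_len, dim). Illustrative shapes only."""
    if method == "pooled":
        return embeddings.mean(dim=1)           # (batch, dim): fixed-size vector
    if method == "sequence":
        return embeddings                       # (batch, seq_len, dim): structure kept
    if method == "flatten":
        return embeddings.flatten(start_dim=1)  # (batch, seq_len * dim): single vector
    if method == "none":
        return embeddings                       # used exactly as extracted
    raise ValueError(f"Unknown input processing: {method}")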
Configuration Examples¶
Basic Linear Probe (Legacy Style)¶
experiments:
  - run_name: "simple_linear"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    layers: "layer_12"  # Legacy field
    frozen: true        # Legacy field
Advanced MLP Probe¶
experiments:
  - run_name: "advanced_mlp"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    probe_config:
      name: "advanced_mlp"
      probe_type: "mlp"
      aggregation: "concat"
      input_processing: "pooled"
      target_layers: ["layer_6", "layer_8", "layer_10", "layer_12"]
      freeze_backbone: true
      learning_rate: 3e-4  # Override global LR
      batch_size: 4        # Override global batch size
      hidden_dims: [1024, 512, 256]
      dropout_rate: 0.15
      activation: "relu"
Sequence LSTM Probe¶
experiments:
  - run_name: "sequence_lstm"
    run_config: "configs/run_configs/example_run.yml"
    pretrained: true
    probe_config:
      name: "sequence_lstm"
      probe_type: "lstm"
      aggregation: "none"
      input_processing: "sequence"
      target_layers: ["layer_8", "layer_12"]
      lstm_hidden_size: 256
      num_layers: 2
      bidirectional: true
      max_sequence_length: 1000
      use_positional_encoding: false
Migration from Legacy System¶
The new system automatically handles legacy configurations:
Legacy fields still work: layers and frozen fields are automatically converted to probe_config
No breaking changes: Existing configurations continue to work without modification
Gradual migration: You can update configurations one at a time
Before (Legacy)¶
experiments:
  - run_name: "old_style"
    layers: "layer_12"
    frozen: true
After (New Style)¶
experiments:
  - run_name: "new_style"
    probe_config:
      probe_type: "linear"
      aggregation: "mean"
      input_processing: "pooled"
      target_layers: ["layer_12"]
      freeze_backbone: true
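Conceptually, the automatic conversion maps the two legacy fields onto the equivalent probe_config shown above. A hedged sketch of that mapping (the function name and defaults are illustrative, not the loader's actual API):

def convert_legacy_experiment(experiment: dict) -> dict:
    """Illustrative sketch: translate legacy 'layers'/'frozen' fields into a probe_config."""
    if "probe_config" in experiment:
        return experiment  # already new-style, nothing to do
    layers = experiment.pop("layers", "layer_12")
    frozen = experiment.pop("frozen", True)
    experiment["probe_config"] = {
        "probe_type": "linear",
        "aggregation": "mean",
        "input_processing": "pooled",
        "target_layers": [layers] if isinstance(layers, str) else list(layers),
        "freeze_backbone": frozen,
    }
    return experiment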
Training Parameter Overrides¶
Each probe can override global training parameters:
probe_config:
  # ... other config ...
  learning_rate: 5e-4   # Override global lr
  batch_size: 4         # Override global batch_size
  train_epochs: 15      # Override global train_epochs
  optimizer: "adam"     # Override global optimizer
  weight_decay: 0.001   # Override global weight_decay
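The effective training settings are simply the global values with any probe-level keys taking precedence. A minimal sketch of that merge (the global defaults shown are placeholders):

GLOBAL_TRAINING = {
    "learning_rate": 1e-3, "batch_size": 8, "train_epochs": 10,
    "optimizer": "adamw", "weight_decay": 0.01,
}

def resolve_training_params(probe_config: dict) -> dict:
    """Probe-level overrides win over the global defaults for the keys they set."""
    overridable = ("learning_rate", "batch_size", "train_epochs", "optimizer", "weight_decay")
    resolved = dict(GLOBAL_TRAINING)
    resolved.update({key: probe_config[key] for key in overridable if key in probe_config})
    return resolved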
Best Practices¶
1. Choose Appropriate Probe Types¶
Linear: Baseline performance, quick experiments
MLP: Better performance, moderate complexity
LSTM: Sequence modeling, moderate complexity
Attention: Sequence modeling, higher complexity
Transformer: Complex sequence modeling, highest complexity
Weighted Probes: Enhanced versions that learn optimal weights for combining multiple layers
Use when you want to leverage multiple layers efficiently
Better performance than concatenation with lower computational cost
Provides interpretability through learned layer weights
2. Layer Selection¶
Single layer: Use ["layer_12"] for final representations
Multiple layers: Use ["layer_6", "layer_8", "layer_10", "layer_12"] for hierarchical features
Early layers: Use ["layer_1", "layer_2", "layer_3"] for low-level features
3. Aggregation Strategy¶
Mean/Max: Good for classification tasks
Concat: Better for complex tasks, requires larger probe networks
None: Required for sequence-based probes and weighted probes
Weighted Sum: Automatic with weighted probes when using aggregation: "none"
4. Input Processing¶
Pooled: Good for classification tasks
Sequence: Required for sequence-based probes
Flatten: Good for spatial features
Validation¶
The system automatically validates configurations:
Required parameters for each probe type
Compatibility between aggregation and input processing methods
Valid parameter ranges (positive integers, valid activation functions, etc.)
Layer name consistency
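A simplified sketch of the kinds of checks involved (the rules and messages here mirror the examples in this document; the actual validator may do more):

def validate_probe_config(cfg: dict) -> None:
    """Illustrative sketch of configuration validation; raises ValueError on problems."""
    probe_type = cfg.get("probe_type", "linear")
    if probe_type == "mlp" and not cfg.get("hidden_dims"):
        raise ValueError("MLP probe requires 'hidden_dims'")
    if probe_type.startswith("weighted_") and cfg.get("aggregation") != "none":
        raise ValueError("Weighted probes require aggregation: 'none'")
    if cfg.get("aggregation") == "cls_token" and cfg.get("input_processing") != "sequence":
        raise ValueError("cls_token aggregation requires input_processing: 'sequence'")
    if not cfg.get("target_layers"):
        raise ValueError("At least one target layer must be specified")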
Error Handling¶
Common validation errors and solutions:
Missing Required Parameters¶
# Error: MLP probe requires hidden_dims
probe_config:
  probe_type: "mlp"
  # Missing: hidden_dims

# Solution: Add required parameters
probe_config:
  probe_type: "mlp"
  hidden_dims: [512, 256]
Incompatible Configuration¶
# Error: cls_token aggregation requires sequence input_processing
probe_config:
  aggregation: "cls_token"
  input_processing: "pooled"

# Solution: Use sequence input_processing
probe_config:
  aggregation: "cls_token"
  input_processing: "sequence"
Performance Considerations¶
Memory Usage¶
Linear/MLP: Low memory usage
LSTM: Moderate memory usage
Attention/Transformer: Higher memory usage
Training Speed¶
Linear: Fastest training
MLP: Fast training
LSTM: Moderate training speed
Attention/Transformer: Slower training
Inference Speed¶
Linear: Fastest inference
MLP: Fast inference
LSTM: Moderate inference speed
Attention/Transformer: Slower inference
Troubleshooting¶
Common Issues¶
Out of Memory: Reduce batch size or use simpler probe types
Slow Training: Use simpler probe types or reduce hidden dimensions
Poor Performance: Try different aggregation methods or layer combinations
Validation Errors: Check parameter compatibility and required fields
Debug Mode¶
Enable debug logging to see detailed configuration validation:
import logging
logging.basicConfig(level=logging.DEBUG)
Future Extensions¶
The system is designed to be extensible:
New probe types: Easy to add new probe architectures
Custom aggregations: Support for custom aggregation functions
Advanced processing: More sophisticated input processing methods
Hyperparameter optimization: Integration with hyperparameter search tools