Probe API Documentation¶
Overview¶
The probe API provides an interface for defining, configuring, and attaching probes to backbone/base models that can be used to adapt or fine-tune the backbone/base models to downstream tasks.
Key Ideas:
Probes (and backbone models) are regular PyTorch modules (linear, MLP, LSTM, attention, transformer heads).
Configuration is done via
ProbeConfig(Python) or YAML files that map toProbeConfig.Probes may be trained online (attached to a base model) or offline (on pre-computed embeddings).
Getting Started¶
1. Start Simple¶
Begin with a simple linear probe on the backbone’s last layer:
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
cfg = ProbeConfig(
probe_type="linear",
target_layers=["last_layer"],
aggregation="mean",
freeze_backbone=True,
online_training=True,
)
probe = build_probe_from_config(cfg, base_model=base, num_classes=50, device="cpu")
2. Increase Complexity If Needed¶
If performance plateaus, move to MLP, LSTM, attention, or transformer probes by changing probe_type and the related fields in ProbeConfig. Generally, attention probe works best with self-supervised models and transformers and it does not improve much on EfficientNet backbones.
3. Match Probe Complexity to Task¶
Simple in-domain tasks → linear probes work well on bird classification/detection tasks because most of the bioacoustics models were trained on this tasks
Out-of-domain tasks → attention/transformer probes on all layers or even lower layers work better for repertoire classification or species that were under-represented in the training data used for the backbones.
4. Consider Computational Budget¶
Limited resources →
_lastvariants with linear/MLPGenerous resources →
_allvariants with attention/transformer
Performance Trade-offs¶
_last Variants¶
Pros:
Fast execution
Simple architecture
Lower memory usage
Fewer parameters to train
Cons:
Single representation, overfitted for species classification (mostly birds) in the case of supervised models
May miss multi-scale features
Use when:
Quick experiments needed
Limited computational resources
Strong, well-trained backbone
Simple classification tasks
_all Variants¶
Pros:
Rich multi-scale features
More expressive models
Better for complex tasks
Learns optimal layer weighting
Cons:
Slower execution
High disk usage in the case of offline probes
Higher memory requirements
More parameters to train
Use when:
Maximum performance needed
Sufficient computational resources
Out-of-domain tasks
Comparing layer-wise representations
Quick Selection Guide¶
Task Complexity: LOW ──────────────────────────────────> HIGH
Probe Type: linear → mlp → lstm → attention → transformer
Feature Scope: SINGLE LAYER ─────────────────────────> ALL LAYERS
Variant: _last ─────────────────────────────────> _all
Computational: FAST ──────────────────────────────────> SLOW
linear_last ──────────────────────> transformer_all
Quick Start¶
Build and Use a Probe (Online Mode)¶
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
# 1. Load a backbone model that returns features
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
# 2. Define a simple linear probe on the backbone features
probe_config = ProbeConfig(
probe_type="linear",
target_layers=["backbone"], # use final backbone layer
aggregation="mean", # mean-pool over time
freeze_backbone=True, # keep backbone frozen
online_training=True, # end-to-end graph (even if backbone is frozen)
)
# 3. Build the probe
probe = build_probe_from_config(
probe_config=probe_config,
base_model=base,
num_classes=50,
device="cpu",
)
Offline Mode (Pre-computed Embeddings)¶
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
# For pre-computed embeddings (no base model needed)
probe_config = ProbeConfig(
probe_type="linear",
target_layers=["backbone"], # conceptual; not used when base_model=None
aggregation="none",
freeze_backbone=True,
online_training=False,
)
probe = build_probe_from_config(
probe_config=probe_config,
input_dim=768, # embedding dimension
num_classes=50,
device="cpu",
)
# Use with embeddings
# For inference, set the probe to eval mode and use torch.no_grad()
probe.eval()
with torch.no_grad():
predictions = probe(embeddings) # embeddings shape: (batch, 768)
Note: The probe’s forward() method does not automatically use inference mode. For inference (when you don’t need gradients), you should:
Call
probe.eval()to set the model to evaluation mode (disables dropout, batch norm updates, etc.)Wrap the forward pass in
with torch.no_grad():to disable gradient computation and reduce memory usage
For training/fine-tuning, use probe.train() and omit the torch.no_grad() context.
Defining Probe Configurations¶
Probe Types¶
Common probe_type values:
linear– simple linear classifiermlp– multi-layer perceptronlstm– LSTM sequence modelattention– self-attention headtransformer– transformer encoder probe
Core Fields in ProbeConfig¶
All probe configs support (non-exhaustive):
Architecture & layers
probe_type:"linear" | "mlp" | "lstm" | "attention" | "transformer"- The architecture of the probe head:"linear": 2D probe - Simple linear classifier (single fully-connected layer). Fastest and most memory-efficient. Expects 2D input(batch, features). Use withaggregation="mean"or"max". Best for: baseline performance, simple tasks, limited resources."mlp": 2D probe - Multi-layer perceptron with configurable hidden layers and non-linear activations. More expressive than linear while still efficient. Expects 2D input(batch, features). Use withaggregation="mean"or"max". Requireshidden_dimsparameter. Best for: tasks needing non-linearity, moderate complexity."lstm": 3D probe - Long Short-Term Memory network for sequence modeling. Processes temporal sequences and captures long-range dependencies. Expects 3D input(batch, time, features). Use withaggregation="none"to preserve sequence structure. Requireslstm_hidden_size,num_layers, and optionallybidirectional. Best for: temporal/sequential tasks, variable-length sequences."attention": 3D probe - Self-attention mechanism for sequence modeling. Captures relationships between all positions in a sequence. Expects 3D input(batch, time, features). Use withaggregation="none"to preserve sequence structure. Requiresnum_headsandattention_dim. Best for: tasks requiring global sequence understanding, parallel processing."transformer": 3D probe - Full transformer encoder architecture with multiple attention layers. Most expressive and powerful probe type. Expects 3D input(batch, time, features). Use withaggregation="none"to preserve sequence structure. Requiresnum_heads,attention_dim, andnum_layers. Best for: complex tasks, maximum performance, sufficient computational resources.
target_layers: List of layer names to extract embeddings from. Main options:["last_layer"]: Uses the final (non-classification) layer of the model. Best for: single-layer probing, baseline experiments, efficient computation.["all"]: Uses all discoverable layers in the model. Best for: multi-layer probing, learning optimal layer combinations, maximum expressiveness.Specific layer names: Use concrete layer names (e.g.,
["layer_6", "layer_12"]). Discover available layers usinglist_model_layers(model_name). Best for: targeted probing of specific layers, custom layer combinations.
aggregation:"mean" | "max" | "none" | "cls_token"- Controls how to reduce the time/sequence dimension of embeddings:"mean": Average pooling over the time dimension. Reduces 3D embeddings(batch, time, features)to 2D(batch, features). Use with 2D probes (linear, MLP) that expect fixed-size feature vectors."max": Max pooling over the time dimension. Reduces 3D embeddings(batch, time, features)to 2D(batch, features). Alternative to mean pooling, can capture peak activations. Use with 2D probes (linear, MLP)."none": No aggregation - preserves the full sequence structure(batch, time, features). Required for 3D probes (LSTM, attention, transformer) that process sequences. Also enables learned weighted combination of multiple layers."cls_token": Uses only the first token (CLS token) from transformer models. Reduces to 2D(batch, features). Use with transformer-based backbones and 2D probes.
input_processing:"pooled" | "sequence" | "flatten" | "none"- How to process input embeddings before feeding to the probe:"pooled": Default - Pools embeddings to a fixed dimension. Works with embeddings that have already been aggregated (e.g., viaaggregation="mean"). Use with 2D probes (linear, MLP) that expect fixed-size feature vectors."sequence": Keeps sequence structure - Preserves the temporal/sequence dimension(batch, time, features). Required for 3D probes (LSTM, attention, transformer) that process sequences. Only compatible with sequence-based probe types. Must use withaggregation="none"."flatten": Flattens all dimensions - Reshapes multi-dimensional embeddings into a single vector. Converts any shape to(batch, features). Use when you need to flatten complex embeddings (e.g., 4D tensors) for 2D probes."none": No processing - Uses embeddings as-is without any transformation. Use when embeddings are already in the correct format for your probe type.
Training behavior
freeze_backbone:Trueto keep base model frozenonline_training:Truefor online (end-to-end graph) vsFalsefor pure offline
Probe-specific parameters
MLP:
hidden_dims,dropout_rate,activation, …LSTM:
lstm_hidden_size,num_layers,bidirectional,max_sequence_length, …Attention/Transformer:
num_heads,attention_dim,num_layers,max_sequence_length,use_positional_encoding, …
See
ProbeConfigclass documentation or useProbeConfig.model_json_schema()for complete parameter details, defaults, and valid ranges.
Example: Minimal Linear Probe (Python)¶
from avex.configs import ProbeConfig
probe_config = ProbeConfig(
probe_type="linear",
target_layers=["backbone"],
aggregation="mean",
freeze_backbone=True,
online_training=True,
)
Example: YAML Probe Definition¶
# my_linear_probe.yml
probe_type: linear
target_layers: ["backbone"]
aggregation: mean
freeze_backbone: true
online_training: true
from avex.models.probes.utils import (
load_probe_config,
build_probe_from_config,
)
from avex import load_model
config = load_probe_config("my_linear_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")
API Reference¶
Factory Functions¶
build_probe_from_config()¶
Unified factory function for building probe instances from a ProbeConfig. Supports both online (with base model) and offline (with pre-computed embeddings) modes.
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
def build_probe_from_config(
probe_config: ProbeConfig,
num_classes: int,
device: str,
base_model: Optional[torch.nn.Module] = None,
input_dim: Optional[int] = None,
target_length: Optional[int] = None,
**kwargs,
) -> torch.nn.Module:
...
Key parameters:
probe_config: TheProbeConfigobject.num_classes: Number of output classes.device:"cpu"or"cuda", etc.base_model: Optional backbone model to attach the probe to (for online mode). If provided, probe will be attached for end-to-end training.input_dim: Optional embedding dimension (for offline mode). Required ifbase_modelis None.target_length: Optional audio target length override.
Mode detection:
Online mode: When
base_modelis provided, the probe is attached to the base model for end-to-end training.Offline mode: When
input_dimis provided, the probe operates on pre-computed embeddings without a base model.
Returns: A torch.nn.Module probe ready for training/inference.
Config Helpers¶
load_probe_config()¶
from avex.models.probes.utils import load_probe_config
config = load_probe_config("my_probe.yml")
Supports:
Files with top-level probe fields.
Files with a nested
probe_config: {...}block.
Configuration Structure¶
All probe configs include:
probe_type- Type of probe architecturetarget_layers- Which layers to extract features fromaggregation- How to aggregate features (mean, max, none)input_processing- How to process inputs (pooled, sequence, flatten)freeze_backbone- Whether to freeze backbone weightsonline_training- Whether to train end-to-end or offline
Probe-specific parameters:
MLP:
hidden_dims,dropout_rate,activationLSTM:
lstm_hidden_size,num_layers,bidirectional,max_sequence_lengthAttention:
num_heads,attention_dim,num_layers,max_sequence_lengthTransformer:
num_heads,attention_dim,num_layers,max_sequence_length
Usage Examples¶
Comparing Different Probe Architectures¶
from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe_types = [
("linear", {"aggregation": "mean"}),
("mlp", {"aggregation": "mean", "hidden_dims": [512, 256]}),
("attention", {"input_processing": "sequence", "num_heads": 4, "attention_dim": 128}),
]
for probe_type, extra_cfg in probe_types:
cfg = ProbeConfig(
probe_type=probe_type,
target_layers=["backbone"],
freeze_backbone=True,
online_training=True,
**extra_cfg,
)
probe = build_probe_from_config(
probe_config=cfg,
base_model=base,
num_classes=10,
device="cpu",
)
print(probe_type, "parameters:", sum(p.numel() for p in probe.parameters()))
Expected output:
linear parameters: 7680
mlp parameters: 395264
attention parameters: 66560
Load from Custom YAML¶
# custom_probe.yml
# probe_type: mlp
# target_layers: ["backbone"]
# aggregation: mean
# hidden_dims: [1024, 512]
from avex.models.probes.utils import (
build_probe_from_config,
load_probe_config,
)
from avex import load_model
config = load_probe_config("custom_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")
Using ProbeConfig Programmatically¶
from avex.configs import ProbeConfig
from avex.models.probes.utils import build_probe_from_config
# Create config programmatically
config = ProbeConfig(
probe_type="attention",
target_layers=["layer_12"],
aggregation="none",
input_processing="sequence",
num_heads=8,
attention_dim=64,
num_layers=1,
)
# Use it
probe = build_probe_from_config(config, base_model=my_model, num_classes=50, device="cpu")
Implementation Details¶
Architecture¶
The probe API mirrors the model API structure for consistency:
avex/
├── models/probes/
│ ├── utils/ # Probe utilities (parallel to models/utils/)
│ │ ├── __init__.py
│ │ ├── registry.py # Probe class discovery + YAML helpers
│ │ └── factory.py # build_probe_from_config
│ └── [probe implementations]
└── examples/
└── 07_probe_training_and_inference.py # Usage examples
Core Components¶
registry.py¶
Probe Class Registry:
_PROBE_CLASSESfor discovered probe implementationsDiscovery: Dynamically finds all probe classes (LinearProbe, MLPProbe, etc.)
YAML Helpers:
load_probe_config()for loadingProbeConfigfrom disk
factory.py¶
build_probe_from_config(): Unified factory for building probes from
ProbeConfig(supports both online and offline modes)Handle parameter filtering and base-model interaction (freezing, hooks, feature-mode)
Testing¶
Verify Installation¶
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
import torch
# Test offline mode (works independently)
cfg = ProbeConfig(
probe_type="linear",
target_layers=["backbone"],
aggregation="none",
freeze_backbone=True,
online_training=False,
)
probe = build_probe_from_config(
cfg,
input_dim=768,
num_classes=10,
device="cpu",
)
# Test forward pass (inference mode)
probe.eval()
with torch.no_grad():
dummy_embeddings = torch.randn(2, 768)
output = probe(dummy_embeddings)
print(f"Output shape: {output.shape}") # Should be (2, 10)
Run Example Script¶
cd /home/marius/code/avex
python examples/07_probe_training_and_inference.py
Tested Functionality¶
✅ Probe Discovery: Automatically finds all probe classes
✅ Config Loading: load_probe_config() builds ProbeConfig from YAML
✅ Factory Usage: build_probe_from_config() builds probes from ProbeConfig (supports both online and offline modes)
✅ Offline Mode: Creates probes for pre-computed embeddings
✅ Online Mode: Loads and attaches to base models
✅ Forward Pass: Correct output shapes with dummy data
✅ No Linter Errors: All code is ruff-compliant
✅ Layer Variants: _last and _all variants work correctly
Known Issues¶
Model Registry: Pre-existing circular import prevents model loading in some contexts
This is a separate issue in the existing codebase
Doesn’t affect offline probe functionality
Doesn’t affect direct model instance usage
Files Created¶
Core Implementation¶
models/probes/utils/__init__.pymodels/probes/utils/registry.pymodels/probes/utils/factory.py
Examples and Documentation¶
docs/api_probes.md(this file)
Future Enhancements¶
The following components were intentionally not implemented:
models/probes/utils/checkpoint.py- Checkpoint save/load utilitiesEmbedding extraction utilities
These can be added in future iterations following the same design patterns.
See Also¶
examples/07_probe_training_and_inference.py- Complete usage examplesavex/models/probes/- Probe implementationsModel API documentation for parallel structure reference