Probe API Documentation

Overview

The probe API provides an interface for defining, configuring, and attaching probes to backbone/base models that can be used to adapt or fine-tune the backbone/base models to downstream tasks.

Key Ideas:

  • Probes (and backbone models) are regular PyTorch modules (linear, MLP, LSTM, attention, transformer heads).

  • Configuration is done via ProbeConfig (Python) or YAML files that map to ProbeConfig.

  • Probes may be trained online (attached to a base model) or offline (on pre-computed embeddings).

Getting Started

1. Start Simple

Begin with a simple linear probe on the backbone’s last layer:

from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
cfg = ProbeConfig(
    probe_type="linear",
    target_layers=["last_layer"],
    aggregation="mean",
    freeze_backbone=True,
    online_training=True,
)
probe = build_probe_from_config(cfg, base_model=base, num_classes=50, device="cpu")

2. Increase Complexity If Needed

If performance plateaus, move to MLP, LSTM, attention, or transformer probes by changing probe_type and the related fields in ProbeConfig. Generally, attention probe works best with self-supervised models and transformers and it does not improve much on EfficientNet backbones.

3. Match Probe Complexity to Task

  • Simple in-domain tasks → linear probes work well on bird classification/detection tasks because most of the bioacoustics models were trained on this tasks

  • Out-of-domain tasks → attention/transformer probes on all layers or even lower layers work better for repertoire classification or species that were under-represented in the training data used for the backbones.

4. Consider Computational Budget

  • Limited resources_last variants with linear/MLP

  • Generous resources_all variants with attention/transformer

Performance Trade-offs

_last Variants

Pros:

  • Fast execution

  • Simple architecture

  • Lower memory usage

  • Fewer parameters to train

Cons:

  • Single representation, overfitted for species classification (mostly birds) in the case of supervised models

  • May miss multi-scale features

Use when:

  • Quick experiments needed

  • Limited computational resources

  • Strong, well-trained backbone

  • Simple classification tasks

_all Variants

Pros:

  • Rich multi-scale features

  • More expressive models

  • Better for complex tasks

  • Learns optimal layer weighting

Cons:

  • Slower execution

  • High disk usage in the case of offline probes

  • Higher memory requirements

  • More parameters to train

Use when:

  • Maximum performance needed

  • Sufficient computational resources

  • Out-of-domain tasks

  • Comparing layer-wise representations

Quick Selection Guide

Task Complexity:  LOW ──────────────────────────────────> HIGH
Probe Type:       linear → mlp → lstm → attention → transformer

Feature Scope:    SINGLE LAYER ─────────────────────────> ALL LAYERS
Variant:          _last ─────────────────────────────────> _all

Computational:    FAST ──────────────────────────────────> SLOW
                  linear_last ──────────────────────> transformer_all

Quick Start

Build and Use a Probe (Online Mode)

from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

# 1. Load a backbone model that returns features
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")

# 2. Define a simple linear probe on the backbone features
probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],   # use final backbone layer
    aggregation="mean",           # mean-pool over time
    freeze_backbone=True,         # keep backbone frozen
    online_training=True,         # end-to-end graph (even if backbone is frozen)
)

# 3. Build the probe
probe = build_probe_from_config(
    probe_config=probe_config,
    base_model=base,
    num_classes=50,
    device="cpu",
)

Offline Mode (Pre-computed Embeddings)

from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

# For pre-computed embeddings (no base model needed)
probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],   # conceptual; not used when base_model=None
    aggregation="none",
    freeze_backbone=True,
    online_training=False,
)

probe = build_probe_from_config(
    probe_config=probe_config,
    input_dim=768,                # embedding dimension
    num_classes=50,
    device="cpu",
)

# Use with embeddings
# For inference, set the probe to eval mode and use torch.no_grad()
probe.eval()
with torch.no_grad():
    predictions = probe(embeddings)  # embeddings shape: (batch, 768)

Note: The probe’s forward() method does not automatically use inference mode. For inference (when you don’t need gradients), you should:

  • Call probe.eval() to set the model to evaluation mode (disables dropout, batch norm updates, etc.)

  • Wrap the forward pass in with torch.no_grad(): to disable gradient computation and reduce memory usage

For training/fine-tuning, use probe.train() and omit the torch.no_grad() context.

Defining Probe Configurations

Probe Types

Common probe_type values:

  • linear – simple linear classifier

  • mlp – multi-layer perceptron

  • lstm – LSTM sequence model

  • attention – self-attention head

  • transformer – transformer encoder probe

Core Fields in ProbeConfig

All probe configs support (non-exhaustive):

  • Architecture & layers

    • probe_type: "linear" | "mlp" | "lstm" | "attention" | "transformer" - The architecture of the probe head:

      • "linear": 2D probe - Simple linear classifier (single fully-connected layer). Fastest and most memory-efficient. Expects 2D input (batch, features). Use with aggregation="mean" or "max". Best for: baseline performance, simple tasks, limited resources.

      • "mlp": 2D probe - Multi-layer perceptron with configurable hidden layers and non-linear activations. More expressive than linear while still efficient. Expects 2D input (batch, features). Use with aggregation="mean" or "max". Requires hidden_dims parameter. Best for: tasks needing non-linearity, moderate complexity.

      • "lstm": 3D probe - Long Short-Term Memory network for sequence modeling. Processes temporal sequences and captures long-range dependencies. Expects 3D input (batch, time, features). Use with aggregation="none" to preserve sequence structure. Requires lstm_hidden_size, num_layers, and optionally bidirectional. Best for: temporal/sequential tasks, variable-length sequences.

      • "attention": 3D probe - Self-attention mechanism for sequence modeling. Captures relationships between all positions in a sequence. Expects 3D input (batch, time, features). Use with aggregation="none" to preserve sequence structure. Requires num_heads and attention_dim. Best for: tasks requiring global sequence understanding, parallel processing.

      • "transformer": 3D probe - Full transformer encoder architecture with multiple attention layers. Most expressive and powerful probe type. Expects 3D input (batch, time, features). Use with aggregation="none" to preserve sequence structure. Requires num_heads, attention_dim, and num_layers. Best for: complex tasks, maximum performance, sufficient computational resources.

    • target_layers: List of layer names to extract embeddings from. Main options:

      • ["last_layer"]: Uses the final (non-classification) layer of the model. Best for: single-layer probing, baseline experiments, efficient computation.

      • ["all"]: Uses all discoverable layers in the model. Best for: multi-layer probing, learning optimal layer combinations, maximum expressiveness.

      • Specific layer names: Use concrete layer names (e.g., ["layer_6", "layer_12"]). Discover available layers using list_model_layers(model_name). Best for: targeted probing of specific layers, custom layer combinations.

    • aggregation: "mean" | "max" | "none" | "cls_token" - Controls how to reduce the time/sequence dimension of embeddings:

      • "mean": Average pooling over the time dimension. Reduces 3D embeddings (batch, time, features) to 2D (batch, features). Use with 2D probes (linear, MLP) that expect fixed-size feature vectors.

      • "max": Max pooling over the time dimension. Reduces 3D embeddings (batch, time, features) to 2D (batch, features). Alternative to mean pooling, can capture peak activations. Use with 2D probes (linear, MLP).

      • "none": No aggregation - preserves the full sequence structure (batch, time, features). Required for 3D probes (LSTM, attention, transformer) that process sequences. Also enables learned weighted combination of multiple layers.

      • "cls_token": Uses only the first token (CLS token) from transformer models. Reduces to 2D (batch, features). Use with transformer-based backbones and 2D probes.

    • input_processing: "pooled" | "sequence" | "flatten" | "none" - How to process input embeddings before feeding to the probe:

      • "pooled": Default - Pools embeddings to a fixed dimension. Works with embeddings that have already been aggregated (e.g., via aggregation="mean"). Use with 2D probes (linear, MLP) that expect fixed-size feature vectors.

      • "sequence": Keeps sequence structure - Preserves the temporal/sequence dimension (batch, time, features). Required for 3D probes (LSTM, attention, transformer) that process sequences. Only compatible with sequence-based probe types. Must use with aggregation="none".

      • "flatten": Flattens all dimensions - Reshapes multi-dimensional embeddings into a single vector. Converts any shape to (batch, features). Use when you need to flatten complex embeddings (e.g., 4D tensors) for 2D probes.

      • "none": No processing - Uses embeddings as-is without any transformation. Use when embeddings are already in the correct format for your probe type.

  • Training behavior

    • freeze_backbone: True to keep base model frozen

    • online_training: True for online (end-to-end graph) vs False for pure offline

  • Probe-specific parameters

    • MLP: hidden_dims, dropout_rate, activation, …

    • LSTM: lstm_hidden_size, num_layers, bidirectional, max_sequence_length, …

    • Attention/Transformer: num_heads, attention_dim, num_layers, max_sequence_length, use_positional_encoding, …

    See ProbeConfig class documentation or use ProbeConfig.model_json_schema() for complete parameter details, defaults, and valid ranges.

Example: Minimal Linear Probe (Python)

from avex.configs import ProbeConfig

probe_config = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],
    aggregation="mean",
    freeze_backbone=True,
    online_training=True,
)

Example: YAML Probe Definition

# my_linear_probe.yml
probe_type: linear
target_layers: ["backbone"]
aggregation: mean
freeze_backbone: true
online_training: true
from avex.models.probes.utils import (
    load_probe_config,
    build_probe_from_config,
)
from avex import load_model

config = load_probe_config("my_linear_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")

API Reference

Factory Functions

build_probe_from_config()

Unified factory function for building probe instances from a ProbeConfig. Supports both online (with base model) and offline (with pre-computed embeddings) modes.

from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

def build_probe_from_config(
    probe_config: ProbeConfig,
    num_classes: int,
    device: str,
    base_model: Optional[torch.nn.Module] = None,
    input_dim: Optional[int] = None,
    target_length: Optional[int] = None,
    **kwargs,
) -> torch.nn.Module:
    ...

Key parameters:

  • probe_config: The ProbeConfig object.

  • num_classes: Number of output classes.

  • device: "cpu" or "cuda", etc.

  • base_model: Optional backbone model to attach the probe to (for online mode). If provided, probe will be attached for end-to-end training.

  • input_dim: Optional embedding dimension (for offline mode). Required if base_model is None.

  • target_length: Optional audio target length override.

Mode detection:

  • Online mode: When base_model is provided, the probe is attached to the base model for end-to-end training.

  • Offline mode: When input_dim is provided, the probe operates on pre-computed embeddings without a base model.

Returns: A torch.nn.Module probe ready for training/inference.

Config Helpers

load_probe_config()

from avex.models.probes.utils import load_probe_config

config = load_probe_config("my_probe.yml")

Supports:

  • Files with top-level probe fields.

  • Files with a nested probe_config: {...} block.

Configuration Structure

All probe configs include:

  • probe_type - Type of probe architecture

  • target_layers - Which layers to extract features from

  • aggregation - How to aggregate features (mean, max, none)

  • input_processing - How to process inputs (pooled, sequence, flatten)

  • freeze_backbone - Whether to freeze backbone weights

  • online_training - Whether to train end-to-end or offline

Probe-specific parameters:

  • MLP: hidden_dims, dropout_rate, activation

  • LSTM: lstm_hidden_size, num_layers, bidirectional, max_sequence_length

  • Attention: num_heads, attention_dim, num_layers, max_sequence_length

  • Transformer: num_heads, attention_dim, num_layers, max_sequence_length

Usage Examples

Comparing Different Probe Architectures

from avex import load_model
from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig

base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")

probe_types = [
    ("linear", {"aggregation": "mean"}),
    ("mlp", {"aggregation": "mean", "hidden_dims": [512, 256]}),
    ("attention", {"input_processing": "sequence", "num_heads": 4, "attention_dim": 128}),
]

for probe_type, extra_cfg in probe_types:
    cfg = ProbeConfig(
        probe_type=probe_type,
        target_layers=["backbone"],
        freeze_backbone=True,
        online_training=True,
        **extra_cfg,
    )
    probe = build_probe_from_config(
        probe_config=cfg,
        base_model=base,
        num_classes=10,
        device="cpu",
    )
    print(probe_type, "parameters:", sum(p.numel() for p in probe.parameters()))

Expected output:

linear parameters: 7680
mlp parameters: 395264
attention parameters: 66560

Load from Custom YAML

# custom_probe.yml
# probe_type: mlp
# target_layers: ["backbone"]
# aggregation: mean
# hidden_dims: [1024, 512]

from avex.models.probes.utils import (
    build_probe_from_config,
    load_probe_config,
)
from avex import load_model

config = load_probe_config("custom_probe.yml")
base = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
probe = build_probe_from_config(config, base_model=base, num_classes=50, device="cpu")

Using ProbeConfig Programmatically

from avex.configs import ProbeConfig
from avex.models.probes.utils import build_probe_from_config

# Create config programmatically
config = ProbeConfig(
    probe_type="attention",
    target_layers=["layer_12"],
    aggregation="none",
    input_processing="sequence",
    num_heads=8,
    attention_dim=64,
    num_layers=1,
)

# Use it
probe = build_probe_from_config(config, base_model=my_model, num_classes=50, device="cpu")

Implementation Details

Architecture

The probe API mirrors the model API structure for consistency:

avex/
├── models/probes/
│   ├── utils/                          # Probe utilities (parallel to models/utils/)
│   │   ├── __init__.py
│   │   ├── registry.py                 # Probe class discovery + YAML helpers
│   │   └── factory.py                  # build_probe_from_config
│   └── [probe implementations]
└── examples/
    └── 07_probe_training_and_inference.py  # Usage examples

Core Components

registry.py

  • Probe Class Registry: _PROBE_CLASSES for discovered probe implementations

  • Discovery: Dynamically finds all probe classes (LinearProbe, MLPProbe, etc.)

  • YAML Helpers: load_probe_config() for loading ProbeConfig from disk

factory.py

  • build_probe_from_config(): Unified factory for building probes from ProbeConfig (supports both online and offline modes)

  • Handle parameter filtering and base-model interaction (freezing, hooks, feature-mode)

Testing

Verify Installation

from avex.models.probes import build_probe_from_config
from avex.configs import ProbeConfig
import torch

# Test offline mode (works independently)
cfg = ProbeConfig(
    probe_type="linear",
    target_layers=["backbone"],
    aggregation="none",
    freeze_backbone=True,
    online_training=False,
)
probe = build_probe_from_config(
    cfg,
    input_dim=768,
    num_classes=10,
    device="cpu",
)

# Test forward pass (inference mode)
probe.eval()
with torch.no_grad():
    dummy_embeddings = torch.randn(2, 768)
    output = probe(dummy_embeddings)
    print(f"Output shape: {output.shape}")  # Should be (2, 10)

Run Example Script

cd /home/marius/code/avex
python examples/07_probe_training_and_inference.py

Tested Functionality

Probe Discovery: Automatically finds all probe classes

Config Loading: load_probe_config() builds ProbeConfig from YAML

Factory Usage: build_probe_from_config() builds probes from ProbeConfig (supports both online and offline modes)

Offline Mode: Creates probes for pre-computed embeddings

Online Mode: Loads and attaches to base models

Forward Pass: Correct output shapes with dummy data

No Linter Errors: All code is ruff-compliant

Layer Variants: _last and _all variants work correctly

Known Issues

  • Model Registry: Pre-existing circular import prevents model loading in some contexts

    • This is a separate issue in the existing codebase

    • Doesn’t affect offline probe functionality

    • Doesn’t affect direct model instance usage

Files Created

Core Implementation

  • models/probes/utils/__init__.py

  • models/probes/utils/registry.py

  • models/probes/utils/factory.py

Examples and Documentation

Future Enhancements

The following components were intentionally not implemented:

  • models/probes/utils/checkpoint.py - Checkpoint save/load utilities

  • Embedding extraction utilities

These can be added in future iterations following the same design patterns.

See Also