Embedding Extraction and Feature Representations

Understanding return_features_only=True

When a model is loaded with return_features_only=True, it returns unpooled features instead of classification logits. This preserves temporal (and, for convolutional models, spatial) structure, providing richer representations for downstream tasks.

import torch

# Load model for embedding extraction
# (load_model is this package's loader; see examples/05_embedding_extraction.py)
model = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
model.eval()

# Get unpooled features
audio = torch.randn(1, 16000 * 5)  # 5 seconds at 16kHz
features = model(audio, padding_mask=None)
# features.shape = (batch, time_steps, feature_dim)

Model-Specific Output Formats

Different models return features in different formats when return_features_only=True:

BEATs (Bidirectional Encoder representation from Audio Transformers)

Output Shape: (batch, seq_len, 768), where seq_len interleaves temporal frames and frequency bands

Key Characteristics:

  • Each temporal frame contributes 8 embeddings (one per frequency band), interleaved along the sequence dimension

  • Structure: [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5, T0_6, T0_7, T1_0, T1_1, ...]

  • Frame rate: 6.25 Hz (not 100 Hz)

    • Calculated as: 100 Hz / 16 (temporal patch size, in spectrogram frames) = 6.25 Hz

    • For 16 kHz input audio

  • Feature dimension: 768 per embedding

Example:

model = load_model("esp_aves2_naturelm_audio_v1_beats", return_features_only=True, device="cpu")
audio = torch.randn(1, 16000 * 5)  # 5 seconds at 16kHz
features = model(audio, padding_mask=None)
# features.shape = (1, ~248, 768)
# ~248 tokens = ~31 temporal frames (5 seconds * 6.25 Hz) * 8 frequency bands
# Each token is a 768-dimensional embedding for one (time, band) patch

Understanding BEATs Frame Structure:

  • Audio at 16 kHz: 16,000 samples per second

  • Spectrogram hop: 10 ms, giving a base frame rate of 100 Hz

  • Temporal patch size: 16 spectrogram frames per patch (not 16 audio samples)

  • Effective frame rate after patching: 100 Hz / 16 = 6.25 Hz

  • Each patch covers: 1 / 6.25 Hz = 160 ms of audio

  • For 5 seconds of audio: 5 * 6.25 ≈ 31 temporal patches, each with 8 frequency-band embeddings (~248 tokens total); see the sketch after this list
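
A quick way to sanity-check these numbers is to compute the expected token count directly. This is a sketch of the arithmetic above, assuming the standard BEATs front end (10 ms hop, 128 mel bins, 16x16 patches); verify against your checkpoint's configuration:

def beats_token_count(duration_s, hop_ms=10.0, n_mels=128, patch=16):
    # Spectrogram frames at the 100 Hz base rate
    n_frames = int(duration_s * 1000 / hop_ms)
    # Temporal patches at 6.25 Hz, and frequency bands (128 / 16 = 8)
    t_patches = n_frames // patch
    f_patches = n_mels // patch
    return t_patches * f_patches

print(beats_token_count(5.0))  # 248 tokens for 5 seconds of audio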

Use Cases:

# Option 1: Pool manually for classification
pooled = features.mean(dim=1)  # (batch, 768)

# Option 2: Select a single frequency band
# Bands are interleaved along the sequence dimension (8 per temporal frame),
# so band 0 is every 8th token starting at index 0
band_0 = features[:, 0::8, :]  # (batch, ~31, 768)

# Option 3: Use for sequence modeling
# Features preserve temporal structure for RNNs, Transformers, etc.
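
If you need an explicit time x frequency grid, the interleaved sequence can be reshaped into separate axes. A minimal sketch, assuming the [T0_0, ..., T0_7, T1_0, ...] layout described above:

# Recover (batch, time, bands, 768) from the interleaved sequence
batch, seq_len, dim = features.shape
n_bands = 8
grid = features.reshape(batch, seq_len // n_bands, n_bands, dim)
per_frame = grid.mean(dim=2)  # average over frequency bands -> (batch, ~31, 768)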

EAT (Efficient Audio Transformer)

Output Shape: (batch, num_patches, 768)

Key Characteristics:

  • Returns unpooled patch embeddings from transformer backbone

  • Includes CLS token as first patch (index 0)

  • Number of patches depends on input length and patch size

  • Feature dimension: 768 per patch

Example:

model = load_model("esp_aves2_sl_eat_all_ssl_all", return_features_only=True, device="cpu")
audio = torch.randn(1, 16000 * 5)  # 5 seconds at 16kHz
features = model(audio, padding_mask=None)
# features.shape = (1, 513, 768)
# 513 patches = 1 CLS token + 512 spectrogram patches

Use Cases:

# Option 1: Use CLS token (typically most informative)
cls_token = features[:, 0]  # (batch, 768)

# Option 2: Mean pooling over all patches
pooled = features.mean(dim=1)  # (batch, 768)

# Option 3: Exclude CLS token and pool
spatial_features = features[:, 1:]  # Exclude CLS token
pooled = spatial_features.mean(dim=1)  # (batch, 768)
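
Any of these pooled embeddings can feed a downstream head. A minimal linear-probe sketch (num_classes and the training loop are placeholders, not part of this package):

import torch.nn as nn

num_classes = 10  # placeholder; set to your label count
probe = nn.Linear(768, num_classes)
logits = probe(cls_token)  # (batch, num_classes); train with your own labels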

EfficientNet

Output Shape: (batch, channels, height, width)

Key Characteristics:

  • Returns spatial feature maps before global average pooling

  • Preserves 2D spatial structure of spectrogram

  • Channel and spatial dimensions depend on model variant

Example:

model = load_model("esp_aves2_effnetb0_all", return_features_only=True, device="cpu")
audio = torch.randn(1, 16000 * 5)  # 5 seconds at 16kHz
features = model(audio, padding_mask=None)
# features.shape = (1, 1280, 4, 5) for EfficientNet-B0
# 1280 channels, 4x5 spatial dimensions

Use Cases:

# Option 1: Global average pooling
pooled = features.mean(dim=[2, 3])  # (batch, 1280)

# Option 2: Max pooling
pooled = features.amax(dim=[2, 3])  # (batch, 1280)

# Option 3: Flatten for spatial awareness
flattened = features.flatten(1)  # (batch, 1280*4*5)
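
Because the spatial dimensions vary with input duration and model variant, adaptive pooling gives a fixed-size vector regardless of clip length. A sketch using standard PyTorch ops:

import torch.nn.functional as F

# Pool to a fixed 2x2 grid, then flatten to a constant-size vector
grid = F.adaptive_avg_pool2d(features, (2, 2))  # (batch, 1280, 2, 2)
fixed = grid.flatten(1)                         # (batch, 1280 * 4)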

See examples/05_embedding_extraction.py for comprehensive examples of embedding extraction with different models.