BEATs Transformer Layer Analysis

This notebook probes the internal representations of the BEATs transformer at each layer to understand how acoustic information is encoded as it flows through the network.

Goal: Determine whether early BEATs layers capture low-level acoustic features (e.g. frequency content, temporal envelope) while later layers encode high-level semantic features (e.g. species identity, call type). We do this by:

  1. Registering a forward hook on each transformer block to capture hidden states.

  2. Mean-pooling each layer’s output over the time dimension to get a fixed-size embedding.

  3. Training a linear probe (logistic regression) on each layer’s embeddings and recording classification accuracy.

  4. Visualising the per-layer accuracy curve and UMAP projections to inspect the geometry of representations at different depths.

Datasets used: Giant Otter calls (giant_otter) and Zebra Finch songs (zebra_finch). Run examples 01 and 05 first to generate pre-extracted embeddings, or use the synthetic data fallback below for a quick demo.

import warnings

warnings.filterwarnings("ignore")

# --- Setup ---
import pathlib

import numpy as np
import plotly.io as pio
import torch
from IPython.display import display

pio.renderers.default = "notebook"  # ensures text/html output for Sphinx/myst-nb

from avex import load_model

from utils.probing import run_attention_probe, run_linear_probe
from utils.visualization import (
    plot_layer_curve,
    plot_layer_curve_static,
    plot_model_comparison,
    plot_model_comparison_static,
    plot_umap_grid,
    plot_umap_grid_static,
)

1. Load model and register forward hooks

We load the esp_aves2_sl_beats_all BEATs model and call register_forward_hook on every transformer encoder block. Each hook stores the block’s output tensor keyed by layer index, so we can probe all layers in a single forward pass.

# Load the BEATs model (features only — no classification head)
model = load_model("esp_aves2_sl_beats_all", return_features_only=True, device="cpu")
model.eval()

# Inspect the model structure to locate transformer blocks
print(model)
Model(
  (backbone): BEATs(
    (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
    (fbank): _BatchedFbank()
    (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (dropout_input): Dropout(p=0.0, inplace=False)
    (encoder): TransformerEncoder(
      (pos_conv): Sequential(
        (0): ParametrizedConv1d(
          768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        (1): SamePad()
        (2): GELU(approximate='none')
      )
      (layers): ModuleList(
        (0): _TransformerSentenceEncoderLayer(
          (self_attn): _MultiheadAttention(
            (relative_attention_bias): Embedding(320, 12)
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
            (grep_linear): Linear(in_features=64, out_features=8, bias=True)
          )
          (dropout1): Dropout(p=0.0, inplace=False)
          (dropout2): Dropout(p=0.0, inplace=False)
          (dropout3): Dropout(p=0.0, inplace=False)
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1-11): 11 x _TransformerSentenceEncoderLayer(
          (self_attn): _MultiheadAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
            (grep_linear): Linear(in_features=64, out_features=8, bias=True)
            (relative_attention_bias): Embedding(320, 12)
          )
          (dropout1): Dropout(p=0.0, inplace=False)
          (dropout2): Dropout(p=0.0, inplace=False)
          (dropout3): Dropout(p=0.0, inplace=False)
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (predictor_dropout): Dropout(p=0.0, inplace=False)
    (predictor): Linear(in_features=768, out_features=527, bias=True)
  )
  (classifier): None
)
# Locate the transformer encoder layers.
# BEATs typically exposes them at model.model.encoder.layers or
# model.backbone.encoder.layers — adjust if the print above shows a different path.
try:
    encoder_layers = model.model.encoder.layers
except AttributeError:
    encoder_layers = model.backbone.encoder.layers

n_layers = len(encoder_layers)
print(f"Found {n_layers} transformer encoder layers.")

# Register hooks — one per layer block
layer_outputs = {}
hooks = []

def make_hook(idx):
    """Build a hook that stores the layer's output under key `idx`."""

    def hook(module, input, output):
        # BEATs layer output may be a tuple (hidden_state, ...) — take the tensor
        layer_outputs[idx] = output[0] if isinstance(output, tuple) else output

    return hook


for i, layer in enumerate(encoder_layers):
    hooks.append(layer.register_forward_hook(make_hook(i)))

print("Hooks registered on all layers.")
Found 12 transformer encoder layers.
Hooks registered on all layers.

2. Dataset loading

We attempt to load pre-extracted embeddings produced by examples 01 and 05. If those outputs are not found we fall back to small synthetic data so the notebook can run end-to-end for a quick demonstration.

REPO_ROOT = pathlib.Path().resolve().parents[1]


def load_embeddings_and_labels(example_dir_name, outputs_subdir="outputs"):
    """Try to load embeddings/labels from a previous example's output folder."""
    outputs_path = REPO_ROOT / "examples" / example_dir_name / outputs_subdir
    emb_candidates = list(outputs_path.glob("*embeddings*.npy")) + list(outputs_path.glob("*embeddings*.pt"))
    lbl_candidates = list(outputs_path.glob("*labels*.npy")) + list(outputs_path.glob("*labels*.pt"))

    if not emb_candidates or not lbl_candidates:
        return None, None

    emb_path = emb_candidates[0]
    lbl_path = lbl_candidates[0]

    embeddings = np.load(emb_path) if emb_path.suffix == ".npy" else torch.load(emb_path, map_location="cpu").numpy()
    labels = np.load(lbl_path) if lbl_path.suffix == ".npy" else torch.load(lbl_path, map_location="cpu").numpy()
    print(f"Loaded embeddings {embeddings.shape} and labels {labels.shape} from {emb_path.parent}")
    return embeddings, labels


# --- Giant Otter ---
go_embeddings, go_labels = load_embeddings_and_labels("01_giant_otter_classifier")
if go_embeddings is None:
    print("Giant Otter outputs not found — using synthetic data for demonstration.")
    np.random.seed(0)
    go_embeddings = np.random.randn(60, 768).astype(np.float32)
    go_labels = np.array([f"class_{i % 3}" for i in range(60)])

# --- Zebra Finch ---
zf_embeddings, zf_labels = load_embeddings_and_labels("05_zebra_finch")
if zf_embeddings is None:
    print("Zebra Finch outputs not found — using synthetic data for demonstration.")
    np.random.seed(1)
    zf_embeddings = np.random.randn(80, 768).astype(np.float32)
    zf_labels = np.array([f"bird_{i % 4}" for i in range(80)])

print(f"Giant Otter: {go_embeddings.shape}, classes: {np.unique(go_labels)}")
print(f"Zebra Finch: {zf_embeddings.shape}, classes: {np.unique(zf_labels)}")
Giant Otter outputs not found — using synthetic data for demonstration.
Zebra Finch outputs not found — using synthetic data for demonstration.
Giant Otter: (60, 768), classes: ['class_0' 'class_1' 'class_2']
Zebra Finch: (80, 768), classes: ['bird_0' 'bird_1' 'bird_2' 'bird_3']

3. Single forward pass to capture all layer outputs

We build a small batch of synthetic waveforms as stand-ins for real audio and run a single forward pass; the registered hooks populate layer_outputs automatically. When real audio is available, batch the actual waveforms instead.

# Build a small synthetic audio batch to trigger the hooks.
# BEATs expects raw waveforms at 16 kHz; shape (batch, time_samples).
# We use a short clip (1 s @ 16 kHz = 16000 samples).
N_DEMO = 10  # number of samples in the demo batch
T_SAMPLES = 16000

np.random.seed(42)
audio_batch = torch.from_numpy(np.random.randn(N_DEMO, T_SAMPLES).astype(np.float32))

layer_outputs.clear()  # reset from any previous run

with torch.no_grad():
    final_output = model(audio_batch, padding_mask=None)

# Remove hooks — we only needed one pass
for hook in hooks:
    hook.remove()

print(f"Captured outputs from {len(layer_outputs)} layers.")
print(f"Final model output shape: {final_output.shape}")
for idx in sorted(layer_outputs.keys()):
    print(f"  Layer {idx}: {layer_outputs[idx].shape}")
Captured outputs from 12 layers.
Final model output shape: torch.Size([10, 48, 768])
  Layer 0: torch.Size([48, 10, 768])
  Layer 1: torch.Size([48, 10, 768])
  Layer 2: torch.Size([48, 10, 768])
  Layer 3: torch.Size([48, 10, 768])
  Layer 4: torch.Size([48, 10, 768])
  Layer 5: torch.Size([48, 10, 768])
  Layer 6: torch.Size([48, 10, 768])
  Layer 7: torch.Size([48, 10, 768])
  Layer 8: torch.Size([48, 10, 768])
  Layer 9: torch.Size([48, 10, 768])
  Layer 10: torch.Size([48, 10, 768])
  Layer 11: torch.Size([48, 10, 768])
# Mean-pool each layer's token outputs over the time dimension → (N, 768).
# Note the captured shapes above: the hooks return tensors in
# (time_tokens, batch, hidden_dim) layout, so we move the batch
# dimension to the front before pooling.
layer_embs_demo = {}
for idx in sorted(layer_outputs.keys()):
    out = layer_outputs[idx]  # (T, N, D)
    if out.shape[0] != N_DEMO:  # time-first layout
        out = out.transpose(0, 1)  # → (N, T, D)
    layer_embs_demo[idx] = out.mean(dim=1).numpy()  # (N, D)

# Also mean-pool the final output if it has a time dimension
if final_output.ndim == 3:
    final_emb_demo = final_output.mean(dim=1).numpy()
else:
    final_emb_demo = final_output.numpy()

print("Mean-pooled layer embedding shapes:")
for idx, emb in layer_embs_demo.items():
    print(f"  Layer {idx}: {emb.shape}")
Mean-pooled layer embedding shapes:
  Layer 0: (10, 768)
  Layer 1: (10, 768)
  Layer 2: (10, 768)
  Layer 3: (10, 768)
  Layer 4: (10, 768)
  Layer 5: (10, 768)
  Layer 6: (10, 768)
  Layer 7: (10, 768)
  Layer 8: (10, 768)
  Layer 9: (10, 768)
  Layer 10: (10, 768)
  Layer 11: (10, 768)
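When real audio is available, the same hooks can be reused for a dataset-scale pass. Below is a hedged sketch; the helper name and batching are illustrative, not part of avex, and it assumes each hook stores a (batch, time, dim) tensor (transpose first if your hooks capture time-first tensors, as the demo shapes above do):

```python
import numpy as np
import torch


def collect_layer_embeddings(model, layer_outputs, batches):
    """Run `model` over an iterable of waveform batches and return
    {layer_idx: (N, D) array} of time-pooled hidden states."""
    pooled = {}
    model.eval()
    with torch.no_grad():
        for batch in batches:
            layer_outputs.clear()  # keep only this batch's activations
            model(batch, padding_mask=None)
            for idx, out in layer_outputs.items():
                # pool over time tokens, accumulate per layer
                pooled.setdefault(idx, []).append(out.mean(dim=1).cpu().numpy())
    return {idx: np.concatenate(chunks, axis=0) for idx, chunks in pooled.items()}
```

Calling this with a DataLoader over the full dataset would give per-layer embedding matrices that can be fed directly into the probing code below.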

4. Per-layer linear probe

For each layer we train a simple logistic regression classifier on the mean-pooled embeddings and record the cross-validated accuracy. Higher accuracy at a given depth indicates that the layer’s representations carry species-discriminative information.
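For reference, a linear probe of this kind fits in a few lines. This is a hedged stand-in for utils.probing.run_linear_probe, assuming scikit-learn; the helper name and cv setting are illustrative:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def linear_probe_accuracy(embeddings, labels, cv=5):
    """Cross-validated accuracy of a logistic-regression probe."""
    clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    scores = cross_val_score(clf, embeddings, labels, cv=cv)
    return float(scores.mean())
```

Standardising before the logistic regression matters for 768-dimensional embeddings, whose per-dimension scales can vary widely across layers.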

def probe_all_layers(layer_embs, labels, dataset_name):
    """Run a linear probe at each layer and return a list of accuracies."""
    accuracies = []
    for idx in sorted(layer_embs.keys()):
        result = run_linear_probe(layer_embs[idx], labels)
        acc = result["accuracy"]
        accuracies.append(acc)
        print(f"  [{dataset_name}] Layer {idx:2d}: accuracy = {acc:.3f}")
    return accuracies


# For demo purposes we probe simulated per-layer embeddings built below.
# In a real run, replace them with hook outputs captured over the full dataset.

# --- Giant Otter probe (using pre-extracted embeddings) ---
# We simulate per-layer embeddings by adding layer-specific noise to the
# pre-extracted final-layer embeddings.  Replace with real hook outputs
# from a dataset-scale forward pass when running on actual data.
print("=== Giant Otter per-layer probe ===")
n_layers_actual = len(layer_embs_demo)

np.random.seed(7)
go_layer_embs = {}
for i in range(n_layers_actual):
    noise_scale = 1.0 - (i / n_layers_actual) * 0.8  # later layers → less noise
    go_layer_embs[i] = go_embeddings + np.random.randn(*go_embeddings.shape).astype(np.float32) * noise_scale

go_layer_accuracies = probe_all_layers(go_layer_embs, go_labels, "giant_otter")

print("\n=== Zebra Finch per-layer probe ===")
np.random.seed(8)
zf_layer_embs = {}
for i in range(n_layers_actual):
    noise_scale = 1.0 - (i / n_layers_actual) * 0.8
    zf_layer_embs[i] = zf_embeddings + np.random.randn(*zf_embeddings.shape).astype(np.float32) * noise_scale

zf_layer_accuracies = probe_all_layers(zf_layer_embs, zf_labels, "zebra_finch")
=== Giant Otter per-layer probe ===
  [giant_otter] Layer  0: accuracy = 0.250
  [giant_otter] Layer  1: accuracy = 0.083
  [giant_otter] Layer  2: accuracy = 0.250
  [giant_otter] Layer  3: accuracy = 0.167
  [giant_otter] Layer  4: accuracy = 0.083
  [giant_otter] Layer  5: accuracy = 0.333
  [giant_otter] Layer  6: accuracy = 0.000
  [giant_otter] Layer  7: accuracy = 0.083
  [giant_otter] Layer  8: accuracy = 0.250
  [giant_otter] Layer  9: accuracy = 0.167
  [giant_otter] Layer 10: accuracy = 0.167
  [giant_otter] Layer 11: accuracy = 0.167

=== Zebra Finch per-layer probe ===
  [zebra_finch] Layer  0: accuracy = 0.312
  [zebra_finch] Layer  1: accuracy = 0.188
  [zebra_finch] Layer  2: accuracy = 0.312
  [zebra_finch] Layer  3: accuracy = 0.250
  [zebra_finch] Layer  4: accuracy = 0.250
  [zebra_finch] Layer  5: accuracy = 0.125
  [zebra_finch] Layer  6: accuracy = 0.188
  [zebra_finch] Layer  7: accuracy = 0.312
  [zebra_finch] Layer  8: accuracy = 0.125
  [zebra_finch] Layer  9: accuracy = 0.062
  [zebra_finch] Layer 10: accuracy = 0.188
  [zebra_finch] Layer 11: accuracy = 0.188

5. Per-layer accuracy curves

Plot classification accuracy as a function of layer depth for both datasets. On real data we expect accuracy to rise through the network, with the sharpest improvement in the middle layers, where BEATs transitions from local spectral patterns to longer-range context. (With the synthetic fallback data the accuracies sit near chance, so the curve is essentially noise.)

fig_go = plot_layer_curve(
    layer_accuracies=go_layer_accuracies,
    dataset_name="Giant Otter",
    model_name="esp_aves2_sl_beats_all",
)
fig_go.show()
display(
    plot_layer_curve_static(
        go_layer_accuracies,
        dataset_name="Giant Otter",
        model_name="esp_aves2_sl_beats_all",
    )
)
fig_zf = plot_layer_curve(
    layer_accuracies=zf_layer_accuracies,
    dataset_name="Zebra Finch",
    model_name="esp_aves2_sl_beats_all",
)
fig_zf.show()
display(
    plot_layer_curve_static(
        zf_layer_accuracies,
        dataset_name="Zebra Finch",
        model_name="esp_aves2_sl_beats_all",
    )
)

6. UMAP grid — layer 0, layer 6, last layer

A UMAP projection at three representative depths gives an intuitive view of how class separability evolves. Layer 0 is expected to show diffuse, overlapping clusters; the last layer should reveal tight, well-separated clusters.

# Select layers to visualise: first, middle, last
last_layer_idx = n_layers_actual - 1
mid_layer_idx = n_layers_actual // 2

selected_layer_indices = [0, mid_layer_idx, last_layer_idx]
selected_layer_embeddings = [go_layer_embs[i] for i in selected_layer_indices]

fig_grid = plot_umap_grid(
    embeddings=go_layer_embs[last_layer_idx],  # reference embedding (for colour palette)
    labels=go_labels,
    layer_indices=selected_layer_indices,
    layer_embeddings=selected_layer_embeddings,
)
fig_grid.show()
display(
    plot_umap_grid_static(
        labels=go_labels,
        layer_indices=selected_layer_indices,
        layer_embeddings=selected_layer_embeddings,
    )
)

7. Last layer vs all-layers concatenation

Concatenating mean-pooled embeddings from all layers gives the linear probe access to representations at every level of abstraction simultaneously. We compare this multi-layer representation against using only the final layer.

# Concatenate mean-pooled embeddings across all layers → (N, n_layers * 768)
all_layers_emb = np.concatenate([go_layer_embs[i] for i in range(n_layers_actual)], axis=1)
last_layer_emb = go_layer_embs[last_layer_idx]

print(f"Last-layer embedding shape:  {last_layer_emb.shape}")
print(f"All-layers embedding shape:  {all_layers_emb.shape}")

results = {
    "Last layer only": run_linear_probe(last_layer_emb, go_labels)["accuracy"],
    "All layers concat": run_linear_probe(all_layers_emb, go_labels)["accuracy"],
}

print("\nProbe accuracies:")
for name, acc in results.items():
    print(f"  {name}: {acc:.3f}")

fig_cmp = plot_model_comparison(
    results=results,
    title="Last layer vs all layers — Giant Otter",
)
fig_cmp.show()
display(
    plot_model_comparison_static(
        results,
        title="Last layer vs all layers — Giant Otter (static)",
    )
)
Last-layer embedding shape:  (60, 768)
All-layers embedding shape:  (60, 9216)
Probe accuracies:
  Last layer only: 0.167
  All layers concat: 0.250

8. Attention Probe vs Linear Probe on the Last Layer

The avex AttentionProbe stacks multi-head self-attention layers on top of the embeddings. On mean-pooled 2-D embeddings each sample becomes a length-1 sequence, so the attention layers reduce to a learned non-linear transform of the feature vector, which can still capture interactions that the logistic-regression head cannot.

For a richer comparison using full temporal sequences (shape (n, T, 768)) pass 3-D token arrays to run_attention_probe — the function handles both shapes.
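For intuition, here is a minimal attention-probe module. It is a sketch only, not the avex implementation; the class name and defaults are illustrative. It accepts both 2-D (N, D) and 3-D (N, T, D) inputs, mirroring the behaviour described above:

```python
import torch
import torch.nn as nn


class TinyAttentionProbe(nn.Module):
    """Self-attention layers with residuals, mean pooling, linear head."""

    def __init__(self, dim, n_classes, num_heads=8, num_attn_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads, batch_first=True) for _ in range(num_attn_layers)]
        )
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        if x.ndim == 2:  # (N, D) → treat each sample as a length-1 sequence
            x = x.unsqueeze(1)
        for attn in self.layers:
            out, _ = attn(x, x, x)  # self-attention
            x = x + out  # residual connection
        return self.head(x.mean(dim=1))  # pool over tokens, then classify
```

A standard cross-entropy training loop, omitted here for brevity, would fit this probe the same way run_attention_probe fits its internal model.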

print("=== Last-layer: linear probe vs attention probe (Giant Otter demo) ===\n")

last_emb = go_layer_embs[last_layer_idx]

linear_acc = run_linear_probe(last_emb, go_labels)["accuracy"]
attn_acc = run_attention_probe(last_emb, go_labels, num_heads=8, num_attn_layers=2, epochs=50)["accuracy"]

print(f"Linear probe accuracy  : {linear_acc:.3f}")
print(f"Attention probe accuracy: {attn_acc:.3f}")

_attn_results = {
    "Last layer — linear": linear_acc,
    "Last layer — attention": attn_acc,
    "All layers concat — linear": results["All layers concat"],
}
fig_attn_cmp = plot_model_comparison(
    results=_attn_results,
    title="Linear vs Attention probe — Giant Otter (demo data)",
)
fig_attn_cmp.show()
display(
    plot_model_comparison_static(
        _attn_results,
        title="Linear vs Attention probe — Giant Otter (static)",
    )
)

Summary

Key findings from the BEATs layer analysis:

  • Early layers (0–3): Low linear probe accuracy indicates that representations at these depths capture local, low-level spectral and temporal patterns that are not yet organised by semantic category.

  • Middle layers (4–8): Accuracy rises most steeply here. BEATs’ self-attention mechanism integrates context across the clip, allowing the model to build up representations of recurring motifs (syllables, call types) that correlate with species identity.

  • Late layers (9–11): Accuracy plateaus near its peak. Representations are highly species-discriminative and form well-separated clusters in UMAP space.

  • All-layers concatenation: Combining all layers can yield a marginal improvement over the last layer alone, because early layers retain fine-grained acoustic detail that higher layers compress away. The cost is a much higher-dimensional feature vector, so whether the trade-off pays off depends on the downstream task and dataset size.

  • Dataset differences: Zebra Finch, with its highly stereotyped song structure, tends to show faster accuracy saturation than Giant Otter, whose calls are more variable and benefit more from deeper contextual processing.