Zebra Finch Call-Type and Age Classification

Classify 11 call types and adult/juvenile age groups in zebra finch (Taeniopygia guttata) vocalizations using BEATs and EfficientNet-B0.

Workflow: download → explore → embed → UMAP → training-free metrics (NMI, ARI, R-AUC) → linear probe → attention probe (sl-BEATs) → static + interactive figures → save artifacts.

import pathlib
import sys


def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    """Return the nearest directory at or above *start* that contains ``pyproject.toml``.

    Args:
        start: Directory to begin the upward search from (checked first).

    Raises:
        FileNotFoundError: If neither *start* nor any of its parents has a pyproject.toml.
    """
    search_path = (start, *start.parents)
    root = next((folder for folder in search_path if (folder / "pyproject.toml").exists()), None)
    if root is None:
        raise FileNotFoundError("Could not locate repo root (pyproject.toml not found).")
    return root


# Locate the repository root starting from the current working directory (pathlib.Path()
# is the CWD) and put it first on sys.path so the repo-local `utils.*` imports below resolve.
REPO_ROOT = find_repo_root(pathlib.Path().resolve())
sys.path.insert(0, str(REPO_ROOT))

import json
import re
import zipfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import soundfile as sf
import torch
from avex import load_model
from IPython.display import display
from tqdm.auto import tqdm

from utils.probing import (
    compute_training_free_metrics,
    run_attention_probe_fixed_split,
    run_linear_probe_fixed_split,
)
from utils.visualization import (
    confusion_heatmap_static,
    plot_model_comparison,
    plot_model_comparison_static,
    plot_umap,
    plot_umap_static,
)

# Example-local layout: extracted audio under data/audio, cached embedding arrays under
# data/embeddings (created eagerly so the cache cells can write into it).
EXAMPLE_DIR = REPO_ROOT / "examples" / "05_zebra_finch"
DATA_DIR = EXAMPLE_DIR / "data"
AUDIO_DIR = DATA_DIR / "audio"
EMBED_DIR = DATA_DIR / "embeddings"
EMBED_DIR.mkdir(parents=True, exist_ok=True)

# Sample rate (Hz) every clip is resampled to before being fed to the models.
TARGET_SR = 16_000
DEVICE = "cpu"
# Model identifiers passed to avex.load_model.
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"

# Full mode: run all cells (can take a while).
# Setting FAST_MODE = True skips the UMAP visualisation cells; everything else still runs.
FAST_MODE = False

print("Setup complete.")
2026-04-21 10:42:15.695893: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-21 10:42:17.648254: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
Setup complete.

1. Download Dataset

# figshare article that hosts the recordings; the API URL returns the article's file list.
FIGSHARE_ARTICLE_ID = 11905533
FIGSHARE_API_URL = f"https://api.figshare.com/v2/articles/{FIGSHARE_ARTICLE_ID}/files"
# Local copy of the downloaded archive.
ARCHIVE_PATH = DATA_DIR / "zebra_finch_calls.zip"
# Marker file touched after a completed extraction, so re-runs skip unzipping.
EXTRACT_SENTINEL = AUDIO_DIR / ".extracted"
import json as _json_mod
import urllib.request as _ur


def download_with_progress(url, dest):
    """Download ``url`` to ``dest`` while printing a percentage progress indicator.

    The bytes are first written to ``<dest>.part`` and only renamed to ``dest`` after the
    transfer finishes. An interrupted or failed download therefore never leaves a
    truncated file at ``dest``, and any partial ``.part`` file is removed.

    Args:
        url: Source URL (anything ``urllib.request.urlretrieve`` accepts).
        dest: Destination path (``str`` or ``pathlib.Path``).

    Raises:
        Whatever ``urlretrieve`` raises (e.g. ``URLError``, ``ContentTooShortError``);
        the temporary file is cleaned up first.
    """
    dest = pathlib.Path(dest)
    tmp = dest.with_name(dest.name + ".part")

    def _hook(count, bs, total):
        # urlretrieve reports total <= 0 when the size is unknown; print nothing then.
        if total > 0:
            print(f"\r  {min(count * bs / total * 100, 100):.1f}%", end="", flush=True)

    try:
        _ur.urlretrieve(url, tmp, reporthook=_hook)
    except BaseException:
        # Includes KeyboardInterrupt, the common way a notebook download gets cut short.
        tmp.unlink(missing_ok=True)
        raise
    tmp.replace(dest)
    print()


AUDIO_DIR.mkdir(parents=True, exist_ok=True)
# Download step: ask the figshare API for the article's files and fetch the largest one
# (presumably the audio zip — revisit if the article gains other large files).
# Skipped whenever ARCHIVE_PATH exists; existence is the only check (no size/checksum test).
if not ARCHIVE_PATH.exists():
    with _ur.urlopen(FIGSHARE_API_URL) as resp:
        files_info = _json_mod.loads(resp.read())
    target = max(files_info, key=lambda f: f["size"])
    print(f"Downloading {target['name']} ({target['size'] / 1e6:.0f} MB) ...")
    download_with_progress(target["download_url"], ARCHIVE_PATH)
else:
    print(f"Archive present ({ARCHIVE_PATH.stat().st_size / 1e6:.1f} MB) — skipping.")

# Extraction step: the sentinel is touched only after extractall returns, so an
# interrupted extraction is redone on the next run.
# NOTE(review): ZipFile.extractall trusts member paths; acceptable for this known archive.
if not EXTRACT_SENTINEL.exists():
    print("Extracting ...")
    with zipfile.ZipFile(ARCHIVE_PATH) as zf:
        zf.extractall(AUDIO_DIR)
    EXTRACT_SENTINEL.touch()
    print("Done.")
else:
    print("Audio already extracted.")
print(f"WAV files: {len(list(AUDIO_DIR.glob('**/*.wav')))}")
Archive present (140.6 MB) — skipping.
Audio already extracted.
WAV files: 6402

2. Data Exploration

# Maps lower-cased call codes parsed from file names to human-readable call names.
# Several spellings/abbreviations collapse onto the same name; codes not listed here are
# kept verbatim as their own call name by the parsing code below.
# NOTE(review): "ltc" may stand for "long tonal call" in the source dataset's
# nomenclature rather than "long-term call" — confirm before relying on the label text.
CALL_TYPE_NAMES = {
    "aggc": "aggressive call",
    "ag": "aggressive call",
    "dc": "distance call",
    "disc": "distance call",
    "ltc": "long-term call",
    "nekakle": "nekakle call",
    "nekaklec": "nekakle call",
    "nestc": "nest call",
    "nest": "nest call",
    "ne": "nest call",
    "nearkc": "nest call",
    "neseq": "nest call seq",
    "nestseq": "nest call seq",
    "nestcseq": "nest call seq",
    "so": "song",
    "song": "song",
    "tetc": "tet call",
    "tet": "tet call",
    "te": "tet call",
    "thuckc": "thuck call",
    "thukc": "thuck call",
    "thuc": "thuck call",
    "tukc": "thuck call",
    "whinec": "whine call",
    "whine": "whine call",
    "whi": "whine call",
    "wh": "whine call",
    "wc": "whine call",
    "whic": "whine call",
    "whinecseq": "whine call seq",
    "whicnestc": "whine+nest call",
    "beggseq": "begging seq",
}
META_PATH = DATA_DIR / "metadata.csv"
# File-name pattern "<specimen>_<digits>-<CODE>-...": group 1 is the specimen ID (starts
# with a letter), group 2 the alphabetic call code.
# e.g. "Abc1234_110302-DC-01.wav" -> specimen "Abc1234", code "DC" (lower-cased later).
FNAME_RE = re.compile(r"^([A-Za-z][A-Za-z0-9]*)_\d*-([A-Za-z]+)-", re.IGNORECASE)

# Per-call metadata table, cached as CSV so later runs see the same rows in the same order
# (the embedding caches below are validated only by row count).
if META_PATH.exists():
    # Load existing metadata (preserves same row ordering as any cached embeddings)
    df = pd.read_csv(META_PATH)
    # Back-fill columns missing from the cached CSV.
    if "call_name" not in df.columns and "call_type" in df.columns:
        df["call_name"] = df["call_type"].map(lambda c: CALL_TYPE_NAMES.get(c, c))
    if "age" not in df.columns:
        df["age"] = "unknown"
    print(f"Loaded existing metadata: {len(df)} rows")
else:
    # Build metadata from the extracted WAV file names; names not matching FNAME_RE are skipped.
    wav_files = sorted(AUDIO_DIR.glob("**/*.wav"))
    records = []
    for wav_path in wav_files:
        m = FNAME_RE.search(wav_path.name)
        if m is None:
            continue
        specimen = m.group(1)
        call_code = m.group(2).lower()
        # Only duration and sample rate are taken from the file's info.
        info = sf.info(str(wav_path))
        # NOTE(review): age is always "unknown" on this path, so a freshly parsed
        # metadata.csv provides no age labels and the age_group probe is dropped later;
        # the age values used in the reported results must come from an existing
        # metadata.csv — confirm where they originate.
        records.append(
            {
                "path": str(wav_path),
                "filename": wav_path.name,
                "specimen": specimen,
                "call_type": call_code,
                "call_name": CALL_TYPE_NAMES.get(call_code, call_code),
                "age": "unknown",
                "duration_s": info.duration,
                "sample_rate": info.samplerate,
            }
        )
    df = pd.DataFrame(records)
    df.to_csv(META_PATH, index=False)
    print(f"Parsed and saved: {len(df)} files")

print(f"Call types: {df['call_name'].nunique()}, specimens: {df['specimen'].nunique()}")
Loaded existing metadata: 3395 rows
Call types: 13, specimens: 49
# Number of calls per normalised call name, most frequent first, as an interactive bar chart.
ct_df = df.groupby("call_name").size().reset_index(name="count").sort_values("count", ascending=False)
fig = px.bar(
    ct_df,
    x="call_name",
    y="count",
    title="Zebra finch call type counts",
    labels={"call_name": "Call type", "count": "Number of calls"},
    color="count",
    color_continuous_scale="Blues",
)
fig.update_layout(xaxis_tickangle=-30, showlegend=False)
display(fig)

3. Embedding Extraction

# Minimum clip length in samples at TARGET_SR (0.5 s); kept for code that reads it directly.
MIN_SAMPLES = int(0.5 * TARGET_SR)


def load_audio(path: str, target_sr: int = TARGET_SR, min_duration_s: float = 0.5) -> torch.Tensor:
    """Load an audio file as a mono, resampled, minimum-length waveform tensor.

    Args:
        path: Path to an audio file readable by soundfile.
        target_sr: Output sample rate in Hz; audio at any other rate is resampled with librosa.
        min_duration_s: Clips shorter than this many seconds (at ``target_sr``) are
            zero-padded at the end. The default 0.5 s equals ``MIN_SAMPLES`` at ``TARGET_SR``.

    Returns:
        A tensor of shape ``(1, n_samples)``.
    """
    wav, sr = sf.read(path, dtype="float32", always_2d=True)
    wav = wav.mean(axis=1)  # average the channel axis -> mono
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    # Derive the pad length from target_sr (not the module-level TARGET_SR) so the minimum
    # duration is honoured for any requested output rate.
    min_samples = int(min_duration_s * target_sr)
    if len(wav) < min_samples:
        wav = np.pad(wav, (0, min_samples - len(wav)))
    return torch.from_numpy(wav).unsqueeze(0)


print(f"Sample shape: {load_audio(df['path'].iloc[0]).shape}")
Sample shape: torch.Size([1, 8000])
BEATS_CACHE = EMBED_DIR / "beats_embeddings.npy"

# Reuse the cached last-layer embeddings when they have one row per metadata entry. The
# cache is read once and serves both the row-count check and the result.
# NOTE(review): only the row count is validated; a cache built from a differently ordered
# metadata table would be silently misaligned.
_cached = np.load(BEATS_CACHE) if BEATS_CACHE.exists() else None
if _cached is not None and _cached.shape[0] == len(df):
    beats_embs = _cached
    print(f"Loaded cached BEATs last-layer embeddings: {beats_embs.shape}")
else:
    print(f"Loading model: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="BEATs last-layer"):
            wav = load_audio(path)
            feats = model(wav)  # (1, T, 768)
            # Mean-pool over dim 1 and drop the batch dim -> one vector per call.
            embeddings.append(feats.mean(dim=1).squeeze(0).cpu().numpy())
    beats_embs = np.stack(embeddings)
    np.save(BEATS_CACHE, beats_embs)
    print(f"Saved BEATs last-layer embeddings: {beats_embs.shape}")
    del model
del _cached  # drop the extra reference (frees a stale, mismatched cache array)
Loaded cached BEATs last-layer embeddings: (3395, 768)
BEATS_ALL_CACHE = EMBED_DIR / "beats_all_layers_embeddings.npy"

# Reuse the cached all-layer embeddings when they have one row per metadata entry; the
# cache is read once and serves both the row-count check and the result.
_cached = np.load(BEATS_ALL_CACHE) if BEATS_ALL_CACHE.exists() else None
if _cached is not None and _cached.shape[0] == len(df):
    beats_all_embs = _cached
    print(f"Loaded cached BEATs all-layers embeddings: {beats_all_embs.shape}")
else:
    print(f"Loading model for all-layers extraction: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()

    # Locate the transformer encoder stack; the attribute path differs between model wrappers.
    try:
        encoder_layers = model.model.encoder.layers
    except AttributeError:
        encoder_layers = model.backbone.encoder.layers
    n_layers = len(encoder_layers)
    layer_store: dict = {}

    def _make_hook(idx):
        """Return a forward hook storing layer ``idx``'s output (first element if a tuple)."""

        def _hook(module, inp, out):
            layer_store[idx] = out[0] if isinstance(out, tuple) else out

        return _hook

    # Register forward hooks on every transformer encoder layer
    hooks = [layer.register_forward_hook(_make_hook(i)) for i, layer in enumerate(encoder_layers)]
    print(f"  Registered hooks on {n_layers} transformer layers.")

    all_embs = []
    try:
        with torch.no_grad():
            for path in tqdm(df["path"], desc="BEATs all-layers"):
                layer_store.clear()
                wav = load_audio(path)
                _ = model(wav)
                # Mean-pool each layer over all non-feature axes, then average across layers -> (D,).
                # Flattening to (-1, D) handles both time-first (T, B, D) and batch-first
                # (B, T, D) outputs because each call is processed with batch size 1.
                # reshape (not view): a time-first output may be a non-contiguous
                # transpose, on which .view raises.
                per_layer = [
                    layer_store[i].reshape(-1, layer_store[i].shape[-1]).mean(dim=0).cpu().numpy()
                    for i in range(n_layers)
                ]
                all_embs.append(np.mean(per_layer, axis=0))
    finally:
        # Always detach the hooks, even if extraction fails part-way.
        for h in hooks:
            h.remove()

    beats_all_embs = np.stack(all_embs)
    np.save(BEATS_ALL_CACHE, beats_all_embs)
    print(f"Saved BEATs all-layers embeddings: {beats_all_embs.shape}")
    del model
del _cached  # drop the extra reference (frees a stale, mismatched cache array)
Loaded cached BEATs all-layers embeddings: (3395, 768)
EFFNET_CACHE = EMBED_DIR / "effnet_embeddings.npy"

# Reuse the cached EfficientNet embeddings when they have one row per metadata entry; the
# cache is read once and serves both the row-count check and the result.
_cached = np.load(EFFNET_CACHE) if EFFNET_CACHE.exists() else None
if _cached is not None and _cached.shape[0] == len(df):
    effnet_embs = _cached
    print(f"Loaded cached EfficientNet embeddings: {effnet_embs.shape}")
else:
    print(f"Loading model: {EFFNET_MODEL}")
    model = load_model(EFFNET_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="EfficientNet"):
            wav = load_audio(path)
            feats = model(wav)  # (1, C, H, W)
            # Global average pool over the two spatial axes -> one C-dim vector per call.
            embeddings.append(feats.mean(dim=(2, 3)).squeeze(0).cpu().numpy())
    effnet_embs = np.stack(embeddings)
    np.save(EFFNET_CACHE, effnet_embs)
    print(f"Saved EfficientNet embeddings: {effnet_embs.shape}")
    del model
del _cached  # drop the extra reference (frees a stale, mismatched cache array)
Loaded cached EfficientNet embeddings: (3395, 1280)

4. UMAP Visualisation

# Per-call colour labels (also reused as the call-type targets by the metric and probe
# cells) and hover text (file name + call type) shared by all UMAP plots.
call_labels = df["call_name"].tolist()
hover = [f"{row['filename']}<br>Type: {row['call_name']}" for _, row in df.iterrows()]

# FAST_MODE skips the projection; the figure variables are then left undefined.
if FAST_MODE:
    print("FAST_MODE: skipping UMAP visualisations (BEATs last layer).")
else:
    # Interactive (Plotly) version with hover text.
    fig_beats = plot_umap(
        beats_embs,
        labels=call_labels,
        title=f"UMAP — {BEATS_MODEL} (last layer)<br><sup>colour = call type</sup>",
        hover_text=hover,
    )
    display(fig_beats)

    # Static version of the same projection (saved later with savefig).
    fig_umap_beats_static = plot_umap_static(
        beats_embs,
        labels=call_labels,
        title=f"UMAP — {BEATS_MODEL} (last layer) — static",
    )
    display(fig_umap_beats_static)
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
../../_images/7bd40577e4e6c4b7b54a2dd512e61e736106bf60c9b563dde5a244f6d54e867e.png ../../_images/7bd40577e4e6c4b7b54a2dd512e61e736106bf60c9b563dde5a244f6d54e867e.png
# UMAP of the BEATs embeddings averaged over all encoder layers (skipped in FAST_MODE).
if FAST_MODE:
    print("FAST_MODE: skipping UMAP visualisations (BEATs all layers avg).")
else:
    fig_beats_all = plot_umap(
        beats_all_embs,
        labels=call_labels,
        title=f"UMAP — {BEATS_MODEL} (all layers avg)",
        hover_text=hover,
    )
    display(fig_beats_all)

    fig_umap_beats_all_static = plot_umap_static(
        beats_all_embs,
        labels=call_labels,
        title=f"UMAP — {BEATS_MODEL} (all layers avg) — static",
    )
    display(fig_umap_beats_all_static)
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
../../_images/0fafe7870c8177c48710ad5d789fa8878e49e1f933f5bc8b9aebcd911c6eb889.png ../../_images/0fafe7870c8177c48710ad5d789fa8878e49e1f933f5bc8b9aebcd911c6eb889.png
# UMAP of the spatially pooled EfficientNet-B0 embeddings (skipped in FAST_MODE).
if FAST_MODE:
    print("FAST_MODE: skipping UMAP visualisations (EfficientNet).")
else:
    fig_effnet = plot_umap(
        effnet_embs,
        labels=call_labels,
        title=f"UMAP — {EFFNET_MODEL}",
        hover_text=hover,
    )
    display(fig_effnet)

    fig_umap_effnet_static = plot_umap_static(
        effnet_embs,
        labels=call_labels,
        title=f"UMAP — {EFFNET_MODEL} — static",
    )
    display(fig_umap_effnet_static)
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
../../_images/e6461a0eef6cae3a16006d2f13f7ac4e89a26c2009cad5bcfcd5a4cb9dec8b2c.png ../../_images/e6461a0eef6cae3a16006d2f13f7ac4e89a26c2009cad5bcfcd5a4cb9dec8b2c.png

5. Training-Free Metrics

print("Computing training-free metrics (NMI, ARI, R-AUC) ...")
print("These evaluate embedding quality without fitting any classifier.\n")

# Metrics are computed against the call-type labels only (not age).
_labels_for_metrics = call_labels

# (display name, embedding matrix) pairs to score.
_metric_models = [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
    (EFFNET_MODEL, effnet_embs),
]

# One table row per embedding, rounded to 3 decimals for display.
_metric_rows = []
for _name, _embs in _metric_models:
    _m = compute_training_free_metrics(_embs, _labels_for_metrics)
    _metric_rows.append(
        {"Model": _name, "NMI": round(_m["nmi"], 3), "ARI": round(_m["ari"], 3), "R-AUC": round(_m["r_auc"], 3)}
    )
    print(f"  {_name}: NMI={_m['nmi']:.3f}  ARI={_m['ari']:.3f}  R-AUC={_m['r_auc']:.3f}")

_metrics_df = pd.DataFrame(_metric_rows).set_index("Model")
display(_metrics_df)
Computing training-free metrics (NMI, ARI, R-AUC) ...
These evaluate embedding quality without fitting any classifier.
  esp_aves2_sl_beats_all (last layer): NMI=0.613  ARI=0.464  R-AUC=0.558
  esp_aves2_sl_beats_all (all layers avg): NMI=0.596  ARI=0.436  R-AUC=0.553
  esp_aves2_effnetb0_all: NMI=0.475  ARI=0.309  R-AUC=0.458
NMI ARI R-AUC
Model
esp_aves2_sl_beats_all (last layer) 0.613 0.464 0.558
esp_aves2_sl_beats_all (all layers avg) 0.596 0.436 0.553
esp_aves2_effnetb0_all 0.475 0.309 0.458

6. Linear Probe

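Two tasks: call-type classification (13 call-name labels in the parsed metadata, of which 11 occur in the specimen-grouped test split) and 2-class age group (adult vs. juvenile).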

from sklearn.model_selection import GroupShuffleSplit

# Shared probe settings: random_state seeds the linear probe and the fallback split below;
# max_iter is forwarded to run_linear_probe_fixed_split.
probe_kwargs = dict(random_state=42, max_iter=1000)


def _probe_acc_or_nan(res: dict) -> float:
    acc = res.get("accuracy")
    if acc is None:
        return float("nan")
    try:
        return float(acc)
    except (TypeError, ValueError):
        return float("nan")


def _probe_acc_str(res: dict) -> str:
    acc = res.get("accuracy")
    if acc is None:
        return "NA"
    try:
        v = float(acc)
    except (TypeError, ValueError):
        return "NA"
    if np.isnan(v):
        return "NA"
    return f"{v:.3f}"


def _label_values(labels: list | np.ndarray, idx: np.ndarray) -> set[str]:
    vals: set[str] = set()
    for x in np.asarray(labels, dtype=object)[idx]:
        if x is None:
            continue
        if isinstance(x, float) and np.isnan(x):
            continue
        s = str(x).strip()
        if not s:
            continue
        vals.add(s)
    return vals


def _get_age_group_labels(_df: pd.DataFrame) -> list[str] | None:
    if "age_group" in _df.columns:
        vals = _df["age_group"].astype(str).tolist()
        uniq = sorted({v for v in vals if v and v.lower() not in {"nan", "none"}})
        if len(uniq) >= 2:
            return vals

    if "age" in _df.columns:
        vals = _df["age"].astype(str).tolist()
        uniq = sorted({v for v in vals if v and v.lower() not in {"nan", "none", "unknown"}})
        if len(uniq) >= 2:
            return vals

    return None


# Age labels are optional: None when the metadata has no usable age column.
age_labels = _get_age_group_labels(df)

# Candidate probing tasks as (task name, per-row labels); age_group only when available.
tasks_all: list[tuple[str, list[str]]] = [("call_type", call_labels)]
if age_labels is not None:
    tasks_all.append(("age_group", age_labels))

# Specimen IDs (parsed from file names) used as groups so no specimen spans train and test.
specimen_groups = np.array(df["specimen"].tolist())


def _split_valid(train_idx: np.ndarray, test_idx: np.ndarray) -> bool:
    """Return True if every task in ``tasks_all`` has >= 2 distinct labels on both sides.

    Labels are counted with ``_label_values``, so missing or blank entries do not count.
    """

    def _has_two_classes(labels, idx) -> bool:
        return len(_label_values(labels, idx)) >= 2

    return all(
        _has_two_classes(labs, train_idx) and _has_two_classes(labs, test_idx)
        for _name, labs in tasks_all
    )


# Choose one fixed specimen-grouped train/test split (~20% of specimens held out), trying
# seeds 0..499 until every task in tasks_all has >= 2 labels on each side.
TRAIN_IDX = None
TEST_IDX = None
split_seed = None

for _seed in range(0, 500):
    _gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=_seed)
    _tr, _te = next(_gss.split(np.zeros(len(df)), call_labels, groups=specimen_groups))
    if _split_valid(_tr, _te):
        TRAIN_IDX, TEST_IDX, split_seed = _tr, _te, _seed
        break

if TRAIN_IDX is None or TEST_IDX is None:
    # No seed satisfied every task: keep only call_type and take a single grouped split
    # seeded with probe_kwargs["random_state"].
    # NOTE(review): this fallback split is not re-checked for >= 2 call_type classes per side.
    print(
        "WARNING: Could not find a specimen-grouped split where all tasks have ≥2 classes in both train and test. "
        "Falling back to call_type-only probing."
    )
    tasks: list[tuple[str, list[str]]] = [("call_type", call_labels)]

    _gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=probe_kwargs["random_state"])
    TRAIN_IDX, TEST_IDX = next(_gss.split(np.zeros(len(df)), call_labels, groups=specimen_groups))
    split_seed = probe_kwargs["random_state"]
else:
    tasks = tasks_all

print(
    f"Split seed: {split_seed} | Train specimens: {len(set(specimen_groups[TRAIN_IDX]))} | Test specimens: {len(set(specimen_groups[TEST_IDX]))}"
)

# Flags read by the attention-probe cell to decide which tasks to run.
RUN_CALL_TYPE_PROBE = any(t == "call_type" for t, _ in tasks)
RUN_AGE_GROUP_PROBE = any(t == "age_group" for t, _ in tasks)

# Report how many distinct labels each task has in train and in test.
for _task, _labs in tasks:
    tr = _label_values(_labs, TRAIN_IDX)
    te = _label_values(_labs, TEST_IDX)
    print(f"  {_task}: train_classes={len(tr)} test_classes={len(te)}")

# One linear probe per (embedding, task) on the fixed split; results keyed "<embedding> | <task>".
probe_results: dict[str, dict] = {}
for emb_name, embs in [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
    (EFFNET_MODEL, effnet_embs),
]:
    for task, task_labels in tasks:
        key = f"{emb_name} | {task}"
        res = run_linear_probe_fixed_split(
            embs,
            task_labels,
            train_idx=TRAIN_IDX,
            test_idx=TEST_IDX,
            random_state=probe_kwargs["random_state"],
            max_iter=probe_kwargs["max_iter"],
        )
        probe_results[key] = res
        print(f"{key}: accuracy = {_probe_acc_str(res)}")
Split seed: 0 | Train specimens: 39 | Test specimens: 10
  call_type: train_classes=13 test_classes=11
  age_group: train_classes=2 test_classes=2
esp_aves2_sl_beats_all (last layer) | call_type: accuracy = 0.796
esp_aves2_sl_beats_all (last layer) | age_group: accuracy = 0.757
esp_aves2_sl_beats_all (all layers avg) | call_type: accuracy = 0.824
esp_aves2_sl_beats_all (all layers avg) | age_group: accuracy = 0.839
esp_aves2_effnetb0_all | call_type: accuracy = 0.833
esp_aves2_effnetb0_all | age_group: accuracy = 0.985
# Summarise the linear-probe accuracies as a Model x Task table plus comparison charts.
# `_probe_acc_or_nan` is defined in the linear-probe cell above; the scores are computed
# once here and reused by the table and both charts.
_linear_scores = {k: _probe_acc_or_nan(v) for k, v in probe_results.items()}

rows = [
    {"Model": k.split(" | ")[0], "Task": k.split(" | ")[1], "Accuracy": round(score, 4)}
    for k, score in _linear_scores.items()
]
acc_df = pd.DataFrame(rows)
display(acc_df.pivot(index="Model", columns="Task", values="Accuracy").style.format("{:.4f}"))

# Each plotting call gets its own copy, as before, in case a helper mutates its input.
fig_cmp = plot_model_comparison(
    dict(_linear_scores),
    title="Zebra Finch — call type & age classification accuracy",
)
display(fig_cmp)  # avoid interactive show() in non-interactive runs
display(
    plot_model_comparison_static(
        dict(_linear_scores),
        title="Zebra Finch — linear probe (static)",
    )
)
Task age_group call_type
Model    
esp_aves2_effnetb0_all 0.9851 0.8334
esp_aves2_sl_beats_all (all layers avg) 0.8390 0.8236
esp_aves2_sl_beats_all (last layer) 0.7572 0.7963
../../_images/d54bb0a97dcfcbb58941c54c49102022d91b5b7993110c282de8536bebbd44ab.png ../../_images/d54bb0a97dcfcbb58941c54c49102022d91b5b7993110c282de8536bebbd44ab.png

7. Attention Probe (sl-BEATs)

Attention probe on the mean-pooled last-layer sl-BEATs embeddings, run for both the call-type and the age-group task on the same fixed split as the linear probes.

# Hyper-parameters forwarded unchanged to run_attention_probe_fixed_split.
attn_probe_kwargs = dict(num_heads=8, num_attn_layers=2, epochs=50, random_state=42, device=DEVICE)

# Results keyed "<model> (attn) | <task>"; figures stay None for tasks that are not run.
attention_results: dict[str, dict] = {}
fig_attn_call = None
fig_attn_age = None

# NOTE(review): beats_embs holds one mean-pooled last-layer vector per call, so how the
# attention probe forms its token sequence is up to run_attention_probe_fixed_split — confirm.
if not RUN_CALL_TYPE_PROBE and not RUN_AGE_GROUP_PROBE:
    print("Skipping attention probe (no valid tasks on this split).")
else:
    print("Running attention probe (this can be slow) ...")

    if RUN_CALL_TYPE_PROBE:
        # Same fixed split as the linear probes, so accuracies are directly comparable.
        key = f"{BEATS_MODEL} (attn) | call_type"
        res = run_attention_probe_fixed_split(
            beats_embs,
            call_labels,
            train_idx=TRAIN_IDX,
            test_idx=TEST_IDX,
            **attn_probe_kwargs,
        )
        attention_results[key] = res
        print(f"  {key}: accuracy = {_probe_acc_str(res)}")

        fig_attn_call = plot_model_comparison(
            {key: _probe_acc_or_nan(res)},
            title="Zebra Finch — BEATs attention probe (call type)",
        )
        display(fig_attn_call)
        display(
            plot_model_comparison_static(
                {key: _probe_acc_or_nan(res)},
                title="Zebra Finch — BEATs attention probe (call type) — static",
            )
        )

    if RUN_AGE_GROUP_PROBE:
        # age_labels is non-None here: the age_group task only exists when it was found.
        key = f"{BEATS_MODEL} (attn) | age_group"
        res = run_attention_probe_fixed_split(
            beats_embs,
            age_labels,
            train_idx=TRAIN_IDX,
            test_idx=TEST_IDX,
            **attn_probe_kwargs,
        )
        attention_results[key] = res
        print(f"  {key}: accuracy = {_probe_acc_str(res)}")

        fig_attn_age = plot_model_comparison(
            {key: _probe_acc_or_nan(res)},
            title="Zebra Finch — BEATs attention probe (age)",
        )
        display(fig_attn_age)
        display(
            plot_model_comparison_static(
                {key: _probe_acc_or_nan(res)},
                title="Zebra Finch — BEATs attention probe (age) — static",
            )
        )
Running attention probe (this can be slow) ...
  esp_aves2_sl_beats_all (attn) | call_type: accuracy = 0.793
../../_images/3fc585ad679947a90d03c16072e642ca26c7dc2fca39266aee3f4d7bc2bebe7d.png
  esp_aves2_sl_beats_all (attn) | age_group: accuracy = 0.942
../../_images/876331f549f815489b4b3631a33f10715044ffa5c192b1f0f85ad0c4edc20075.png ../../_images/3fc585ad679947a90d03c16072e642ca26c7dc2fca39266aee3f4d7bc2bebe7d.png ../../_images/876331f549f815489b4b3631a33f10715044ffa5c192b1f0f85ad0c4edc20075.png
def plot_confusion_matrix(cm, classes, title):
    """Plot a row-normalised confusion matrix (per-true-class recall) as a Plotly heatmap.

    Args:
        cm: Square confusion matrix with true classes on rows and predictions on columns.
        classes: Class names used for both axes, in the matrix's row/column order.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Rows with no true samples (a class absent from the test split) are shown as 0 rather
    than NaN, matching ``confusion_heatmap_static``; this also avoids the divide-by-zero
    RuntimeWarning.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    return px.imshow(
        cm_norm,
        x=classes,
        y=classes,
        color_continuous_scale="Blues",
        zmin=0,
        zmax=1,
        text_auto=".2f",
        title=title,
        labels={"x": "Predicted", "y": "True", "color": "Recall"},
        aspect="auto",
    )


# Interactive and static confusion matrices for every linear-probe call_type result
# (each result dict supplies its own "confusion_matrix" and "classes").
for key in [k for k in probe_results if "call_type" in k]:
    res = probe_results[key]
    display(plot_confusion_matrix(res["confusion_matrix"], res["classes"], title=f"Call type confusion — {key}"))
    display(
        confusion_heatmap_static(
            res["confusion_matrix"],
            res["classes"],
            title=f"Call type confusion — {key} (static)",
        )
    )
/tmp/ipykernel_1192382/1326845768.py:2: RuntimeWarning: invalid value encountered in divide
  cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
/home/marius_miron_earthspecies_org/code/avex-examples/utils/visualization.py:562: RuntimeWarning: invalid value encountered in divide
  cm_norm = np.where(row_sums > 0, cm.astype(float) / row_sums, 0.0)
../../_images/f0c5af7543ebc6ddb250a7efcffa82e5c7616ab61e518828ab73173e0cf5c91a.png
/tmp/ipykernel_1192382/1326845768.py:2: RuntimeWarning: invalid value encountered in divide
  cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
/home/marius_miron_earthspecies_org/code/avex-examples/utils/visualization.py:562: RuntimeWarning: invalid value encountered in divide
  cm_norm = np.where(row_sums > 0, cm.astype(float) / row_sums, 0.0)
../../_images/f591ebfeb4f5f8e518a8c0e295bd97c28fc8ac73189f3347f02a4b52b048a2e6.png
/tmp/ipykernel_1192382/1326845768.py:2: RuntimeWarning: invalid value encountered in divide
  cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
/home/marius_miron_earthspecies_org/code/avex-examples/utils/visualization.py:562: RuntimeWarning: invalid value encountered in divide
  cm_norm = np.where(row_sums > 0, cm.astype(float) / row_sums, 0.0)
../../_images/b40bbf0aaa0017f4cd29d420efa0b4eeebd1c4f9145f1cf6b853185e957c53db.png ../../_images/f0c5af7543ebc6ddb250a7efcffa82e5c7616ab61e518828ab73173e0cf5c91a.png ../../_images/f591ebfeb4f5f8e518a8c0e295bd97c28fc8ac73189f3347f02a4b52b048a2e6.png ../../_images/b40bbf0aaa0017f4cd29d420efa0b4eeebd1c4f9145f1cf6b853185e957c53db.png

9. Save Artifacts

ARTIFACTS_DIR = EXAMPLE_DIR / "artifacts"
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Some figures are skipped in fast runs.
# At module (notebook) scope locals() is the global namespace, so .get() yields None for
# any figure variable whose cell did not create it.
for _fig, _name in [
    (locals().get("fig_beats"), "umap_beats.html"),
    (locals().get("fig_beats_all"), "umap_beats_all_layers.html"),
    (locals().get("fig_effnet"), "umap_effnet.html"),
    (locals().get("fig_cmp"), "model_comparison.html"),
    (locals().get("fig_attn_call"), "model_comparison_beats_attn_call_type.html"),
    (locals().get("fig_attn_age"), "model_comparison_beats_attn_age.html"),
]:
    if _fig is not None:
        _fig.write_html(str(ARTIFACTS_DIR / _name))

# Static matplotlib UMAP figures.
for _fig, _name in [
    (locals().get("fig_umap_beats_static"), "umap_beats_static.png"),
    (locals().get("fig_umap_beats_all_static"), "umap_beats_all_layers_static.png"),
    (locals().get("fig_umap_effnet_static"), "umap_effnet_static.png"),
]:
    if _fig is not None:
        _fig.savefig(str(ARTIFACTS_DIR / _name), dpi=150, bbox_inches="tight")

plt.close("all")


def _to_jsonable(obj):
    """Recursively convert ``obj`` into strict-JSON-serialisable Python values.

    numpy arrays become lists and numpy scalars become Python scalars. Non-finite floats
    (NaN, e.g. from ``_probe_acc_or_nan``, or +/-inf) become None, so the file is written
    as ``null`` rather than the non-standard ``NaN`` token.
    """
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return _to_jsonable(obj.tolist())
    if isinstance(obj, np.generic):
        obj = obj.item()
    if isinstance(obj, float) and not np.isfinite(obj):
        return None
    return obj


# NOTE(review): training-free metrics are recomputed here at full precision (the section-5
# table kept only rounded values); assumes compute_training_free_metrics is deterministic.
_metrics_out = {
    "n_calls": len(df),
    "n_call_types": df["call_name"].nunique(),
    "training_free": {
        k: compute_training_free_metrics(v, call_labels)
        for k, v in [
            ("beats_last", beats_embs),
            ("beats_all_layers", beats_all_embs),
            ("effnet", effnet_embs),
        ]
    },
    "linear_probe_accuracy": {k: round(_probe_acc_or_nan(v), 4) for k, v in probe_results.items()},
    "attention_probe_accuracy": {k: round(_probe_acc_or_nan(v), 4) for k, v in attention_results.items()},
}

with open(ARTIFACTS_DIR / "zebra_finch_metrics.json", "w", encoding="utf-8") as _f:
    # allow_nan=False guarantees strict JSON; _to_jsonable has already mapped NaN/inf to None.
    json.dump(_to_jsonable(_metrics_out), _f, indent=2, allow_nan=False)

print(f"Artifacts saved to {ARTIFACTS_DIR}")
Artifacts saved to /home/marius_miron_earthspecies_org/code/avex-examples/examples/05_zebra_finch/artifacts