Macaque Individual ID Classification

Classify 8 individuals from Japanese macaque (Macaca fuscata) coo calls using BEATs and EfficientNet-B0 embeddings.

Labels available in the processed dataset: individual codename only (annotations.csv provides class, split, filename).

Workflow: download → explore → embed → UMAP → training-free metrics (NMI, ARI, R-AUC) → linear probe → attention probe (sl-BEATs) → static + interactive figures → save artifacts.

import warnings

warnings.filterwarnings("ignore")

import pathlib
import sys


def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    """Return the first ancestor of *start* (inclusive) that contains pyproject.toml.

    Raises:
        FileNotFoundError: if no ancestor up to the filesystem root has one.
    """
    root = next(
        (d for d in (start, *start.parents) if (d / "pyproject.toml").exists()),
        None,
    )
    if root is None:
        raise FileNotFoundError("Could not locate repo root (pyproject.toml not found).")
    return root


# Make repo-local packages (utils.*) importable regardless of the cwd the
# notebook is launched from.
REPO_ROOT = find_repo_root(pathlib.Path().resolve())
sys.path.insert(0, str(REPO_ROOT))

import json
import re
import urllib.request
import zipfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import soundfile as sf
import torch
from avex import load_model
from IPython.display import display
from tqdm.auto import tqdm

from utils.probing import (
    compute_training_free_metrics,
    run_attention_probe,
    run_attention_probe_fixed_split,
    run_linear_probe,
    run_linear_probe_fixed_split,
)
from utils.visualization import (
    confusion_heatmap_static,
    plot_model_comparison,
    plot_model_comparison_static,
    plot_umap,
    plot_umap_static,
)

# Directory layout for this example; the embeddings cache dir is created eagerly.
EXAMPLE_DIR = REPO_ROOT / "examples" / "03_macaques_individual_id"
DATA_DIR = EXAMPLE_DIR / "data"
AUDIO_DIR = DATA_DIR / "audio"
EMBED_DIR = DATA_DIR / "embeddings"
EMBED_DIR.mkdir(parents=True, exist_ok=True)

TARGET_SR = 16_000  # all audio is resampled to this rate before embedding (see load_audio)
DEVICE = "cpu"
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"
print("Setup complete.")
2026-04-21 07:56:39.540524: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-21 07:56:41.532433: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
Setup complete.

1. Download Dataset

# Dataset archive location, local download target, and a sentinel file that
# marks a fully completed extraction (so partial extractions are retried).
ARCHIVE_URL = "https://archive.org/download/macaque_coo_calls/macaques.zip"
ARCHIVE_PATH = DATA_DIR / "macaques.zip"
EXTRACT_SENTINEL = AUDIO_DIR / ".extracted"


def download_with_progress(url: str, dest: pathlib.Path) -> None:
    """Fetch *url* to *dest*, printing an in-place percentage progress line."""

    def _report(block_num, block_bytes, total_bytes):
        # urlretrieve reports (blocks transferred, block size, total size);
        # total_size can be -1 when the server omits Content-Length.
        if total_bytes > 0:
            done = min(block_num * block_bytes / total_bytes * 100, 100)
            print(f"\r  {done:5.1f}%", end="", flush=True)

    urllib.request.urlretrieve(url, dest, reporthook=_report)
    print()


# Download once; an existing archive is trusted and reused.
if not ARCHIVE_PATH.exists():
    print(f"Downloading {ARCHIVE_URL} ...")
    download_with_progress(ARCHIVE_URL, ARCHIVE_PATH)
else:
    print(f"Archive present ({ARCHIVE_PATH.stat().st_size / 1e6:.1f} MB) — skipping.")

# Extract once; the sentinel is touched only after a full extractall, so an
# interrupted extraction is redone on the next run.
if not EXTRACT_SENTINEL.exists():
    AUDIO_DIR.mkdir(parents=True, exist_ok=True)
    print("Extracting ...")
    with zipfile.ZipFile(ARCHIVE_PATH) as zf:
        zf.extractall(AUDIO_DIR)
    EXTRACT_SENTINEL.touch()
    print("Done.")
else:
    print("Audio already extracted.")

print(f"WAV files: {len(list(AUDIO_DIR.rglob('*.wav')))}")
Archive present (132.0 MB) — skipping.
Audio already extracted.
WAV files: 7285

2. Data Exploration

# Fail fast with actionable messages if the annotations file or its expected
# schema (class, split, filename) is missing.
ANNOTATIONS_PATH = AUDIO_DIR / "annotations.csv"
if not ANNOTATIONS_PATH.exists():
    raise FileNotFoundError(
        f"Expected annotations at {ANNOTATIONS_PATH}. "
        "If you downloaded the dataset manually, ensure annotations.csv is present."
    )

ann = pd.read_csv(ANNOTATIONS_PATH)
required_cols = {"class", "split", "filename"}
missing = required_cols.difference(set(ann.columns))
if missing:
    raise ValueError(
        f"annotations.csv missing columns: {sorted(missing)} (found {list(ann.columns)})"
    )

# Build a fast lookup from filename -> absolute path.
# NOTE(review): duplicate basenames in different subdirectories would collide
# here (last one wins); the split-based fallback below partially mitigates
# this — confirm filenames are unique across the archive.
all_wavs = list(AUDIO_DIR.rglob("*.wav"))
path_by_name = {p.name: p for p in all_wavs}

records: list[dict[str, object]] = []
missing_paths: list[str] = []

# NOTE: avoid itertuples attribute access for the `class` column.
# (`class` is a Python keyword, so `row.class` would be a syntax error.)
for row in ann.to_dict(orient="records"):
    filename = str(row["filename"])
    individual = str(row["class"])
    split = str(row["split"])

    wav_path = path_by_name.get(filename)
    if wav_path is None:
        # Fall back to the conventional <split>/<filename> layout.
        wav_path = AUDIO_DIR / split / filename
        if not wav_path.exists():
            missing_paths.append(filename)
            continue

    # sf.info reads only the WAV header — cheap per-file metadata.
    info = sf.info(str(wav_path))
    records.append(
        {
            "path": str(wav_path),
            "filename": filename,
            "split": split,
            "individual": individual,
            "sample_rate": int(info.samplerate),
            "duration_s": float(info.duration),
        }
    )

# Fail loudly rather than silently working on a truncated dataset.
if missing_paths:
    raise FileNotFoundError(
        "Some annotated files were not found under data/audio/. "
        f"Example missing: {missing_paths[:5]} (n_missing={len(missing_paths)})"
    )

df = pd.DataFrame.from_records(records)

# Prefer 44100 Hz recordings for consistent preprocessing.
# NOTE(review): this silently drops all calls at other sample rates (and any
# individual whose calls are all non-44.1 kHz) — confirm this is intended.
if (df["sample_rate"] == 44100).any():
    df = df[df["sample_rate"] == 44100].reset_index(drop=True)

# Save metadata so later cells / external tooling can reuse the working set.
META_PATH = DATA_DIR / "metadata.csv"
df.to_csv(META_PATH, index=False)

print(
    f"Working set: {len(df)} calls across {df['individual'].nunique()} individuals "
    f"(splits: {sorted(df['split'].unique().tolist())})"
)
display(df.groupby(["individual", "split"]).size().rename("n_calls"))
Working set: 1529 calls across 5 individuals (splits: ['train', 'valid'])
individual  split
AL          train    190
            valid     52
BE          train      9
            valid      4
IO          train    708
            valid    178
SN          train    306
            valid     81
TW          train      1
Name: n_calls, dtype: int64
# Per-individual call counts, split by train/valid, as an interactive bar chart.
count_df = (
    df.groupby(["individual", "split"])
    .size()
    .reset_index(name="n_calls")
    .sort_values(["individual", "split"])
)
fig = px.bar(
    count_df,
    x="individual",
    y="n_calls",
    color="split",
    title="Number of coo calls per individual",
    labels={"individual": "Individual", "n_calls": "Number of calls", "split": "Split"},
    text="n_calls",
)
fig.update_traces(textposition="outside")  # draw counts above the bars
fig.show()

3. Embedding Extraction

MIN_SAMPLES = int(0.5 * TARGET_SR)  # clips shorter than 0.5 s get zero-padded


def load_audio(path: str, target_sr: int = TARGET_SR) -> torch.Tensor:
    """Load a WAV as a (1, n_samples) mono float tensor at *target_sr*.

    Channels are averaged to mono, the signal is resampled when the native
    rate differs, and the tail is zero-padded up to a 0.5 s minimum length.
    """
    samples, native_sr = sf.read(path, dtype="float32", always_2d=True)
    samples = samples.mean(axis=1)
    if native_sr != target_sr:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
    shortfall = MIN_SAMPLES - len(samples)
    if shortfall > 0:
        samples = np.pad(samples, (0, shortfall))
    return torch.from_numpy(samples).unsqueeze(0)


print(f"Sample shape: {load_audio(df['path'].iloc[0]).shape}")
Sample shape: torch.Size([1, 8000])
BEATS_CACHE = EMBED_DIR / "beats_embeddings.npy"

# Read the cache at most once (the previous pattern loaded the .npy twice:
# once for the shape check and again for the actual array).
beats_embs = np.load(BEATS_CACHE) if BEATS_CACHE.exists() else None

if beats_embs is not None and beats_embs.shape[0] == len(df):
    print(f"Loaded cached BEATs last-layer embeddings: {beats_embs.shape}")
else:
    # Cache missing or stale (row count differs from the working set): recompute.
    print(f"Loading model: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="BEATs last-layer"):
            wav = load_audio(path)
            feats = model(wav)  # (1, T, 768)
            # Mean-pool over time → one 768-d vector per call.
            embeddings.append(feats.mean(dim=1).squeeze(0).cpu().numpy())
    beats_embs = np.stack(embeddings)
    np.save(BEATS_CACHE, beats_embs)
    print(f"Saved BEATs last-layer embeddings: {beats_embs.shape}")
    del model  # free the backbone before the next extraction cell
Loaded cached BEATs last-layer embeddings: (1529, 768)
BEATS_ALL_CACHE = EMBED_DIR / "beats_all_layers_embeddings.npy"

# Read the cache at most once (the previous pattern loaded the .npy twice:
# once for the shape check and again for the actual array).
beats_all_embs = np.load(BEATS_ALL_CACHE) if BEATS_ALL_CACHE.exists() else None

if beats_all_embs is not None and beats_all_embs.shape[0] == len(df):
    print(f"Loaded cached BEATs all-layers embeddings: {beats_all_embs.shape}")
else:
    print(f"Loading model for all-layers extraction: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()

    # Locate the transformer encoder stack; the attribute path differs between
    # model wrappers, hence the AttributeError fallback.
    try:
        encoder_layers = model.model.encoder.layers
    except AttributeError:
        encoder_layers = model.backbone.encoder.layers
    n_layers = len(encoder_layers)
    layer_store: dict = {}

    def _make_hook(idx):
        # Factory so each hook captures its own layer index (avoids the
        # late-binding-closure pitfall); defined once, not per loop iteration
        # as before.
        def _hook(module, inp, out):
            layer_store[idx] = out[0] if isinstance(out, tuple) else out

        return _hook

    # Register a forward hook on every transformer encoder layer.
    hooks = [layer.register_forward_hook(_make_hook(i)) for i, layer in enumerate(encoder_layers)]
    print(f"  Registered hooks on {n_layers} transformer layers.")

    all_embs = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="BEATs all-layers"):
            layer_store.clear()
            wav = load_audio(path)
            _ = model(wav)
            # Mean-pool each layer then average across layers → (D,)
            # Handles both time-first (T, B, D) and batch-first (B, T, D)
            # outputs: batch size is 1, so flattening leading dims is safe.
            per_layer = [
                layer_store[i].view(-1, layer_store[i].shape[-1]).mean(dim=0).cpu().numpy() for i in range(n_layers)
            ]
            all_embs.append(np.mean(per_layer, axis=0))

    # Detach the hooks so the model is clean if reused.
    for h in hooks:
        h.remove()

    beats_all_embs = np.stack(all_embs)
    np.save(BEATS_ALL_CACHE, beats_all_embs)
    print(f"Saved BEATs all-layers embeddings: {beats_all_embs.shape}")
    del model
Loaded cached BEATs all-layers embeddings: (1529, 768)
EFFNET_CACHE = EMBED_DIR / "effnet_embeddings.npy"

# Read the cache at most once (the previous pattern loaded the .npy twice:
# once for the shape check and again for the actual array).
effnet_embs = np.load(EFFNET_CACHE) if EFFNET_CACHE.exists() else None

if effnet_embs is not None and effnet_embs.shape[0] == len(df):
    print(f"Loaded cached EfficientNet embeddings: {effnet_embs.shape}")
else:
    print(f"Loading model: {EFFNET_MODEL}")
    model = load_model(EFFNET_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="EfficientNet"):
            wav = load_audio(path)
            feats = model(wav)  # (1, C, H, W)
            # Global average pool over the spatial dims → one C-d vector.
            embeddings.append(feats.mean(dim=(2, 3)).squeeze(0).cpu().numpy())
    effnet_embs = np.stack(embeddings)
    np.save(EFFNET_CACHE, effnet_embs)
    print(f"Saved EfficientNet embeddings: {effnet_embs.shape}")
    del model
Loaded cached EfficientNet embeddings: (1529, 1280)

4. UMAP Visualisation

id_labels = df["individual"].tolist()

# Individual-ID only example.

# One hover string per call: individual, split, and source filename.
_splits = df["split"].tolist()
_files = df["filename"].tolist()
hover = [
    f"{ind} · split={split} · file={fn}"
    for ind, split, fn in zip(id_labels, _splits, _files, strict=False)
]

# Consistent colour map.
# - If IDs look like "F_1" / "M_7", use warm (F) vs cool (M) palettes.
# - Otherwise (e.g. "AL", "BE", ...), assign a stable categorical palette.
_unique_ids = sorted(set(id_labels))
if any(name.startswith(("F_", "M_")) for name in _unique_ids):
    _warm = ["#e6550d", "#fd8d3c", "#fdae6b", "#fdd0a2", "#a63603", "#d94801", "#f16913", "#7f2704"]
    _cool = ["#08519c", "#2171b5", "#4292c6", "#6baed6", "#084594", "#2c7bb6", "#08306b", "#3182bd"]
    INDIVIDUAL_COLOR_MAP = {}
    for _prefix, _palette in (("F_", _warm), ("M_", _cool)):
        _members = [name for name in _unique_ids if name.startswith(_prefix)]
        for _i, _name in enumerate(_members):
            INDIVIDUAL_COLOR_MAP[_name] = _palette[_i % len(_palette)]
else:
    _pal = px.colors.qualitative.Dark24
    INDIVIDUAL_COLOR_MAP = {name: _pal[i % len(_pal)] for i, name in enumerate(_unique_ids)}

# BEATs last-layer embeddings: interactive (plotly) and static UMAP projections.
fig_beats = plot_umap(
    beats_embs,
    labels=id_labels,
    title=f"UMAP — {BEATS_MODEL} (last layer)<br><sup>colour = individual</sup>",
    hover_text=hover,
    color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_beats.show()
fig_umap_beats_static = plot_umap_static(
    beats_embs,
    labels=id_labels,
    title=f"UMAP — {BEATS_MODEL} (last layer) — static",
    color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_beats_static)
../../_images/d2e981c07e1dcd3293d536ea2a69a486f89d34405507b933940dff3d03cccbbe.png
# BEATs all-layers-averaged embeddings: interactive and static UMAP projections.
fig_beats_all = plot_umap(
    beats_all_embs,
    labels=id_labels,
    title=f"UMAP — {BEATS_MODEL} (all layers avg)",
    hover_text=hover,
    color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_beats_all.show()
fig_umap_beats_all_static = plot_umap_static(
    beats_all_embs,
    labels=id_labels,
    title=f"UMAP — {BEATS_MODEL} (all layers avg) — static",
    color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_beats_all_static)
../../_images/afceb3e2e5cbacc68b6937733de6c80fcf110948b8458cb9ce4b0b3458828285.png
# EfficientNet embeddings: interactive and static UMAP projections.
fig_effnet = plot_umap(
    effnet_embs,
    labels=id_labels,
    title=f"UMAP — {EFFNET_MODEL}",
    hover_text=hover,
    color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_effnet.show()
fig_umap_effnet_static = plot_umap_static(
    effnet_embs,
    labels=id_labels,
    title=f"UMAP — {EFFNET_MODEL} — static",
    color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_effnet_static)
../../_images/dc7af4b391e86788bf4ff579263024ad251acb8f95f59c075355e93768d74f32.png

5. Training-Free Metrics

print("Computing training-free metrics (NMI, ARI, R-AUC) ...")
print("These evaluate embedding quality without fitting any classifier.\n")

_labels_for_metrics = id_labels

# (display name, embedding matrix) pairs to score.
_metric_models = [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
    (EFFNET_MODEL, effnet_embs),
]

_metric_rows = []
for _model_name, _model_embs in _metric_models:
    _scores = compute_training_free_metrics(_model_embs, _labels_for_metrics)
    _nmi, _ari, _rauc = _scores["nmi"], _scores["ari"], _scores["r_auc"]
    _metric_rows.append(
        {"Model": _model_name, "NMI": round(_nmi, 3), "ARI": round(_ari, 3), "R-AUC": round(_rauc, 3)}
    )
    print(f"  {_model_name}: NMI={_nmi:.3f}  ARI={_ari:.3f}  R-AUC={_rauc:.3f}")

_metrics_df = pd.DataFrame(_metric_rows).set_index("Model")
display(_metrics_df)
Computing training-free metrics (NMI, ARI, R-AUC) ...
These evaluate embedding quality without fitting any classifier.
  esp_aves2_sl_beats_all (last layer): NMI=0.154  ARI=0.070  R-AUC=0.504
  esp_aves2_sl_beats_all (all layers avg): NMI=0.119  ARI=0.078  R-AUC=0.493
  esp_aves2_effnetb0_all: NMI=0.258  ARI=0.187  R-AUC=0.533
NMI ARI R-AUC
Model
esp_aves2_sl_beats_all (last layer) 0.154 0.070 0.504
esp_aves2_sl_beats_all (all layers avg) 0.119 0.078 0.493
esp_aves2_effnetb0_all 0.258 0.187 0.533

6. Linear Probe

Single task: multi-class individual ID classification (one class per individual in the working set; the filtered working set shown above contains 5 individuals).

probe_kwargs = dict(test_size=0.2, random_state=42, max_iter=1000)

# Fit one logistic-regression probe per embedding variant.
probe_results: dict[str, dict] = {}
_probe_inputs = [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
    (EFFNET_MODEL, effnet_embs),
]
for emb_name, embs in _probe_inputs:
    key = f"{emb_name} | individual_id"
    # probe_kwargs keys match run_linear_probe's parameter names exactly.
    res = run_linear_probe(embs, id_labels, **probe_kwargs)
    probe_results[key] = res
    print(f"{key}: accuracy = {res['accuracy']:.3f}")
esp_aves2_sl_beats_all (last layer) | individual_id: accuracy = 0.605
esp_aves2_sl_beats_all (all layers avg) | individual_id: accuracy = 0.609
esp_aves2_effnetb0_all | individual_id: accuracy = 0.650
# Tabulate linear-probe accuracies, then draw interactive + static comparisons.
rows = [
    {"Model": k.split(" | ")[0], "Accuracy": round(v["accuracy"], 4)}
    for k, v in probe_results.items()
]
acc_df = pd.DataFrame(rows)
display(acc_df.set_index("Model").style.format({"Accuracy": "{:.4f}"}))

fig_cmp = plot_model_comparison(
    {k: v["accuracy"] for k, v in probe_results.items()},
    title="Macaque — individual ID accuracy",
)
fig_cmp.show()
display(
    plot_model_comparison_static(
        {k: v["accuracy"] for k, v in probe_results.items()},
        title="Macaque — linear probe (static)",
    )
)
  Accuracy
Model  
esp_aves2_sl_beats_all (last layer) 0.6047
esp_aves2_sl_beats_all (all layers avg) 0.6087
esp_aves2_effnetb0_all 0.6500
../../_images/ccf5e32a41e3c028ceb36d26ecd5913ba0ecd7e02facf9be04c6d903f949da2b.png ../../_images/ccf5e32a41e3c028ceb36d26ecd5913ba0ecd7e02facf9be04c6d903f949da2b.png

7. Attention Probe (sl-BEATs)

Multi-head attention probe on mean-pooled BEATs embeddings for the same individual-ID task.

attn_probe_kwargs = dict(num_heads=8, num_attn_layers=2, epochs=50, test_size=0.2, random_state=42)

# Fit one attention probe per BEATs embedding variant.
attention_results: dict[str, dict] = {}
for emb_name, embs in [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
]:
    key = f"{emb_name} | individual_id"
    # attn_probe_kwargs keys match run_attention_probe's parameter names exactly.
    res = run_attention_probe(embs, id_labels, device=DEVICE, **attn_probe_kwargs)
    attention_results[key] = res
    print(f"{key} (attention): accuracy = {res['accuracy']:.3f}")

# Side-by-side linear vs attention accuracy for each embedding variant.
cmp_individual: dict[str, float] = {}
for base in (f"{BEATS_MODEL} (last layer)", f"{BEATS_MODEL} (all layers avg)"):
    lk = f"{base} | individual_id"
    cmp_individual[f"{base} (linear)"] = float(probe_results[lk]["accuracy"])
    cmp_individual[f"{base} (attention)"] = float(attention_results[lk]["accuracy"])

fig_attn_individual = plot_model_comparison(
    cmp_individual, title="Macaque — sl-BEATs individual ID: linear vs attention"
)
fig_attn_individual.show()
display(plot_model_comparison_static(cmp_individual, title="Macaque — sl-BEATs individual ID (static)"))
esp_aves2_sl_beats_all (last layer) | individual_id (attention): accuracy = 0.751
esp_aves2_sl_beats_all (all layers avg) | individual_id (attention): accuracy = 0.713
../../_images/56f11ad48fc1960faea238da6e73cc1be8d1c87c5842600e9340250b2776ea92.png ../../_images/56f11ad48fc1960faea238da6e73cc1be8d1c87c5842600e9340250b2776ea92.png
def plot_confusion_matrix(cm, classes, title):
    """Return an interactive row-normalised (per-class recall) confusion heatmap.

    Args:
        cm: square confusion matrix of raw counts, rows = true classes.
        classes: class names in the same order as the matrix rows/columns.
        title: figure title.

    Rows with zero support (a class absent from the evaluation split — e.g. an
    individual with a single call that landed entirely in train) are rendered
    as all-zero instead of NaN; the previous code divided by the raw row sum.
    """
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Safe division: cells whose row total is 0 keep the 0 from `out`.
    cm_norm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals > 0)
    return px.imshow(
        cm_norm,
        x=classes,
        y=classes,
        color_continuous_scale="Blues",
        zmin=0,
        zmax=1,
        text_auto=".2f",
        title=title,
        labels={"x": "Predicted", "y": "True", "color": "Recall"},
        aspect="auto",
    )


# One interactive + one static confusion matrix per linear-probe result.
for key, res in probe_results.items():
    plot_confusion_matrix(res["confusion_matrix"], res["classes"], title=f"Confusion matrix: {key}").show()
    display(
        confusion_heatmap_static(
            res["confusion_matrix"],
            res["classes"],
            title=f"Confusion matrix: {key} (static)",
        )
    )
../../_images/aa91b91b02e57843965384bb7715c47fae3ecdb27a097687ffba2b2c7f35a91b.png ../../_images/3a5ae6a9e1e586076b3a855cf418a1422ff6ffecc09592a41d3e20dd76e7726c.png ../../_images/7a0af0265459e7e46717cc5fbe2d3a6a97243b61a841b40f958cd4bf2312c0b0.png ../../_images/aa91b91b02e57843965384bb7715c47fae3ecdb27a097687ffba2b2c7f35a91b.png ../../_images/3a5ae6a9e1e586076b3a855cf418a1422ff6ffecc09592a41d3e20dd76e7726c.png ../../_images/7a0af0265459e7e46717cc5fbe2d3a6a97243b61a841b40f958cd4bf2312c0b0.png

8. Save Artifacts

ARTIFACTS_DIR = EXAMPLE_DIR / "artifacts"
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Interactive figures as self-contained HTML files.
fig_beats.write_html(str(ARTIFACTS_DIR / "umap_beats.html"))
fig_beats_all.write_html(str(ARTIFACTS_DIR / "umap_beats_all_layers.html"))
fig_effnet.write_html(str(ARTIFACTS_DIR / "umap_effnet.html"))
fig_cmp.write_html(str(ARTIFACTS_DIR / "model_comparison.html"))
fig_attn_individual.write_html(str(ARTIFACTS_DIR / "model_comparison_beats_attn_individual.html"))

# Static figures as PNGs.
fig_umap_beats_static.savefig(str(ARTIFACTS_DIR / "umap_beats_static.png"), dpi=150, bbox_inches="tight")
fig_umap_beats_all_static.savefig(str(ARTIFACTS_DIR / "umap_beats_all_layers_static.png"), dpi=150, bbox_inches="tight")
fig_umap_effnet_static.savefig(str(ARTIFACTS_DIR / "umap_effnet_static.png"), dpi=150, bbox_inches="tight")
plt.close("all")  # release matplotlib figure memory

# Summary metrics as JSON.
# NOTE(review): training-free metrics are recomputed here rather than reused
# from section 5; also assumes compute_training_free_metrics returns
# JSON-serialisable (builtin float) values — confirm.
_metrics_out = {
    "n_calls": int(len(df)),
    "n_individuals": int(df["individual"].nunique()),
    "training_free": {
        k: compute_training_free_metrics(v, id_labels)
        for k, v in [("beats_last", beats_embs), ("beats_all_layers", beats_all_embs), ("effnet", effnet_embs)]
    },
    "linear_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in probe_results.items()},
    "attention_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in attention_results.items()},
}

with open(ARTIFACTS_DIR / "macaques_metrics.json", "w") as _f:
    json.dump(_metrics_out, _f, indent=2)
print(f"Artifacts saved to {ARTIFACTS_DIR}")
Artifacts saved to /home/marius_miron_earthspecies_org/code/avex-examples/examples/03_macaques_individual_id/artifacts