Macaque Individual ID Classification¶
Identify individual Japanese macaques (Macaca fuscata) from their coo calls using BEATs and EfficientNet-B0 embeddings. The archive is annotated with 8 individuals; restricting to 44.1 kHz recordings (Section 2) leaves 5 in the working set.
The processed dataset carries a single label per call, the individual's codename, provided by annotations.csv (columns: class, split, filename).
Workflow: download → explore → embed → UMAP → training-free metrics (NMI, ARI, R-AUC) → linear probe → attention probe (sl-BEATs) → static + interactive figures → save artifacts.
import warnings
warnings.filterwarnings("ignore")
import pathlib
import sys
def find_repo_root(start: pathlib.Path) -> pathlib.Path:
for p in [start, *start.parents]:
if (p / "pyproject.toml").exists():
return p
raise FileNotFoundError("Could not locate repo root (pyproject.toml not found).")
REPO_ROOT = find_repo_root(pathlib.Path().resolve())
sys.path.insert(0, str(REPO_ROOT))
import json
import re
import urllib.request
import zipfile
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import soundfile as sf
import torch
from avex import load_model
from IPython.display import display
from tqdm.auto import tqdm
from utils.probing import (
    compute_training_free_metrics,
    run_attention_probe,
    run_linear_probe,
)
from utils.visualization import (
confusion_heatmap_static,
plot_model_comparison,
plot_model_comparison_static,
plot_umap,
plot_umap_static,
)
EXAMPLE_DIR = REPO_ROOT / "examples" / "03_macaques_individual_id"
DATA_DIR = EXAMPLE_DIR / "data"
AUDIO_DIR = DATA_DIR / "audio"
EMBED_DIR = DATA_DIR / "embeddings"
EMBED_DIR.mkdir(parents=True, exist_ok=True)
TARGET_SR = 16_000
DEVICE = "cpu"
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"
print("Setup complete.")
Setup complete.
1. Download Dataset¶
ARCHIVE_URL = "https://archive.org/download/macaque_coo_calls/macaques.zip"
ARCHIVE_PATH = DATA_DIR / "macaques.zip"
EXTRACT_SENTINEL = AUDIO_DIR / ".extracted"
def download_with_progress(url: str, dest: pathlib.Path) -> None:
def _hook(count, block_size, total_size):
if total_size > 0:
pct = min(count * block_size / total_size * 100, 100)
print(f"\r {pct:5.1f}%", end="", flush=True)
urllib.request.urlretrieve(url, dest, reporthook=_hook)
print()
if not ARCHIVE_PATH.exists():
print(f"Downloading {ARCHIVE_URL} ...")
download_with_progress(ARCHIVE_URL, ARCHIVE_PATH)
else:
print(f"Archive present ({ARCHIVE_PATH.stat().st_size / 1e6:.1f} MB) — skipping.")
if not EXTRACT_SENTINEL.exists():
AUDIO_DIR.mkdir(parents=True, exist_ok=True)
print("Extracting ...")
with zipfile.ZipFile(ARCHIVE_PATH) as zf:
zf.extractall(AUDIO_DIR)
EXTRACT_SENTINEL.touch()
print("Done.")
else:
print("Audio already extracted.")
print(f"WAV files: {len(list(AUDIO_DIR.rglob('*.wav')))}")
Archive present (132.0 MB) — skipping.
Audio already extracted.
WAV files: 7285
2. Data Exploration¶
ANNOTATIONS_PATH = AUDIO_DIR / "annotations.csv"
if not ANNOTATIONS_PATH.exists():
raise FileNotFoundError(
f"Expected annotations at {ANNOTATIONS_PATH}. "
"If you downloaded the dataset manually, ensure annotations.csv is present."
)
ann = pd.read_csv(ANNOTATIONS_PATH)
required_cols = {"class", "split", "filename"}
missing = required_cols.difference(set(ann.columns))
if missing:
raise ValueError(
f"annotations.csv missing columns: {sorted(missing)} (found {list(ann.columns)})"
)
# Build a fast lookup from filename -> absolute path.
all_wavs = list(AUDIO_DIR.rglob("*.wav"))
path_by_name = {p.name: p for p in all_wavs}
records: list[dict[str, object]] = []
missing_paths: list[str] = []
# NOTE: avoid itertuples attribute access for the `class` column.
for row in ann.to_dict(orient="records"):
filename = str(row["filename"])
individual = str(row["class"])
split = str(row["split"])
wav_path = path_by_name.get(filename)
if wav_path is None:
wav_path = AUDIO_DIR / split / filename
if not wav_path.exists():
missing_paths.append(filename)
continue
info = sf.info(str(wav_path))
records.append(
{
"path": str(wav_path),
"filename": filename,
"split": split,
"individual": individual,
"sample_rate": int(info.samplerate),
"duration_s": float(info.duration),
}
)
if missing_paths:
raise FileNotFoundError(
"Some annotated files were not found under data/audio/. "
f"Example missing: {missing_paths[:5]} (n_missing={len(missing_paths)})"
)
df = pd.DataFrame.from_records(records)
# Keep only 44.1 kHz recordings for consistent preprocessing.
# NOTE: this also drops every individual recorded at other sample rates (8 annotated → 5 kept).
if (df["sample_rate"] == 44100).any():
df = df[df["sample_rate"] == 44100].reset_index(drop=True)
# Save metadata
META_PATH = DATA_DIR / "metadata.csv"
df.to_csv(META_PATH, index=False)
print(
f"Working set: {len(df)} calls across {df['individual'].nunique()} individuals "
f"(splits: {sorted(df['split'].unique().tolist())})"
)
display(df.groupby(["individual", "split"]).size().rename("n_calls"))
Working set: 1529 calls across 5 individuals (splits: ['train', 'valid'])
| individual | split | n_calls |
|---|---|---|
| AL | train | 190 |
| AL | valid | 52 |
| BE | train | 9 |
| BE | valid | 4 |
| IO | train | 708 |
| IO | valid | 178 |
| SN | train | 306 |
| SN | valid | 81 |
| TW | train | 1 |
count_df = (
df.groupby(["individual", "split"])
.size()
.reset_index(name="n_calls")
.sort_values(["individual", "split"])
)
fig = px.bar(
count_df,
x="individual",
y="n_calls",
color="split",
title="Number of coo calls per individual",
labels={"individual": "Individual", "n_calls": "Number of calls", "split": "Split"},
text="n_calls",
)
fig.update_traces(textposition="outside")
fig.show()
3. Embedding Extraction¶
MIN_SAMPLES = int(0.5 * TARGET_SR)
def load_audio(path: str, target_sr: int = TARGET_SR) -> torch.Tensor:
"""Load a WAV, convert to mono, resample, zero-pad to at least 0.5 s."""
wav, sr = sf.read(path, dtype="float32", always_2d=True)
wav = wav.mean(axis=1)
if sr != target_sr:
wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
if len(wav) < MIN_SAMPLES:
wav = np.pad(wav, (0, MIN_SAMPLES - len(wav)))
return torch.from_numpy(wav).unsqueeze(0)
print(f"Sample shape: {load_audio(df['path'].iloc[0]).shape}")
Sample shape: torch.Size([1, 8000])
BEATS_CACHE = EMBED_DIR / "beats_embeddings.npy"
if BEATS_CACHE.exists() and np.load(BEATS_CACHE).shape[0] == len(df):
beats_embs = np.load(BEATS_CACHE)
print(f"Loaded cached BEATs last-layer embeddings: {beats_embs.shape}")
else:
print(f"Loading model: {BEATS_MODEL}")
model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
model.eval()
embeddings = []
with torch.no_grad():
for path in tqdm(df["path"], desc="BEATs last-layer"):
wav = load_audio(path)
feats = model(wav) # (1, T, 768)
embeddings.append(feats.mean(dim=1).squeeze(0).cpu().numpy())
beats_embs = np.stack(embeddings)
np.save(BEATS_CACHE, beats_embs)
print(f"Saved BEATs last-layer embeddings: {beats_embs.shape}")
del model
Loaded cached BEATs last-layer embeddings: (1529, 768)
BEATS_ALL_CACHE = EMBED_DIR / "beats_all_layers_embeddings.npy"
if BEATS_ALL_CACHE.exists() and np.load(BEATS_ALL_CACHE).shape[0] == len(df):
beats_all_embs = np.load(BEATS_ALL_CACHE)
print(f"Loaded cached BEATs all-layers embeddings: {beats_all_embs.shape}")
else:
print(f"Loading model for all-layers extraction: {BEATS_MODEL}")
model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
model.eval()
# Register forward hooks on every transformer encoder layer
try:
encoder_layers = model.model.encoder.layers
except AttributeError:
encoder_layers = model.backbone.encoder.layers
n_layers = len(encoder_layers)
layer_store: dict = {}
hooks = []
for i, layer in enumerate(encoder_layers):
def _make_hook(idx):
def _hook(module, inp, out):
layer_store[idx] = out[0] if isinstance(out, tuple) else out
return _hook
hooks.append(layer.register_forward_hook(_make_hook(i)))
print(f" Registered hooks on {n_layers} transformer layers.")
all_embs = []
with torch.no_grad():
for path in tqdm(df["path"], desc="BEATs all-layers"):
layer_store.clear()
wav = load_audio(path)
_ = model(wav)
# Mean-pool each layer then average across layers → (D,)
# Handles both time-first (T, B, D) and batch-first (B, T, D) outputs
per_layer = [
layer_store[i].view(-1, layer_store[i].shape[-1]).mean(dim=0).cpu().numpy() for i in range(n_layers)
]
all_embs.append(np.mean(per_layer, axis=0))
for h in hooks:
h.remove()
beats_all_embs = np.stack(all_embs)
np.save(BEATS_ALL_CACHE, beats_all_embs)
print(f"Saved BEATs all-layers embeddings: {beats_all_embs.shape}")
del model
Loaded cached BEATs all-layers embeddings: (1529, 768)
EFFNET_CACHE = EMBED_DIR / "effnet_embeddings.npy"
if EFFNET_CACHE.exists() and np.load(EFFNET_CACHE).shape[0] == len(df):
effnet_embs = np.load(EFFNET_CACHE)
print(f"Loaded cached EfficientNet embeddings: {effnet_embs.shape}")
else:
print(f"Loading model: {EFFNET_MODEL}")
model = load_model(EFFNET_MODEL, return_features_only=True, device=DEVICE)
model.eval()
embeddings = []
with torch.no_grad():
for path in tqdm(df["path"], desc="EfficientNet"):
wav = load_audio(path)
feats = model(wav) # (1, C, H, W)
embeddings.append(feats.mean(dim=(2, 3)).squeeze(0).cpu().numpy())
effnet_embs = np.stack(embeddings)
np.save(EFFNET_CACHE, effnet_embs)
print(f"Saved EfficientNet embeddings: {effnet_embs.shape}")
del model
Loaded cached EfficientNet embeddings: (1529, 1280)
4. UMAP Visualisation¶
id_labels = df["individual"].tolist()
# Individual-ID only example.
hover = [
f"{ind} · split={split} · file={fn}"
for ind, split, fn in zip(
df["individual"].tolist(),
df["split"].tolist(),
df["filename"].tolist(),
strict=False,
)
]
# Consistent colour map.
# - If IDs look like "F_1" / "M_7", use warm (F) vs cool (M) palettes.
# - Otherwise (e.g. "AL", "BE", ...), assign a stable categorical palette.
_unique_ids = sorted(set(id_labels))
if any(i.startswith(("F_", "M_")) for i in _unique_ids):
_female_ids = [i for i in _unique_ids if i.startswith("F_")]
_male_ids = [i for i in _unique_ids if i.startswith("M_")]
_f_pal = ["#e6550d", "#fd8d3c", "#fdae6b", "#fdd0a2", "#a63603", "#d94801", "#f16913", "#7f2704"]
_m_pal = ["#08519c", "#2171b5", "#4292c6", "#6baed6", "#084594", "#2c7bb6", "#08306b", "#3182bd"]
INDIVIDUAL_COLOR_MAP = {ind: _f_pal[i % len(_f_pal)] for i, ind in enumerate(_female_ids)}
INDIVIDUAL_COLOR_MAP.update({ind: _m_pal[i % len(_m_pal)] for i, ind in enumerate(_male_ids)})
else:
_pal = px.colors.qualitative.Dark24
INDIVIDUAL_COLOR_MAP = {ind: _pal[i % len(_pal)] for i, ind in enumerate(_unique_ids)}
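The figures below use the repo's plot_umap / plot_umap_static helpers. For readers without utils.visualization, a rough stand-in built on umap-learn and Plotly might look like the sketch below; the helper's actual projection settings and argument handling are assumptions.

import plotly.express as px
import umap  # umap-learn

def plot_umap_sketch(embs, labels, title, hover_text=None, color_discrete_map=None):
    # Project embeddings to 2-D, then draw an interactive scatter coloured by label.
    xy = umap.UMAP(n_components=2, random_state=42).fit_transform(embs)
    return px.scatter(
        x=xy[:, 0],
        y=xy[:, 1],
        color=labels,
        hover_name=hover_text,
        color_discrete_map=color_discrete_map,
        title=title,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Individual"},
    )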
fig_beats = plot_umap(
beats_embs,
labels=id_labels,
title=f"UMAP — {BEATS_MODEL} (last layer)<br><sup>colour = individual</sup>",
hover_text=hover,
color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_beats.show()
fig_umap_beats_static = plot_umap_static(
beats_embs,
labels=id_labels,
title=f"UMAP — {BEATS_MODEL} (last layer) — static",
color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_beats_static)
fig_beats_all = plot_umap(
beats_all_embs,
labels=id_labels,
title=f"UMAP — {BEATS_MODEL} (all layers avg)",
hover_text=hover,
color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_beats_all.show()
fig_umap_beats_all_static = plot_umap_static(
beats_all_embs,
labels=id_labels,
title=f"UMAP — {BEATS_MODEL} (all layers avg) — static",
color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_beats_all_static)
fig_effnet = plot_umap(
effnet_embs,
labels=id_labels,
title=f"UMAP — {EFFNET_MODEL}",
hover_text=hover,
color_discrete_map=INDIVIDUAL_COLOR_MAP,
)
fig_effnet.show()
fig_umap_effnet_static = plot_umap_static(
effnet_embs,
labels=id_labels,
title=f"UMAP — {EFFNET_MODEL} — static",
color_map=INDIVIDUAL_COLOR_MAP,
)
display(fig_umap_effnet_static)
5. Training-Free Metrics¶
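NMI and ARI compare a k-means clustering of the embeddings against the true identities; R-AUC is taken here to be a retrieval-style ROC-AUC over pairwise cosine similarities. Both readings are assumptions: the exact definitions live in utils.probing.compute_training_free_metrics. A minimal sketch under those assumptions:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity

def training_free_sketch(embs: np.ndarray, labels: list[str]) -> dict[str, float]:
    y = np.asarray(labels)
    # Cluster into as many groups as there are individuals, then score the match.
    pred = KMeans(n_clusters=len(set(labels)), n_init=10, random_state=0).fit_predict(embs)
    # Retrieval view: do same-individual pairs get higher cosine similarity?
    sim = cosine_similarity(embs)
    iu = np.triu_indices(len(y), k=1)
    return {
        "nmi": normalized_mutual_info_score(y, pred),
        "ari": adjusted_rand_score(y, pred),
        "r_auc": roc_auc_score((y[iu[0]] == y[iu[1]]).astype(int), sim[iu]),
    }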
print("Computing training-free metrics (NMI, ARI, R-AUC) ...")
print("These evaluate embedding quality without fitting any classifier.\n")
_labels_for_metrics = id_labels
_metric_models = [
(f"{BEATS_MODEL} (last layer)", beats_embs),
(f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
(EFFNET_MODEL, effnet_embs),
]
_metric_rows = []
for _name, _embs in _metric_models:
_m = compute_training_free_metrics(_embs, _labels_for_metrics)
_metric_rows.append(
{"Model": _name, "NMI": round(_m["nmi"], 3), "ARI": round(_m["ari"], 3), "R-AUC": round(_m["r_auc"], 3)}
)
print(f" {_name}: NMI={_m['nmi']:.3f} ARI={_m['ari']:.3f} R-AUC={_m['r_auc']:.3f}")
_metrics_df = pd.DataFrame(_metric_rows).set_index("Model")
display(_metrics_df)
Computing training-free metrics (NMI, ARI, R-AUC) ...
These evaluate embedding quality without fitting any classifier.
esp_aves2_sl_beats_all (last layer): NMI=0.154 ARI=0.070 R-AUC=0.504
esp_aves2_sl_beats_all (all layers avg): NMI=0.119 ARI=0.078 R-AUC=0.493
esp_aves2_effnetb0_all: NMI=0.258 ARI=0.187 R-AUC=0.533
| Model | NMI | ARI | R-AUC |
|---|---|---|---|
| esp_aves2_sl_beats_all (last layer) | 0.154 | 0.070 | 0.504 |
| esp_aves2_sl_beats_all (all layers avg) | 0.119 | 0.078 | 0.493 |
| esp_aves2_effnetb0_all | 0.258 | 0.187 | 0.533 |
6. Linear Probe¶
Single task: individual-ID classification (5 classes in the 44.1 kHz working set).
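run_linear_probe comes from utils.probing; conceptually it is close to the sketch below (assumed behaviour: an L2-regularised logistic regression fit on a random train/test split):

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

def linear_probe_sketch(embs, labels, test_size=0.2, random_state=42, max_iter=1000):
    # Plain random split (no stratification: class TW has a single call).
    X_tr, X_te, y_tr, y_te = train_test_split(
        embs, labels, test_size=test_size, random_state=random_state
    )
    clf = LogisticRegression(max_iter=max_iter).fit(X_tr, y_tr)
    pred = clf.predict(X_te)
    classes = sorted(set(labels))
    return {
        "accuracy": accuracy_score(y_te, pred),
        "confusion_matrix": confusion_matrix(y_te, pred, labels=classes),
        "classes": classes,
    }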
probe_kwargs = dict(test_size=0.2, random_state=42, max_iter=1000)
probe_results: dict[str, dict] = {}
for emb_name, embs in [
(f"{BEATS_MODEL} (last layer)", beats_embs),
(f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
(EFFNET_MODEL, effnet_embs),
]:
key = f"{emb_name} | individual_id"
res = run_linear_probe(
embs,
id_labels,
test_size=probe_kwargs["test_size"],
random_state=probe_kwargs["random_state"],
max_iter=probe_kwargs["max_iter"],
)
probe_results[key] = res
print(f"{key}: accuracy = {res['accuracy']:.3f}")
esp_aves2_sl_beats_all (last layer) | individual_id: accuracy = 0.605
esp_aves2_sl_beats_all (all layers avg) | individual_id: accuracy = 0.609
esp_aves2_effnetb0_all | individual_id: accuracy = 0.650
rows = [
{"Model": k.split(" | ")[0], "Accuracy": round(v["accuracy"], 4)}
for k, v in probe_results.items()
]
acc_df = pd.DataFrame(rows)
display(acc_df.set_index("Model").style.format({"Accuracy": "{:.4f}"}))
fig_cmp = plot_model_comparison(
{k: v["accuracy"] for k, v in probe_results.items()},
title="Macaque — individual ID accuracy",
)
fig_cmp.show()
display(
plot_model_comparison_static(
{k: v["accuracy"] for k, v in probe_results.items()},
title="Macaque — linear probe (static)",
)
)
| Model | Accuracy |
|---|---|
| esp_aves2_sl_beats_all (last layer) | 0.6047 |
| esp_aves2_sl_beats_all (all layers avg) | 0.6087 |
| esp_aves2_effnetb0_all | 0.6500 |
7. Attention Probe (sl-BEATs)¶
Multi-head attention probe on mean-pooled BEATs embeddings for the same individual-ID task.
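run_attention_probe's internals are not shown in this notebook. One plausible construction, sketched below, is a small stack of multi-head attention layers in which a learnable query attends over the (length-1) token sequence formed by each pooled clip embedding, followed by a linear classifier; treat the architecture and the mapping of num_heads / num_attn_layers as assumptions.

import torch
from torch import nn

class AttentionProbe(nn.Module):
    def __init__(self, dim: int, n_classes: int, num_heads: int = 8, num_attn_layers: int = 2):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads, batch_first=True) for _ in range(num_attn_layers)]
        )
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D) pooled clip embeddings → treat each as a length-1 token sequence.
        tokens = x.unsqueeze(1)                    # (B, 1, D)
        q = self.query.expand(x.shape[0], -1, -1)  # (B, 1, D)
        for attn in self.layers:
            q, _ = attn(q, tokens, tokens)
        return self.head(q.squeeze(1))             # (B, n_classes)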
attn_probe_kwargs = dict(num_heads=8, num_attn_layers=2, epochs=50, test_size=0.2, random_state=42)
attention_results: dict[str, dict] = {}
for emb_name, embs in [
(f"{BEATS_MODEL} (last layer)", beats_embs),
(f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
]:
key = f"{emb_name} | individual_id"
res = run_attention_probe(
embs,
id_labels,
num_heads=attn_probe_kwargs["num_heads"],
num_attn_layers=attn_probe_kwargs["num_attn_layers"],
epochs=attn_probe_kwargs["epochs"],
test_size=attn_probe_kwargs["test_size"],
random_state=attn_probe_kwargs["random_state"],
device=DEVICE,
)
attention_results[key] = res
print(f"{key} (attention): accuracy = {res['accuracy']:.3f}")
cmp_individual: dict[str, float] = {}
for base in [f"{BEATS_MODEL} (last layer)", f"{BEATS_MODEL} (all layers avg)"]:
lk = f"{base} | individual_id"
cmp_individual[f"{base} (linear)"] = float(probe_results[lk]["accuracy"])
cmp_individual[f"{base} (attention)"] = float(attention_results[lk]["accuracy"])
fig_attn_individual = plot_model_comparison(
cmp_individual, title="Macaque — sl-BEATs individual ID: linear vs attention"
)
fig_attn_individual.show()
display(plot_model_comparison_static(cmp_individual, title="Macaque — sl-BEATs individual ID (static)"))
esp_aves2_sl_beats_all (last layer) | individual_id (attention): accuracy = 0.751
esp_aves2_sl_beats_all (all layers avg) | individual_id (attention): accuracy = 0.713
8. Confusion Matrices¶
def plot_confusion_matrix(cm, classes, title):
    # Row-normalise so each cell shows per-class recall; guard against all-zero rows
    # (a class can be absent from the random test split, e.g. TW has a single call).
    cm_norm = cm.astype(float) / np.maximum(cm.sum(axis=1, keepdims=True), 1)
return px.imshow(
cm_norm,
x=classes,
y=classes,
color_continuous_scale="Blues",
zmin=0,
zmax=1,
text_auto=".2f",
title=title,
labels={"x": "Predicted", "y": "True", "color": "Recall"},
aspect="auto",
)
for key, res in probe_results.items():
plot_confusion_matrix(res["confusion_matrix"], res["classes"], title=f"Confusion matrix: {key}").show()
display(
confusion_heatmap_static(
res["confusion_matrix"],
res["classes"],
title=f"Confusion matrix: {key} (static)",
)
)
9. Save Artifacts¶
ARTIFACTS_DIR = EXAMPLE_DIR / "artifacts"
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
fig_beats.write_html(str(ARTIFACTS_DIR / "umap_beats.html"))
fig_beats_all.write_html(str(ARTIFACTS_DIR / "umap_beats_all_layers.html"))
fig_effnet.write_html(str(ARTIFACTS_DIR / "umap_effnet.html"))
fig_cmp.write_html(str(ARTIFACTS_DIR / "model_comparison.html"))
fig_attn_individual.write_html(str(ARTIFACTS_DIR / "model_comparison_beats_attn_individual.html"))
fig_umap_beats_static.savefig(str(ARTIFACTS_DIR / "umap_beats_static.png"), dpi=150, bbox_inches="tight")
fig_umap_beats_all_static.savefig(str(ARTIFACTS_DIR / "umap_beats_all_layers_static.png"), dpi=150, bbox_inches="tight")
fig_umap_effnet_static.savefig(str(ARTIFACTS_DIR / "umap_effnet_static.png"), dpi=150, bbox_inches="tight")
plt.close("all")
_metrics_out = {
"n_calls": int(len(df)),
"n_individuals": int(df["individual"].nunique()),
"training_free": {
k: compute_training_free_metrics(v, id_labels)
for k, v in [("beats_last", beats_embs), ("beats_all_layers", beats_all_embs), ("effnet", effnet_embs)]
},
"linear_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in probe_results.items()},
"attention_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in attention_results.items()},
}
with open(ARTIFACTS_DIR / "macaques_metrics.json", "w") as _f:
json.dump(_metrics_out, _f, indent=2)
print(f"Artifacts saved to {ARTIFACTS_DIR}")
Artifacts saved to /home/marius_miron_earthspecies_org/code/avex-examples/examples/03_macaques_individual_id/artifacts