Ungulate Emotional Valence Classification

Classify emotional valence (positive vs. negative) and species from ungulate contact calls using BEATs and EfficientNet-B0 embeddings.

Species: Cow, Pig, Sheep, Goat, Horse, Przewalski’s Horse, Wild Boar (7 species, 3,181 calls)

Dataset: Zenodo 14636641 — supplemental data for Machine Learning Algorithms Can Predict Emotional Valence Across Ungulate Vocalizations

Workflow: download → explore → embed → UMAP → training-free metrics (NMI, ARI, R-AUC) → linear probe → attention probe (sl-BEATs) → static + interactive figures → save artifacts.

import pathlib
import sys


def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    for p in [start, *start.parents]:
        if (p / "pyproject.toml").exists():
            return p
    raise FileNotFoundError("Could not locate repo root (pyproject.toml not found).")


REPO_ROOT = find_repo_root(pathlib.Path().resolve())
sys.path.insert(0, str(REPO_ROOT))

import json
import urllib.request
import zipfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import openpyxl
import pandas as pd
import plotly.express as px
import torch
from avex import load_model
from IPython.display import display
from sklearn.linear_model import LogisticRegression
from tqdm.auto import tqdm

from utils.probing import compute_training_free_metrics, run_attention_probe, run_linear_probe
from utils.visualization import (
    confusion_heatmap_static,
    plot_model_comparison,
    plot_model_comparison_static,
    plot_umap,
    plot_umap_static,
)

EXAMPLE_DIR = REPO_ROOT / "examples" / "09_ungulate_valence"
DATA_DIR = EXAMPLE_DIR / "data"
AUDIO_DIR = DATA_DIR / "audio"
META_DIR = DATA_DIR / "metadata"
EMBED_DIR = DATA_DIR / "embeddings"
ARTIFACTS_DIR = EXAMPLE_DIR / "artifacts"
for _d in [AUDIO_DIR, META_DIR, EMBED_DIR, ARTIFACTS_DIR]:
    _d.mkdir(parents=True, exist_ok=True)

TARGET_SR = 16_000
DEVICE = "cpu"
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"
print("Setup complete.")
Setup complete.

1. Download Dataset

SOUNDS_URL = "https://zenodo.org/records/14636641/files/SoundsDatabase.zip"
META_URL = "https://zenodo.org/records/14636641/files/Data%20S1.zip"

SOUNDS_ZIP = DATA_DIR / "SoundsDatabase.zip"
META_ZIP = DATA_DIR / "Data_S1.zip"
EXTRACT_SENTINEL = AUDIO_DIR / ".extracted"
META_SENTINEL = META_DIR / ".extracted"


def download_with_progress(url: str, dest: pathlib.Path) -> None:
    def _hook(count: int, block_size: int, total_size: int) -> None:
        if total_size > 0:
            pct = min(count * block_size / total_size * 100, 100)
            print(f"\r  {pct:5.1f}%", end="", flush=True)

    urllib.request.urlretrieve(url, dest, reporthook=_hook)
    print()


if not SOUNDS_ZIP.exists():
    print("Downloading sounds database (~287 MB) ...")
    download_with_progress(SOUNDS_URL, SOUNDS_ZIP)
else:
    print(f"Sounds archive present ({SOUNDS_ZIP.stat().st_size / 1e6:.1f} MB) — skipping.")

if not META_ZIP.exists():
    print("Downloading metadata (~19 KB) ...")
    download_with_progress(META_URL, META_ZIP)
else:
    print(f"Metadata archive present ({META_ZIP.stat().st_size / 1024:.1f} KB) — skipping.")

if not EXTRACT_SENTINEL.exists():
    print("Extracting sounds (may take a moment) ...")
    with zipfile.ZipFile(SOUNDS_ZIP) as zf:
        zf.extractall(AUDIO_DIR)
    EXTRACT_SENTINEL.touch()
    print("Done.")
else:
    print("Sounds already extracted.")

if not META_SENTINEL.exists():
    print("Extracting metadata ...")
    with zipfile.ZipFile(META_ZIP) as zf:
        zf.extractall(META_DIR)
    META_SENTINEL.touch()
    print("Done.")
else:
    print("Metadata already extracted.")
Sounds archive present (286.9 MB) — skipping.
Metadata archive present (18.9 KB) — skipping.
Sounds already extracted.
Metadata already extracted.

2. Data Exploration

# Inspect extracted audio directory structure
print("Top-level contents of audio directory:")
for item in sorted(AUDIO_DIR.iterdir()):
    if item.name.startswith("."):
        continue
    if item.is_dir():
        n = sum(1 for _ in item.rglob("*") if _.is_file())
        print(f"  [DIR]  {item.name}/  ({n} files)")
    else:
        print(f"  [FILE] {item.name}")

print("\nMetadata files:")
for f in sorted(META_DIR.rglob("*")):
    if f.is_file() and not f.name.startswith(".") and not f.name.startswith("~"):
        print(f"  {f.relative_to(META_DIR)}  ({f.stat().st_size / 1024:.1f} KB)")
Top-level contents of audio directory:
  [DIR]  SoundsDatabase/  (3181 files)

Metadata files:
  Data S1.xlsx  (21.2 KB)
# Convert raw Excel annotations to CSV in artifacts (once), then always load from there.
# openpyxl is used explicitly here; after this cell openpyxl is no longer needed.
ANNOTATIONS_CSV = ARTIFACTS_DIR / "annotations.csv"

if not ANNOTATIONS_CSV.exists():
    _xlsx_candidates = [
        f for f in META_DIR.rglob("*") if f.suffix.lower() in {".xlsx", ".xls"} and not f.name.startswith((".", "~"))
    ]
    assert _xlsx_candidates, f"No Excel spreadsheet found in {META_DIR}. Run the download cell first."
    _xlsx_path = _xlsx_candidates[0]
    print(f"Converting {_xlsx_path.name} → annotations.csv via openpyxl ...")

    wb = openpyxl.load_workbook(_xlsx_path, read_only=True, data_only=True)
    ws = wb.active
    rows = list(ws.values)
    wb.close()

    _headers = [str(h).strip() if h is not None else f"col_{i}" for i, h in enumerate(rows[0])]
    _data = [list(row) for row in rows[1:] if any(v is not None for v in row)]
    _ann_df = pd.DataFrame(_data, columns=_headers)
    _ann_df.to_csv(ANNOTATIONS_CSV, index=False)
    print(f"Saved {len(_ann_df)} annotation rows → {ANNOTATIONS_CSV.relative_to(EXAMPLE_DIR)}")
else:
    print("annotations.csv already exists — skipping openpyxl conversion.")

# Always load from the CSV copy in artifacts
annotations_df = pd.read_csv(ANNOTATIONS_CSV)
print(f"\nAnnotations shape: {annotations_df.shape}")
print(f"Columns: {list(annotations_df.columns)}")
display(annotations_df.head(10))
annotations.csv already exists — skipping openpyxl conversion.

Annotations shape: (335, 7)
Columns: ['Data S1. Repartition of the calls according to animal identity (ID), species, call type, emotional valence, context of vocal production, and their respective number of occurrences, along with references to original studies. Related to STAR Methods.', 'col_1', 'col_2', 'col_3', 'col_4', 'col_5', 'col_6']
Data S1. Repartition of the calls according to animal identity (ID), species, call type, emotional valence, context of vocal production, and their respective number of occurrences, along with references to original studies. Related to STAR Methods. col_1 col_2 col_3 col_4 col_5 col_6
0 Animal ID Species Call type Valence Context Nb of calls Reference
1 1140 Cow Low- frequency (closed-mouth) and high-frequen... Negative Full (physical and visual) separation from cal... 40 Padilla de la Torre, M., Hillmann, E., and Bri...
2 1140 NaN NaN Positive Physical reunion with calf 2 NaN
3 1140 NaN NaN Positive Visual (but not physical) reunion with calf 2 NaN
4 1150 NaN NaN Negative Full (physical and visual) separation from cal... 48 NaN
5 1150 NaN NaN Negative Physical (but not visual) separation from calf 1 NaN
6 1150 NaN NaN Positive Visual (but not physical) reunion with calf 3 NaN
7 1152 NaN NaN Negative Full (physical and visual) separation from cal... 44 NaN
8 1152 NaN NaN Positive Physical reunion with calf 1 NaN
9 1152 NaN NaN Positive Visual (but not physical) reunion with calf 2 NaN
# Collect all audio files
_audio_exts = {".wav", ".mp3", ".flac", ".ogg"}
all_audio_paths = sorted(p for p in AUDIO_DIR.rglob("*") if p.suffix.lower() in _audio_exts)
print(f"Found {len(all_audio_paths)} audio files")
print("\nSample paths (first 5):")
for p in all_audio_paths[:5]:
    print(f"  {p.relative_to(AUDIO_DIR)}")
Found 3181 audio files

Sample paths (first 5):
  SoundsDatabase/1140-Cow-FullSeparation-Negative-1094.wav
  SoundsDatabase/1140-Cow-FullSeparation-Negative-1168.wav
  SoundsDatabase/1140-Cow-FullSeparation-Negative-1185.wav
  SoundsDatabase/1140-Cow-FullSeparation-Negative-1216.wav
  SoundsDatabase/1140-Cow-FullSeparation-Negative-1289.wav
# Extract species and valence labels from path structure
# The dataset encodes labels in folder names or filenames.
_VALENCE_POS = {"positive", "pos"}
_VALENCE_NEG = {"negative", "neg"}
_SPECIES_KEYS = {
    "cow": "Cow",
    "cattle": "Cow",
    "pig": "Pig",
    "sheep": "Sheep",
    "goat": "Goat",
    # Match the more specific key first: Przewalski's Horse filenames would
    # otherwise be swallowed by the generic "horse" key (the loop below
    # breaks on the first match, so insertion order matters).
    "przewalski": "Przewalski's Horse",
    "horse": "Horse",
    "wildboar": "Wild Boar",
    "boar": "Wild Boar",
}


def _parse_labels(path: pathlib.Path, base: pathlib.Path) -> dict[str, str | None]:
    """Extract species and valence from file path components."""
    parts = [p.lower().replace(" ", "").replace("-", "").replace("_", "") for p in path.relative_to(base).parts]
    parts_joined = " ".join(parts)

    species = None
    for key, name in _SPECIES_KEYS.items():
        if key in parts_joined:
            species = name
            break

    valence = None
    for part in parts:
        if any(kw in part for kw in _VALENCE_POS):
            valence = "positive"
            break
        if any(kw in part for kw in _VALENCE_NEG):
            valence = "negative"
            break

    return {"species": species, "valence": valence}


records = []
for p in all_audio_paths:
    labels = _parse_labels(p, AUDIO_DIR)
    records.append({"path": str(p), **labels})

df = pd.DataFrame(records)

# Report any unresolved labels
_missing_species = df["species"].isna().sum()
_missing_valence = df["valence"].isna().sum()
if _missing_species > 0:
    print(f"WARNING: {_missing_species} files with unknown species. Sample paths:")
    display(df[df["species"].isna()]["path"].head(5))
if _missing_valence > 0:
    print(f"WARNING: {_missing_valence} files with unknown valence. Sample paths:")
    display(df[df["valence"].isna()]["path"].head(5))

# Drop rows where labels could not be resolved
df = df.dropna(subset=["species", "valence"]).reset_index(drop=True)

META_PATH = DATA_DIR / "metadata.csv"
df.to_csv(META_PATH, index=False)
print(f"\nWorking set: {len(df)} calls")
print(f"  Species: {sorted(df['species'].unique())}")
print(f"  Valence: {sorted(df['valence'].unique())}")
display(df.groupby(["species", "valence"]).size().rename("n_calls").unstack(fill_value=0))
Working set: 3181 calls
  Species: ['Cow', 'Goat', 'Horse', 'Pig', 'Sheep', 'Wild Boar']
  Valence: ['negative', 'positive']
valence negative positive
species
Cow 1179 75
Goat 268 93
Horse 245 99
Pig 316 248
Sheep 342 54
Wild Boar 158 104
# Distribution: calls per species coloured by valence
count_df = df.groupby(["species", "valence"]).size().reset_index(name="n_calls").sort_values("species")
fig_dist = px.bar(
    count_df,
    x="species",
    y="n_calls",
    color="valence",
    color_discrete_map={"positive": "#2ecc71", "negative": "#e74c3c"},
    barmode="group",
    title="Calls per species and emotional valence",
    labels={"species": "Species", "n_calls": "Number of calls", "valence": "Valence"},
    text="n_calls",
)
fig_dist.update_traces(textposition="outside")
fig_dist.show()

3. Embedding Extraction

MIN_SAMPLES = int(0.5 * TARGET_SR)


def load_audio(path: str, target_sr: int = TARGET_SR) -> torch.Tensor:
    """Load any audio file, convert to mono, resample, zero-pad to at least 0.5 s."""
    # librosa handles WAV, MP3, FLAC, OGG, etc.
    wav, _ = librosa.load(path, sr=target_sr, mono=True)
    if len(wav) < MIN_SAMPLES:
        wav = np.pad(wav, (0, MIN_SAMPLES - len(wav)))
    return torch.from_numpy(wav).unsqueeze(0)


print(f"Sample shape: {load_audio(df['path'].iloc[0]).shape}")
Sample shape: torch.Size([1, 24957])
BEATS_CACHE = EMBED_DIR / "beats_embeddings.npy"

if BEATS_CACHE.exists() and np.load(BEATS_CACHE).shape[0] == len(df):
    beats_embs = np.load(BEATS_CACHE)
    print(f"Loaded cached BEATs last-layer embeddings: {beats_embs.shape}")
else:
    print(f"Loading model: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="BEATs last-layer"):
            wav = load_audio(path)
            feats = model(wav)  # (1, T, 768)
            embeddings.append(feats.mean(dim=1).squeeze(0).cpu().numpy())
    beats_embs = np.stack(embeddings)
    np.save(BEATS_CACHE, beats_embs)
    print(f"Saved BEATs last-layer embeddings: {beats_embs.shape}")
    del model
Loaded cached BEATs last-layer embeddings: (3181, 768)
BEATS_ALL_CACHE = EMBED_DIR / "beats_all_layers_embeddings.npy"

if BEATS_ALL_CACHE.exists() and np.load(BEATS_ALL_CACHE).shape[0] == len(df):
    beats_all_embs = np.load(BEATS_ALL_CACHE)
    print(f"Loaded cached BEATs all-layers embeddings: {beats_all_embs.shape}")
else:
    print(f"Loading model for all-layers extraction: {BEATS_MODEL}")
    model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
    model.eval()

    try:
        encoder_layers = model.model.encoder.layers
    except AttributeError:
        encoder_layers = model.backbone.encoder.layers
    n_layers = len(encoder_layers)
    layer_store: dict = {}
    hooks = []
    for i, layer in enumerate(encoder_layers):

        def _make_hook(idx: int):
            def _hook(module, inp, out):
                layer_store[idx] = out[0] if isinstance(out, tuple) else out

            return _hook

        hooks.append(layer.register_forward_hook(_make_hook(i)))
    print(f"  Registered hooks on {n_layers} transformer layers.")

    all_embs = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="BEATs all-layers"):
            layer_store.clear()
            wav = load_audio(path)
            _ = model(wav)
            per_layer = [
                layer_store[i].view(-1, layer_store[i].shape[-1]).mean(dim=0).cpu().numpy() for i in range(n_layers)
            ]
            all_embs.append(np.mean(per_layer, axis=0))

    for h in hooks:
        h.remove()

    beats_all_embs = np.stack(all_embs)
    np.save(BEATS_ALL_CACHE, beats_all_embs)
    print(f"Saved BEATs all-layers embeddings: {beats_all_embs.shape}")
    del model
Loaded cached BEATs all-layers embeddings: (3181, 768)
EFFNET_CACHE = EMBED_DIR / "effnet_embeddings.npy"

if EFFNET_CACHE.exists() and np.load(EFFNET_CACHE).shape[0] == len(df):
    effnet_embs = np.load(EFFNET_CACHE)
    print(f"Loaded cached EfficientNet embeddings: {effnet_embs.shape}")
else:
    print(f"Loading model: {EFFNET_MODEL}")
    model = load_model(EFFNET_MODEL, return_features_only=True, device=DEVICE)
    model.eval()
    embeddings = []
    with torch.no_grad():
        for path in tqdm(df["path"], desc="EfficientNet"):
            wav = load_audio(path)
            feats = model(wav)  # (1, C, H, W)
            embeddings.append(feats.mean(dim=(2, 3)).squeeze(0).cpu().numpy())
    effnet_embs = np.stack(embeddings)
    np.save(EFFNET_CACHE, effnet_embs)
    print(f"Saved EfficientNet embeddings: {effnet_embs.shape}")
    del model
Loaded cached EfficientNet embeddings: (3181, 1280)

4. UMAP Visualisation

valence_labels = df["valence"].tolist()
species_labels = df["species"].tolist()
hover = [f"{sp} ({v})" for sp, v in zip(species_labels, valence_labels, strict=False)]

VALENCE_COLOR_MAP = {"positive": "#2ecc71", "negative": "#e74c3c"}

_unique_species = sorted(set(species_labels))
_sp_palette = px.colors.qualitative.Dark24
SPECIES_COLOR_MAP = {sp: _sp_palette[i % len(_sp_palette)] for i, sp in enumerate(_unique_species)}

_model_embs = [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
    (EFFNET_MODEL, effnet_embs),
]

# UMAP coloured by emotional valence
umap_valence_figs = {}
umap_valence_static_figs = {}
for name, embs in _model_embs:
    fig = plot_umap(
        embs,
        labels=valence_labels,
        title=f"UMAP — {name}<br><sup>colour = emotional valence</sup>",
        hover_text=hover,
        color_discrete_map=VALENCE_COLOR_MAP,
    )
    umap_valence_figs[name] = fig
    fig.show()
    static = plot_umap_static(
        embs,
        labels=valence_labels,
        title=f"UMAP — {name} — valence (static)",
        color_map=VALENCE_COLOR_MAP,
    )
    umap_valence_static_figs[name] = static
    display(static)
(Figures: interactive + static UMAP projections per model, coloured by emotional valence. umap-learn repeatedly warned: "n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.")
# UMAP coloured by species
umap_species_figs = {}
umap_species_static_figs = {}
for name, embs in _model_embs:
    fig = plot_umap(
        embs,
        labels=species_labels,
        title=f"UMAP — {name}<br><sup>colour = species</sup>",
        hover_text=hover,
        color_discrete_map=SPECIES_COLOR_MAP,
    )
    umap_species_figs[name] = fig
    fig.show()
    static = plot_umap_static(
        embs,
        labels=species_labels,
        title=f"UMAP — {name} — species (static)",
        color_map=SPECIES_COLOR_MAP,
    )
    umap_species_static_figs[name] = static
    display(static)
(Figures: interactive + static UMAP projections per model, coloured by species. umap-learn repeatedly warned: "n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.")

5. Training-Free Metrics

Normalised mutual information (NMI), adjusted Rand index (ARI), and R-AUC evaluate embedding quality without fitting any classifier.
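The exact metric implementations live in utils.probing. For intuition, the two cluster-agreement scores can be approximated by running k-means on the embeddings (k = number of classes) and comparing cluster assignments to the true labels. The sketch below is hypothetical (helper name included) and omits R-AUC; it is not the package's implementation.

# Hypothetical sketch of cluster-based NMI / ARI; utils.probing may differ.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score


def cluster_agreement(embs: np.ndarray, labels: list[str], random_state: int = 42) -> dict[str, float]:
    k = len(set(labels))
    pred = KMeans(n_clusters=k, random_state=random_state, n_init=10).fit_predict(embs)
    return {
        "nmi": normalized_mutual_info_score(labels, pred),
        "ari": adjusted_rand_score(labels, pred),
    }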

tf_results = {}
for task_name, task_labels in [("valence", valence_labels), ("species", species_labels)]:
    print(f"\n--- Task: {task_name} ---")
    for name, embs in _model_embs:
        m = compute_training_free_metrics(embs, task_labels)
        key = f"{name} | {task_name}"
        tf_results[key] = m
        print(f"  {name}: NMI={m['nmi']:.3f}  ARI={m['ari']:.3f}  R-AUC={m['r_auc']:.3f}")

_tf_rows = [
    {
        "Model": k.split(" | ")[0],
        "Task": k.split(" | ")[1],
        "NMI": round(v["nmi"], 3),
        "ARI": round(v["ari"], 3),
        "R-AUC": round(v["r_auc"], 3),
    }
    for k, v in tf_results.items()
]
display(pd.DataFrame(_tf_rows).set_index(["Model", "Task"]))
--- Task: valence ---
  esp_aves2_sl_beats_all (last layer): NMI=0.092  ARI=0.213  R-AUC=0.725
  esp_aves2_sl_beats_all (all layers avg): NMI=0.131  ARI=0.278  R-AUC=0.740
  esp_aves2_effnetb0_all: NMI=0.099  ARI=0.189  R-AUC=0.707

--- Task: species ---
  esp_aves2_sl_beats_all (last layer): NMI=0.785  ARI=0.641  R-AUC=0.785
  esp_aves2_sl_beats_all (all layers avg): NMI=0.751  ARI=0.680  R-AUC=0.797
  esp_aves2_effnetb0_all: NMI=0.581  ARI=0.446  R-AUC=0.674
NMI ARI R-AUC
Model Task
esp_aves2_sl_beats_all (last layer) valence 0.092 0.213 0.725
esp_aves2_sl_beats_all (all layers avg) valence 0.131 0.278 0.740
esp_aves2_effnetb0_all valence 0.099 0.189 0.707
esp_aves2_sl_beats_all (last layer) species 0.785 0.641 0.785
esp_aves2_sl_beats_all (all layers avg) species 0.751 0.680 0.797
esp_aves2_effnetb0_all species 0.581 0.446 0.674

6. Linear Probe

Two tasks: binary emotional valence (positive / negative) and multi-class species identification (six species in the working set).
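run_linear_probe comes from utils.probing; it also returns the confusion matrix and class list used later. As a rough sketch of the core recipe (assumed, not the package's exact implementation): hold out a stratified test split, standardise features, fit a logistic regression, and report test accuracy.

# Hypothetical sketch of a linear probe; the function name is illustrative.
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def sketch_linear_probe(embs, labels, test_size=0.2, random_state=42, max_iter=1000) -> float:
    x_tr, x_te, y_tr, y_te = train_test_split(
        embs, labels, test_size=test_size, random_state=random_state, stratify=labels
    )
    scaler = StandardScaler().fit(x_tr)
    clf = LogisticRegression(max_iter=max_iter, random_state=random_state)
    clf.fit(scaler.transform(x_tr), y_tr)
    return float(clf.score(scaler.transform(x_te), y_te))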

probe_kwargs = dict(test_size=0.2, random_state=42, max_iter=1000)

probe_results = {}
for emb_name, embs in _model_embs:
    for task, task_labels in [("valence", valence_labels), ("species", species_labels)]:
        key = f"{emb_name} | {task}"
        res = run_linear_probe(embs, task_labels, **probe_kwargs)
        probe_results[key] = res
        print(f"{key}: accuracy = {res['accuracy']:.3f}")
esp_aves2_sl_beats_all (last layer) | valence: accuracy = 0.750
esp_aves2_sl_beats_all (last layer) | species: accuracy = 0.992
esp_aves2_sl_beats_all (all layers avg) | valence: accuracy = 0.725
esp_aves2_sl_beats_all (all layers avg) | species: accuracy = 0.998
esp_aves2_effnetb0_all | valence: accuracy = 0.759
esp_aves2_effnetb0_all | species: accuracy = 0.979
rows = [
    {"Model": k.split(" | ")[0], "Task": k.split(" | ")[1], "Accuracy": round(v["accuracy"], 4)}
    for k, v in probe_results.items()
]
acc_df = pd.DataFrame(rows)
display(acc_df.pivot(index="Model", columns="Task", values="Accuracy").style.format("{:.4f}"))

fig_cmp_valence = plot_model_comparison(
    {k: v["accuracy"] for k, v in probe_results.items() if "valence" in k},
    title="Ungulate — emotional valence classification accuracy",
)
fig_cmp_valence.show()
display(
    plot_model_comparison_static(
        {k: v["accuracy"] for k, v in probe_results.items() if "valence" in k},
        title="Ungulate — valence linear probe (static)",
    )
)

fig_cmp_species = plot_model_comparison(
    {k: v["accuracy"] for k, v in probe_results.items() if "species" in k},
    title="Ungulate — species identification accuracy",
)
fig_cmp_species.show()
display(
    plot_model_comparison_static(
        {k: v["accuracy"] for k, v in probe_results.items() if "species" in k},
        title="Ungulate — species linear probe (static)",
    )
)
Task species valence
Model    
esp_aves2_effnetb0_all 0.9790 0.7586
esp_aves2_sl_beats_all (all layers avg) 0.9975 0.7254
esp_aves2_sl_beats_all (last layer) 0.9917 0.7497
(Figures: valence and species model-comparison bar charts, interactive + static.)

7. Attention Probe (sl-BEATs)

Multi-head attention probe on BEATs embeddings for both classification tasks.
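run_attention_probe is also provided by utils.probing. Since the cached embeddings are already mean-pooled vectors, one plausible architecture (an assumption, not verified against the package) treats each embedding as a one-token sequence passed through a small transformer encoder before a linear head:

# Hypothetical attention-probe architecture; the real implementation may differ.
import torch
import torch.nn as nn


class SketchAttentionProbe(nn.Module):
    def __init__(self, dim: int, n_classes: int, num_heads: int = 8, num_attn_layers: int = 2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_attn_layers)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.unsqueeze(1)                    # (B, 1, D): one token per clip
        pooled = self.encoder(tokens).mean(dim=1)  # (B, D)
        return self.head(pooled)                   # (B, n_classes)


logits = SketchAttentionProbe(dim=768, n_classes=2)(torch.randn(4, 768))  # (4, 2)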

attn_probe_kwargs = dict(num_heads=8, num_attn_layers=2, epochs=50, test_size=0.2, random_state=42)
attention_results = {}
for emb_name, embs in [
    (f"{BEATS_MODEL} (last layer)", beats_embs),
    (f"{BEATS_MODEL} (all layers avg)", beats_all_embs),
]:
    for task, task_labels in [("valence", valence_labels), ("species", species_labels)]:
        key = f"{emb_name} | {task}"
        res = run_attention_probe(embs, task_labels, **attn_probe_kwargs)
        attention_results[key] = res
        print(f"{key} (attention): accuracy = {res['accuracy']:.3f}")

for task in ["valence", "species"]:
    cmp = {}
    for base in [f"{BEATS_MODEL} (last layer)", f"{BEATS_MODEL} (all layers avg)"]:
        k = f"{base} | {task}"
        cmp[f"{base} (linear)"] = probe_results[k]["accuracy"]
        cmp[f"{base} (attention)"] = attention_results[k]["accuracy"]
    fig_attn = plot_model_comparison(cmp, title=f"Ungulate — sl-BEATs {task}: linear vs attention")
    fig_attn.show()
    display(plot_model_comparison_static(cmp, title=f"Ungulate — sl-BEATs {task} (static)"))
esp_aves2_sl_beats_all (last layer) | valence (attention): accuracy = 0.762
esp_aves2_sl_beats_all (last layer) | species (attention): accuracy = 0.986
esp_aves2_sl_beats_all (all layers avg) | valence (attention): accuracy = 0.757
esp_aves2_sl_beats_all (all layers avg) | species (attention): accuracy = 0.988
(Figures: linear vs. attention probe comparison bar charts for valence and species, interactive + static.)
def plot_confusion_matrix(cm: np.ndarray, classes: list, title: str):
    cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
    return px.imshow(
        cm_norm,
        x=classes,
        y=classes,
        color_continuous_scale="Blues",
        zmin=0,
        zmax=1,
        text_auto=".2f",
        title=title,
        labels={"x": "Predicted", "y": "True", "color": "Recall"},
        aspect="auto",
    )


for key, res in probe_results.items():
    plot_confusion_matrix(res["confusion_matrix"], res["classes"], title=f"Confusion matrix: {key}").show()
    display(
        confusion_heatmap_static(
            res["confusion_matrix"],
            res["classes"],
            title=f"Confusion matrix: {key} (static)",
        )
    )
(Figures: row-normalised confusion-matrix heatmaps, one per model and task, interactive + static.)

8. Cross-Species Evaluation (Leave-One-Species-Out)

Train a linear valence probe on all species except one, then test on the held-out species. This measures how well embeddings transfer to an unseen species at inference time — a realistic deployment scenario where not all species are labelled at training time.
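The cell below implements these splits by hand (plus training-fold class balancing). For reference, scikit-learn's LeaveOneGroupOut expresses the same split logic; this is an illustrative sketch, not what the cell actually runs:

# Illustrative only: the hand-rolled loop below adds oversampling on top of this.
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

groups = np.array(species_labels)
for train_idx, test_idx in LeaveOneGroupOut().split(beats_embs, valence_labels, groups):
    held_out_species = groups[test_idx][0]
    # ...fit a valence probe on train_idx, evaluate on test_idx...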

from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample

from utils.probing import _balanced_accuracy_from_confusion_matrix


def loso_valence_probe(
    embs: np.ndarray,
    species: list[str],
    valence: list[str],
    max_iter: int = 2000,
    random_state: int = 42,
) -> dict[str, dict]:
    """Leave-one-species-out cross-species valence classification.

    For each species S: train LogisticRegression on all other species
    (with the training fold balanced by oversampling), evaluate on S.
    Returns per-species balanced accuracy and test-set size.
    """
    sp_arr = np.array(species)
    val_arr = np.array(valence)
    unique_species = sorted(set(species))
    results = {}
    for test_sp in unique_species:
        train_mask = sp_arr != test_sp
        test_mask = sp_arr == test_sp
        X_train, y_train = embs[train_mask], val_arr[train_mask]
        X_test, y_test = embs[test_mask], val_arr[test_mask]
        if len(np.unique(y_train)) < 2 or len(X_test) == 0:
            results[test_sp] = {"accuracy": float("nan"), "n_test": int(test_mask.sum())}
            continue

        # Balance the *training* fold by oversampling minority classes.
        cls_u, cnt_u = np.unique(y_train, return_counts=True)
        max_cnt = int(cnt_u.max())
        x_parts: list[np.ndarray] = []
        y_parts: list[np.ndarray] = []
        for cls in cls_u:
            mask = y_train == cls
            x_c = X_train[mask]
            y_c = y_train[mask]
            if len(x_c) < max_cnt:
                x_c, y_c = resample(
                    x_c,
                    y_c,
                    n_samples=max_cnt,
                    random_state=random_state,
                    replace=True,
                )
            x_parts.append(x_c)
            y_parts.append(y_c)
        X_train = np.vstack(x_parts)
        y_train = np.concatenate(y_parts)

        scaler = StandardScaler()
        X_train_sc = scaler.fit_transform(X_train)
        X_test_sc = scaler.transform(X_test)
        clf = LogisticRegression(
            max_iter=max_iter,
            random_state=random_state,
            C=1.0,
            solver="lbfgs",
        )
        clf.fit(X_train_sc, y_train)
        y_pred = clf.predict(X_test_sc)
        all_labels = np.array(sorted(set(valence)))
        cm = confusion_matrix(y_test, y_pred, labels=all_labels)
        results[test_sp] = {
            "accuracy": _balanced_accuracy_from_confusion_matrix(cm),
            "n_test": int(test_mask.sum()),
        }
    return results


loso_results: dict[str, dict[str, dict]] = {}
for emb_name, embs in _model_embs:
    print(f"\n--- {emb_name} ---")
    res = loso_valence_probe(embs, species_labels, valence_labels)
    loso_results[emb_name] = res
    for sp, r in res.items():
        acc_str = f"{r['accuracy']:.3f}" if not np.isnan(r["accuracy"]) else "  n/a"
        print(f"  test={sp:<22s}  acc={acc_str}  (n={r['n_test']})")
--- esp_aves2_sl_beats_all (last layer) ---
(sklearn ConvergenceWarning, repeated across folds: "The max_iter was reached which means the coef_ did not converge".)
  test=Cow                     acc=0.503  (n=1254)
  test=Goat                    acc=0.477  (n=361)
  test=Horse                   acc=0.479  (n=344)
  test=Pig                     acc=0.485  (n=564)
  test=Sheep                   acc=0.552  (n=396)
  test=Wild Boar               acc=0.551  (n=262)

--- esp_aves2_sl_beats_all (all layers avg) ---
  test=Cow                     acc=0.520  (n=1254)
  test=Goat                    acc=0.410  (n=361)
  test=Horse                   acc=0.500  (n=344)
  test=Pig                     acc=0.568  (n=564)
  test=Sheep                   acc=0.499  (n=396)
  test=Wild Boar               acc=0.554  (n=262)

--- esp_aves2_effnetb0_all ---
(sklearn ConvergenceWarning, repeated across folds: "The max_iter was reached which means the coef_ did not converge".)
  test=Cow                     acc=0.589  (n=1254)
  test=Goat                    acc=0.463  (n=361)
  test=Horse                   acc=0.520  (n=344)
  test=Pig                     acc=0.553  (n=564)
  test=Sheep                   acc=0.507  (n=396)
  test=Wild Boar               acc=0.629  (n=262)
# Summarise LOSO as a heatmap: species (rows) × model (columns)
_loso_rows = []
for emb_name, per_sp in loso_results.items():
    for sp, r in per_sp.items():
        _loso_rows.append({"Model": emb_name, "Test species": sp, "Accuracy": r["accuracy"]})

loso_df = pd.DataFrame(_loso_rows)
loso_pivot = loso_df.pivot(index="Test species", columns="Model", values="Accuracy")

# Append mean row
loso_pivot.loc["Mean"] = loso_pivot.mean()
display(loso_pivot.style.format("{:.3f}").background_gradient(cmap="RdYlGn", axis=None, vmin=0.4, vmax=1.0))

# Interactive bar chart: per-species LOSO accuracy for each model
fig_loso = px.bar(
    loso_df,
    x="Test species",
    y="Accuracy",
    color="Model",
    barmode="group",
    title="Cross-species valence classification (LOSO) — per held-out species accuracy",
    labels={"Accuracy": "Valence accuracy (test species)"},
    text_auto=".3f",
)
fig_loso.update_layout(yaxis_range=[0, 1])
fig_loso.show()

display(
    plot_model_comparison_static(
        {k: float(np.nanmean([r["accuracy"] for r in v.values()])) for k, v in loso_results.items()},
        title="LOSO — mean cross-species valence accuracy (static)",
    )
)
Model esp_aves2_effnetb0_all esp_aves2_sl_beats_all (all layers avg) esp_aves2_sl_beats_all (last layer)
Test species      
Cow 0.589 0.520 0.503
Goat 0.463 0.410 0.477
Horse 0.520 0.500 0.479
Pig 0.553 0.568 0.485
Sheep 0.507 0.499 0.552
Wild Boar 0.629 0.554 0.551
Mean 0.544 0.508 0.508
(Figure: mean cross-species LOSO valence accuracy per model, static.)

Padding Strategy Experiment

For audio clips shorter than 3 s, the choice of padding method can affect how well embeddings capture the vocalization. We compare six strategies for extending clips to a fixed 3-second window:

Pad left: silence prepended before the call
Pad right: silence appended after the call (current default)
Pad both: silence split evenly before and after the call
Pad both + pink noise: silence regions filled with low-amplitude pink (1/f) noise
Slow down 2×: time-stretched to half speed; any remainder zero-padded right
Slow down 3×: time-stretched to one-third speed; any remainder zero-padded right

Clips already ≥ 3 s are trimmed to 3 s. Only BEATs last-layer embeddings are used for efficiency.

PAD_TARGET_SAMPLES = int(3.0 * TARGET_SR)
PAD_EXP_EMBED_DIR = EMBED_DIR / "padding_experiment"
PAD_EXP_EMBED_DIR.mkdir(exist_ok=True)

_pad_rng = np.random.default_rng(42)


def _pink_noise(n_samples: int, rng: np.random.Generator) -> np.ndarray:
    """Approximate pink (1/f) noise via FFT spectrum shaping."""
    white = rng.standard_normal(n_samples)
    fft = np.fft.rfft(white)
    freqs = np.fft.rfftfreq(n_samples)
    freqs[0] = 1.0  # avoid division by zero at DC
    fft /= np.sqrt(freqs)  # amplitude ∝ f^(-1/2), so power ∝ 1/f
    pink = np.fft.irfft(fft, n=n_samples)
    return (pink / (np.std(pink) + 1e-8) * 0.01).astype(np.float32)


def _pad_audio(wav: np.ndarray, strategy: str, target: int, rng: np.random.Generator) -> np.ndarray:
    """Pad or time-stretch wav to exactly `target` samples."""
    if strategy in ("slowdown_2x", "slowdown_3x"):
        rate = 0.5 if strategy == "slowdown_2x" else 1.0 / 3.0
        stretched = librosa.effects.time_stretch(wav, rate=rate)
        if len(stretched) >= target:
            return stretched[:target]
        return np.pad(stretched, (0, target - len(stretched)))
    clipped = wav[:target]
    n_pad = max(0, target - len(clipped))
    if strategy == "pad_left":
        return np.pad(clipped, (n_pad, 0))
    if strategy == "pad_right":
        return np.pad(clipped, (0, n_pad))
    if strategy == "pad_both":
        left = n_pad // 2
        return np.pad(clipped, (left, n_pad - left))
    if strategy == "pad_both_pink":
        left = n_pad // 2
        noise = _pink_noise(target, rng)
        result = noise.copy()
        result[left : left + len(clipped)] = clipped
        return result
    raise ValueError(f"Unknown strategy: {strategy!r}")


PADDING_STRATEGIES = ["pad_left", "pad_right", "pad_both", "pad_both_pink", "slowdown_2x", "slowdown_3x"]
STRATEGY_LABELS = {
    "pad_left": "Pad left",
    "pad_right": "Pad right",
    "pad_both": "Pad both",
    "pad_both_pink": "Pad both + pink noise",
    "slowdown_2x": "Slow 2×",
    "slowdown_3x": "Slow 3×",
}

padding_embs: dict[str, np.ndarray] = {}
_pad_model = None

for _strategy in PADDING_STRATEGIES:
    _cache = PAD_EXP_EMBED_DIR / f"beats_last_{_strategy}.npy"
    if _cache.exists():
        _arr = np.load(_cache)
        if _arr.shape[0] == len(df):
            padding_embs[_strategy] = _arr
            print(f"Loaded  '{_strategy}': {padding_embs[_strategy].shape}")
            continue

    # Recompute embeddings (inference only; no training)
    if _pad_model is None:
        print(f"Loading {BEATS_MODEL} ...")
        _pad_model = load_model(BEATS_MODEL, return_features_only=True, device=DEVICE)
        _pad_model.eval()

    _embeddings = []
    with torch.no_grad():
        for _path in tqdm(df["path"], desc=f"BEATs ({_strategy})"):
            _wav, _ = librosa.load(_path, sr=TARGET_SR, mono=True)
            _wav = _pad_audio(_wav, _strategy, PAD_TARGET_SAMPLES, _pad_rng)
            _tensor = torch.from_numpy(_wav).unsqueeze(0).to(DEVICE)
            _feats = _pad_model(_tensor)
            _embeddings.append(_feats.mean(dim=1).squeeze(0).cpu().numpy())

    padding_embs[_strategy] = np.stack(_embeddings)
    np.save(_cache, padding_embs[_strategy])
    print(f"Saved   '{_strategy}': {padding_embs[_strategy].shape}")

if _pad_model is not None:
    del _pad_model
if not padding_embs:
    print("Skipping padding probes: no cached embeddings.")
else:
    pad_probe_results: dict[str, dict] = {}
    for _strategy, _embs in padding_embs.items():
        for _task, _task_labels in [("valence", valence_labels), ("species", species_labels)]:
            _key = f"{_strategy} | {_task}"
            _res = run_linear_probe(_embs, _task_labels, **probe_kwargs)
            pad_probe_results[_key] = _res
            print(f"{STRATEGY_LABELS[_strategy]:<24s} | {_task:<14s}: {_res['accuracy']:.3f}")
Skipping padding probes: no cached embeddings.
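As a quick sanity check of the padding arithmetic (a hypothetical snippet, not part of the recorded run): centre-padding a 0.5 s clip should yield exactly 3 s with 1.25 s of silence on each side.

# 0.5 s of ones at 16 kHz, centre-padded to the 3 s window (48 000 samples).
_demo = np.ones(8_000, dtype=np.float32)
_padded = _pad_audio(_demo, "pad_both", PAD_TARGET_SAMPLES, np.random.default_rng(0))
assert _padded.shape == (48_000,)
assert _padded[:20_000].sum() == 0 and _padded[20_000:28_000].sum() == 8_000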
if not globals().get("pad_probe_results"):
    print("Skipping padding viz: no results.")
else:
    _pad_rows = []
    for _key, _res in pad_probe_results.items():
        _strategy, _task = _key.split(" | ")
        _pad_rows.append(
            {"Strategy": STRATEGY_LABELS[_strategy], "Task": _task, "Accuracy": round(_res["accuracy"], 4)}
        )
    pad_df = pd.DataFrame(_pad_rows)
    display(pad_df.pivot(index="Strategy", columns="Task", values="Accuracy").style.format("{:.4f}"))

    fig_pad = px.bar(
        pad_df,
        x="Strategy",
        y="Accuracy",
        color="Task",
        barmode="group",
        title="Padding strategy comparison — BEATs last layer, 3-second window",
        labels={"Strategy": "Padding strategy", "Accuracy": "Linear probe accuracy"},
        text_auto=".3f",
    )
    fig_pad.update_layout(yaxis_range=[0, 1], xaxis_tickangle=-20)
    fig_pad.show()

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    fig_pad.write_html(str(ARTIFACTS_DIR / "padding_strategy_comparison.html"))

    for _task in sorted(pad_df["Task"].unique()):
        _task_dict = {row["Strategy"]: row["Accuracy"] for _, row in pad_df[pad_df["Task"] == _task].iterrows()}
        _static = plot_model_comparison_static(_task_dict, title=f"Padding strategy — {_task} (static)")
        display(_static)
        _static.savefig(str(ARTIFACTS_DIR / f"padding_strategy_{_task}_static.png"), dpi=150, bbox_inches="tight")
        plt.close()
    print("Padding experiment figures saved.")
Skipping padding viz: no results.

9. Save Artifacts

ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# Save interactive HTML figures
for name, fig in umap_valence_figs.items():
    slug = name.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "-")
    fig.write_html(str(ARTIFACTS_DIR / f"umap_valence_{slug}.html"))

for name, fig in umap_species_figs.items():
    slug = name.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "-")
    fig.write_html(str(ARTIFACTS_DIR / f"umap_species_{slug}.html"))

fig_cmp_valence.write_html(str(ARTIFACTS_DIR / "model_comparison_valence.html"))
fig_cmp_species.write_html(str(ARTIFACTS_DIR / "model_comparison_species.html"))
fig_loso.write_html(str(ARTIFACTS_DIR / "loso_cross_species.html"))

# Save static PNG figures
for name, fig in umap_valence_static_figs.items():
    slug = name.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "-")
    fig.savefig(str(ARTIFACTS_DIR / f"umap_valence_{slug}.png"), dpi=150, bbox_inches="tight")

for name, fig in umap_species_static_figs.items():
    slug = name.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "-")
    fig.savefig(str(ARTIFACTS_DIR / f"umap_species_{slug}.png"), dpi=150, bbox_inches="tight")

plt.close("all")

# Save metrics JSON
_metrics_out = {
    "n_calls": len(df),
    "n_species": df["species"].nunique(),
    "species": sorted(df["species"].unique().tolist()),
    "valence_counts": df["valence"].value_counts().to_dict(),
    "training_free": {k: v for k, v in tf_results.items()},
    "linear_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in probe_results.items()},
    "attention_probe_accuracy": {k: round(v["accuracy"], 4) for k, v in attention_results.items()},
    "loso_cross_species": {
        emb_name: {sp: round(r["accuracy"], 4) for sp, r in per_sp.items()} for emb_name, per_sp in loso_results.items()
    },
    "loso_cross_species_mean": {
        emb_name: round(float(np.nanmean([r["accuracy"] for r in per_sp.values()])), 4)
        for emb_name, per_sp in loso_results.items()
    },
}
with open(ARTIFACTS_DIR / "ungulate_metrics.json", "w") as _f:
    json.dump(_metrics_out, _f, indent=2, default=float)

print(f"Artifacts saved to {ARTIFACTS_DIR}")
print("  annotations.csv (raw annotations from Excel)")
print("  ungulate_metrics.json")
print(f"  {len(umap_valence_figs) + len(umap_species_figs)} UMAP HTML + PNG files")
print("  model_comparison_valence.html / _species.html")
print("  loso_cross_species.html")
Artifacts saved to /home/marius_miron_earthspecies_org/code/avex-examples/examples/09_ungulate_valence/artifacts
  annotations.csv (raw annotations from Excel)
  ungulate_metrics.json
  6 UMAP HTML + PNG files
  model_comparison_valence.html / _species.html
  loso_cross_species.html