Interactive Embedding Visualisation

Extracts audio embeddings using esp_aves2_sl_beats_all (BEATs) and esp_aves2_effnetb0_all (EfficientNet) and renders an interactive side-by-side UMAP dashboard.

Data sources (in priority order):

  1. Real audio files found in audio_dir — labelled via labels_csv or filename parsing

  2. Pre-extracted .npy embeddings found in embeddings_dir

  3. Synthetic audio (Gaussian noise) passed through the real avex models — always works out of the box without any data, and still produces genuine model embeddings

The exported embedding_explorer.html can be opened in any browser — no Python environment required. (The Plotly library itself is loaded from a CDN, so an internet connection is needed to view the file.)

Configuration

Edit the paths below. Leave audio_dir and embeddings_dir as None to run with synthetic audio (genuine model embeddings, fake labels).

import warnings

# Suppress all library warnings (librosa, umap, torch, etc.) for cleaner
# notebook output; comment this out when debugging.
warnings.filterwarnings("ignore")

# ── User configuration ────────────────────────────────────────────────────────

# Directory of .wav / .flac files to embed (None → synthetic audio)
audio_dir = None  # e.g. "../01_giant_otter_classifier/data/giant_otter/audio"

# Optional CSV with columns 'filename' and 'label' for colouring points
labels_csv = None  # e.g. "../01_giant_otter_classifier/data/giant_otter/labels.csv"

# Or: directory of pre-extracted .npy embedding files from a previous run
# (used only if audio_dir is None and this path exists)
embeddings_dir = None  # e.g. "../01_giant_otter_classifier/outputs/giant_otter/embeddings"

# Output HTML path (relative to this notebook)
output_html = "embedding_explorer.html"

# avex model names (feature extractors run in Step 2)
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"

# Synthetic-audio fallback settings (ignored when real data is used)
N_SYNTHETIC = 120  # total synthetic clips
N_CLASSES_SYN = 4  # fake class count
CLIP_SECONDS = 5.0  # clip length in seconds; real audio is padded/trimmed to this
SAMPLE_RATE = 16_000  # Hz — real audio is resampled to this rate in Step 1
# ──────────────────────────────────────────────────────────────────────────────
import os
import pathlib

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
from plotly.subplots import make_subplots

pio.renderers.default = "notebook"  # ensures text/html output for Sphinx/myst-nb

from avex import load_model
from IPython.display import display

from utils.visualization import plot_umap, plot_umap_static

# Current working directory (the notebook's directory when run via Jupyter);
# used in Step 5 to anchor the exported HTML path.
NOTEBOOK_DIR = pathlib.Path().resolve()
print(f"Working directory: {NOTEBOOK_DIR}")

# CI uses a tight per-cell timeout; avoid heavyweight model inference there.
# Note: GITHUB_ACTIONS counts as truthy for ANY non-empty value.
FAST_SMOKE = os.environ.get("CI") == "true" or bool(os.environ.get("GITHUB_ACTIONS"))
print(f"FAST_SMOKE: {FAST_SMOKE}")
2026-04-10 08:53:52.632184: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-04-10 08:53:54.524994: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
Working directory: /home/marius_miron_earthspecies_org/code/avex-examples/examples/07_interactive_visualization

Step 1 — Gather audio clips and labels

We collect (waveform, label, filename) triples from whichever source is available.

import re

import librosa
import pandas as pd
import soundfile as sf

# Fixed per-clip sample count; every waveform is padded/trimmed to this length
N_SAMPLES = int(CLIP_SECONDS * SAMPLE_RATE)
waveforms: list[np.ndarray] = []  # one float32 array of length N_SAMPLES per clip
labels_list: list[str] = []  # class label per clip (parallel to waveforms)
filenames_list: list[str] = []  # display name per clip (parallel to waveforms)

SOURCE = "synthetic"  # will be overridden below

# ── Option A: real audio directory ────────────────────────────────────────────
if audio_dir is not None:
    audio_path = pathlib.Path(audio_dir)
    wav_files = sorted(audio_path.rglob("*.wav")) + sorted(audio_path.rglob("*.flac"))

    # Load optional labels CSV mapping filename -> label
    label_map: dict[str, str] = {}
    if labels_csv is not None:
        df_lbl = pd.read_csv(labels_csv)
        label_map = dict(zip(df_lbl["filename"], df_lbl["label"], strict=False))

    for wav in wav_files:
        audio, sr = sf.read(str(wav), always_2d=False)
        if audio.ndim > 1:  # mix multi-channel down to mono
            audio = audio.mean(axis=1)
        if sr != SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
        # Pad / trim to the fixed length expected downstream
        if len(audio) < N_SAMPLES:
            audio = np.pad(audio, (0, N_SAMPLES - len(audio)))
        else:
            audio = audio[:N_SAMPLES]
        waveforms.append(audio.astype(np.float32))
        # Label priority: CSV entry → alphabetic filename prefix → 'unknown'.
        # (Bug fix: the previous fallback stored the literal string "None"
        # whenever the stem had no alphabetic prefix, e.g. "01_clip.wav".)
        lbl = label_map.get(wav.name)
        if lbl is None:
            prefix = re.match(r"^([a-zA-Z]+)", wav.stem)
            lbl = prefix.group(1) if prefix else "unknown"
        labels_list.append(str(lbl))
        filenames_list.append(wav.name)

    SOURCE = f"audio_dir ({len(waveforms)} clips)"

# ── Option B: synthetic audio passed through real avex models ─────────────────
if not waveforms:
    print("No real audio found — generating synthetic clips (real model, fake labels).")
    rng = np.random.default_rng(42)
    SPECIES = [f"species_{chr(65 + i)}" for i in range(N_CLASSES_SYN)]
    # Time axis is identical for every clip, so compute it once.
    time_axis = np.linspace(0, CLIP_SECONDS, N_SAMPLES, endpoint=False)
    for idx in range(N_SYNTHETIC):
        cls = idx % N_CLASSES_SYN
        labels_list.append(SPECIES[cls])
        filenames_list.append(f"clip_{idx:04d}.wav")
        # Give each fake class its own base frequency so the embeddings form
        # loose clusters instead of one undifferentiated blob.
        base_freq = 440.0 * (1 + 0.3 * cls)
        sine = 0.3 * np.sin(2 * np.pi * base_freq * time_axis).astype(np.float32)
        hiss = rng.standard_normal(N_SAMPLES).astype(np.float32) * 0.05
        waveforms.append(sine + hiss)
    SOURCE = f"synthetic ({N_SYNTHETIC} clips, {N_CLASSES_SYN} classes)"

# Collect the parallel per-clip lists into array/tensor form for batching.
labels_arr = np.array(labels_list)
filenames_arr = np.array(filenames_list)
audio_tensor = torch.from_numpy(np.stack(waveforms))  # (N, T) float32 batch

print(f"Source: {SOURCE}")
print(f"Audio tensor: {audio_tensor.shape}  labels: {np.unique(labels_arr)}")
No real audio found — generating synthetic clips (real model, fake labels).
Source: synthetic (120 clips, 4 classes)
Audio tensor: torch.Size([120, 80000])  labels: ['species_A' 'species_B' 'species_C' 'species_D']

Step 2 — Extract embeddings with avex

We run both avex models in feature-extraction mode and mean-pool the outputs:

  • BEATs(N, T_tokens, 768) → mean over time → (N, 768)

  • EfficientNet(N, C, H, W) → global average pool → (N, C)

Warning (slow cell)

Last observed runtime: ~55s (from stored execution timestamps). Runtime depends on cache/hardware.

def extract(model_name: str, audio: torch.Tensor, batch_size: int = 16) -> np.ndarray:
    """Extract mean-pooled embeddings for a batch of audio clips.

    Runs the named avex model over `audio` in mini-batches of `batch_size`
    to bound peak memory, pools each batch's features to one vector per
    clip, and concatenates the results into a single (N, D) array.
    """
    model = load_model(model_name, return_features_only=True, device="cpu")
    model.eval()

    def _pool(feat: torch.Tensor) -> torch.Tensor:
        # 4-D → EfficientNet feature maps (B, C, H, W): global average pool.
        # 3-D → token sequences (B, T, D): mean over time.
        if feat.ndim == 4:
            return feat.mean(dim=(2, 3))
        if feat.ndim == 3:
            return feat.mean(dim=1)
        return feat

    pieces = []
    for offset in range(0, len(audio), batch_size):
        minibatch = audio[offset : offset + batch_size]
        with torch.no_grad():
            features = model(minibatch, padding_mask=None)
        pieces.append(_pool(features).cpu().numpy())
    return np.concatenate(pieces, axis=0)


def smoke_embeddings(audio: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """Fast deterministic embeddings for CI smoke tests.

    Chunk-averages each (N, T) waveform down to (N, 100), z-normalises per
    clip, then applies two fixed-seed random projections so the outputs
    match the BEATs (768-d) and EfficientNet (1280-d) embedding widths.
    """
    chunks = 100
    waves = audio.float()
    n_clips, n_samples = waves.shape
    usable = (n_samples // chunks) * chunks  # drop the ragged tail
    pooled = waves[:, :usable].reshape(n_clips, chunks, -1).mean(dim=-1)
    mu = pooled.mean(dim=1, keepdim=True)
    sigma = pooled.std(dim=1, keepdim=True)
    pooled = (pooled - mu) / (sigma + 1e-6)

    # Fixed seed keeps the projection matrices identical across runs.
    gen = np.random.default_rng(0)
    scale = np.sqrt(chunks)
    proj_beats = gen.standard_normal((chunks, 768), dtype=np.float32) / scale
    proj_effnet = gen.standard_normal((chunks, 1280), dtype=np.float32) / scale
    feats = pooled.cpu().numpy().astype(np.float32, copy=False)
    return (feats @ proj_beats, feats @ proj_effnet)


# Check if pre-extracted embeddings exist (skip model inference if so)
_emb_dir = pathlib.Path(embeddings_dir) if embeddings_dir else None
_beats_npy = _emb_dir / f"{BEATS_MODEL}.npy" if _emb_dir else None
_effnet_npy = _emb_dir / f"{EFFNET_MODEL}.npy" if _emb_dir else None
_any_npy = (_beats_npy is not None and _beats_npy.exists()) or (
    _effnet_npy is not None and _effnet_npy.exists()
)

# Track explicitly whether the smoke path produced both embeddings, instead
# of probing globals(): a stale `effnet_emb` left over from a previous
# notebook run would otherwise silently suppress re-extraction here.
_smoke_done = False

if FAST_SMOKE and not _any_npy:
    print("FAST_SMOKE enabled: using lightweight deterministic embeddings")
    beats_emb, effnet_emb = smoke_embeddings(audio_tensor)
    _smoke_done = True
    print(f"  → BEATs-like shape: {beats_emb.shape}")
    print(f"  → EfficientNet-like shape: {effnet_emb.shape}")
elif _beats_npy and _beats_npy.exists():
    print(f"Loading pre-extracted BEATs embeddings from {_beats_npy}")
    beats_emb = np.load(_beats_npy)
else:
    print(f"Extracting BEATs embeddings with {BEATS_MODEL} …")
    beats_emb = extract(BEATS_MODEL, audio_tensor)
    print(f"  → shape: {beats_emb.shape}")

if _smoke_done:
    pass  # effnet_emb was already set by smoke_embeddings above
elif _effnet_npy and _effnet_npy.exists():
    print(f"Loading pre-extracted EfficientNet embeddings from {_effnet_npy}")
    effnet_emb = np.load(_effnet_npy)
else:
    print(f"Extracting EfficientNet embeddings with {EFFNET_MODEL} …")
    effnet_emb = extract(EFFNET_MODEL, audio_tensor)
    print(f"  → shape: {effnet_emb.shape}")
Extracting BEATs embeddings with esp_aves2_sl_beats_all …
  → shape: (120, 768)
Extracting EfficientNet embeddings with esp_aves2_effnetb0_all …
  → shape: (120, 1280)

Step 3 — UMAP projections

We fit UMAP independently on each model’s embeddings and plot them side by side. Both panels share the same colour palette so classes are directly comparable.

import umap

# Fixed qualitative palette; cycled (i % len) when there are more classes
# than colours, so every plot in this notebook shares one class→colour map.
PALETTE = [
    "#636EFA",
    "#EF553B",
    "#00CC96",
    "#AB63FA",
    "#FFA15A",
    "#19D3F3",
    "#FF6692",
    "#B6E880",
]
unique_labels = sorted(np.unique(labels_arr))
color_map = {lbl: PALETTE[i % len(PALETTE)] for i, lbl in enumerate(unique_labels)}


def fit_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
    """Project `emb` to 2-D with UMAP, seeded for reproducibility."""
    # Cap n_neighbors below the sample count so tiny datasets still fit.
    neighbours = min(15, len(emb) - 1)
    reducer = umap.UMAP(n_components=2, random_state=seed, n_neighbors=neighbours)
    return reducer.fit_transform(emb)


def scatter_traces(
    coords: np.ndarray,
    labels: np.ndarray,
    filenames: np.ndarray,
    color_map: dict,
    show_legend: bool = True,
) -> list[go.Scattergl]:
    """Build one Scattergl trace per label class, coloured via `color_map`."""
    built: list[go.Scattergl] = []
    for cls in sorted(np.unique(labels)):
        sel = labels == cls
        trace = go.Scattergl(
            x=coords[sel, 0],
            y=coords[sel, 1],
            mode="markers",
            name=cls,
            marker=dict(color=color_map[cls], size=7, opacity=0.85),
            # Hover shows the class on one line and the filename below it.
            text=[f"{cls}<br>{fn}" for fn in filenames[sel]],
            hoverinfo="text",
            legendgroup=cls,
            showlegend=show_legend,
        )
        built.append(trace)
    return built


print("Fitting UMAP for BEATs …")
beats_2d = fit_umap(beats_emb)
print("Fitting UMAP for EfficientNet …")
effnet_2d = fit_umap(effnet_emb)

# Side-by-side panels; the legend is attached to the left panel only, and
# legendgroup (set inside scatter_traces) links each class across both.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(f"BEATs ({BEATS_MODEL})", f"EfficientNet ({EFFNET_MODEL})"),
)

for coords_2d, col, legend_on in ((beats_2d, 1, True), (effnet_2d, 2, False)):
    for tr in scatter_traces(coords_2d, labels_arr, filenames_arr, color_map, show_legend=legend_on):
        fig.add_trace(tr, row=1, col=col)

fig.update_layout(
    title_text="Embedding Space Explorer — UMAP Projections",
    height=540,
    hovermode="closest",
    legend=dict(title="Class", itemsizing="constant"),
)
fig.update_xaxes(title_text="UMAP 1", showgrid=False, zeroline=False)
fig.update_yaxes(title_text="UMAP 2", showgrid=False, zeroline=False)
fig.show()
Fitting UMAP for BEATs …
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
Fitting UMAP for EfficientNet …
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(

Step 4 — Per-model detail view

Full-width interactive plot for each model with richer hover information.

# Full-width interactive view plus a static fallback for each model.
beats_title = f"BEATs — {BEATS_MODEL}"
fig_beats = plot_umap(
    embeddings=beats_emb,
    labels=labels_arr,
    title=beats_title,
    hover_text=filenames_arr.tolist(),
)
fig_beats.show()
display(plot_umap_static(beats_emb, labels_arr, title=f"{beats_title} (static)"))

effnet_title = f"EfficientNet — {EFFNET_MODEL}"
fig_effnet = plot_umap(
    embeddings=effnet_emb,
    labels=labels_arr,
    title=effnet_title,
    hover_text=filenames_arr.tolist(),
)
fig_effnet.show()
display(plot_umap_static(effnet_emb, labels_arr, title=f"{effnet_title} (static)"))
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(

Step 5 — Export to self-contained HTML

The combined figure is written as a single HTML file with Plotly loaded from CDN. Share it with anyone — no Python installation needed.

# Write the combined figure as one shareable HTML file (Plotly loaded via CDN).
html_path = NOTEBOOK_DIR / output_html
fig.write_html(str(html_path), include_plotlyjs="cdn")
size_kb = html_path.stat().st_size / 1024
print(f"Exported: {html_path}  ({size_kb:.1f} KB)")
Exported: /home/marius_miron_earthspecies_org/code/avex-examples/examples/07_interactive_visualization/embedding_explorer.html  (22.2 KB)