Interactive Embedding Visualisation¶
Extracts audio embeddings using esp_aves2_sl_beats_all (BEATs) and
esp_aves2_effnetb0_all (EfficientNet) and renders an interactive
side-by-side UMAP dashboard.
Data sources (in priority order):

1. Real audio files found in audio_dir — labelled via labels_csv or filename parsing
2. Pre-extracted .npy embeddings found in embeddings_dir
3. Synthetic audio (sine tones plus Gaussian noise) passed through the real avex models — always works out of the box without any data, and still produces genuine model embeddings
The exported embedding_explorer.html is a fully self-contained file that can be
opened in any browser — no Python environment required.
Configuration¶
Edit the paths below. Leave audio_dir and embeddings_dir as None to run
with synthetic audio (genuine model embeddings, fake labels).
import warnings
warnings.filterwarnings("ignore")
# ── User configuration ────────────────────────────────────────────────────────
# Directory of .wav / .flac files to embed (None → synthetic audio)
audio_dir = None # e.g. "../01_giant_otter_classifier/data/giant_otter/audio"
# Optional CSV with columns 'filename' and 'label' for colouring points
labels_csv = None # e.g. "../01_giant_otter_classifier/data/giant_otter/labels.csv"
# Or: directory of pre-extracted .npy embedding files from a previous run
# (used only if audio_dir is None and this path exists)
embeddings_dir = None # e.g. "../01_giant_otter_classifier/outputs/giant_otter/embeddings"
# Output HTML path (relative to this notebook)
output_html = "embedding_explorer.html"
# avex model names
BEATS_MODEL = "esp_aves2_sl_beats_all"
EFFNET_MODEL = "esp_aves2_effnetb0_all"
# Synthetic-audio fallback settings (ignored when real data is used)
N_SYNTHETIC = 120 # total synthetic clips
N_CLASSES_SYN = 4 # fake class count
CLIP_SECONDS = 5.0 # clip length
SAMPLE_RATE = 16_000
# ──────────────────────────────────────────────────────────────────────────────
import os
import pathlib
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
from plotly.subplots import make_subplots
pio.renderers.default = "notebook" # ensures text/html output for Sphinx/myst-nb
from avex import load_model
from IPython.display import display
from utils.visualization import plot_umap, plot_umap_static
NOTEBOOK_DIR = pathlib.Path().resolve()
print(f"Working directory: {NOTEBOOK_DIR}")
# CI uses a tight per-cell timeout; avoid heavyweight model inference there.
FAST_SMOKE = os.environ.get("CI") == "true" or bool(os.environ.get("GITHUB_ACTIONS"))
print(f"FAST_SMOKE: {FAST_SMOKE}")
Working directory: /home/marius_miron_earthspecies_org/code/avex-examples/examples/07_interactive_visualization
Step 1 — Gather audio clips and labels¶
We collect (waveform, label, filename) triples from whichever source is available.
import re
import librosa
import pandas as pd
import soundfile as sf
N_SAMPLES = int(CLIP_SECONDS * SAMPLE_RATE)
waveforms: list[np.ndarray] = []
labels_list: list[str] = []
filenames_list: list[str] = []
SOURCE = "synthetic" # will be overridden below
# ── Option A: real audio directory ────────────────────────────────────────────
if audio_dir is not None:
audio_path = pathlib.Path(audio_dir)
wav_files = sorted(audio_path.rglob("*.wav")) + sorted(audio_path.rglob("*.flac"))
# Load optional labels CSV
label_map: dict[str, str] = {}
if labels_csv is not None:
df_lbl = pd.read_csv(labels_csv)
label_map = dict(zip(df_lbl["filename"], df_lbl["label"], strict=False))
for wav in wav_files:
audio, sr = sf.read(str(wav), always_2d=False)
if audio.ndim > 1:
audio = audio.mean(axis=1)
if sr != SAMPLE_RATE:
audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
# Pad / trim to fixed length
if len(audio) < N_SAMPLES:
audio = np.pad(audio, (0, N_SAMPLES - len(audio)))
else:
audio = audio[:N_SAMPLES]
waveforms.append(audio.astype(np.float32))
        # Label: from CSV, filename prefix, or 'unknown'
        if wav.name in label_map:
            lbl = label_map[wav.name]
        else:
            m = re.match(r"^([a-zA-Z]+)", wav.stem)
            lbl = m.group(1) if m else "unknown"
        labels_list.append(lbl)
filenames_list.append(wav.name)
SOURCE = f"audio_dir ({len(waveforms)} clips)"
# ── Option B: synthetic audio passed through real avex models ─────────────────
if not waveforms:
print("No real audio found — generating synthetic clips (real model, fake labels).")
rng = np.random.default_rng(42)
SPECIES = [f"species_{chr(65 + i)}" for i in range(N_CLASSES_SYN)]
for i in range(N_SYNTHETIC):
labels_list.append(SPECIES[i % N_CLASSES_SYN])
filenames_list.append(f"clip_{i:04d}.wav")
# Each class gets a slightly different spectral centroid to produce
# loose cluster structure in embedding space
freq = 440.0 * (1 + 0.3 * (i % N_CLASSES_SYN))
t = np.linspace(0, CLIP_SECONDS, N_SAMPLES, endpoint=False)
tone = 0.3 * np.sin(2 * np.pi * freq * t).astype(np.float32)
noise = rng.standard_normal(N_SAMPLES).astype(np.float32) * 0.05
waveforms.append(tone + noise)
SOURCE = f"synthetic ({N_SYNTHETIC} clips, {N_CLASSES_SYN} classes)"
labels_arr = np.array(labels_list)
filenames_arr = np.array(filenames_list)
audio_tensor = torch.from_numpy(np.stack(waveforms)) # (N, T)
print(f"Source: {SOURCE}")
print(f"Audio tensor: {audio_tensor.shape} labels: {np.unique(labels_arr)}")
No real audio found — generating synthetic clips (real model, fake labels).
Source: synthetic (120 clips, 4 classes)
Audio tensor: torch.Size([120, 80000]) labels: ['species_A' 'species_B' 'species_C' 'species_D']
Step 2 — Extract embeddings with avex¶
We run both avex models in feature-extraction mode and mean-pool the outputs:
- BEATs → (N, T_tokens, 768) → mean over time → (N, 768)
- EfficientNet → (N, C, H, W) → global average pool → (N, C)
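The two pooling reductions can be illustrated on random tensors with the shapes above (the token count 250 and the 8×16 feature map are placeholder values, not actual model outputs):

```python
import torch

# Token sequence (BEATs-style): mean over the time axis.
feat_beats = torch.randn(120, 250, 768)      # (N, T_tokens, D)
emb_beats = feat_beats.mean(dim=1)           # → (N, 768)

# Feature map (EfficientNet-style): global average pool over H and W.
feat_effnet = torch.randn(120, 1280, 8, 16)  # (N, C, H, W)
emb_effnet = feat_effnet.mean(dim=(2, 3))    # → (N, 1280)

print(emb_beats.shape, emb_effnet.shape)
```

Either way, each clip ends up as a single fixed-length vector, which is what UMAP expects downstream.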
Warning (slow cell)
Last observed runtime: ~55s (from stored execution timestamps). Runtime depends on cache/hardware.
def extract(model_name: str, audio: torch.Tensor, batch_size: int = 16) -> np.ndarray:
"""Extract mean-pooled embeddings for a batch of audio clips.
Processes `audio` in mini-batches to avoid OOM on large datasets.
"""
model = load_model(model_name, return_features_only=True, device="cpu")
model.eval()
all_emb = []
for start in range(0, len(audio), batch_size):
batch = audio[start : start + batch_size]
with torch.no_grad():
feat = model(batch, padding_mask=None)
if feat.ndim == 4: # EfficientNet: (B, C, H, W)
emb = feat.mean(dim=(2, 3))
elif feat.ndim == 3: # BEATs / EAT: (B, T, D)
emb = feat.mean(dim=1)
else:
emb = feat
all_emb.append(emb.cpu().numpy())
return np.concatenate(all_emb, axis=0)
def smoke_embeddings(audio: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
"""Fast deterministic embeddings for CI smoke tests."""
# (N, T) -> (N, 100) by chunk-averaging, then random project.
x = audio.float()
n, t = x.shape
n_chunks = 100
trim = (t // n_chunks) * n_chunks
x = x[:, :trim].reshape(n, n_chunks, -1).mean(dim=-1)
x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)
rng = np.random.default_rng(0)
w_beats = rng.standard_normal((n_chunks, 768), dtype=np.float32) / np.sqrt(n_chunks)
w_effnet = rng.standard_normal((n_chunks, 1280), dtype=np.float32) / np.sqrt(n_chunks)
x_np = x.cpu().numpy().astype(np.float32, copy=False)
return (x_np @ w_beats, x_np @ w_effnet)
# Check if pre-extracted embeddings exist (skip model inference if so)
_emb_dir = pathlib.Path(embeddings_dir) if embeddings_dir else None
_beats_npy = _emb_dir / f"{BEATS_MODEL}.npy" if _emb_dir else None
_effnet_npy = _emb_dir / f"{EFFNET_MODEL}.npy" if _emb_dir else None
if FAST_SMOKE and not ((_beats_npy and _beats_npy.exists()) or (_effnet_npy and _effnet_npy.exists())):
print("FAST_SMOKE enabled: using lightweight deterministic embeddings")
beats_emb, effnet_emb = smoke_embeddings(audio_tensor)
print(f" → BEATs-like shape: {beats_emb.shape}")
print(f" → EfficientNet-like shape: {effnet_emb.shape}")
elif _beats_npy and _beats_npy.exists():
print(f"Loading pre-extracted BEATs embeddings from {_beats_npy}")
beats_emb = np.load(_beats_npy)
else:
print(f"Extracting BEATs embeddings with {BEATS_MODEL} …")
beats_emb = extract(BEATS_MODEL, audio_tensor)
print(f" → shape: {beats_emb.shape}")
if "effnet_emb" in globals():
pass
elif _effnet_npy and _effnet_npy.exists():
print(f"Loading pre-extracted EfficientNet embeddings from {_effnet_npy}")
effnet_emb = np.load(_effnet_npy)
else:
print(f"Extracting EfficientNet embeddings with {EFFNET_MODEL} …")
effnet_emb = extract(EFFNET_MODEL, audio_tensor)
print(f" → shape: {effnet_emb.shape}")
Extracting BEATs embeddings with esp_aves2_sl_beats_all …
→ shape: (120, 768)
Extracting EfficientNet embeddings with esp_aves2_effnetb0_all …
→ shape: (120, 1280)
Step 3 — UMAP projections¶
We fit UMAP independently on each model’s embeddings and plot them side by side. Both panels share the same colour palette so classes are directly comparable.
import umap
PALETTE = [
"#636EFA",
"#EF553B",
"#00CC96",
"#AB63FA",
"#FFA15A",
"#19D3F3",
"#FF6692",
"#B6E880",
]
unique_labels = sorted(np.unique(labels_arr))
color_map = {lbl: PALETTE[i % len(PALETTE)] for i, lbl in enumerate(unique_labels)}
def fit_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
"""Fit UMAP and return 2-D coordinates."""
reducer = umap.UMAP(n_components=2, random_state=seed, n_neighbors=min(15, len(emb) - 1))
return reducer.fit_transform(emb)
def scatter_traces(
coords: np.ndarray,
labels: np.ndarray,
filenames: np.ndarray,
color_map: dict,
show_legend: bool = True,
) -> list[go.Scattergl]:
"""One Scattergl trace per label class."""
traces = []
for lbl in sorted(np.unique(labels)):
mask = labels == lbl
hover = [f"{lbl}<br>{fn}" for fn in filenames[mask]]
traces.append(
go.Scattergl(
x=coords[mask, 0],
y=coords[mask, 1],
mode="markers",
name=lbl,
marker=dict(color=color_map[lbl], size=7, opacity=0.85),
text=hover,
hoverinfo="text",
legendgroup=lbl,
showlegend=show_legend,
)
)
return traces
print("Fitting UMAP for BEATs …")
beats_2d = fit_umap(beats_emb)
print("Fitting UMAP for EfficientNet …")
effnet_2d = fit_umap(effnet_emb)
fig = make_subplots(
rows=1,
cols=2,
subplot_titles=(f"BEATs ({BEATS_MODEL})", f"EfficientNet ({EFFNET_MODEL})"),
)
for trace in scatter_traces(beats_2d, labels_arr, filenames_arr, color_map, show_legend=True):
fig.add_trace(trace, row=1, col=1)
for trace in scatter_traces(effnet_2d, labels_arr, filenames_arr, color_map, show_legend=False):
fig.add_trace(trace, row=1, col=2)
fig.update_layout(
title_text="Embedding Space Explorer — UMAP Projections",
height=540,
hovermode="closest",
legend=dict(title="Class", itemsizing="constant"),
)
fig.update_xaxes(title_text="UMAP 1", showgrid=False, zeroline=False)
fig.update_yaxes(title_text="UMAP 2", showgrid=False, zeroline=False)
fig.show()
Fitting UMAP for BEATs …
/home/marius_miron_earthspecies_org/code/avex-examples/.venv/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
warn(
Fitting UMAP for EfficientNet …
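Beyond eyeballing the two panels, cluster separation can be quantified with a metric such as the silhouette score (this is not part of the notebook; the sketch below uses synthetic embeddings standing in for beats_emb / effnet_emb):

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
# Stand-in for an embedding matrix: 120 points, 4 well-separated classes.
labels = np.repeat([f"species_{c}" for c in "ABCD"], 30)
centers = rng.standard_normal((4, 768)) * 5.0
emb = centers[np.repeat(np.arange(4), 30)] + rng.standard_normal((120, 768))

# Score in [-1, 1]; higher → tighter, better-separated classes.
score = silhouette_score(emb, labels)
print(f"silhouette: {score:.3f}")
```

Computing this on both beats_emb and effnet_emb (against labels_arr) gives a single number per model to compare, complementing the visual inspection.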
Step 4 — Per-model detail view¶
Full-width interactive plot for each model with richer hover information.
fig_beats = plot_umap(
embeddings=beats_emb,
labels=labels_arr,
title=f"BEATs — {BEATS_MODEL}",
hover_text=filenames_arr.tolist(),
)
fig_beats.show()
display(plot_umap_static(beats_emb, labels_arr, title=f"BEATs — {BEATS_MODEL} (static)"))
fig_effnet = plot_umap(
embeddings=effnet_emb,
labels=labels_arr,
title=f"EfficientNet — {EFFNET_MODEL}",
hover_text=filenames_arr.tolist(),
)
fig_effnet.show()
display(plot_umap_static(effnet_emb, labels_arr, title=f"EfficientNet — {EFFNET_MODEL} (static)"))
Step 5 — Export to self-contained HTML¶
The combined figure is written as a single HTML file with Plotly loaded from CDN. Share it with anyone — no Python installation needed.
html_path = NOTEBOOK_DIR / output_html
fig.write_html(str(html_path), include_plotlyjs="cdn")
print(f"Exported: {html_path} ({html_path.stat().st_size / 1024:.1f} KB)")
Exported: /home/marius_miron_earthspecies_org/code/avex-examples/examples/07_interactive_visualization/embedding_explorer.html (22.2 KB)