# Source code for src.toolbox.utils.diagnostics

# This file is part of the NOC Autonomy Toolbox.
#
# Copyright 2025-2026 National Oceanography Centre and The Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A module for diagnostic plotting and data summarization.
"""

import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.dates as mdates
from geopy.distance import geodesic
from toolbox.utils.time import safe_median_datetime, add_datetime_secondary_xaxis
from typing import Dict, List, Optional


def plot_time_series(
    data, x_var, y_var, title="Time Series Plot", xlabel=None, ylabel=None, **kwargs
):
    """Plot one variable against another (typically time).

    Accepts either an ``xr.Dataset`` (values are pulled from ``x_var`` /
    ``y_var``) or a two-element sequence of ``(x values, y values)``.
    Extra keyword arguments are forwarded to ``plt.plot``.
    """
    if not isinstance(data, xr.Dataset):
        # Fallback for plain sequences: first element is x, second is y.
        xs, ys = data[0], data[1]
    else:
        # Validate up front so the error names the offending variables.
        if x_var not in data.coords or y_var not in data:
            raise ValueError(
                f"Variables {x_var} and {y_var} must exist in the dataset."
            )
        xs = data[x_var].values
        ys = data[y_var].values

    plt.figure(figsize=(10, 6))
    plt.plot(xs, ys, **kwargs)
    plt.xlabel(xlabel or x_var)
    plt.ylabel(ylabel or y_var)
    plt.title(title)
    plt.show()
def plot_histogram(data, var, bins=30, title="Histogram", xlabel=None, **kwargs):
    """Draw a histogram of one variable.

    ``data`` may be an ``xr.Dataset`` (histogram of ``data[var]``) or any
    array-like, which is plotted as-is. Extra keyword arguments go to
    ``plt.hist``.
    """
    if not isinstance(data, xr.Dataset):
        # Raw list/array input is plotted directly.
        values = data
    else:
        if var not in data:
            raise ValueError(f"Variable {var} must exist in the dataset.")
        values = data[var].values

    plt.figure(figsize=(10, 6))
    plt.hist(values, bins=bins, alpha=0.7, **kwargs)
    plt.xlabel(xlabel or var)
    plt.ylabel("Frequency")
    plt.title(title)
    plt.show()
def plot_boxplot(data, var, title="Box Plot", xlabel=None, **kwargs):
    """Draw a seaborn box plot of one variable.

    ``data`` may be an ``xr.Dataset`` (values taken from ``data[var]``) or
    any array-like accepted by ``sns.boxplot``.
    """
    if not isinstance(data, xr.Dataset):
        # Raw list/array input is plotted directly.
        values = data
    else:
        if var not in data:
            raise ValueError(f"Variable {var} must exist in the dataset.")
        values = data[var].values

    plt.figure(figsize=(10, 6))
    sns.boxplot(data=values, **kwargs)
    plt.title(title)
    plt.xlabel(xlabel or var)
    plt.show()
def plot_correlation_matrix(data, variables=None, title="Correlation Matrix", **kwargs):
    """Render a heatmap of pairwise correlations between dataset variables.

    Only ``xr.Dataset`` input is supported; ``variables`` defaults to every
    data variable in the dataset. Extra keyword arguments go to
    ``sns.heatmap``.
    """
    if not isinstance(data, xr.Dataset):
        raise TypeError("Data must be a Xarray Dataset to generate correlation matrix.")

    if variables is None:
        # Default to all data variables in the dataset.
        variables = list(data.data_vars)
    # NOTE(review): relies on DataArray.corr over the stacked "dim_0" axis —
    # confirm this call exists in the pinned xarray version.
    corr = data[variables].to_array().T.corr(dim="dim_0")

    plt.figure(figsize=(10, 6))
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5, **kwargs)
    plt.title(title)
    plt.show()
def generate_info(data):
    """Print a structural summary of an xarray Dataset.

    Parameters
    ----------
    data : xr.Dataset
        Dataset to summarise. Any other type only triggers an
        informational message.
    """
    if isinstance(data, xr.Dataset):
        print("Data Info:")
        # Dataset.info() writes its report directly to stdout and returns
        # None, so wrapping it in print() would emit a spurious "None" line.
        data.info()
    else:
        print("Data Info only supported for xarray Dataset ")
def check_missing_values(data):
    """Report the per-variable count of missing values in a dataset.

    Only ``xr.Dataset`` input is supported; other types trigger an
    informational message.
    """
    if not isinstance(data, xr.Dataset):
        print("Missing value check only supported for xarray Dataset ")
        return
    missing = data.isnull().sum()
    print("Missing Values in Xarray Dataset:\n", missing)
#### General Diagnostics Functions ####
def summarising_profiles(ds: xr.Dataset, source_name: str) -> pd.DataFrame:
    """
    Summarise profiles from an xarray Dataset by computing medians of
    TIME, LATITUDE, LONGITUDE grouped by PROFILE_NUMBER.

    Datetime medians are computed with ``safe_median_datetime`` (pandas-based)
    because a plain median is not defined for datetime64 groups.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset containing PROFILE_NUMBER (promoted to a coordinate
        if it is not one already).
    source_name : str
        Name of the glider/source, stored in the ``glider_name`` column.

    Returns
    -------
    pd.DataFrame
        One row per profile with ``median_TIME`` / ``median_LATITUDE`` /
        ``median_LONGITUDE`` columns (those that exist in ``ds``), sorted
        by ``median_TIME`` when present.

    Raises
    ------
    ValueError
        If the dataset has no PROFILE_NUMBER variable.
    """
    if "PROFILE_NUMBER" not in ds:
        raise ValueError("Dataset must include PROFILE_NUMBER.")
    if "PROFILE_NUMBER" not in ds.coords:
        ds = ds.set_coords("PROFILE_NUMBER")

    # Only summarise the coordinates that are actually present.
    summary_vars = [v for v in ["TIME", "LATITUDE", "LONGITUDE"] if v in ds]

    medians = {}
    for var in summary_vars:
        # PROFILE_NUMBER is guaranteed to be a coordinate of ds (set above),
        # so each extracted DataArray carries it for the groupby.
        # (The previous per-DataArray set_coords fallback was unreachable and
        # would have raised: DataArray has no set_coords method.)
        da = ds[var]
        grouped = da.groupby("PROFILE_NUMBER")
        if np.issubdtype(da.dtype, np.datetime64):
            # Median of datetime64 must go through the pandas-safe helper.
            medians[f"median_{var}"] = grouped.reduce(safe_median_datetime)
        else:
            medians[f"median_{var}"] = grouped.median(skipna=True)

    df = xr.Dataset(medians).to_dataframe().reset_index()
    df["glider_name"] = source_name

    # Sort chronologically when a TIME summary exists; datasets without TIME
    # are returned in groupby order.
    if "median_TIME" in df.columns:
        df.sort_values(by="median_TIME", inplace=True)
    return df
def find_closest_prof(df_a: pd.DataFrame, df_b: pd.DataFrame) -> pd.DataFrame:
    """
    Annotate each profile of glider A with its temporally nearest profile
    from glider B, plus the geodesic distance between the two.

    Parameters
    ----------
    df_a : pd.DataFrame
        Summary dataframe for glider A (reference).
    df_b : pd.DataFrame
        Summary dataframe for glider B (comparison).

    Returns
    -------
    pd.DataFrame
        Copy of ``df_a`` with three extra columns:
        ``closest_glider_b_profile``, ``glider_b_time_diff`` and
        ``glider_b_distance_km`` (NaN when either position is not finite).
    """
    b_times = df_b["median_TIME"].values
    b_lats = df_b["median_LATITUDE"].values
    b_lons = df_b["median_LONGITUDE"].values
    b_ids = df_b["PROFILE_NUMBER"].values

    closest_ids, time_diffs, distances = [], [], []
    for a_time, a_lat, a_lon in zip(
        df_a["median_TIME"].values,
        df_a["median_LATITUDE"].values,
        df_a["median_LONGITUDE"].values,
    ):
        # Nearest-in-time B profile for this A profile.
        deltas = np.abs(b_times - a_time)
        nearest = deltas.argmin()
        closest_ids.append(b_ids[nearest])
        time_diffs.append(deltas[nearest])

        # Geodesic distance only makes sense for finite coordinates.
        if np.all(np.isfinite([a_lat, a_lon, b_lats[nearest], b_lons[nearest]])):
            distances.append(
                geodesic((a_lat, a_lon), (b_lats[nearest], b_lons[nearest])).km
            )
        else:
            distances.append(np.nan)

    annotated = df_a.copy()
    annotated["closest_glider_b_profile"] = closest_ids
    annotated["glider_b_time_diff"] = time_diffs
    annotated["glider_b_distance_km"] = distances
    return annotated
def plot_distance_time_grid(
    summaries: Dict[str, pd.DataFrame],
    output_path: Optional[str] = None,
    show: bool = True,
    figsize: tuple = (16, 16),
):
    """
    Plot a grid of distance-over-time plots for all glider pair combinations.

    Parameters
    ----------
    summaries : dict
        Dictionary of {glider_name: pd.DataFrame} from summarising_profiles().
    output_path : str, optional
        If provided, the grid will be saved to this path.
    show : bool
        If True, plt.show() will be called. Note: only consulted when no
        output_path is given (the branches below are if/elif/else).
    figsize : tuple
        Size of the full figure.

    Returns
    -------
    pd.DataFrame
        Concatenation of the find_closest_prof() results for every ordered
        non-self glider pair.
        NOTE(review): with a single glider (no non-self pairs) the list is
        empty and pd.concat raises ValueError — confirm callers always pass
        at least two gliders.
    """
    glider_names = list(summaries.keys())
    grid_size = len(glider_names)
    fig, axes = plt.subplots(
        grid_size, grid_size, figsize=figsize, sharex=True, sharey=True
    )
    fig.suptitle("Distance Between Gliders Over Time", fontsize=18)
    combined_summaries = []
    for i, g_id in enumerate(glider_names):
        for j, g_b_id in enumerate(glider_names):
            if g_id == g_b_id:
                # Diagonal cells: labelled but blanked, except the top-right
                # corner-adjacent diagonal cell (i == 0 and j == last) which
                # keeps its axes visible.
                axes[i, j].set_title(f"{g_id} vs {g_b_id} (self-comparison)")
                if i != 0 or j != len(glider_names) - 1:
                    axes[i, j].axis("off")
                continue
            ref_df = summaries[g_id]
            comp_df = summaries[g_b_id]
            # Pair each reference profile with its nearest-in-time comparison
            # profile and the geodesic distance to it.
            paired_df = find_closest_prof(ref_df, comp_df)
            # TODO: ------- Rename column headers and add glider name labels to PROFILE_NUMBER -------
            combined_summaries.append(paired_df)
            ax = axes[i, j]
            if paired_df.empty:
                ax.set_title(f"{g_id} vs {g_b_id}\n(no data)")
                ax.axis("off")
                continue
            # One line per glider_name present in the paired summary.
            for name, group in paired_df.groupby("glider_name"):
                ax.plot(
                    group["median_TIME"],
                    group["glider_b_distance_km"],
                    label=name,
                    marker="o",
                    linestyle="-",
                )
            # Rotate X tick labels
            for label in ax.get_xticklabels():
                label.set_rotation(45)
                label.set_ha("right")
            ax.set_title(f"{g_id} vs {g_b_id}")
            ax.grid(True)
            # add additional axis if top row or right column
            if i == 0:
                add_datetime_secondary_xaxis(ax)
            if j == grid_size - 1:
                ax.secondary_yaxis("right")
            # Outer-edge axis labels only.
            if i == grid_size - 1:
                ax.set_xlabel("Datetime")
            if j == 0:
                ax.set_ylabel("Distance (km)")
            # NOTE(review): unreachable — i == j cells hit the self-comparison
            # `continue` above, so this legend is never drawn.
            if i == j:
                ax.legend(fontsize=8)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    if output_path:
        plt.savefig(output_path)
        print(f"[Diagnostics] Saved glider distance grid to: {output_path}")
    elif show:
        plt.show()
    else:
        plt.close()
    return pd.concat(combined_summaries, ignore_index=True)
def find_candidate_glider_pairs(
    df_a: pd.DataFrame,
    df_b: pd.DataFrame,
    glider_a_name: str,
    glider_b_name: str,
    time_thresh_hr: float = 2.0,
    dist_thresh_km: float = 5.0,
) -> pd.DataFrame:
    """
    Vectorised matching of glider A profiles to glider B profiles within
    time and space thresholds.

    Parameters
    ----------
    df_a, df_b : pd.DataFrame
        Profile summaries (from summarising_profiles()) with median_TIME,
        median_LATITUDE, median_LONGITUDE and PROFILE_NUMBER columns.
    glider_a_name, glider_b_name : str
        Labels written into the output's name columns.
    time_thresh_hr : float
        Maximum allowed |time difference| in hours.
    dist_thresh_km : float
        Maximum allowed geodesic distance in kilometres.

    Returns
    -------
    pd.DataFrame
        One row per glider A profile: the closest (minimum-distance) B match
        within both thresholds. Empty DataFrame if either input is empty or
        no pair passes the thresholds.
    """
    if df_a.empty or df_b.empty:
        return pd.DataFrame()

    # Work on copies: the original version added median_datetime/_key columns
    # to the caller's dataframes as a side effect.
    df_a = df_a.copy()
    df_b = df_b.copy()

    # Ensure datetime format
    df_a["median_datetime"] = pd.to_datetime(df_a["median_TIME"])
    df_b["median_datetime"] = pd.to_datetime(df_b["median_TIME"])

    # Cartesian join: every profile A against every profile B
    df_a["_key"] = 1
    df_b["_key"] = 1
    df_cross = pd.merge(df_a, df_b, on="_key", suffixes=("_a", "_b")).drop(
        columns="_key"
    )

    # Absolute time difference in hours.
    df_cross["time_diff_hr"] = (
        np.abs(
            (
                df_cross["median_datetime_a"] - df_cross["median_datetime_b"]
            ).dt.total_seconds()
        )
        / 3600.0
    )

    # Filter on time first: it is cheap and prunes the expensive geodesic step.
    df_cross = df_cross[df_cross["time_diff_hr"] <= time_thresh_hr]
    if df_cross.empty:
        return pd.DataFrame()

    # Geodesic distance per remaining candidate pair (np.vectorize is a
    # convenience loop, not true vectorisation, but geopy offers no batch API).
    def compute_dist_km(lat_a, lon_a, lat_b, lon_b):
        # Missing coordinates cannot produce a distance.
        if pd.isna(lat_a) or pd.isna(lon_a) or pd.isna(lat_b) or pd.isna(lon_b):
            return np.nan
        return geodesic((lat_a, lon_a), (lat_b, lon_b)).km

    dist_func = np.vectorize(compute_dist_km)
    df_cross["dist_km"] = dist_func(
        df_cross["median_LATITUDE_a"],
        df_cross["median_LONGITUDE_a"],
        df_cross["median_LATITUDE_b"],
        df_cross["median_LONGITUDE_b"],
    )

    # Filter by distance threshold TODO: ---------- SAVE THIS --------------
    df_cross = df_cross[df_cross["dist_km"] <= dist_thresh_km]
    if df_cross.empty:
        return pd.DataFrame()

    # Keep only best match (min dist) per PROFILE_NUMBER_a TODO: ------------ CHECK IF NECESSARY ------------
    best_matches = df_cross.loc[
        df_cross.groupby("PROFILE_NUMBER_a")["dist_km"].idxmin()
    ].copy()

    # Return clean structure.
    # NOTE(review): after the suffixed merge there is no plain "glider_name"
    # column, so that rename entry is a no-op; the names are set explicitly
    # from the arguments just below.
    best_matches = best_matches.rename(
        columns={
            "glider_name": "glider_a_name",
            "PROFILE_NUMBER_a": "glider_a_PROFILE_NUMBER",
            "PROFILE_NUMBER_b": "glider_b_PROFILE_NUMBER",
        }
    )
    best_matches["glider_a_name"] = glider_a_name
    best_matches["glider_b_name"] = glider_b_name

    return best_matches[
        [
            "glider_a_PROFILE_NUMBER",
            "glider_a_name",
            "glider_b_PROFILE_NUMBER",
            "glider_b_name",
            "time_diff_hr",
            "dist_km",
        ]
    ].reset_index(drop=True)
def plot_heatmap_glider_df(
    ax,
    matchup_df: pd.DataFrame,
    time_bins: np.ndarray,
    dist_bins: np.ndarray,
    glider_a_name: str,
    glider_b_name: str,
    i: int,
    j: int,
    grid_size: int,
):
    """
    Plot a cumulative 2D histogram of time/distance matchups for one glider
    pair on the given axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axis (one cell of the pair grid).
    matchup_df : pd.DataFrame
        Output of find_candidate_glider_pairs() with ``time_diff_hr`` and
        ``dist_km`` columns; an empty frame blanks the cell.
    time_bins, dist_bins : np.ndarray
        Bin edges for the time and distance axes.
    glider_a_name, glider_b_name : str
        Labels for the subplot title.
    i, j : int
        Row/column position of this cell in the grid (controls which edges
        get secondary axes and axis labels).
    grid_size : int
        Total number of rows/columns in the grid.
    """
    if matchup_df.empty:
        ax.set_title(f"{glider_a_name} vs {glider_b_name} (no matches)")
        ax.axis("off")
        return

    H, xedges, yedges = np.histogram2d(
        matchup_df["time_diff_hr"], matchup_df["dist_km"], bins=[time_bins, dist_bins]
    )
    # Double cumulative sum: cell (r, c) counts all matches within BOTH the
    # time threshold xedges[r+1] and the distance threshold yedges[c+1].
    H_cum = H.cumsum(axis=0).cumsum(axis=1)

    X, Y = np.meshgrid(yedges, xedges)
    ax.pcolormesh(X, Y, H_cum, cmap="PuBu", shading="auto")

    # add additional axis if top row or right column
    if i == 0:
        secax = ax.secondary_xaxis("top")
        secax.set_xlabel("Distance Threshold (km)")
    if j == grid_size - 1:
        secax = ax.secondary_yaxis("right")
        secax.set_ylabel("Time Threshold (hr)")
    if i == grid_size - 1:
        ax.set_xlabel("Distance Threshold (km)")
    if j == 0:
        ax.set_ylabel("Time Threshold (hr)")

    ax.set_title(f"{glider_a_name} vs {glider_b_name}")

    # Annotate each populated cell with its cumulative count. Use distinct
    # loop variables (r, c) — the original reused and shadowed the i/j
    # grid-position parameters here. Hoist the color threshold out of the loop.
    half_max = H_cum.max() / 2
    for r in range(H_cum.shape[0]):
        for c in range(H_cum.shape[1]):
            val = int(H_cum[r, c])
            if val > 0:
                x_center = (yedges[c] + yedges[c + 1]) / 2
                y_center = (xedges[r] + xedges[r + 1]) / 2
                # White text on dark (high-count) cells for contrast.
                color = "white" if val > half_max else "black"
                ax.text(
                    x_center,
                    y_center,
                    str(val),
                    ha="center",
                    va="center",
                    fontsize=7,
                    color=color,
                )
def plot_glider_pair_heatmap_grid(
    summaries: Dict[str, pd.DataFrame],
    time_bins: np.ndarray,
    dist_bins: np.ndarray,
    output_path: Optional[str] = None,
    show: bool = True,
    figsize: tuple = (16, 16),
):
    """
    Generate an NxN grid of cumulative matchup heatmaps, one cell per
    ordered glider pair (diagonal cells are blanked).

    Parameters
    ----------
    summaries : dict
        {glider_name: summary DataFrame} from summarising_profiles().
    time_bins, dist_bins : np.ndarray
        Bin edges for the heatmaps; their maxima double as the matching
        thresholds.
    output_path : str, optional
        If given, the figure is saved here instead of being shown.
    show : bool
        Whether to call plt.show() when no output_path is given.
    figsize : tuple
        Size of the full figure.
    """
    names = list(summaries.keys())
    n = len(names)
    fig, axes = plt.subplots(n, n, figsize=figsize, sharex=True, sharey=True)
    fig.suptitle("Heatmap of Matchups Between Gliders", fontsize=18)

    for row, g_a in enumerate(names):
        for col, g_b in enumerate(names):
            if g_a == g_b:
                # Self-comparison cells carry no information.
                axes[row, col].axis("off")
                continue
            # Thresholds are the widest bin edges so every binned pair counts.
            matches = find_candidate_glider_pairs(
                summaries[g_a],
                summaries[g_b],
                glider_a_name=g_a,
                glider_b_name=g_b,
                time_thresh_hr=max(time_bins),
                dist_thresh_km=max(dist_bins),
            )
            plot_heatmap_glider_df(
                ax=axes[row, col],
                matchup_df=matches,
                time_bins=time_bins,
                dist_bins=dist_bins,
                glider_a_name=g_a,
                glider_b_name=g_b,
                i=row,
                j=col,
                grid_size=n,
            )

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    if output_path:
        plt.savefig(output_path)
        print(f"[Diagnostics] Saved glider heatmap grid to: {output_path}")
    elif show:
        plt.show()
    else:
        plt.close()