# Source code for src.toolbox.utils.validation

# This file is part of the NOC Autonomy Toolbox.
#
# Copyright 2025-2026 National Oceanography Centre and The Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# validation.py

import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

from toolbox.utils.diagnostics import summarising_profiles

from toolbox.utils.alignment import (
    interpolate_DEPTH,
    aggregate_vars,
    find_profile_pair_metadata,
    merge_pairs_from_filtered_aggregates,
    compute_r2_for_merged_profiles_xr,
    plot_r2_heatmaps_per_pair,
    filter_xarray_by_profile_ids,
)

from testing.sandbox import target_ds_raw


def load_device_folder_to_xarray(
    path_or_glob,
    alias_map=None,
    depth_candidates=("DEPTH", "depth"),
    time_candidates=("TIME", "time", "DateTime", "datetime"),
    lat_candidates=("LATITUDE", "latitude", "lat"),
    lon_candidates=("LONGITUDE", "longitude", "lon"),
    profile_start_index=1,
):
    """
    Read many device NetCDF files and combine them into one ``xarray.Dataset``.

    The result has:
      - dim: ``N_MEASUREMENTS``
      - vars/coords: ``PROFILE_NUMBER`` (int), ``DEPTH``, ``TIME``,
        ``LATITUDE``, ``LONGITUDE``, plus all remaining data columns.
    Each input file becomes one profile, numbered sequentially from
    ``profile_start_index``.

    Parameters
    ----------
    path_or_glob : str
        Either a directory (walked recursively) or a glob pattern; only
        files ending in ``.nc``/``.nc4``/``.cdf``/``.netcdf`` are read.
    alias_map : dict or None
        ``{device_col: standard_name}`` renames applied to each file
        before standard columns are picked.
    depth_candidates, time_candidates, lat_candidates, lon_candidates : tuple of str
        Candidate variable names, tried in order; the first one present
        in a file wins.
    profile_start_index : int
        Profile number assigned to the first file.

    Returns
    -------
    xarray.Dataset
        Combined dataset over ``N_MEASUREMENTS``.

    Raises
    ------
    FileNotFoundError
        If no NetCDF files match ``path_or_glob``.
    """
    netcdf_exts = (".nc", ".nc4", ".cdf", ".netcdf")
    if os.path.isdir(path_or_glob):
        files = [
            os.path.join(r, f)
            for r, _, fs in os.walk(path_or_glob)
            for f in fs
            if f.endswith(netcdf_exts)
        ]
    else:
        files = [f for f in glob.glob(path_or_glob) if f.endswith(netcdf_exts)]
    files = sorted(files)
    if not files:
        raise FileNotFoundError(f"No NetCDF files found for '{path_or_glob}'")

    def _pick(ds, cands):
        # choose first available var name
        for c in cands:
            if c in ds:
                return c
        return None

    rows = []
    pid = int(profile_start_index)
    for fp in files:
        # Use a context manager so the NetCDF handle is closed even if a
        # later step raises (the original only closed on the success path).
        with xr.open_dataset(fp, decode_times=True) as ds:
            if alias_map:
                to_rename = {k: v for k, v in alias_map.items() if k in ds}
                if to_rename:
                    ds = ds.rename(to_rename)
            depth_col = _pick(ds, depth_candidates) or "DEPTH"
            time_col = _pick(ds, time_candidates)
            lat_col = _pick(ds, lat_candidates)
            lon_col = _pick(ds, lon_candidates)
            df = ds.to_dataframe().reset_index()
            n = len(df)
            if time_col is None and "TIME" not in df.columns:
                # No time variable: fall back to a file-level time attribute,
                # repeated for every measurement in the file.
                t = ds.attrs.get("time") or ds.attrs.get("TIME")
                if t is not None:
                    df["TIME"] = np.repeat(pd.to_datetime(t), n)
            elif time_col is not None:
                df["TIME"] = pd.to_datetime(df[time_col], errors="coerce")
            if lat_col is not None and "LATITUDE" not in df.columns:
                df["LATITUDE"] = df[lat_col]
            if lon_col is not None and "LONGITUDE" not in df.columns:
                df["LONGITUDE"] = df[lon_col]
            if depth_col != "DEPTH" and "DEPTH" not in df.columns:
                df["DEPTH"] = pd.to_numeric(df[depth_col], errors="coerce")
            df["PROFILE_NUMBER"] = pid
            rows.append(df)
            pid += 1

    big = pd.concat(rows, ignore_index=True)
    ds_out = xr.Dataset(
        {c: (("N_MEASUREMENTS",), big[c].to_numpy()) for c in big.columns},
        coords={"N_MEASUREMENTS": np.arange(len(big))},
    )
    ds_out["PROFILE_NUMBER"] = ds_out["PROFILE_NUMBER"].astype("int32")
    if "PROFILE_NUMBER" not in ds_out.coords:
        ds_out = ds_out.set_coords("PROFILE_NUMBER")
    if "TIME" in ds_out:
        # Rebuild TIME from pandas so mixed per-file sources end up as a
        # single datetime64 array.
        ds_out["TIME"] = xr.DataArray(
            pd.to_datetime(big["TIME"]).to_numpy(), dims=("N_MEASUREMENTS",)
        )
    for k in ("LATITUDE", "LONGITUDE", "DEPTH"):
        if k in ds_out:
            ds_out[k] = ds_out[k].astype("float64")
    print(
        f"[Device] Loaded {len(files)} files → {ds_out.sizes['N_MEASUREMENTS']} rows, "
        f"{len(np.unique(ds_out['PROFILE_NUMBER'].values))} profiles."
    )
    # output variables
    print(
        f"Variables: {', '.join([v for v in ds_out.data_vars if v != 'PROFILE_NUMBER'])}"
    )
    return ds_out
def validate(pmanager, target="None"):
    """
    End-to-end validation using ``settings.validation``:
      - load device NetCDFs
      - summarise & pair profiles
      - (re)use cached target medians if available; otherwise interpolate+aggregate once
      - interpolate, bin, aggregate device (2-D medians)
      - merge per-pair on depth bins
      - compute per-pair R²
      - plot heatmaps per variable using ``plot_r2_heatmaps_per_pair``

    Parameters
    ----------
    pmanager : pipeline manager
        Must expose ``settings``, ``alignment_map``, ``pipelines`` and
        ``_contexts``; caches are attached to it as
        ``processed_per_glider`` / ``_exportables``.
    target : str or list of str
        Target platform name(s); each must be registered in
        ``pmanager.pipelines`` and ``pmanager._contexts``.  The default
        string ``"None"`` is kept for backward compatibility and will
        raise ``ValueError`` unless a platform named "None" exists.

    Returns
    -------
    dict
        ``{"paired_df": ..., "merged": ..., "r2_ds": ...}``; ``merged`` and
        ``r2_ds`` are None when no profile pairs are found.

    Raises
    ------
    ValueError
        If any requested target platform is not available.
    """
    # --- config ---
    vcfg = pmanager.settings.get("validation", {}) or {}
    device_name = vcfg.get("device_name", "DEVICE")
    variables = vcfg.get("variable_names", list(pmanager.alignment_map.keys()))
    folder_path = vcfg.get("folder_path", "")
    max_time_hr = vcfg.get("max_time_threshold", 12)
    max_dist_km = vcfg.get("max_distance_threshold", 10)
    # NOTE(review): the three settings below are read but not used anywhere
    # in this function yet — presumably reserved for apply/save behaviour.
    var_r2_criteria = vcfg.get("variable_r2_criteria", {})
    apply_and_save = bool(vcfg.get("apply_and_save", False))
    out_path = vcfg.get("output_path", "")
    save_plots = bool(vcfg.get("save_plots", False))
    show_plots = bool(vcfg.get("show_plots", True))
    plot_output_path = vcfg.get("plot_output_path", "")

    # ---- Target: prefer cached aggregated medians from preview_alignment() ----
    # Normalise `target` to a list; the combined name joins multiple targets.
    if not isinstance(target, list):
        target_name = target
        target = [target]
    else:
        target_name = "_".join(target)
    for platform in target:
        if platform not in pmanager.pipelines or platform not in pmanager._contexts:
            raise ValueError(f"Target '{platform}' not available.")

    # Make merged dataset from all of the targets
    to_merge = []
    for platform in target:
        raw_ds = pmanager._contexts[platform]["data"][
            ["PROFILE_NUMBER", "DEPTH", "TIME", "LATITUDE", "LONGITUDE"] + variables
        ]
        # Append the platform name to the profile number so profile IDs stay
        # unique across merged targets.
        raw_ds["PROFILE_NUMBER"] = (
            ("N_MEASUREMENTS",),
            raw_ds["PROFILE_NUMBER"].values.astype("str") + f"_{platform}",
        )
        # Remap the variable names if specified
        rename_map = {
            alias: std
            for std, alias_map in pmanager.alignment_map.items()
            if (alias := alias_map.get(platform)) and alias in raw_ds
        }
        # BUG FIX: Dataset.rename() returns a new Dataset; the original code
        # discarded the result, so aliases were never actually renamed.
        raw_ds = raw_ds.rename(rename_map)
        to_merge.append(raw_ds)

    target_ds_raw = to_merge[0]
    target_ds_raw = target_ds_raw.assign_coords(
        {"N_MEASUREMENTS": target_ds_raw["N_MEASUREMENTS"]}
    )
    if len(to_merge) > 1:
        for ds in to_merge[1:]:
            # Offset the measurement index so concatenated coords don't collide.
            offset = len(target_ds_raw["N_MEASUREMENTS"])
            ds = ds.assign_coords(N_MEASUREMENTS=ds["N_MEASUREMENTS"] + offset)
            target_ds_raw = xr.concat([target_ds_raw, ds], dim="N_MEASUREMENTS")

    # Create caches if not present
    if not hasattr(pmanager, "processed_per_glider"):
        pmanager.processed_per_glider = {}
    if not hasattr(pmanager, "_exportables"):
        pmanager._exportables = {"raw": {}, "processed": {}, "lite": {}}

    # Use cached medians if available, else compute once and cache
    if (
        target_name in pmanager.processed_per_glider
        and "agg" in pmanager.processed_per_glider[target_name]
    ):
        t_med = pmanager.processed_per_glider[target_name]["agg"]
    else:
        # standardize names → interpolate → aggregate → cache + export handle
        t_interp = interpolate_DEPTH(target_ds_raw)
        t_med = aggregate_vars(t_interp, variables)  # dims: PROFILE_NUMBER, DEPTH_bin
        pmanager.processed_per_glider[target_name] = {
            "interp": t_interp,
            "agg": t_med,
        }
        pmanager._exportables["raw"][target_name] = target_ds_raw
        pmanager._exportables["processed"][target_name] = t_med

    # ---- Device: load & aggregate (external; not part of pipelines) ----
    device_alias = vcfg.get("aliases", None)  # {STD: device_col}
    # Convert to {device_col: STD} for loader renaming
    dev_to_std = (
        {dev: std for std, dev in device_alias.items() if dev}
        if device_alias
        else None
    )
    device_ds_raw = load_device_folder_to_xarray(folder_path, alias_map=dev_to_std)

    # Summaries (for pairing)
    target_summary = summarising_profiles(target_ds_raw, target_name).reset_index(
        drop=True
    )
    device_summary = summarising_profiles(device_ds_raw, device_name).reset_index(
        drop=True
    )

    # Pairs
    paired_df = find_profile_pair_metadata(
        df_target=target_summary,
        df_ancillary=device_summary,
        target_name=target_name,
        ancillary_name=device_name,
        time_thresh_hr=max_time_hr,
        dist_thresh_km=max_dist_km,
    )
    if paired_df.empty:
        print("[Validation] No matched target/device profiles found.")
        return {"paired_df": paired_df, "r2_ds": None, "merged": None}
    print(f"[Validation] Matched {len(paired_df)} pairs with {device_name}.")

    # Device medians (computed here)
    d_interp = interpolate_DEPTH(device_ds_raw)
    d_med = aggregate_vars(d_interp, variables)  # dims: PROFILE_NUMBER, DEPTH_bin

    # IDs from the pairs
    t_ids = paired_df[f"{target_name}_PROFILE_NUMBER"].values
    d_ids = paired_df[f"{device_name}_PROFILE_NUMBER"].values

    # filter to just those profiles (works on aggregated 2-D medians)
    t_med = filter_xarray_by_profile_ids(t_med, "PROFILE_NUMBER", t_ids)
    d_med = filter_xarray_by_profile_ids(d_med, "PROFILE_NUMBER", d_ids)

    # Trim pairs to those actually present in the aggregated sets
    t_present = set(t_med["PROFILE_NUMBER"].values.tolist())
    d_present = set(d_med["PROFILE_NUMBER"].values.tolist())
    mask_pairs = paired_df[f"{target_name}_PROFILE_NUMBER"].isin(t_present) & paired_df[
        f"{device_name}_PROFILE_NUMBER"
    ].isin(d_present)
    paired_df = paired_df.loc[mask_pairs].reset_index(drop=True)

    # Log what was dropped
    dropped_t = set(t_ids) - t_present
    dropped_d = set(d_ids) - d_present
    if dropped_t:
        print(
            f"[Validation] Dropped {len(dropped_t)} target profiles with no aggregated data."
        )
    if dropped_d:
        print(
            f"[Validation] Dropped {len(dropped_d)} device profiles with no aggregated data."
        )

    # Merge pairs → dims: PAIR_INDEX, DEPTH_bin
    merged = merge_pairs_from_filtered_aggregates(
        paired_df=paired_df,
        agg_target=t_med,
        agg_anc=d_med,
        target_name=target_name,
        ancillary_name=device_name,
        variables=variables,
        bin_dim="DEPTH_bin",
        pair_dim="PAIR_INDEX",
    )
    print(f"[Validation] Merged data has {merged.sizes['PAIR_INDEX']} pairs.")

    # R² per pair
    r2_ds = compute_r2_for_merged_profiles_xr(
        merged, variables=variables, target_name=target_name, ancillary_name=device_name
    )

    # Plot heatmaps with the shared helper
    align_cfg = pmanager.settings.get("alignment", {}) or {}
    r2_thresholds = align_cfg.get(
        "r2_thresholds", [0.99, 0.95, 0.90, 0.85, 0.80, 0.75, 0.70]
    )
    r2_datasets_for_plot = {device_name: r2_ds}
    plot_r2_heatmaps_per_pair(
        r2_datasets=r2_datasets_for_plot,
        variables=variables,
        target_name=target_name,
        r2_thresholds=r2_thresholds,
        time_thresh_hr=max_time_hr,
        dist_thresh_km=max_dist_km,
        figsize=(9, 6),
        save_plots=save_plots,
        output_path=plot_output_path or None,
        show_plots=show_plots,
    )

    return {"paired_df": paired_df, "merged": merged, "r2_ds": r2_ds}