# This file is part of the NOC Autonomy Toolbox.
#
# Copyright 2025-2026 National Oceanography Centre and The Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from toolbox.utils.config_mirror import ConfigMirrorMixin
import os
import yaml
import pandas as pd
import numpy as np
import xarray as xr
import datetime as _dt
from toolbox.pipeline import Pipeline
from toolbox.utils.diagnostics import (
summarising_profiles,
plot_distance_time_grid,
plot_glider_pair_heatmap_grid,
)
from toolbox.utils.alignment import (
interpolate_DEPTH,
aggregate_vars,
merge_pairs_from_filtered_aggregates,
filter_xarray_by_profile_ids,
find_profile_pair_metadata,
compute_r2_for_merged_profiles_xr,
plot_r2_heatmaps_per_pair,
plot_pair_scatter_grid,
collect_xy_from_r2_ds,
fit_linear_map,
)
from toolbox.utils.validation import validate
class PipelineManager(ConfigMirrorMixin):
"""A class enabling the execution of multiple pipelines in sequence."""
def __init__(self):
    """Initialise an empty manager with no pipelines or cached state."""
    # Registered pipelines: {pipeline_name: Pipeline instance}
    self.pipelines = {}
    # Variable-alias lookup: {standard_name: {pipeline_name: alias}}
    self.alignment_map = {}
    # Cached per-pipeline contexts, populated by run_all() / load_data()
    self._contexts = {}
    # Guard flag: summarise_all_profiles() must run before alignment
    self._summary_ran = False
    # Private config store provided by ConfigMirrorMixin
    self._init_config_mirror()
def load_mission_control(self, config_path, mirror_keys=None):
    """
    Load a MissionControl YAML file into the private config store.

    - Builds the pipelines listed under ``pipelines``
    - Builds ``alignment_map`` from ``alignment.variables``
    - Mirrors selected top-level keys as attributes (e.g. ``settings``)

    Parameters
    ----------
    config_path : str
        Path to the MissionControl YAML file.
    mirror_keys : list[str], optional
        Config keys to expose as attributes (default: ``["settings"]``).
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f) or {}
    # 1) Store the full mission config in the private _parameters store
    self.load_config(config, mirror_keys=mirror_keys or ["settings"])
    # 2) Build pipelines; each Pipeline loads its own config file into
    #    its own private store
    for entry in self._parameters.get("pipelines", []) or []:
        self.add_pipeline(entry["name"], entry["config"])
    # 3) Alignment aliases -> alignment_map
    alignment_vars = (
        self._parameters.get("alignment", {}).get("variables", {}) or {}
    )
    self.alignment_map = {
        std: (details or {}).get("aliases", {}) or {}
        for std, details in alignment_vars.items()
    }
    # 4) Mirror settings (and any other mirrored keys) into attributes
    self._reset_parameter_bridge(mirror_keys=self._param_attr_keys or {"settings"})
def add_pipeline(self, name, config_path):
    """Register a single pipeline under a unique name.

    Raises
    ------
    ValueError
        If a pipeline with the same name was already added.
    """
    if name in self.pipelines:
        raise ValueError(f"Pipeline '{name}' already added.")
    # Pipeline loads its own config from the given path
    self.pipelines[name] = Pipeline(config_path)
    print(f"[Pipeline Manager] Pipeline '{name}' added from {config_path}.")
def save_manager_config(self, path: str):
    """Write the MissionControl/manager config (self._parameters) to *path*."""
    self.save_config(path)
def save_pipeline_configs(self, out_dir: str, filename="{name}.yaml"):
    """
    Ask each Pipeline to write its private config to YAML.

    The file content comes from each pipeline's own ``_parameters``
    (including its steps). ``filename`` is a template formatted with
    the pipeline name.
    """
    os.makedirs(out_dir, exist_ok=True)
    for name, pl in self.pipelines.items():
        pl.save_config(os.path.join(out_dir, filename.format(name=name)))
    print(f"[Config] Saved pipeline configs → {out_dir}")
def save_all_configs(
    self, manager_path: str, pipelines_dir: str, pipeline_filename="{name}.yaml"
):
    """Convenience: save the manager config and all pipeline configs."""
    self.save_manager_config(manager_path)
    self.save_pipeline_configs(pipelines_dir, filename=pipeline_filename)
def run_all(self):
    """Run all registered pipelines in order, then cache their contexts."""
    for name, pipeline in self.pipelines.items():
        print("#" * 20)
        print(f"Running pipeline: {name}")
        pipeline.run()
    # Cache once after all runs complete (re-caching per iteration is
    # redundant — the final state is identical)
    self._contexts = self.get_contexts()
    print("#" * 20)
    print("All pipelines executed successfully.")
    print(f"Contexts cached: {self._contexts.keys()}")
    print("#" * 20)
def get_contexts(self):
    """Return ``{pipeline_name: context dict}`` for every registered pipeline."""
    return {name: p._context for name, p in self.pipelines.items()}
def load_data(self, filepath, platform_name):
    """Load a NetCDF dataset from disk and register it as a platform context.

    Lets data that did not come from a pipeline run participate in
    summaries and alignment alongside pipeline outputs.
    """
    self._contexts[platform_name] = {"data": xr.load_dataset(filepath)}
    print(f"[Pipeline Manager] {platform_name} successfully added")
def summarise_all_profiles(self) -> pd.DataFrame:
    """
    Summarise profiles for every cached context and plot diagnostics.

    This includes:
    - Computing median TIME, LATITUDE, LONGITUDE per profile
    - Optionally plotting a glider-to-glider distance/time grid
    - Optionally plotting pairwise matchup heatmaps

    Returns
    -------
    pd.DataFrame
        Concatenation of all per-glider profile summaries.

    Raises
    ------
    RuntimeError
        If no contexts are cached (pipelines not run / no data loaded).
    TypeError
        If any context's "data" entry is not an xarray Dataset.
    """
    self._summary_ran = True
    # _contexts starts as {} (never None), so test truthiness instead of
    # the original `is None` check, which could never fire
    if not self._contexts:
        raise RuntimeError("Pipelines must be run before generating summaries.")
    print("[Pipeline Manager] Generating glider distance summaries...")
    # Step 1: per-glider summaries
    self.summary_per_glider = {}
    for pipeline_name, context in self._contexts.items():
        ds = context["data"]
        if not isinstance(ds, xr.Dataset):
            raise TypeError(f"Pipeline '{pipeline_name}' has invalid dataset.")
        print(f"[Pipeline Manager] Processing dataset for {pipeline_name}...")
        summary_df = summarising_profiles(ds, pipeline_name)
        print("Summary Columns:", summary_df.columns.tolist())
        self.summary_per_glider[pipeline_name] = summary_df
    # Step 2: diagnostics configuration (hoisted once instead of repeated
    # settings.get chains)
    diag = self.settings.get("diagnostics", {})
    distance_over_time_matrix = diag.get("distance_over_time_matrix", False)
    self.matchup_thresholds = diag.get("matchup_thresholds", {})
    max_time_threshold = self.matchup_thresholds.get("max_time_threshold", 12)
    max_distance_threshold = self.matchup_thresholds.get("max_distance_threshold", 20)
    bin_size = self.matchup_thresholds.get("bin_size", 2)
    if not distance_over_time_matrix:
        print("[Pipeline Manager] Distance over time matrix is disabled.")
    else:
        print("[Pipeline Manager] Plotting distance time grid...")
        plot_distance_time_grid(
            summaries=self.summary_per_glider,
            output_path=diag.get("distance_plot_output", None),
            show=diag.get("show_plots", True),
        )
    if not self.matchup_thresholds:
        print(
            "[Pipeline Manager] Matchup thresholds are not set. Skipping heatmap grid."
        )
    else:
        print("[Pipeline Manager] Finding closest profiles across gliders...")
        # Time the heatmap computation
        start_time = pd.Timestamp.now()
        plot_glider_pair_heatmap_grid(
            summaries=self.summary_per_glider,
            time_bins=np.arange(0, max_time_threshold + 1, bin_size),
            dist_bins=np.arange(0, max_distance_threshold + 1, bin_size),
            output_path=diag.get("heatmap_output", None),
            show=diag.get("show_plots", True),
        )
        end_time = pd.Timestamp.now()
        print(f"[Pipeline Manager] Heatmap grid plotted in {end_time - start_time}")
    # The original returned None despite advertising a DataFrame; honour
    # the documented contract by returning the combined summary
    return pd.concat(self.summary_per_glider.values(), ignore_index=True)
def preview_alignment(self, target="None"):
    """
    Align all datasets to *target* and compute R² against ancillary sources.

    This version:
    - Renames each pipeline's variables to the standard names (from alignment_map)
    - Runs interpolate + aggregate ONCE per pipeline and caches the results
    - Uses the cached medians for pairing/merging/R²
    - Populates exportable handles for raw/processed/lite data

    Parameters
    ----------
    target : str
        Name of the registered pipeline everything is aligned to.

    Raises
    ------
    RuntimeError
        If summarise_all_profiles() has not been run first.
    ValueError
        If *target* is not a registered pipeline.
    """
    # === PRECONDITIONS ===
    if not self._summary_ran:
        raise RuntimeError("Run summarise_all_profiles() before alignment.")
    if target not in self.pipelines:
        raise ValueError(f"Target pipeline '{target}' not found.")
    # === CONFIG ===
    alignment_vars = list(self.alignment_map.keys())
    self.r2_datasets = {}  # reset R² result container
    # Matchup thresholds hoisted once (the original repeated the
    # settings.get chain at every use site)
    matchup = self.settings.get("diagnostics", {}).get("matchup_thresholds", {})
    time_thresh_hr = matchup.get("max_time_threshold", 12)
    dist_thresh_km = matchup.get("max_distance_threshold", 20)

    def _rename_to_standard(name, ds):
        # Map this pipeline's aliases onto the standard variable names
        # (only aliases actually present in the dataset are renamed).
        rename_map = {
            alias: std
            for std, alias_map in self.alignment_map.items()
            if (alias := alias_map.get(name)) and alias in ds
        }
        return ds.rename(rename_map) if rename_map else ds

    if not hasattr(self, "processed_per_glider"):
        self.processed_per_glider = {}
    # Export registry the rest of the workflow can use later to write files
    if not hasattr(self, "_exportables"):
        self._exportables = {"raw": {}, "processed": {}, "lite": {}}
    # === COLLECT: standardised & processed datasets for ALL pipelines ===
    for name, ctx in self._contexts.items():
        raw_ds = ctx["data"]
        # Keep a pointer to raw data for export (no copy)
        self._exportables["raw"][name] = raw_ds
        # Skip recomputation if this pipeline is already processed
        if "agg" in self.processed_per_glider.get(name, {}):
            continue
        # 1) rename variables to standard names
        ds_std = _rename_to_standard(name, raw_ds)
        # 2) interpolate depth
        print(f"[Pipeline Manager] Interpolating DEPTH for '{name}'...")
        ds_interp = interpolate_DEPTH(ds_std)
        # 3) aggregate medians (2-D by PROFILE_NUMBER × DEPTH_bin)
        print(f"[Pipeline Manager] Aggregating medians for '{name}'...")
        ds_agg = aggregate_vars(ds_interp, alignment_vars)
        # Cache intermediate products per pipeline
        self.processed_per_glider[name] = {
            "renamed": ds_std,
            "interp": ds_interp,
            "agg": ds_agg,
        }
        # Make processed export handle available
        self._exportables["processed"][name] = ds_agg
    self._exportables.setdefault("lite", {})
    # Prepare target objects
    target_summary = self.summary_per_glider[target].reset_index()
    target_agg = self.processed_per_glider[target]["agg"]
    # === LOOP: align each ancillary to target using the cached medians ===
    for ancillary_name in self._contexts:
        if ancillary_name == target:
            continue
        print(
            f"\n[Pipeline Manager] Aligning '{ancillary_name}' to target '{target}'..."
        )
        ancillary_summary = self.summary_per_glider[ancillary_name]
        if ancillary_summary.index.names[0] is not None:
            ancillary_summary = ancillary_summary.reset_index()
        # === STEP 1: Find Matched Profile Pairs ===
        paired_df = find_profile_pair_metadata(
            df_target=target_summary,
            df_ancillary=ancillary_summary,
            target_name=target,
            ancillary_name=ancillary_name,
            time_thresh_hr=time_thresh_hr,
            dist_thresh_km=dist_thresh_km,
        )
        if paired_df.empty:
            print(
                f"[Pipeline Manager] No matched profiles between {target} and {ancillary_name}."
            )
            continue
        print(f"[Pipeline Manager] Found {len(paired_df)} matched profile pairs.")
        # === STEP 2: Use CACHED aggregated medians ===
        binned_ds = {
            target: target_agg,
            ancillary_name: self.processed_per_glider[ancillary_name]["agg"],
        }
        # === STEP 3: Filter Datasets by Matched Profile IDs ===
        filtered_ds = {}
        for glider_name, agg_ds in binned_ds.items():
            profile_ids = paired_df[f"{glider_name}_PROFILE_NUMBER"].values
            filtered_ds[glider_name] = filter_xarray_by_profile_ids(
                ds=agg_ds,
                profile_id_var="PROFILE_NUMBER",
                valid_ids=profile_ids,
            )
        # === STEP 4: Build pairwise merged dataset ===
        merged = merge_pairs_from_filtered_aggregates(
            paired_df=paired_df,
            agg_target=filtered_ds[target],
            agg_anc=filtered_ds[ancillary_name],
            target_name=target,
            ancillary_name=ancillary_name,
            variables=alignment_vars,  # raw names; helper will use median_{var}
        )
        print("[Align] Merged dims:", merged.dims)
        print("[Align] Vars:", list(merged.data_vars))
        # === STEP 5: Compute R² ===
        print(f"[Pipeline Manager] Computing R² for '{ancillary_name}'...")
        r2_ds = compute_r2_for_merged_profiles_xr(
            ds=merged,
            variables=alignment_vars,
            target_name=target,
            ancillary_name=ancillary_name,
        )
        self.r2_datasets[ancillary_name] = r2_ds
        # Add to the export cache as well
        self._exportables["processed"][f"{target}_vs_{ancillary_name}"] = r2_ds
        print(
            f"[Pipeline Manager] R² dataset stored for '{target}' vs '{ancillary_name}'."
        )
    print("\n[Pipeline Manager] Alignment complete for all datasets.")
    # R² thresholds for the heatmap grid
    align_cfg = self.settings.get("alignment", {})
    r2_thresholds = align_cfg.get(
        "r2_thresholds", [0.99, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7]
    )
    plot_r2_heatmaps_per_pair(
        r2_datasets=self.r2_datasets,
        variables=alignment_vars,
        target_name=target,
        r2_thresholds=r2_thresholds,
        time_thresh_hr=time_thresh_hr,
        dist_thresh_km=dist_thresh_km,
        figsize=(9, 6),
        save_plots=align_cfg.get("save_plots", False),
        output_path=align_cfg.get("plot_output_path", "r2_heatmap_grid.png"),
        show_plots=align_cfg.get("show_plots", True),
    )
def fit_and_save_to_target(
    self,
    target,
    out_dir=None,
    variable_r2_criteria=None,
    max_time_hr=None,
    max_dist_km=None,
    ancillaries=None,
    overwrite=False,
    show_plots=True,
):
    """
    Fit ancillary variables to *target* using profile-pair medians and
    per-variable R² criteria, then apply the fits and save NetCDF files.

    Parameters
    ----------
    target : str
        Pipeline name the ancillaries are aligned to.
    out_dir : str, optional
        Output directory; defaults to settings['alignment']['output_path']
        or a UTC-timestamped directory name.
    variable_r2_criteria : dict, optional
        {variable: minimum R²}; defaults to settings['alignment'].
    max_time_hr, max_dist_km : float, optional
        Extra pair filters passed to collect_xy_from_r2_ds.
    ancillaries : iterable, optional
        Subset of pipelines to fit; defaults to all non-target pipelines.
    overwrite : bool
        Overwrite existing output files.
    show_plots : bool
        Show the QA scatter grid before fitting.

    Returns
    -------
    dict
        {"paths": {ancillary: saved file path},
         "fits": {ancillary: {var: fit info}}}
    """
    if out_dir is None:
        # Directory from settings, or a timestamped fallback
        # (timezone-aware replacement for deprecated utcnow(); output
        # string is identical)
        datetime_str = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        out_dir = self.settings.get("alignment", {}).get(
            "output_path", datetime_str
        )
    if not getattr(self, "r2_datasets", None):
        raise RuntimeError("Run preview_alignment() before fitting.")
    if target not in self.pipelines or target not in self._contexts:
        raise ValueError(f"Target '{target}' not available in manager contexts.")
    alignment_vars = list(self.alignment_map.keys())
    if variable_r2_criteria is None:
        variable_r2_criteria = self.settings.get("alignment", {}).get(
            "variable_r2_criteria", {}
        )
    missing = [v for v in alignment_vars if v not in variable_r2_criteria]
    if missing:
        raise ValueError(f"Missing R² thresholds for variables: {missing}")
    os.makedirs(out_dir, exist_ok=True)
    all_sources = [n for n in self.pipelines.keys() if n != target]
    anc_list = list(ancillaries) if ancillaries else all_sources
    # BUGFIX: the original guarded on hasattr(self, "plot_pair_scatter_grid"),
    # but plot_pair_scatter_grid is a module-level function, not a method,
    # so the QA grid was silently never drawn.
    if show_plots:
        try:
            plot_pair_scatter_grid(
                r2_datasets=self.r2_datasets,
                variables=alignment_vars,
                target_name=target,
                variable_r2_criteria=variable_r2_criteria,
                max_time_hr=max_time_hr,
                max_dist_km=max_dist_km,
            )
        except Exception as e:
            print(f"[Fit] (Plot) Skipped grid due to: {e}")

    def _alias_map_for(source_name):
        # Reverse alias lookup: {alias_in_source: standard_name}
        return {
            alias: std
            for std, mapping in self.alignment_map.items()
            if (alias := mapping.get(source_name))
        }

    saved_paths = {}
    fits_summary = {}
    for anc in anc_list:
        if anc not in self._contexts:
            print(f"[Fit] Skipping '{anc}' (no context).")
            continue
        print(f"\n[Fit] === {anc} → align to {target} ===")
        r2_ds = self.r2_datasets.get(anc)
        if r2_ds is None or not isinstance(r2_ds, xr.Dataset):
            print(f"[Fit] No R² dataset for '{anc}'. Skipping.")
            continue
        anc_ds = self._contexts[anc]["data"]
        rename_map = _alias_map_for(anc)
        anc_ds_std = anc_ds.rename(rename_map) if rename_map else anc_ds
        # Compute per-variable fits (ancillary -> target)
        anc_fits = {}
        for var in alignment_vars:
            x, y = collect_xy_from_r2_ds(
                r2_ds,
                var=var,
                target_name=target,
                ancillary_name=anc,
                r2_min=variable_r2_criteria.get(var),
                time_max=max_time_hr,
                dist_max=max_dist_km,
            )
            fit = fit_linear_map(x, y)
            anc_fits[var] = fit
            print(
                f"[Fit] {anc}:{var} slope={fit['slope']:.4g} intercept={fit['intercept']:.4g} "
                f"R²={fit['r2']:.3f} N={fit['n']}"
            )
        # Apply to full ancillary dataset (creates {VAR}_ALIGNED_TO_{target})
        ds_out = anc_ds_std.copy()
        created_vars = []
        for var in alignment_vars:
            if var not in ds_out:
                print(f"[Fit] [{anc}] missing '{var}' — skip.")
                continue
            slope = anc_fits[var]["slope"]
            intercept = anc_fits[var]["intercept"]
            npts = anc_fits[var]["n"]
            out_name = f"{var}_ALIGNED_TO_{target}"
            aligned = slope * ds_out[var] + intercept
            # Preserve the source variable's dtype
            aligned = aligned.astype(ds_out[var].dtype, copy=False)
            aligned.name = out_name
            aligned.attrs.update(
                {
                    "long_name": f"{var} aligned to {target}",
                    "alignment_target": target,
                    "alignment_source": anc,
                    "alignment_slope": float(slope),
                    "alignment_intercept": float(intercept),
                    "alignment_fit_points": int(npts),
                    # timezone-aware replacement for deprecated utcnow()
                    "alignment_generated": _dt.datetime.now(_dt.timezone.utc)
                    .isoformat()
                    .replace("+00:00", "Z"),
                }
            )
            ds_out[out_name] = aligned
            created_vars.append(out_name)
        if not created_vars:
            print(f"[Fit] No aligned variables for '{anc}'. Skipping save.")
            continue
        # Keep aligned vars in memory (best effort)
        try:
            self._contexts[anc]["data"] = ds_out
        except Exception:
            pass
        # Immediately build & cache aggregated aligned medians for this
        # ancillary; populates
        # self.processed_per_glider[anc][f'agg_aligned_to_{target}']
        try:
            self._aggregate_aligned_vars_for_ancillary(
                anc_name=anc, target=target, vars_to_aggregate=alignment_vars
            )
        except Exception as e:
            print(
                f"[Fit] Warning: failed to cache aggregated aligned medians for '{anc}': {e}"
            )
        # Save per-ancillary file
        out_path = os.path.join(out_dir, f"{anc}_aligned_to_{target}.nc")
        if (not overwrite) and os.path.exists(out_path):
            print(f"[Fit] File exists, not overwriting: {out_path}")
        else:
            encoding = {
                name: {"zlib": True, "complevel": 2} for name in created_vars
            }
            try:
                ds_out.to_netcdf(out_path, encoding=encoding)
                print(f"[Fit] Saved: {out_path}")
                saved_paths[anc] = out_path
                fits_summary[anc] = anc_fits
            except Exception as e:
                print(f"[Fit] Failed to save '{anc}': {e}")
        # Cache fits for metadata
        if anc in self.processed_per_glider:
            self.processed_per_glider[anc][
                f"last_fit_to_target_{target}"
            ] = anc_fits
    return {"paths": saved_paths, "fits": fits_summary}
def validate_with_device(self, target="None", **overrides):
    """
    Run the validation workflow using settings['validation'].

    Optionally pass keyword overrides (e.g. ``show_plots=False``) that
    apply to this call only; the original validation config is restored
    afterwards.

    Examples
    --------
    mngr.validate_with_device("Doombar")
    mngr.validate_with_device("Doombar", show_plots=False, apply_and_save=True)
    """
    # Fast path: no overrides → just call through
    if not overrides:
        validate(self, target=target)
        return
    # One-shot overrides: temporarily update settings['validation'].
    # NOTE: only a shallow copy is kept, so nested dicts mutated by
    # validate() would not be fully restored — TODO confirm acceptable.
    vcfg_orig = dict(self.settings.get("validation", {}))
    try:
        vcfg = self.settings.setdefault("validation", {})
        vcfg.update(overrides)
        validate(self, target=target)
    finally:
        # Restore the original validation config
        self.settings["validation"] = vcfg_orig
def fit_to_device(self, target="None"):
    """
    Fit TARGET variables to a validation device using profile-pair medians
    and per-variable R² criteria.

    The mapping is fit as ``device = slope * target + intercept``. Applying
    the fits and saving is handled separately by :meth:`apply_adjustment`.

    Reads options from ``self.settings['validation']``::

        validation:
          device_name: "<device label>"
          variable_names: ["CNDC","TEMP", ...]   # optional; defaults to alignment_map keys
          variable_r2_criteria: {CNDC: 0.95, TEMP: 0.9, ...}
          max_time_threshold: <float>
          max_distance_threshold: <float>
          save_plots: <bool>
          show_plots: <bool>
          plot_output_path: "<file or dir>"

    Parameters
    ----------
    target : str or list of str
        Registered pipeline name(s) to fit against the device.

    Returns
    -------
    dict with:
        - "fits": {var: {slope, intercept, r2, n}, ...}
        - "device_name": device label used
    """
    # --- Preconditions (isinstance instead of type == comparisons; an
    # unsupported type now raises instead of hitting a NameError below) ---
    if isinstance(target, str):
        if target not in self.pipelines:
            raise ValueError(f"Target pipeline '{target}' not found.")
        if target not in self._contexts:
            raise ValueError(f"Target pipeline '{target}' has no context data.")
        target_name = target
    elif isinstance(target, list):
        for platform in target:
            if platform not in self.pipelines or platform not in self._contexts:
                raise ValueError(f"Target '{platform}' not available.")
        target_name = "_".join(target)
    else:
        raise TypeError(f"target must be str or list, got {type(target).__name__}")
    # --- Validation config ---
    vcfg = self.settings.get("validation", {}) or {}
    device_name = vcfg.get("device_name", "DEVICE")
    variables = vcfg.get("variable_names", list(self.alignment_map.keys()))
    var_r2_criteria = vcfg.get("variable_r2_criteria", {}) or {}
    max_time_hr = vcfg.get("max_time_threshold", None)
    max_dist_km = vcfg.get("max_distance_threshold", None)
    show_plots = bool(vcfg.get("show_plots", True))
    save_plots = bool(vcfg.get("save_plots", False))
    plot_output_path = vcfg.get("plot_output_path", "device_fit_scatter_grid.png")
    # Validate that thresholds exist for all requested variables
    missing = [v for v in variables if v not in var_r2_criteria]
    if missing:
        raise ValueError(
            f"[Fit→Device] R² threshold missing for variables: {missing}"
        )
    print(f"[Fit→Device] Using device='{device_name}', variables={variables}")
    print(f"[Fit→Device] R² thresholds: {var_r2_criteria}")
    # --- Ensure we have the R² dataset for TARGET vs DEVICE ---
    # validate() runs the whole validation pipeline (load device, pair,
    # aggregate, merge, compute R²). Uses the module-level import; the
    # original re-imported it locally, shadowing the same name.
    val_res = validate(self, target=target)
    r2_ds = val_res.get("r2_ds", None)
    if r2_ds is None or not isinstance(r2_ds, xr.Dataset):
        raise RuntimeError(
            "[Fit→Device] No R² dataset available from validation()."
        )
    # --- QA scatter grid (X=device, Y=target) before fitting ---
    if show_plots or save_plots:
        try:
            # plot_pair_scatter_grid expects a dict of {ancillary_name: ds}
            ds_map = {device_name: r2_ds}
            fig, _ = plot_pair_scatter_grid(
                r2_datasets=ds_map,
                variables=variables,
                target_name=target_name,
                variable_r2_criteria=var_r2_criteria,
                max_time_hr=max_time_hr,
                max_dist_km=max_dist_km,
                ancillaries_order=[device_name],
            )
            if save_plots:
                # If the path looks like a directory, drop a default
                # filename into it
                out_is_dir = (plot_output_path.endswith(os.sep)) or (
                    os.path.isdir(plot_output_path)
                )
                if out_is_dir:
                    os.makedirs(plot_output_path, exist_ok=True)
                    fout = os.path.join(
                        plot_output_path, f"{target_name}_vs_{device_name}_fit_grid.png"
                    )
                else:
                    os.makedirs(
                        os.path.dirname(plot_output_path) or ".", exist_ok=True
                    )
                    fout = plot_output_path
                fig.savefig(fout, dpi=300)
                print(f"[Fit→Device] Saved scatter grid to: {fout}")
            if not show_plots:
                import matplotlib.pyplot as plt
                plt.close(fig)
        except Exception as e:
            print(f"[Fit→Device] (Plot) Skipped grid due to: {e}")
    # --- Compute fits to map TARGET → DEVICE for each variable ---
    # collect_xy_from_r2_ds returns (X=device, Y=target); for
    # TARGET→DEVICE we invert to (x=target, y=device).
    fits = {}
    for var in variables:
        X_dev, Y_tgt = collect_xy_from_r2_ds(
            r2_ds,
            var=var,
            target_name=target_name,
            ancillary_name=device_name,
            r2_min=var_r2_criteria.get(var),
            time_max=max_time_hr,
            dist_max=max_dist_km,
        )
        # Fit y_device = a * x_target + b
        info = fit_linear_map(Y_tgt, X_dev)
        fits[var] = info
        print(
            f"[Fit→Device] {var}: device ≈ {info['slope']:.4g}·target + {info['intercept']:.4g} "
            f"(R²={info['r2']:.3f}, N={info['n']})"
        )
    # The original returned None despite the documented dict contract
    return {"fits": fits, "device_name": device_name}
def apply_adjustment(self, target, fit_params):
    """Apply precomputed target→device fits to the full target dataset.

    For each variable in *fit_params*, creates a new variable
    ``{var}_ALIGNED_TO_{device}`` in the target's cached context dataset,
    with provenance attributes describing the linear mapping.

    Parameters
    ----------
    target : str
        Registered pipeline name whose dataset receives the aligned vars.
    fit_params : dict
        {var: {"slope": float, "intercept": float, "n": int, ...}},
        e.g. the "fits" output of fit_to_device().

    Returns
    -------
    dict
        {"fits": fit_params, "device_name": device label used}.
    """
    if target not in self._contexts:
        raise ValueError(f"Target pipeline '{target}' has no context data.")
    # --- config ---
    vcfg = self.settings.get("validation", {}) or {}
    device_name = vcfg.get("device_name", "DEVICE")
    # Rename target variables to standard names based on alignment_map
    # aliases, so fit_params keys (standard names) resolve.
    target_ds_raw = self._contexts[target]["data"]
    rename_map = {
        alias: std
        for std, alias_map in self.alignment_map.items()
        if (alias := alias_map.get(target)) and alias in target_ds_raw
    }
    target_ds_std = (
        target_ds_raw.rename(rename_map) if rename_map else target_ds_raw
    )
    ds_out = target_ds_std.copy()
    for var, info in fit_params.items():
        if var not in ds_out:
            print(f"[Fit→Device] Target missing variable '{var}' — skipping.")
            continue
        slope, intercept, npts = info["slope"], info["intercept"], info["n"]
        out_name = f"{var}_ALIGNED_TO_{device_name}"
        # Preserve the source variable's dtype
        aligned = (slope * ds_out[var] + intercept).astype(
            ds_out[var].dtype, copy=False
        )
        aligned.name = out_name
        aligned.attrs.update(
            {
                "long_name": f"{var} aligned to {device_name}",
                "alignment_target": target,
                "alignment_reference_device": device_name,
                "alignment_direction": "target_to_device",
                "alignment_slope": float(slope),
                "alignment_intercept": float(intercept),
                "alignment_fit_points": int(npts),
            }
        )
        # NOTE(review): the aligned variable is written back into the RAW
        # context dataset (original alias names), under the new out_name —
        # confirm this mix of naming conventions is intended.
        self._contexts[target]["data"][out_name] = aligned
    # Cache fit params for potential later use
    if not hasattr(self, "processed_per_glider"):
        self.processed_per_glider = {}
    if target not in self.processed_per_glider:
        self.processed_per_glider[target] = {}
    self.processed_per_glider[target][f"last_fit_to_device_{device_name}"] = fit_params
    return {"fits": fit_params, "device_name": device_name}
def save(self, dir, raw=True, processed=True):
    """Write cached raw/processed datasets to NetCDF files in *dir*.

    Parameters
    ----------
    dir : str
        Output directory (must already exist).
    raw, processed : bool
        Which export categories to write.
    """
    saving = {"raw": raw, "processed": processed}
    exportables = getattr(self, "_exportables", {})
    for data_output, to_save in saving.items():
        if not to_save:
            # The original printed "Saving … outputs." even for disabled
            # categories; only announce what we actually save
            continue
        print(f"Saving {data_output} outputs.")
        # BUGFIX: the original tested len(self._exportables) — the OUTER
        # dict, which always has its 3 category keys once alignment has
        # run — instead of the per-category mapping
        datasets = exportables.get(data_output, {})
        if not datasets:
            print(f"There is no {data_output} data to save.")
            continue
        for platform_name, data in datasets.items():
            data.to_netcdf(os.path.join(dir, f"{platform_name}_{data_output}.nc"))