# Source code for src.toolbox.steps.custom.variables.salinity

# This file is part of the NOC Autonomy Toolbox.
#
# Copyright 2025-2026 National Oceanography Centre and The Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#### Mandatory imports ####
from toolbox.steps.base_step import BaseStep, register_step
from toolbox.utils.qc_handling import QCHandlingMixin
import toolbox.utils.diagnostics as diag

#### Custom imports ####
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy import interpolate
from tqdm import tqdm
import xarray as xr
import numpy as np
import gsw


def running_average_nan(arr: np.ndarray, window_size: int) -> np.ndarray:
    """
    Compute a centred running mean that ignores NaN values.

    Parameters
    ----------
    arr : np.ndarray
        1-D input signal; may contain NaNs.
    window_size : int
        Width of the averaging window. Must be a positive odd integer so
        the window is symmetric about each sample.

    Returns
    -------
    np.ndarray
        Array with the same length as ``arr``. Positions whose entire
        window is NaN are returned as NaN.

    Raises
    ------
    ValueError
        If ``window_size`` is not a positive odd integer.
    """
    if window_size < 1:
        raise ValueError("Window size must be a positive integer.")
    if window_size % 2 == 0:
        raise ValueError("Window size must be odd for symmetry.")
    arr = np.asarray(arr, dtype=float)
    # np.pad(mode="reflect") raises on an empty array, so short-circuit:
    # the running mean of an empty signal is empty.
    if arr.size == 0:
        return arr.copy()
    pad_size = window_size // 2
    # Symmetric reflect-padding so edge samples still see a full window.
    padded = np.pad(arr, pad_size, mode="reflect")
    kernel = np.ones(window_size)
    # Per-window sum of non-NaN values (NaNs counted as 0) and per-window
    # count of non-NaN samples.
    sum_vals = np.convolve(np.nan_to_num(padded), kernel, mode="valid")
    count_vals = np.convolve(~np.isnan(padded), kernel, mode="valid")
    # Divide sums by counts; windows containing only NaNs keep the NaN
    # pre-filled in 'out' (which also avoids uninitialised-memory warnings).
    avg = np.divide(
        sum_vals,
        count_vals,
        out=np.full_like(sum_vals, np.nan, dtype=float),
        where=(count_vals != 0),
    )
    return avg
def compute_optimal_lag(profile_data, filter_window_size, time_col):
    """
    Find the conductivity time lag (relative to temperature, in seconds)
    that minimises salinity spiking for one glider profile.

    Candidate lags between -2 s and +2 s are scored by the standard
    deviation of the difference between the derived practical salinity and
    its running mean; the lag with the smallest score is returned.
    Returns NaN when the profile has no valid conductivity samples.
    """
    subset = profile_data[[time_col, "CNDC", "PRES", "TEMP"]].dropna(
        dim="N_MEASUREMENTS", subset=["CNDC"]
    )
    if len(subset[time_col]) == 0:
        return np.nan

    start_time = subset[time_col].values[0]
    seconds = (subset[time_col] - start_time).dt.total_seconds().values
    temps = subset["TEMP"].values
    pressures = subset["PRES"].values

    # Interpolator so conductivity can be evaluated at time-shifted points;
    # out-of-range queries become NaN (bounds_error=False).
    cndc_at = interpolate.interp1d(
        seconds, subset["CNDC"].values, bounds_error=False
    )

    candidate_lags = np.linspace(-2, 2, 41)
    scores = np.full(candidate_lags.shape, np.nan)
    for idx, lag in enumerate(candidate_lags):
        shifted_cndc = cndc_at(seconds + lag)
        # Rescale when the values look like S/m rather than mS/cm.
        if np.nanmax(shifted_cndc) < 10:
            scaled_cndc = shifted_cndc * 10
        else:
            scaled_cndc = shifted_cndc
        practical_salinity = gsw.conversions.SP_from_C(
            scaled_cndc, temps, pressures
        )
        smoothed = running_average_nan(practical_salinity, filter_window_size)
        # Spikiness metric: spread of the high-frequency residual.
        scores[idx] = np.nanstd(practical_salinity - smoothed)

    return candidate_lags[np.argmin(scores)]
@register_step
class AdjustSalinity(BaseStep, QCHandlingMixin):
    """
    Pipeline step that corrects conductivity/temperature misalignment
    artefacts in glider CTD data: it estimates and applies a per-direction
    conductivity-temperature time lag and a thermal-mass correction, so
    that salinity derived from CNDC/TEMP downstream is less spiky.
    """
    # Human-readable step name (presumably used by the step registry /
    # logging — confirm against BaseStep/register_step).
    step_name = "Salinity Adjustment"
    # Dataset variables that must exist before this step can run.
    required_variables = ["TIME", "PROFILE_NUMBER", "CNDC", "TEMP", "PRES"]
    # This step modifies CNDC/TEMP in place and adds no new variables.
    provided_variables = []
[docs] def run(self): self.log(f"Running adjustment...") self.data_copy = self.data.copy(deep=True) self.time_col = "TIME_CTD" if self.time_col not in self.data: self.log("TIME_CTD cound not be found. Defaulting to TIME instead.") self.time_col = "TIME" self.filter_qc() self.correct_ct_lag() self.correct_thermal_lag() self.reconstruct_data() self.update_qc() if self.diagnostics: self.generate_diagnostics() self.context["data"] = self.data return self.context
    def correct_ct_lag(self):
        """
        Estimate and apply the conductivity-vs-temperature time lag.

        A random sample of up to 50 profiles per direction (downcast -1,
        upcast 1, transect 0) is scored with compute_optimal_lag; the
        per-direction median lag is then applied to CNDC by shifting it in
        time via interpolation. CNDC is updated in place on self.data.
        Sets self.per_profile_optimal_lag (profile, lag, direction rows)
        and self.ct_lag_medians for use by generate_diagnostics.
        """
        profile_numbers = np.unique(self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS").values)
        # One row per profile: [profile_number, optimal_lag, direction].
        self.per_profile_optimal_lag = np.full((len(profile_numbers), 3), np.nan)
        # Extract numpy arrays for fast filtering to avoid Xarray subsetting overhead
        prof_arr = self.data["PROFILE_NUMBER"].values
        dir_arr = self.data["PROFILE_DIRECTION"].values
        time_arr = self.data[self.time_col].values
        cndc_arr = self.data["CNDC"].values
        # Shuffle indices to randomly sample profiles across the dataset
        indices = np.random.permutation(len(profile_numbers))
        processed_counts = {-1: 0, 1: 0, 0: 0}
        max_profiles = 50
        for i in tqdm(indices, colour="green", desc='\033[97mCT Lag Progress\033[0m', unit="prof"):
            # Break the loop entirely once we have enough profiles to save computation
            if all(count >= max_profiles for count in processed_counts.values()):
                break
            profile_number = profile_numbers[i]
            # Fast numpy subsetting
            prof_indices = np.where(prof_arr == profile_number)[0]
            if len(prof_indices) == 0:
                continue
            dir_subset = dir_arr[prof_indices]
            valid_dirs = dir_subset[~np.isnan(dir_subset)]
            # Take the first non-NaN direction as the profile's direction.
            prof_direction = valid_dirs[0] if len(valid_dirs) > 0 else np.nan
            if np.isnan(prof_direction) or prof_direction not in processed_counts:
                continue
            if processed_counts[prof_direction] >= max_profiles:
                continue
            prof_times = time_arr[prof_indices]
            valid_times = prof_times[~np.isnat(prof_times)]
            if len(valid_times) > 0:
                duration = valid_times[-1] - valid_times[0]
                # Check if profile lasted >= 30 minutes and has enough points for the filter
                if duration >= np.timedelta64(30, 'm') and len(valid_times) > 3 * self.filter_window_size:
                    # Only subset Xarray when we know we are calculating the lag
                    profile = self.data.isel(N_MEASUREMENTS=prof_indices)
                    optimal_lag = compute_optimal_lag(profile, self.filter_window_size, self.time_col)
                    self.per_profile_optimal_lag[i, :] = [profile_number, optimal_lag, prof_direction]
                    processed_counts[prof_direction] += 1
        # Use numpy mask to extract valid data for interpolation
        valid_data_mask = (~np.isnan(cndc_arr)) & (~np.isnat(time_arr))
        if not np.any(valid_data_mask):
            self.log("No valid CNDC data found. Skipping CT lag correction.")
            return
        # Calculate medians for Downcast (-1), Upcast (1), and Transect (0)
        self.ct_lag_medians = {}
        for d in [-1, 1, 0]:
            mask = (self.per_profile_optimal_lag[:, 2] == d) & (~np.isnan(self.per_profile_optimal_lag[:, 1]))
            dir_lags = self.per_profile_optimal_lag[mask, 1]
            if len(dir_lags) > 0:
                self.ct_lag_medians[d] = np.median(dir_lags)
            else:
                # Default to 0 if no valid profiles survive the 30min check
                self.ct_lag_medians[d] = 0.0
        valid_times = time_arr[valid_data_mask]
        valid_cndc = cndc_arr[valid_data_mask]
        valid_dirs = dir_arr[valid_data_mask]
        t0 = valid_times[0]
        # Elapsed seconds from the first valid sample.
        elapsed_time = (valid_times - t0) / np.timedelta64(1, 's')
        CNDC_from_TIME = interpolate.interp1d(
            elapsed_time, valid_cndc, bounds_error=False
        )
        corrected_cndc = valid_cndc.copy()
        # Apply each direction's median lag by evaluating CNDC at the
        # time-shifted sample positions.
        for d in [-1, 1, 0]:
            dir_mask = valid_dirs == d
            if np.any(dir_mask):
                shifted_time = elapsed_time[dir_mask] + self.ct_lag_medians[d]
                corrected_cndc[dir_mask] = CNDC_from_TIME(shifted_time)
        # Apply corrections back to the main array directly
        final_cndc = cndc_arr.copy()
        final_cndc[valid_data_mask] = corrected_cndc
        self.data["CNDC"].values = final_cndc
[docs] def correct_thermal_lag(self): """ Applies thermal mass correction independently to the downcast, transect, and upcast. """ import scipy.signal corrected_temp_array = np.full(len(self.data["TEMP"]), np.nan) profile_numbers = np.unique(self.data["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS").values) self.filter_params = {} prof_arr = self.data["PROFILE_NUMBER"].values dir_arr = self.data["PROFILE_DIRECTION"].values temp_arr = self.data["TEMP"].values time_arr = self.data[self.time_col].values valid_mask = (~np.isnan(temp_arr)) & (~np.isnan(time_arr)) for prof in tqdm(profile_numbers, colour="blue", desc='\033[97mThermal Lag Progress\033[0m', unit="prof"): for direction in [-1, 1, 0]: mask = (prof_arr == prof) & (dir_arr == direction) & valid_mask indices = np.where(mask)[0] if len(indices) < 5: continue cast_times = time_arr[indices] cast_temps = temp_arr[indices] t0 = cast_times[0] elapsed_time = (cast_times - t0) / np.timedelta64(1, 's') TEMP_from_TIME = interpolate.interp1d(elapsed_time, cast_temps, bounds_error=False, fill_value="extrapolate") TIME_1Hz_sampling = np.arange(0, elapsed_time[-1], 1) if len(TIME_1Hz_sampling) < 2: continue TEMP_1Hz_sampling = TEMP_from_TIME(TIME_1Hz_sampling) alpha_offset = 0.0135 alpha_slope = 0.0264 tau_offset = 7.1499 tau_slope = 2.7858 flow_rate = 0.4867 tau = tau_offset + tau_slope / np.sqrt(flow_rate) alpha = alpha_offset + alpha_slope / flow_rate self.filter_params = {"alpha": alpha, "tau": tau} nyquist_frequency = 1/2 a = 4 * nyquist_frequency * alpha * tau / (1 + 4 * nyquist_frequency * tau) b = 1 - (2 * a / alpha) delta_TEMP = np.zeros_like(TEMP_1Hz_sampling) delta_TEMP[1:] = np.diff(TEMP_1Hz_sampling) TEMP_correction = scipy.signal.lfilter([a], [1, b], delta_TEMP) corrected_TEMP_1Hz_sampling = TEMP_1Hz_sampling - TEMP_correction corrected_TEMP_from_TIME = interpolate.interp1d( TIME_1Hz_sampling, corrected_TEMP_1Hz_sampling, bounds_error=False, fill_value="extrapolate" ) corrected_temp_array[indices] = 
corrected_TEMP_from_TIME(elapsed_time) final_temp = np.where(np.isnan(corrected_temp_array), self.data["TEMP"].values, corrected_temp_array) self.data["TEMP"][:] = final_temp
    def generate_diagnostics(self):
        """
        Show a diagnostics figure for the salinity adjustment.

        Panels: per-profile optimal C-T lags with per-direction medians;
        thermal-correction temperature error vs depth for downcast,
        transect, and upcast; and raw vs corrected practical salinity
        profiles for the profile range in self.plot_profiles_in_range.
        Blocks until the window is closed.
        """
        COLOUR_RAW = "indianred"
        COLOUR_CORRECTED = "steelblue"
        # Force an interactive backend for plt.show; fails headless —
        # NOTE(review): consider making this configurable.
        mpl.use("tkagg")
        fig = plt.figure(figsize=(15, 5), dpi=150)
        gs = fig.add_gridspec(1, 6, wspace=0.4)
        ax_lag = fig.add_subplot(gs[0, 0:2])
        ax_therm_down = fig.add_subplot(gs[0, 2])
        ax_therm_trans = fig.add_subplot(gs[0, 3], sharey=ax_therm_down)
        ax_therm_up = fig.add_subplot(gs[0, 4], sharey=ax_therm_down)
        ax_prof = fig.add_subplot(gs[0, 5])
        # Only the leftmost thermal panel keeps depth tick labels.
        plt.setp(ax_therm_trans.get_yticklabels(), visible=False)
        plt.setp(ax_therm_up.get_yticklabels(), visible=False)
        # Handle purely empty profiles gently
        if np.all(np.isnan(self.per_profile_optimal_lag[:, 0])):
            prof_min, prof_max = 0, 1
        else:
            prof_min = np.nanmin(self.per_profile_optimal_lag[:, 0])
            prof_max = np.nanmax(self.per_profile_optimal_lag[:, 0])
        colours_dir = {-1: "forestgreen", 1: "royalblue", 0: "mediumorchid"}
        labels_dir = {-1: "Down", 1: "Up", 0: "Trans"}
        # Per-direction median lag lines plus the individual profile lags.
        for d in [-1, 1, 0]:
            if d in self.ct_lag_medians:
                med_val = self.ct_lag_medians[d]
                ax_lag.plot([prof_min, prof_max], [med_val, med_val],
                            c=colours_dir[d], ls="--", lw=2,
                            label=f"{labels_dir[d]}: {med_val:.2f}s")
                mask = self.per_profile_optimal_lag[:, 2] == d
                ax_lag.scatter(self.per_profile_optimal_lag[mask, 0],
                               self.per_profile_optimal_lag[mask, 1],
                               c=colours_dir[d], s=10, alpha=0.6)
        ax_lag.plot([prof_min, prof_max], [0, 0], "k-", lw=1)
        ax_lag.set_xlabel("Profile Index")
        ax_lag.set_ylabel("CTlag (s)")
        ax_lag.set_title("Optimal C-T Lag per Profile")
        ax_lag.legend(loc="upper right", fontsize=8)
        ax_lag.grid(True, alpha=0.3)
        # Restrict the before/after comparison to a configured profile range.
        p_start, p_end = self.plot_profiles_in_range
        mask_range = (self.data_copy["PROFILE_NUMBER"] >= p_start) & (self.data_copy["PROFILE_NUMBER"] <= p_end)
        uncorrected = self.data_copy.where(mask_range, drop=True)
        corrected = self.data.where(mask_range, drop=True)
        if len(uncorrected["DEPTH"].dropna(dim="N_MEASUREMENTS")) == 0:
            ax_therm_down.text(0.5, 0.5, "No Data", ha="center", va="center")
            ax_prof.text(0.5, 0.5, "No Data", ha="center", va="center")
        else:
            # Temperature change introduced by the thermal-lag correction.
            delta_temp = uncorrected["TEMP"] - corrected["TEMP"]
            alpha_val = self.filter_params.get('alpha', 0.0)
            tau_val = self.filter_params.get('tau', 0.0)
            ax_therm_trans.set_title(f"Thermal Error (°C)\nalpha={alpha_val:.3f}, tau={tau_val:.1f}s\n\nTransect", fontsize=9)
            ax_therm_down.set_title("\n\nDowncast", fontsize=9)
            ax_therm_up.set_title("\n\nUpcast", fontsize=9)
            for direction, ax_therm, marker in zip([-1, 0, 1],
                                                   [ax_therm_down, ax_therm_trans, ax_therm_up],
                                                   ["o", "s", "x"]):
                mask_dir = uncorrected["PROFILE_DIRECTION"] == direction
                if np.any(mask_dir):
                    ax_therm.plot(delta_temp[mask_dir], uncorrected["DEPTH"][mask_dir],
                                  ls="", marker=marker, markersize=3, alpha=0.5,
                                  c=colours_dir[direction])
                ax_therm.axvline(0, color="black", ls="--", alpha=0.5)
                ax_therm.grid(True, alpha=0.3)
            # Invert the shared y-axis so 0 is at the top
            ax_therm_down.invert_yaxis()
            ax_therm_down.set_ylabel("Depth (m)")
            # Rescale conductivity when values look like S/m (max < 10).
            cndc_raw_max = uncorrected["CNDC"].max(skipna=True).values
            cndc_raw = uncorrected["CNDC"] * 10 if not np.isnan(cndc_raw_max) and cndc_raw_max < 10 else uncorrected["CNDC"]
            temp_raw = uncorrected["TEMP"]
            cndc_lagged_max = corrected["CNDC"].max(skipna=True).values
            cndc_lagged = corrected["CNDC"] * 10 if not np.isnan(cndc_lagged_max) and cndc_lagged_max < 10 else corrected["CNDC"]
            temp_therm = corrected["TEMP"]
            pres = uncorrected["PRES"]
            psal_raw = gsw.conversions.SP_from_C(cndc_raw, temp_raw, pres)
            psal_full = gsw.conversions.SP_from_C(cndc_lagged, temp_therm, pres)
            # Label only one direction so the legend has a single Raw /
            # Corrected pair.
            for direction, marker in zip([-1, 1, 0], ["o", "x", "s"]):
                mask_dir = uncorrected["PROFILE_DIRECTION"] == direction
                if np.any(mask_dir):
                    label_raw = "Raw" if direction == -1 else None
                    label_corr = "Corrected" if direction == -1 else None
                    ax_prof.plot(psal_raw[mask_dir], uncorrected["DEPTH"][mask_dir],
                                 c=COLOUR_RAW, ls="", marker=marker, markersize=2,
                                 alpha=0.4, label=label_raw)
                    ax_prof.plot(psal_full[mask_dir], uncorrected["DEPTH"][mask_dir],
                                 c=COLOUR_CORRECTED, ls="", marker=marker, markersize=2,
                                 alpha=0.6, label=label_corr)
        leg = ax_prof.legend(loc="lower left", fontsize=8)
        # legend_handles is the modern attribute; legendHandles is the
        # pre-3.7 matplotlib spelling.
        handles = getattr(leg, "legend_handles", getattr(leg, "legendHandles", []))
        for handle in handles:
            # Enlarge and opacify legend markers for readability.
            if hasattr(handle, "set_sizes"):
                handle.set_sizes([20])
                handle.set_alpha(1)
        # Invert the profile plot y-axis so 0 is at the top
        ax_prof.invert_yaxis()
        ax_prof.set_xlabel("Practical Salinity")
        ax_prof.set_title(f"Final Result ({p_start}-{p_end})", fontsize=10)
        ax_prof.grid(True, alpha=0.3)
        fig.subplots_adjust(top=0.85)
        plt.show(block=True)