Source code for src.toolbox.steps.custom.variables.chla

# This file is part of the NOC Autonomy Toolbox.
#
# Copyright 2025-2026 National Oceanography Centre and The Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#### Mandatory imports ####
from toolbox.steps.base_step import BaseStep, register_step
from toolbox.utils.qc_handling import QCHandlingMixin
import toolbox.utils.diagnostics as diag

#### Custom imports ####
import xarray as xr
import numpy as np
import pandas as pd
import pvlib
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm

def check_chl_variables(self, allowed_requests):
    """Resolve which chlorophyll variable a step reads and which it writes.

    Parameters
    ----------
    self
        The calling step. It must expose ``apply_to`` (requested variable
        name), ``data.data_vars`` (container of available variable names),
        ``step_name`` (used in the error message) and ``log`` (callable
        taking one message string).
    allowed_requests
        Collection of variable names this step is permitted to process.

    Returns
    -------
    tuple of (str, str)
        ``(variable_to_process, output_variable)``. If an ``_ADJUSTED``
        counterpart of the requested variable already exists it is
        processed instead. The output name always carries exactly one
        ``_ADJUSTED`` suffix.

    Raises
    ------
    KeyError
        If the requested variable is absent from the data or is not in
        ``allowed_requests``.
    """
    suffix = "_ADJUSTED"
    user_request = self.apply_to

    if user_request not in self.data.data_vars:
        raise KeyError(f"The variable {user_request} does not exist in the data.")
    if user_request not in allowed_requests:
        raise KeyError(
            f"The variable {user_request} is not permitted for [{self.step_name}]"
        )

    # Prefer an already-adjusted version so corrections stack on each other
    # instead of restarting from the unadjusted variable.
    adjusted_name = f"{user_request}{suffix}"
    if adjusted_name in self.data.data_vars:
        self.log(
            f"User requested processing on {user_request} but {user_request}_ADJUSTED already exists. Using {user_request}_ADJUSTED..."
        )
        user_request = adjusted_name

    # Suffix test (not a substring test): only a trailing "_ADJUSTED" means
    # the variable is already the adjusted product.
    output_as = user_request if user_request.endswith(suffix) else user_request + suffix

    self.log(f"Processing {user_request}...")
    return user_request, output_as
@register_step
class chla_deep_correction(BaseStep, QCHandlingMixin):
    """Subtract a dark (zero-signal) value from a chlorophyll-a variable.

    The dark value is either taken from the step configuration
    (``dark_value``) or estimated as the median of the per-profile minima
    of the samples selected by ``depth_threshold``.

    ``self.data``, ``self.context``, ``self.diagnostics``, ``self.log`` and
    the QC helpers (``filter_qc``, ``reconstruct_data``, ``update_qc``,
    ``generate_qc``) are not defined in this class; they are presumably
    supplied by ``BaseStep`` / ``QCHandlingMixin``.
    """

    step_name = "Chla Deep Correction"
    required_variables = ["TIME", "PROFILE_NUMBER", "DEPTH"]
    provided_variables = []

    def run(self):
        """Run the dark-value correction and return the updated context.

        Example
        -------
        - name: "Chla Deep Correction"
          parameters:
            apply_to: "CHLA"
            dark_value: null
            depth_threshold: 200
          diagnostics: true
        """
        self.filter_qc()

        # Save a copy of the pre-corrected data for the diagnostics plot
        self.data_copy = self.data.copy(deep=True)

        # Check this step is being applied to a valid variable
        self.apply_to, self.output_as = check_chl_variables(
            self,
            [
                "CHLA",
                "CHLA_ADJUSTED",
                "CHLA_FLUORESCENCE",
                "CHLA_FLUORESCENCE_ADJUSTED",
            ],
        )

        self.compute_dark_value()
        self.apply_dark_correction()

        self.reconstruct_data()
        self.update_qc()

        # Generate new QC if a non-adjusted variable was used in processing
        if self.apply_to != self.output_as:
            self.generate_qc({f"{self.output_as}_QC": [f"{self.apply_to}_QC"]})

        if self.diagnostics:
            self.generate_diagnostics()

        self.context["data"] = self.data
        return self.context

    def compute_dark_value(self):
        """Determine the dark value and store it on ``self.dark_value``.

        If ``dark_value`` is already set (e.g. from the config) it is used
        unchanged and ``self.chla_deep_minima`` is set to ``None``.
        Otherwise the minimum of ``self.apply_to`` is taken per profile
        over the samples with ``DEPTH <= depth_threshold``, and the dark
        value is the median of those minima. The selected rows are kept in
        ``self.chla_deep_minima`` for the diagnostics.

        Returns
        -------
        float
            The dark value (also stored on ``self.dark_value``).

        Raises
        ------
        ValueError
            If neither ``dark_value`` nor ``depth_threshold`` is available,
            or if no profile has any sample passing the depth filter.
        KeyError
            If a required variable is missing from ``self.data.data_vars``.
        """
        # Check config file for existing dark value
        if getattr(self, "dark_value", None) is not None:
            self.log(f"Using dark value from config: {self.dark_value}")
            # No per-profile minima exist in this branch; record that
            # explicitly so the diagnostics do not hit a missing attribute.
            self.chla_deep_minima = None
            return self.dark_value

        if getattr(self, "depth_threshold", None) is None:
            raise ValueError(
                "[Chla Deep Correction] Either 'dark_value' or "
                "'depth_threshold' must be provided."
            )

        self.log(
            f"Computing dark value from profiles reaching >= {self.depth_threshold}m"
        )

        columns = ["TIME", "PROFILE_NUMBER", "DEPTH", self.apply_to]
        missing_vars = set(columns) - set(self.data.data_vars)
        if missing_vars:
            raise KeyError(
                f"[Chla Deep Correction] {missing_vars} could not be found in the data."
            )

        interp_data = self.data[columns].to_pandas()

        # Fill DEPTH gaps by linear interpolation between neighbouring
        # samples. Interpolating the column in place keeps the frame's own
        # index, so the result cannot be misaligned on assignment.
        interp_data["DEPTH"] = interp_data["DEPTH"].interpolate()
        interp_data = interp_data.dropna(subset=[self.apply_to, "PROFILE_NUMBER"])

        # Subset the data to only deep measurements
        # NOTE(review): `DEPTH <= depth_threshold` selects deep samples only
        # if DEPTH is negative-down and the threshold is negative; the
        # example in run() uses 200 — confirm the sign convention.
        interp_data = interp_data[interp_data["DEPTH"] <= self.depth_threshold]

        samples_per_profile = interp_data.groupby("PROFILE_NUMBER")[self.apply_to].count()
        deep_profiles = samples_per_profile[samples_per_profile > 0].index.to_numpy()
        if len(deep_profiles) == 0:
            raise ValueError(
                "[Chla Deep Correction] No deep profiles could be identified. "
                "Try adjusting the 'depth_threshold' parameter."
            )

        interp_data = interp_data[interp_data["PROFILE_NUMBER"].isin(deep_profiles)]
        minima_rows = interp_data.groupby("PROFILE_NUMBER")[self.apply_to].idxmin()
        self.chla_deep_minima = interp_data.loc[minima_rows, columns]

        minima = self.chla_deep_minima[self.apply_to]
        self.dark_value = np.nanmedian(minima)
        self.log(
            f"\nComputed dark value: {self.dark_value:.6f} "
            f"(median of {len(self.chla_deep_minima)} profile minimums)\n"
            f"Min values range: {np.min(minima):.6f} "
            f"to {np.max(minima):.6f}"
        )
        return self.dark_value

    def apply_dark_correction(self):
        """Write ``apply_to - dark_value`` into ``self.data[self.output_as]``.

        The attributes of the source variable are carried over, then
        ``comment`` and ``dark_value`` are set to record the correction.
        """
        source = self.data[self.apply_to]
        # Capture the attributes before the assignment: when apply_to and
        # output_as are the same variable, the source is overwritten below
        # and its metadata would otherwise be read from the new array.
        source_attrs = dict(source.attrs)

        corrected = source - self.dark_value
        corrected.attrs = source_attrs
        corrected.attrs[
            "comment"
        ] = f"{self.apply_to} with dark value correction (dark_value={self.dark_value:.6f})"
        corrected.attrs["dark_value"] = self.dark_value

        self.data[self.output_as] = corrected

    def generate_diagnostics(self):
        """Show a two-panel figure summarising the dark correction.

        The left panel is a histogram of the per-profile deep minima (or a
        note when the dark value came from the config and no minima were
        computed). The right panel scatters the variable against DEPTH
        before and after the correction. Blocks until the window is closed.
        """
        mpl.use("tkagg")
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), dpi=150)

        deep_minima = getattr(self, "chla_deep_minima", None)
        depth_threshold = getattr(self, "depth_threshold", None)

        # Panel 1: Histogram of deep minima
        if deep_minima is not None:
            ax1.hist(
                deep_minima[self.apply_to],
                bins=30,
                color="steelblue",
                edgecolor="black",
                alpha=0.8,
            )
            ax1.set_xlabel(f"Minimum {self.apply_to} below {depth_threshold}m")
        else:
            # Dark value was supplied directly, so there is nothing to bin.
            ax1.text(
                0.5,
                0.5,
                "Dark value supplied in config\n(no deep minima computed)",
                ha="center",
                va="center",
                transform=ax1.transAxes,
            )
            ax1.set_xlabel(self.apply_to)
        ax1.axvline(
            self.dark_value,
            color="red",
            linestyle="dashed",
            linewidth=2,
            label=f"Dark Value (Median): {self.dark_value:.4f}",
        )
        ax1.set_ylabel("Frequency (Number of Profiles)")
        ax1.set_title("Distribution of Deep Minima")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: Before and After Depth Shift
        ax2.scatter(
            self.data_copy[self.apply_to],
            self.data_copy["DEPTH"],
            c="indianred",
            s=2,
            alpha=0.3,
            label="Uncorrected",
        )
        ax2.scatter(
            self.data[self.output_as],
            self.data["DEPTH"],
            c="steelblue",
            s=2,
            alpha=0.3,
            label="Corrected",
        )
        if depth_threshold is not None:
            ax2.axhline(
                depth_threshold,
                color="black",
                linestyle="--",
                label=f"Threshold ({depth_threshold}m)",
            )
        ax2.axvline(0, color="black", linestyle="-", linewidth=1, alpha=0.5)
        ax2.axvline(self.dark_value, color="red", linestyle=":", label="Dark Value Offset")
        ax2.set_xlabel(self.apply_to)
        ax2.set_ylabel("Depth (m)")
        ax2.set_title("Deployment Profiles: Dark Correction Shift")

        # Enlarge the tiny scatter markers in the legend so they are legible.
        # The attribute name differs between matplotlib versions.
        leg = ax2.legend(loc="lower right")
        handles = getattr(leg, "legend_handles", getattr(leg, "legendHandles", []))
        for handle in handles:
            if hasattr(handle, "set_sizes"):
                handle.set_sizes([30])
            handle.set_alpha(1)
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show(block=True)
@register_step
class chla_quenching_correction(BaseStep, QCHandlingMixin):
    """Correct daytime fluorescence quenching in chlorophyll-a profiles.

    Each profile is handled independently. The only supported ``method`` is
    ``"Argo"``, implemented by ``apply_xing2012_quenching_correction``.

    ``self.data``, ``self.context``, ``self.diagnostics``, ``self.log`` and
    the QC helpers (``filter_qc``, ``reconstruct_data``, ``update_qc``,
    ``generate_qc``) are not defined in this class; they are presumably
    supplied by ``BaseStep`` / ``QCHandlingMixin``.
    """

    step_name = "Chla Quenching Correction"
    required_variables = ["PROFILE_NUMBER", "TIME", "DEPTH", "LATITUDE", "LONGITUDE"]
    provided_variables = []

    def run(self):
        """Apply the quenching correction profile by profile.

        Example
        -------
        - name: "Chla Quenching Correction"
          parameters:
            method: "Argo"
            apply_to: "CHLA"
            mld_settings: {
              "threshold_on": "TEMP",
              "reference_depth": 10,
              "threshold": 0.2
            }
            plot_profiles: []
          diagnostics: true

        Raises
        ------
        KeyError
            If ``method`` is not a supported method name.
        """
        # Snapshots used only by the diagnostics plots.
        self.pre_qc_data = self.data.copy(deep=True)
        self.filter_qc()
        self.pre_correction_data = self.data.copy(deep=True)

        self.apply_to, self.output_as = check_chl_variables(
            self,
            [
                "CHLA",
                "CHLA_ADJUSTED",
                "CHLA_FLUORESCENCE",
                "CHLA_FLUORESCENCE_ADJUSTED",
            ],
        )

        if self.apply_to != self.output_as:
            # Deep copy: the output is written in place below and must not
            # share its buffer with the uncorrected source variable.
            self.data[self.output_as] = self.data[self.apply_to].copy(deep=True)

        methods = {"argo": self.apply_xing2012_quenching_correction}
        method_key = self.method.lower()
        if method_key not in methods:
            raise KeyError(f"Method {self.method} is not supported")
        method_function = methods[method_key]

        if method_key in ["argo"]:
            # One representative time/position per profile for the solar
            # elevation: medians over the 50 samples with the largest DEPTH
            # value in each profile.
            sun_df = self.data[
                ["PROFILE_NUMBER", "TIME", "DEPTH", "LATITUDE", "LONGITUDE"]
            ].to_pandas()
            self.sun_args = (
                sun_df.sort_values(["PROFILE_NUMBER", "DEPTH"], ascending=[True, False])
                .groupby("PROFILE_NUMBER")
                .head(50)
                .groupby("PROFILE_NUMBER")
                .agg({"TIME": "median", "LATITUDE": "median", "LONGITUDE": "median"})
            )

        method_variable_requirements = {
            "argo": {
                "PROFILE_NUMBER",
                "DEPTH",
                self.apply_to,
                self.mld_settings["threshold_on"],
            }
        }
        data_subset = self.data[list(method_variable_requirements[method_key])]
        profile_numbers = np.unique(
            data_subset["PROFILE_NUMBER"].dropna(dim="N_MEASUREMENTS")
        )

        self.diagnostic_meta = {}
        self.stats = {
            "total": len(profile_numbers),
            "corrected": 0,
            "skipped_night": 0,
            "skipped_no_mld": 0,
            "skipped_no_gps": 0,
            "skipped_no_data": 0,
        }

        # Hoisted out of the loop: the profile labels do not change.
        all_profile_ids = self.data["PROFILE_NUMBER"].values

        for profile_number in tqdm(
            profile_numbers, colour="green", desc="\033[97mProgress\033[0m", unit="prof"
        ):
            profile = data_subset.where(
                data_subset["PROFILE_NUMBER"] == profile_number, drop=True
            )
            corrected_chla, meta = method_function(profile)
            self.diagnostic_meta[profile_number] = meta
            self.stats[meta["status"]] += 1

            profile_indices = np.where(all_profile_ids == profile_number)
            self.data[self.output_as][profile_indices] = corrected_chla

        self.reconstruct_data()
        self.update_qc()

        if self.apply_to != self.output_as:
            self.generate_qc({f"{self.output_as}_QC": [f"{self.apply_to}_QC"]})

        self.log("\n--- Quenching Correction Summary ---")
        self.log(f"Total profiles evaluated: {self.stats['total']}")
        self.log(f"Successfully corrected: {self.stats['corrected']}")
        self.log(f"Skipped (Night time): {self.stats['skipped_night']}")
        self.log(f"Skipped (No valid MLD): {self.stats['skipped_no_mld']}")
        self.log(f"Skipped (No CHLA data): {self.stats['skipped_no_data']}")
        self.log(f"Skipped (Missing GPS): {self.stats['skipped_no_gps']}\n")

        if self.diagnostics:
            self.generate_diagnostics()

        self.context["data"] = self.data
        return self.context

    def calculate_mld(self, profile):
        """Estimate the mixed layer depth (MLD) of one profile.

        The reference sample is the one with the largest DEPTH among the
        samples with ``DEPTH <= reference_depth``. The MLD is the DEPTH of
        the shallowest sample (largest DEPTH value) whose ``threshold_on``
        value differs from the reference by at least ``|threshold|``.

        NOTE(review): the comparisons only make sense if DEPTH is
        negative-down and ``reference_depth`` is negative; the example in
        run() shows a positive value — confirm the convention.

        Parameters
        ----------
        profile
            Dataset for one profile with ``DEPTH`` and the variable named by
            ``mld_settings["threshold_on"]`` along ``N_MEASUREMENTS``.

        Returns
        -------
        float
            The MLD, or ``nan`` when there is no usable reference sample,
            the reference sample lies beyond ``2 * reference_depth``, or the
            threshold is never exceeded.
        """
        # Expose the MLD settings as attributes (reference_depth,
        # threshold_on, threshold, ...).
        for k, v in self.mld_settings.items():
            setattr(self, k, v)

        profile_subset = profile.where(
            profile["DEPTH"] <= self.reference_depth, drop=True
        ).dropna(dim="N_MEASUREMENTS", subset=["DEPTH", self.threshold_on])
        if len(profile_subset["DEPTH"]) == 0:
            return np.nan

        reference_point = profile_subset.isel(
            {"N_MEASUREMENTS": np.nanargmax(profile_subset["DEPTH"])},
        )
        # Reject profiles whose first usable sample is too far from the
        # requested reference depth.
        if reference_point["DEPTH"] < 2 * self.reference_depth:
            return np.nan

        reference_value = reference_point[self.threshold_on]
        profile_subset["delta"] = profile_subset[self.threshold_on] - reference_value
        profile_subset = profile_subset.where(
            np.abs(profile_subset["delta"]) >= np.abs(self.threshold), drop=True
        )
        if len(profile_subset["DEPTH"]) == 0:
            return np.nan

        # Select by depth rather than by storage order. The first stored
        # sample is the shallowest only for a downcast; for a climb it would
        # be the deepest sample of the profile.
        shallowest = int(np.nanargmax(profile_subset["DEPTH"].values))
        return float(profile_subset.isel({"N_MEASUREMENTS": shallowest})["DEPTH"])

    def apply_xing2012_quenching_correction(self, profile):
        """Apply the mixed-layer quenching correction to one profile.

        The method name refers to Xing et al. (2012). For a daytime profile
        with a valid MLD, the maximum of ``apply_to`` among samples with
        ``DEPTH >= mld`` is located. Every non-NaN sample with DEPTH at or
        above the depth of that maximum is then set to the maximum value.

        Parameters
        ----------
        profile
            Data for one profile exposing ``PROFILE_NUMBER``, ``DEPTH`` and
            ``apply_to`` (plus what ``calculate_mld`` needs).

        Returns
        -------
        tuple of (numpy.ndarray, dict)
            The corrected values (the unmodified input values when the
            profile is skipped) and a metadata dict with keys ``status``,
            ``mld``, ``sun_angle`` and ``z_qd``. ``status`` is one of
            ``corrected``, ``skipped_night``, ``skipped_no_mld``,
            ``skipped_no_gps`` or ``skipped_no_data``.
        """
        chlf = np.asarray(profile[self.apply_to].values, dtype=float)
        depth = np.asarray(profile["DEPTH"].values, dtype=float)
        n_samples = len(chlf)

        meta = {
            "status": "skipped_no_data",
            "mld": np.nan,
            "sun_angle": np.nan,
            "z_qd": np.nan,
        }

        profile_ids = profile["PROFILE_NUMBER"].values
        if len(profile_ids) == 0 or np.isnan(profile_ids[0]):
            return chlf, meta

        profile_number = int(profile_ids[0])
        if profile_number not in self.sun_args.index:
            meta["status"] = "skipped_no_gps"
            return chlf, meta

        time, lat, lon = self.sun_args.loc[profile_number].to_numpy()
        if pd.isna(time) or pd.isna(lat) or pd.isna(lon):
            meta["status"] = "skipped_no_gps"
            return chlf, meta

        # Naive timestamps are treated as UTC.
        time_utc = pd.to_datetime(time)
        if time_utc.tzinfo is None:
            time_utc = time_utc.tz_localize("UTC")

        solar_position = pvlib.solarposition.get_solarposition(time_utc, lat, lon)
        sun_angle = float(solar_position["elevation"].values[0])
        meta["sun_angle"] = sun_angle

        # Sun at or below the horizon: nothing to correct.
        if sun_angle <= 0:
            meta["status"] = "skipped_night"
            return chlf, meta

        mld = self.calculate_mld(profile)
        meta["mld"] = mld

        if n_samples == 0 or len(depth) != n_samples or np.all(np.isnan(chlf)):
            return chlf, meta

        # Ensure mld is valid and negative (below surface)
        if not np.isfinite(mld) or mld >= 0:
            meta["status"] = "skipped_no_mld"
            return chlf, meta

        within_mld = depth >= mld
        if not np.any(within_mld):
            return chlf, meta

        chlf_mld = np.where(within_mld, chlf, np.nan)
        if np.all(np.isnan(chlf_mld)):
            return chlf, meta

        idx_max = np.nanargmax(chlf_mld)
        chlf_max = chlf_mld[idx_max]
        chlf_max_depth = float(depth[idx_max])

        meta["z_qd"] = chlf_max_depth
        meta["status"] = "corrected"

        # Replace everything at or above the mixed-layer maximum with the
        # maximum itself; NaN samples stay NaN.
        chl_corr = np.copy(chlf)
        chl_corr[(depth >= chlf_max_depth) & (~np.isnan(chlf))] = chlf_max

        return chl_corr, meta

    def generate_diagnostics(self):
        """Show the deployment overview and optional per-profile plots.

        The overview figure shows solar elevation per profile and the
        variable before QC filtering, before correction and after
        correction. If ``plot_profiles`` lists 1-based positions, each is
        mapped onto the dense (>= 100 valid points) profiles with non-zero
        ``PROFILE_DIRECTION`` and plotted against DEPTH with its MLD and
        quenching depth. Blocks until the windows are closed.
        """
        # --- Configurable Plot Variables ---
        MIN_POINTS_TO_PLOT = 100
        N_COLS = 3
        COLOR_SUN = "orange"
        COLOR_RAW = "lightgrey"
        COLOR_UNCORRECTED = "indianred"
        COLOR_CORRECTED = "steelblue"
        PLOT_SIZE_OVERVIEW = (14, 8)
        PLOT_SIZE_PROFILES = (12, 5)

        mpl.use("tkagg")

        fig_overview, (ax_sun, ax_chla) = plt.subplots(
            2,
            1,
            figsize=PLOT_SIZE_OVERVIEW,
            sharex=True,
            gridspec_kw={"height_ratios": [1, 3]},
            dpi=150,
        )

        time_vals = self.data["TIME"].values

        # Sort values chronologically to prevent zigzag plotting artifacts.
        # Rows without a time or position cannot yield a solar elevation.
        sun_args_sorted = self.sun_args.dropna(
            subset=["TIME", "LATITUDE", "LONGITUDE"]
        ).sort_values(by="TIME")
        sun_times = pd.to_datetime(sun_args_sorted["TIME"].values)

        sun_angles = []
        for t, lat, lon in zip(
            sun_times, sun_args_sorted["LATITUDE"], sun_args_sorted["LONGITUDE"]
        ):
            t_utc = t.tz_localize("UTC") if t.tzinfo is None else t
            sun_angles.append(
                pvlib.solarposition.get_solarposition(t_utc, lat, lon)["elevation"].values[0]
            )
        sun_angles = np.asarray(sun_angles, dtype=float)

        ax_sun.plot(sun_times, sun_angles, color=COLOR_SUN, lw=1.5)
        ax_sun.axhline(0, color="black", ls="--", lw=1)
        ax_sun.set_ylabel("Sun Elevation (deg)")
        ax_sun.set_title("Deployment Overview: Sun Elevation and CHLA Adjustments")
        ax_sun.grid(True, alpha=0.3)
        ax_sun.fill_between(
            sun_times, sun_angles, 0, where=(sun_angles > 0), color="yellow", alpha=0.2
        )
        ax_sun.fill_between(
            sun_times, sun_angles, 0, where=(sun_angles <= 0), color="grey", alpha=0.2
        )

        # Drawn back to front: points that survive QC are over-plotted, so
        # the grey layer stays visible only where data was filtered out.
        ax_chla.scatter(
            self.pre_qc_data["TIME"].values,
            self.pre_qc_data[self.apply_to].values,
            c=COLOR_RAW,
            s=5,
            alpha=0.5,
            label="Raw (Failed QC)",
        )
        ax_chla.scatter(
            time_vals,
            self.pre_correction_data[self.apply_to].values,
            c=COLOR_UNCORRECTED,
            s=5,
            alpha=0.7,
            label="Uncorrected",
        )
        ax_chla.scatter(
            time_vals,
            self.data[self.output_as].values,
            c=COLOR_CORRECTED,
            s=5,
            alpha=0.7,
            label="Corrected",
        )
        ax_chla.set_ylabel(self.apply_to)
        ax_chla.set_xlabel("Time")
        ax_chla.legend(loc="upper right")
        ax_chla.grid(True, alpha=0.3)

        fig_overview.tight_layout()

        requested_profiles = getattr(self, "plot_profiles", None)
        requested_profiles = [] if requested_profiles is None else list(requested_profiles)
        if not requested_profiles:
            # Block here so the overview stays open when it is the only figure.
            plt.show(block=True)
            return

        if "PROFILE_DIRECTION" not in self.data:
            self.log("PROFILE_DIRECTION is not available; skipping profile-level diagnostics.")
            plt.show(block=True)
            return

        # Find profiles that meet the minimum data point requirement
        valid_df = self.data[["PROFILE_NUMBER", "DEPTH", self.apply_to]].to_pandas().dropna()
        counts = valid_df.groupby("PROFILE_NUMBER").size()
        dense_profiles = counts[counts >= MIN_POINTS_TO_PLOT].index.values

        # Drop incomplete rows BEFORE de-duplicating, so a profile is not
        # lost just because its first row has no direction.
        profile_df = (
            self.data[["PROFILE_NUMBER", "PROFILE_DIRECTION"]]
            .to_pandas()
            .dropna(subset=["PROFILE_NUMBER", "PROFILE_DIRECTION"])
            .drop_duplicates(subset=["PROFILE_NUMBER"])
        )

        # Filter for vertical profiles that are also dense
        vertical_profiles = profile_df[
            (profile_df["PROFILE_DIRECTION"] != 0)
            & (profile_df["PROFILE_NUMBER"].isin(dense_profiles))
        ]["PROFILE_NUMBER"].values

        # Keep each request paired with the profile it maps to, so skipped
        # requests cannot shift the labels of later ones.
        plot_pairs = []
        for requested_prof in requested_profiles:
            idx = int(requested_prof) - 1
            if 0 <= idx < len(vertical_profiles):
                plot_pairs.append((requested_prof, int(vertical_profiles[idx])))
            else:
                self.log(f"Requested profile {requested_prof} is out of bounds for the available dense vertical profiles.")

        if not plot_pairs:
            self.log("No valid profiles found to plot after mapping.")
            plt.show(block=True)
            return

        nrows = int(np.ceil(len(plot_pairs) / N_COLS))
        # squeeze=False always yields a 2-D array of axes, so flattening is
        # valid for any number of requested profiles (including one).
        fig_profs, axs = plt.subplots(
            nrows=nrows,
            ncols=N_COLS,
            figsize=(PLOT_SIZE_PROFILES[0], nrows * PLOT_SIZE_PROFILES[1]),
            dpi=150,
            squeeze=False,
        )
        axs = axs.flatten()
        for unused_ax in axs[len(plot_pairs):]:
            unused_ax.set_visible(False)

        for (requested_prof, profile_number), ax in zip(plot_pairs, axs):
            meta = self.diagnostic_meta.get(profile_number, {})

            prof_pre_corr = self.pre_correction_data.where(
                self.pre_correction_data["PROFILE_NUMBER"] == profile_number, drop=True
            )
            prof_post_corr = self.data.where(
                self.data["PROFILE_NUMBER"] == profile_number, drop=True
            )

            if len(prof_pre_corr["DEPTH"].dropna(dim="N_MEASUREMENTS")) == 0:
                ax.text(
                    0.5,
                    0.5,
                    f"No Data\nProf: {profile_number} (Old: {requested_prof})",
                    ha="center",
                    va="center",
                )
                continue

            ax.plot(
                prof_pre_corr[self.apply_to],
                prof_pre_corr["DEPTH"],
                c=COLOR_UNCORRECTED,
                marker="o",
                ls="-",
                lw=1,
                markersize=3,
                label="Uncorrected",
                alpha=0.7,
            )
            ax.plot(
                prof_post_corr[self.output_as],
                prof_post_corr["DEPTH"],
                c=COLOR_CORRECTED,
                marker="o",
                ls="-",
                lw=1,
                markersize=3,
                label="Corrected",
                alpha=0.8,
            )

            if pd.notna(meta.get("mld")):
                ax.axhline(meta["mld"], color="purple", ls="--", lw=1.5, label=f"MLD ({meta['mld']:.1f}m)")
            if pd.notna(meta.get("z_qd")):
                ax.axhline(meta["z_qd"], color="green", ls=":", lw=1.5, label=f"Z_qd ({meta['z_qd']:.1f}m)")

            # NOTE(review): inverting the axis puts the surface at the top
            # only if DEPTH is positive-down — confirm against the sign
            # convention used by the correction itself.
            ax.invert_yaxis()
            ax.set_xlabel(self.apply_to)
            ax.set_ylabel("Depth (m)")

            status_text = f"Status: {meta.get('status', 'Unknown')}\nSun: {meta.get('sun_angle', np.nan):.1f} deg"
            ax.text(
                0.05,
                0.05,
                status_text,
                transform=ax.transAxes,
                fontsize=8,
                bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
            )
            ax.legend(title=f"Prof {profile_number} (Old {requested_prof})", loc="lower right", fontsize=8)

        fig_profs.suptitle("Quenching Correction: Profile Level Diagnostics")
        fig_profs.tight_layout()
        plt.show(block=True)