Skip to content

clinops.monitor

clinops.monitor.drift.DistributionDriftDetector

DistributionDriftDetector(
    n_bins=10,
    psi_threshold_medium=0.1,
    psi_threshold_high=0.2,
    run_ks_test=True,
    columns=None,
)

Detect distribution drift between a reference dataset and a current batch.

Fit on a reference dataset (typically the training set), then call detect() on each new batch to get per-column PSI and KS statistics.

Parameters:

Name Type Description Default
n_bins int

Number of equal-frequency bins used to compute PSI. Default 10. Use fewer bins for small datasets (< 500 rows).

10
psi_threshold_medium float

PSI threshold for MEDIUM severity. Default 0.1.

0.1
psi_threshold_high float

PSI threshold for HIGH severity. Default 0.2.

0.2
run_ks_test bool

If True, run a KS two-sample test in addition to PSI. Default True.

True
columns list[str] | None

Explicit list of columns to monitor. If None, all numeric columns in the reference DataFrame are monitored.

None

Examples:

>>> detector = DistributionDriftDetector()
>>> detector.fit(train_df)
>>> report = detector.detect(production_batch_df)
>>> print(report.summary())
>>> print(report.to_dataframe())
>>> # Only alert on high-severity drift
>>> drifted = report.drifted_columns(DriftSeverity.HIGH)
Source code in clinops/monitor/drift.py
def __init__(
    self,
    n_bins: int = 10,
    psi_threshold_medium: float = 0.1,
    psi_threshold_high: float = 0.2,
    run_ks_test: bool = True,
    columns: list[str] | None = None,
) -> None:
    if n_bins < 2:
        raise ValueError(f"n_bins must be >= 2, got {n_bins}")
    self.n_bins = n_bins
    self.psi_threshold_medium = psi_threshold_medium
    self.psi_threshold_high = psi_threshold_high
    self.run_ks_test = run_ks_test
    self.columns = columns
    self._reference_data: dict[str, np.ndarray] = {}
    self._bin_edges: dict[str, np.ndarray] = {}

fit

fit(df)

Compute reference statistics from a training/baseline DataFrame.

Parameters:

Name Type Description Default
df DataFrame

Reference DataFrame (typically the training set).

required

Returns:

Type Description
DistributionDriftDetector

Self, for method chaining.

Source code in clinops/monitor/drift.py
def fit(self, df: pd.DataFrame) -> DistributionDriftDetector:
    """
    Compute reference statistics from a training/baseline DataFrame.

    For every monitored column the non-null values are stored as the
    reference sample, and equal-frequency bin edges are derived from them.
    Columns that are absent from ``df`` or contain no non-null values are
    skipped with a warning. State from any earlier ``fit()`` call is
    discarded first.

    Parameters
    ----------
    df:
        Reference DataFrame (typically the training set).

    Returns
    -------
    DistributionDriftDetector
        Self, for method chaining.
    """
    cols = self.columns or list(df.select_dtypes(include=[np.number]).columns)
    self._reference_data = {}
    self._bin_edges = {}

    # The percentile ranks depend only on n_bins, so compute them once.
    quantiles = np.linspace(0, 100, self.n_bins + 1)

    for col in cols:
        if col not in df.columns:
            logger.warning(f"DriftDetector.fit: column '{col}' not in DataFrame — skipping")
            continue
        values = df[col].dropna().to_numpy(dtype=float)
        if len(values) == 0:
            logger.warning(
                f"DriftDetector.fit: column '{col}' has no non-null values — skipping"
            )
            continue
        self._reference_data[col] = values

        # Build equal-frequency bin edges from the reference distribution;
        # np.unique collapses percentiles that coincide because of ties.
        edges = np.unique(np.percentile(values, quantiles))
        if edges.size < 3:
            # Constant or near-constant column: no interior cut point is
            # left. Overwriting the outer edges in place would turn a
            # single remaining edge into [inf], which defines no bin at
            # all, so store one explicit catch-all bin instead.
            logger.warning(
                f"DriftDetector.fit: column '{col}' has too few distinct quantiles "
                "— using a single catch-all bin"
            )
            edges = np.array([-np.inf, np.inf])
        else:
            # Open-ended outer bins cover current values outside the
            # reference range.
            edges[0] = -np.inf
            edges[-1] = np.inf
        self._bin_edges[col] = edges

    logger.info(f"DistributionDriftDetector fitted on {len(self._reference_data)} columns")
    return self

detect

detect(df)

Compute drift metrics for each fitted column.

Parameters:

Name Type Description Default
df DataFrame

Current DataFrame to compare against the reference.

required

Returns:

Type Description
DriftReport

Raises:

Type Description
RuntimeError

If fit() has not been called.

Source code in clinops/monitor/drift.py
def detect(self, df: pd.DataFrame) -> DriftReport:
    """
    Compare a current batch against the fitted reference, column by column.

    Fitted columns that are absent from ``df``, or that hold no non-null
    values in it, are skipped with a warning and do not appear in the
    report.

    Parameters
    ----------
    df:
        Current DataFrame to compare against the reference.

    Returns
    -------
    DriftReport

    Raises
    ------
    RuntimeError
        If ``fit()`` has not been called.
    """
    if not self._reference_data:
        raise RuntimeError("Call fit() before detect()")

    column_results: list[ColumnDriftResult] = []

    for col, reference in self._reference_data.items():
        if col not in df.columns:
            logger.warning(f"DriftDetector.detect: column '{col}' missing from current batch")
            continue

        current = df[col].dropna().to_numpy(dtype=float)
        if current.size == 0:
            logger.warning(
                f"DriftDetector.detect: column '{col}' has no non-null values in current batch"
            )
            continue

        # PSI is computed on the bin edges learned in fit() and drives the severity.
        psi = self._compute_psi(reference, current, self._bin_edges[col])
        severity = self._severity(psi)

        # The KS test is optional; both fields stay None when it is disabled.
        if self.run_ks_test:
            ks = stats.ks_2samp(reference, current)
            ks_stat, ks_pval = float(ks.statistic), float(ks.pvalue)
        else:
            ks_stat, ks_pval = None, None

        column_result = ColumnDriftResult(
            column=col,
            psi=psi,
            ks_statistic=ks_stat,
            ks_pvalue=ks_pval,
            severity=severity,
            reference_mean=float(np.mean(reference)),
            current_mean=float(np.mean(current)),
            reference_std=float(np.std(reference)),
            current_std=float(np.std(current)),
            n_reference=len(reference),
            n_current=len(current),
        )
        column_results.append(column_result)

    report = DriftReport(results=column_results, n_columns_checked=len(column_results))

    logger.info(
        f"DriftDetector: {report.n_high} HIGH, {report.n_medium} MEDIUM, "
        f"{report.n_low} LOW across {report.n_columns_checked} columns"
    )
    return report

clinops.monitor.drift.DriftReport dataclass

DriftReport(results, n_columns_checked)

Structured result returned by DistributionDriftDetector.detect().

Attributes:

Name Type Description
results list[ColumnDriftResult]

Per-column drift metrics.

n_columns_checked int

Total number of numeric columns evaluated.

n_low property

n_low

Number of columns with LOW severity drift.

n_medium property

n_medium

Number of columns with MEDIUM severity drift.

n_high property

n_high

Number of columns with HIGH severity drift.

drifted_columns

drifted_columns(min_severity=DriftSeverity.MEDIUM)

Return column names with drift at or above min_severity.

Parameters:

Name Type Description Default
min_severity DriftSeverity

Minimum severity level to include. Default: MEDIUM.

MEDIUM

Returns:

Type Description
list[str]
Source code in clinops/monitor/drift.py
def drifted_columns(self, min_severity: DriftSeverity = DriftSeverity.MEDIUM) -> list[str]:
    """
    List the columns whose drift severity is ``min_severity`` or worse.

    Parameters
    ----------
    min_severity:
        Minimum severity level to include. Default: MEDIUM.

    Returns
    -------
    list[str]
        Column names, in the order they appear in ``self.results``.
    """
    # Rank the severities from least to most severe so they can be compared.
    ascending = (DriftSeverity.LOW, DriftSeverity.MEDIUM, DriftSeverity.HIGH)
    rank = {level: position for position, level in enumerate(ascending)}
    cutoff = rank[min_severity]

    selected = []
    for result in self.results:
        if rank[result.severity] >= cutoff:
            selected.append(result.column)
    return selected

to_dataframe

to_dataframe()

Return per-column results as a DataFrame sorted by PSI descending.

Source code in clinops/monitor/drift.py
def to_dataframe(self) -> pd.DataFrame:
    """
    Return per-column results as a DataFrame sorted by PSI descending.

    Float metrics are rounded (4 decimals; 2 for ``mean_shift_pct``), and KS
    fields stay ``None`` when the KS test was not run. A report without
    results yields an empty DataFrame that still carries every column, so
    callers can rely on a fixed schema.

    Returns
    -------
    pd.DataFrame
        One row per checked column, with a fresh 0..n-1 index.
    """
    columns = [
        "column",
        "psi",
        "severity",
        "ks_statistic",
        "ks_pvalue",
        "reference_mean",
        "current_mean",
        "mean_shift_pct",
        "n_reference",
        "n_current",
    ]
    # Without results there is no "psi" column to sort on; sorting the
    # column-less frame would raise KeyError, so return the empty schema.
    if not self.results:
        return pd.DataFrame(columns=columns)

    rows: list[dict[str, Any]] = [
        {
            "column": r.column,
            "psi": round(r.psi, 4),
            "severity": r.severity.value,
            "ks_statistic": round(r.ks_statistic, 4) if r.ks_statistic is not None else None,
            "ks_pvalue": round(r.ks_pvalue, 4) if r.ks_pvalue is not None else None,
            "reference_mean": round(r.reference_mean, 4),
            "current_mean": round(r.current_mean, 4),
            "mean_shift_pct": round(r.mean_shift_pct, 2),
            "n_reference": r.n_reference,
            "n_current": r.n_current,
        }
        for r in self.results
    ]
    frame = pd.DataFrame(rows, columns=columns)
    return frame.sort_values("psi", ascending=False).reset_index(drop=True)

summary

summary()

Human-readable drift summary.

Source code in clinops/monitor/drift.py
def summary(self) -> str:
    """
    Build a human-readable drift summary.

    The text starts with the per-severity counts. A line naming the HIGH
    columns, and one naming the MEDIUM-only columns, is added only when
    that group is non-empty.
    """
    report = [
        f"Columns checked : {self.n_columns_checked}",
        f"HIGH drift      : {self.n_high}",
        f"MEDIUM drift    : {self.n_medium}",
        f"LOW drift       : {self.n_low}",
    ]

    severe = self.drifted_columns(DriftSeverity.HIGH)
    if severe:
        report.append(f"HIGH columns    : {', '.join(severe)}")

    # drifted_columns(MEDIUM) also returns the HIGH columns; drop those so
    # each column is named on one line only.
    moderate = [
        name for name in self.drifted_columns(DriftSeverity.MEDIUM) if name not in severe
    ]
    if moderate:
        report.append(f"MEDIUM columns  : {', '.join(moderate)}")

    return "\n".join(report)

clinops.monitor.drift.ColumnDriftResult dataclass

ColumnDriftResult(
    column,
    psi,
    ks_statistic,
    ks_pvalue,
    severity,
    reference_mean,
    current_mean,
    reference_std,
    current_std,
    n_reference,
    n_current,
)

Drift metrics for a single column.

Attributes:

Name Type Description
column str

Column name.

psi float

Population Stability Index (lower is more stable).

ks_statistic float | None

KS two-sample test statistic, or None if not computed.

ks_pvalue float | None

KS test p-value, or None if not computed. Values below 0.05 indicate a statistically significant distributional difference.

severity DriftSeverity

Drift severity based on PSI thresholds.

reference_mean float

Mean of the column in the reference (training) dataset.

current_mean float

Mean of the column in the current (production) dataset.

reference_std float

Standard deviation in the reference dataset.

current_std float

Standard deviation in the current dataset.

n_reference int

Number of non-null observations in the reference dataset.

n_current int

Number of non-null observations in the current dataset.

mean_shift property

mean_shift

Absolute shift in the column mean.

mean_shift_pct property

mean_shift_pct

Mean shift as a percentage of the reference mean (0 if reference mean is 0).

clinops.monitor.drift.DriftSeverity

Bases: StrEnum

Severity levels for distribution drift.

Based on standard PSI thresholds used in healthcare model validation:

- LOW: PSI < 0.1 — distribution is stable; no action required.
- MEDIUM: 0.1 <= PSI < 0.2 — moderate shift; review the column and investigate whether the change is clinically meaningful.
- HIGH: PSI >= 0.2 — significant drift; model retraining or pipeline investigation is strongly recommended.

clinops.monitor.quality.DataQualityChecker

DataQualityChecker(
    max_null_rate=0.5,
    required_columns=None,
    expected_dtypes=None,
    min_rows=None,
    max_rows=None,
)

Run data quality checks on a clinical DataFrame.

Can be used standalone (check(df) only) or fitted on a reference DataFrame to also detect schema drift between pipeline runs.

Parameters:

Name Type Description Default
max_null_rate float

Null rate above which a column is flagged as a warning. Default 0.5.

0.5
required_columns list[str] | None

Columns that must be present and non-null. Any missing column is an error; any all-null required column is also an error.

None
expected_dtypes dict[str, str] | None

Dict mapping column name to expected dtype string (e.g. {"subject_id": "int64", "charttime": "datetime64[ns]"}). Dtype mismatches are reported as warnings.

None
min_rows int | None

Minimum number of rows expected. Fewer rows triggers an error.

None
max_rows int | None

Maximum number of rows expected. More rows triggers a warning.

None

Examples:

>>> checker = DataQualityChecker(required_columns=["subject_id", "charttime"])
>>> checker.fit(train_df)          # learn reference schema and row count
>>> report = checker.check(df)
>>> print(report.summary())
>>> if not report.passed:
...     raise RuntimeError("Data quality check failed")
>>> # Standalone (no reference schema)
>>> report = DataQualityChecker(max_null_rate=0.3).check(df)
Source code in clinops/monitor/quality.py
def __init__(
    self,
    max_null_rate: float = 0.5,
    required_columns: list[str] | None = None,
    expected_dtypes: dict[str, str] | None = None,
    min_rows: int | None = None,
    max_rows: int | None = None,
) -> None:
    self.max_null_rate = max_null_rate
    self.required_columns: list[str] = required_columns or []
    self.expected_dtypes: dict[str, str] = expected_dtypes or {}
    self.min_rows = min_rows
    self.max_rows = max_rows
    self._reference_schema: dict[str, str] = {}
    self._reference_row_count: int | None = None

fit

fit(df)

Learn the reference schema and row count from a baseline DataFrame.

After fitting, check() will also report columns that were added or removed relative to this reference.

Parameters:

Name Type Description Default
df DataFrame

Reference DataFrame (typically the training set).

required

Returns:

Type Description
DataQualityChecker

Self, for method chaining.

Source code in clinops/monitor/quality.py
def fit(self, df: pd.DataFrame) -> DataQualityChecker:
    """
    Record the baseline DataFrame's schema and row count as the reference.

    The schema is stored as a mapping of column name to dtype, both as
    strings. Once a reference is stored, ``check()`` also runs its
    schema-drift comparison against it.

    Parameters
    ----------
    df:
        Reference DataFrame (typically the training set).

    Returns
    -------
    DataQualityChecker
        Self, for method chaining.
    """
    schema: dict[str, str] = {}
    for name, dtype in df.dtypes.items():
        schema[str(name)] = str(dtype)
    n_rows = len(df)

    self._reference_schema = schema
    self._reference_row_count = n_rows

    logger.info(
        f"DataQualityChecker fitted: {n_rows:,} rows, {len(schema)} columns"
    )
    return self

check

check(df)

Run all configured quality checks against df.

Parameters:

Name Type Description Default
df DataFrame

DataFrame to check.

required

Returns:

Type Description
QualityReport
Source code in clinops/monitor/quality.py
def check(self, df: pd.DataFrame) -> QualityReport:
    """
    Run all configured quality checks against ``df``.

    Issues are collected in this order: row-count checks, missing required
    columns, schema drift against the fitted reference (only when a
    reference schema exists), then per-column null-rate and dtype checks.

    Parameters
    ----------
    df:
        DataFrame to check.

    Returns
    -------
    QualityReport
    """
    issues: list[QualityIssue] = []
    null_rates: dict[str, float] = {}

    # Row count checks
    issues.extend(self._check_row_counts(df))

    # Required columns that are absent altogether are errors.
    issues.extend(
        QualityIssue(
            column=col,
            issue_type="column_removed",
            severity="error",
            detail=f"Required column '{col}' is missing from DataFrame",
        )
        for col in self.required_columns
        if col not in df.columns
    )

    # Schema drift is only evaluated once fit() has stored a reference schema.
    if self._reference_schema:
        issues.extend(self._check_schema_drift(df))

    # Per-column checks
    for col in df.columns:
        rate = float(df[col].isna().mean())
        null_rates[col] = rate

        entirely_null_required = rate == 1.0 and col in self.required_columns
        if entirely_null_required:
            # An all-null required column is an error and is not also
            # reported as a high-null-rate warning.
            issues.append(
                QualityIssue(
                    column=col,
                    issue_type="all_null",
                    severity="error",
                    detail=f"Required column '{col}' is entirely null",
                )
            )
        elif rate > self.max_null_rate:
            issues.append(
                QualityIssue(
                    column=col,
                    issue_type="high_null_rate",
                    severity="warning",
                    detail=(
                        f"Column '{col}' null rate {rate:.1%} "
                        f"exceeds threshold {self.max_null_rate:.1%}"
                    ),
                )
            )

        # Dtype expectations apply only to columns the caller listed.
        if col not in self.expected_dtypes:
            continue
        actual = str(df[col].dtype)
        expected = self.expected_dtypes[col]
        if actual != expected:
            issues.append(
                QualityIssue(
                    column=col,
                    issue_type="dtype_changed",
                    severity="warning",
                    detail=(f"Column '{col}' dtype is '{actual}', expected '{expected}'"),
                )
            )

    n_errors = len([issue for issue in issues if issue.severity == "error"])
    n_warnings = len([issue for issue in issues if issue.severity == "warning"])
    status = "FAILED" if n_errors else "PASSED"
    logger.info(
        f"DataQualityChecker [{status}]: "
        f"{n_errors} errors, {n_warnings} warnings "
        f"on {len(df):,} rows × {len(df.columns)} columns"
    )

    return QualityReport(
        issues=issues,
        n_rows=len(df),
        n_columns=len(df.columns),
        null_rates=null_rates,
    )

clinops.monitor.quality.QualityReport dataclass

QualityReport(issues, n_rows, n_columns, null_rates)

Structured result returned by DataQualityChecker.check().

Attributes:

Name Type Description
issues list[QualityIssue]

All detected quality issues.

n_rows int

Number of rows in the checked DataFrame.

n_columns int

Number of columns in the checked DataFrame.

null_rates dict[str, float]

Per-column null rate (fraction 0–1).

errors property

errors

Issues with severity "error".

warnings property

warnings

Issues with severity "warning".

passed property

passed

True if there are no error-severity issues.

to_dataframe

to_dataframe()

Return issues as a DataFrame.

Source code in clinops/monitor/quality.py
def to_dataframe(self) -> pd.DataFrame:
    """
    Return issues as a DataFrame.

    Each issue becomes one row with the columns ``column``, ``issue_type``,
    ``severity`` and ``detail``. With no issues, the result is an empty
    DataFrame that still has those four columns.
    """
    fields = ["column", "issue_type", "severity", "detail"]
    if not self.issues:
        return pd.DataFrame(columns=fields)

    records = []
    for issue in self.issues:
        records.append({name: getattr(issue, name) for name in fields})
    return pd.DataFrame(records)

summary

summary()

Human-readable quality summary.

Source code in clinops/monitor/quality.py
def summary(self) -> str:
    """
    Build a human-readable quality summary.

    The text starts with the row, column, error and warning counts and the
    pass flag, followed by one line per issue. Error issues are tagged
    ``[ERROR]``; every other issue is tagged ``[WARN]``.
    """
    header = [
        f"Rows     : {self.n_rows:,}",
        f"Columns  : {self.n_columns}",
        f"Errors   : {len(self.errors)}",
        f"Warnings : {len(self.warnings)}",
        f"Passed   : {self.passed}",
    ]
    details = [
        ("  [ERROR]  " if issue.severity == "error" else "  [WARN]   ") + f"{issue.detail}"
        for issue in self.issues
    ]
    return "\n".join(header + details)

clinops.monitor.quality.QualityIssue dataclass

QualityIssue(column, issue_type, severity, detail)

A single data quality issue detected by DataQualityChecker.

Attributes:

Name Type Description
column str

Affected column name, or "__dataframe__" for row-level issues.

issue_type str

One of "high_null_rate", "all_null", "column_added", "column_removed", "dtype_changed", "row_count_anomaly".

severity str

"error" for issues that will break downstream steps; "warning" for issues worth investigating but not fatal.

detail str

Human-readable description of the issue.