# ai_cdss/models.py
import pandera as pa
import pandas as pd
from typing import List, Callable, Type
from functools import partial, wraps
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shorthand for schema fields that accept missing values:
# ``NullableField(alias=...)`` is equivalent to ``pa.Field(alias=..., nullable=True)``.
NullableField = partial(pa.Field, nullable=True)
# ---------------------------------------------------------------------
# RGS Data Input
[docs]
class SessionSchema(pa.DataFrameModel):
    """
    Schema for RGS session-level data, including patient profile, prescription and session details.
    """

    # NOTE(review): every column below is declared via NullableField, so nulls
    # are accepted. pandas stores a partially-null integer column as float64,
    # which may fail the ``int`` dtype check; confirm how such columns are
    # produced upstream.

    # --- Patient profile ---
    patient_id: int = NullableField(alias="PATIENT_ID")

    # --- Identifiers ---
    prescription_id: int = NullableField(alias="PRESCRIPTION_ID")
    session_id: int = NullableField(alias="SESSION_ID")
    protocol_id: int = NullableField(alias="PROTOCOL_ID")

    # --- Prescription window ---
    prescription_starting_date: pa.DateTime = NullableField(
        alias="PRESCRIPTION_STARTING_DATE"
    )
    prescription_ending_date: pa.DateTime = NullableField(
        alias="PRESCRIPTION_ENDING_DATE"
    )

    # --- Session ---
    session_date: pa.DateTime = NullableField(alias="SESSION_DATE")
    weekday: int = NullableField(
        alias="WEEKDAY_INDEX",
        ge=0,
        le=6,
        description="Weekday Index (0=Monday, 6=Sunday)",
    )
    status: str = NullableField(
        alias="STATUS",
        isin=["CLOSED", "ABORTED", "ONGOING"],
    )

    # --- Metrics: counts/durations are >= 0, adherence lies in [0, 1] ---
    real_session_duration: int = NullableField(alias="REAL_SESSION_DURATION", ge=0)
    prescribed_session_duration: int = NullableField(
        alias="PRESCRIBED_SESSION_DURATION", ge=0
    )
    session_duration: int = NullableField(alias="SESSION_DURATION", ge=0)
    adherence: float = NullableField(alias="ADHERENCE", ge=0, le=1)
    total_success: int = NullableField(alias="TOTAL_SUCCESS", ge=0)
    total_errors: int = NullableField(alias="TOTAL_ERRORS", ge=0)
    game_score: int = NullableField(alias="GAME_SCORE", ge=0)
[docs]
class TimeseriesSchema(pa.DataFrameModel):
    """
    Schema for timeseries session data. Includes measurements per-second of difficulty modulators (DM) and performance estimates (PE).
    """

    # All columns are declared via NullableField, so nulls are accepted.

    # Identifiers: must be strictly positive when present.
    patient_id: int = NullableField(gt=0, alias="PATIENT_ID")
    session_id: int = NullableField(gt=0, alias="SESSION_ID")
    protocol_id: int = NullableField(gt=0, alias="PROTOCOL_ID")

    # Protocol configuration
    game_mode: str = NullableField(alias="GAME_MODE")

    # Sample time, read from the SECONDS_FROM_START column.
    timepoint: int = NullableField(alias="SECONDS_FROM_START")

    # Difficulty-modulator (DM) and performance-estimate (PE) key/value pairs
    dm_key: str = NullableField(alias="DM_KEY")
    dm_value: float = NullableField(alias="DM_VALUE")
    pe_key: str = NullableField(alias="PE_KEY")
    pe_value: float = NullableField(alias="PE_VALUE")
[docs]
class PPFSchema(pa.DataFrameModel):
    """
    Schema for Patient-Protocol Fit (PPF) data. Represents how well a protocol fits a patient, including a PPF score and feature contributions.
    """

    # Every column rejects nulls (``nullable=False`` is pandera's default,
    # spelled out here for clarity).

    # Key: one row per (patient, protocol) pair
    patient_id: int = pa.Field(alias="PATIENT_ID", nullable=False)
    protocol_id: int = pa.Field(alias="PROTOCOL_ID", nullable=False)

    # Fit score and its per-feature contributions (arbitrary Python objects)
    ppf: float = pa.Field(alias="PPF", nullable=False)
    contrib: object = pa.Field(alias="CONTRIB", nullable=False)
[docs]
class PCMSchema(pa.DataFrameModel):
    """
    Schema for protocol similarity matrix. Includes pairwise similarity scores between protocols based on clinical domain overlap.
    """

    # Long-format matrix: each row is one (protocol_a, protocol_b) pair.
    # Nulls are rejected (``nullable=False`` is pandera's default, made explicit).
    protocol_a: int = pa.Field(alias="PROTOCOL_A", nullable=False)
    protocol_b: int = pa.Field(alias="PROTOCOL_B", nullable=False)
    similarity: float = pa.Field(alias="SIMILARITY", nullable=False)
# ---------------------------------------------------------------------
# Recommender Output
[docs]
class ScoringSchema(pa.DataFrameModel):
    """
    Schema for prescription scoring output. Represents the result of a recommendation.

    ``Config.coerce`` is enabled, so pandera casts each column to its declared
    dtype before running the value checks.
    """

    class Config:
        coerce = True

    # Identifiers
    patient_id: int = pa.Field(alias="PATIENT_ID", gt=0, description="Must be a positive integer.")
    protocol_id: int = pa.Field(alias="PROTOCOL_ID", gt=0, description="Must be a positive integer.")

    # Scoring components
    adherence: float = pa.Field(alias="ADHERENCE_RECENT", ge=0, le=1, description="Must be a probability (0-1).")
    # NOTE(review): DELTA_DM is deliberately unconstrained; a [-1, 1] range
    # check existed here but was disabled. Confirm before re-enabling.
    dm: float = pa.Field(alias="DELTA_DM")
    ppf: float = pa.Field(alias="PPF", ge=0, le=1, description="Must be a probability (0-1).")
    contrib: List[float] = pa.Field(alias="CONTRIB", nullable=False, coerce=True)

    # Aggregate result. The check is ge=0, so zero is allowed: the description
    # says "non-negative" to match the actual constraint.
    score: float = pa.Field(alias="SCORE", ge=0, description="Score must be a non-negative float.")
    usage: int = pa.Field(alias="USAGE", ge=0, description="Usage count must be a non-negative integer.")
    days: List[int] = pa.Field(alias="DAYS", description="Days of the week the protocol is prescribed.")
# ---------------------------------------------------------------------
# Validation Decorator
def safe_check_types(schema_model: Type[pa.DataFrameModel]):
    """
    Custom decorator: skips dtype checks for nullable columns with all null values.

    The wrapped function must return a ``pd.DataFrame``.
    - If that DataFrame is empty, a warning is logged and it is returned
      without validation.
    - Otherwise it is validated against ``schema_model``. For every nullable
      column that is present and contains only nulls, the dtype check is
      skipped. All other properties of that column (checks, nullability,
      required, ...) are kept, as are all schema-level settings.

    Args:
        schema_model: A pandera ``DataFrameModel`` class.

    Returns:
        A decorator whose wrapper returns the DataFrame produced by pandera's
        ``validate`` (or the unvalidated DataFrame when it is empty).

    Raises:
        pandera.errors.SchemaError: propagated from ``validate`` when the
            DataFrame violates the schema.
    """
    schema = schema_model.to_schema()
    schema_name = schema_model.__name__  # Get the name of the schema model class

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            df: pd.DataFrame = func(*args, **kwargs)
            if df.empty:
                logger.warning(
                    "Returned DataFrame from `%s` is empty. Kwargs: %s",
                    func.__name__,
                    kwargs,
                )
                return df

            # Only inspect columns that actually exist in ``df``. A nullable
            # column missing from the frame must reach pandera's own
            # required/optional handling rather than raise KeyError here.
            skipped_columns = [
                col_name
                for col_name, col_schema in schema.columns.items()
                if col_schema.nullable
                and col_name in df.columns
                and df[col_name].isna().all()
            ]

            if not skipped_columns:
                # Nothing to relax: validate with the untouched schema.
                return schema.validate(df)

            # Log all skipped columns once
            logger.debug(
                "Skipped dtype check for empty columns in `%s`: %s",
                schema_name,
                ", ".join(skipped_columns),
            )
            # update_columns returns a modified copy of the schema. Only the
            # dtype of the skipped columns changes; every other column and
            # schema-level property is preserved.
            temp_schema = schema.update_columns(
                {col_name: {"dtype": None} for col_name in skipped_columns}
            )
            return temp_schema.validate(df)

        return wrapper

    return decorator