# ai_cdss/data_loader.py
from abc import ABC, abstractmethod
from typing import List, Optional
from pathlib import Path
import pandas as pd
import pandera as pa
from pandera.errors import SchemaError
from pandera.typing import DataFrame
from ai_cdss.models import SessionSchema, TimeseriesSchema, PPFSchema, PCMSchema, safe_check_types
from ai_cdss.evaluation.synthetic import (
generate_synthetic_session_data,
generate_synthetic_protocol_similarity,
generate_synthetic_timeseries_data,
generate_synthetic_ppf_data,
generate_synthetic_ids,
generate_synthetic_protocol_metric
)
from ai_cdss.constants import DEFAULT_DATA_DIR, PPF_PARQUET_FILEPATH
from rgs_interface.data.interface import DatabaseInterface
import shutil
import logging
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Base Data Loader
class DataLoaderBase(ABC):
@abstractmethod
def load_session_data(self, patient_list: List[int]) -> DataFrame[SessionSchema]:
pass
@abstractmethod
def load_timeseries_data(self, patient_list: List[int]) -> DataFrame[TimeseriesSchema]:
pass
@abstractmethod
def load_ppf_data(self, patient_list: List[int]) -> DataFrame[PPFSchema]:
pass
@abstractmethod
def load_protocol_similarity(self) -> DataFrame[PCMSchema]:
pass
@abstractmethod
def load_patient_clinical_data(self, patient_list: List[int]) -> pd.DataFrame:
pass
@abstractmethod
    def load_patient_subscales(self, patient_list: Optional[List[int]] = None) -> pd.DataFrame:
pass
@abstractmethod
    def load_protocol_attributes(self, file_path: Optional[str] = None) -> pd.DataFrame:
pass
@abstractmethod
def load_protocol_init(self) -> pd.DataFrame:
pass
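# A concrete loader implements every abstract method above. A minimal sketch
# of a custom subclass (hypothetical path and filtering, for illustration only):
#
#     class ParquetLoader(DataLoaderBase):
#         def load_session_data(self, patient_list):
#             df = pd.read_parquet("/data/sessions.parquet")  # hypothetical path
#             return df[df["PATIENT_ID"].isin(patient_list)]
#         ...  # remaining abstract methods omitted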
# ---------------------------------------------------------------------
# RGS Data Loader
class DataLoader(DataLoaderBase):
"""
Loads data from database and CSV files.
Parameters
----------
rgs_mode : str, optional
Mode for fetching RGS data. Default is "plus".
"""
    def __init__(self, rgs_mode: str = "plus"):
        """
        Initialize the DataLoader with the given RGS mode.
        """
self.interface: DatabaseInterface = DatabaseInterface()
self.rgs_mode = rgs_mode
# @safe_check_types(SessionSchema)
def load_session_data(self, patient_list: List[int]) -> DataFrame[SessionSchema]:
"""
Load session data from the RGS interface.
        New patients without prescriptions are not included in this table.
Returns
-------
DataFrame[SessionSchema]
Session data for the specified patients.
"""
try:
session = self.interface.fetch_rgs_data(patient_list, rgs_mode=self.rgs_mode)
logger.debug("Session data loaded successfully.")
return session
except SchemaError as e:
logger.error(f"Data validation failed: {e}")
return pd.DataFrame(columns=SessionSchema.to_schema().columns.keys())
except Exception as e:
logger.error(f"Failed to load session data: {e}")
raise
# @safe_check_types(TimeseriesSchema)
def load_timeseries_data(self, patient_list: List[int]) -> DataFrame[TimeseriesSchema]:
"""
Load timeseries data from the RGS interface.
        New patients without prescriptions are not included in this table.
Returns
-------
DataFrame[TimeseriesSchema]
Timeseries data for the specified patients.
"""
try:
timeseries = self.interface.fetch_dm_data(patient_list, rgs_mode=self.rgs_mode)
logger.debug(f"Timeseries data loaded successfully.")
return timeseries
except SchemaError as e:
logger.error(f"Data validation failed: {e}")
return pd.DataFrame(columns=TimeseriesSchema.to_schema().columns.keys())
except Exception as e:
logger.error(f"Failed to load timeseries data: {e}")
raise
# @safe_check_types(PPFSchema)
def load_ppf_data(self, patient_list: List[int]) -> DataFrame[PPFSchema]:
"""
Load PPF (precomputed patient-protocol fit) from internal data.
Returns
-------
        DataFrame[PPFSchema]
            PPF data for the specified patients, one row per patient-protocol pair.
"""
        try:
            # Define PPF file path
            ppf_path = PPF_PARQUET_FILEPATH
            if not ppf_path.exists():
                raise FileNotFoundError(f"PPF file not found: {ppf_path}")
            ppf_data = pd.read_parquet(ppf_path)
            # Filter PPF by patient list
            ppf_data = ppf_data[ppf_data["PATIENT_ID"].isin(patient_list)]
            # Check for patients with no precomputed PPF
            missing_patients = set(patient_list) - set(ppf_data["PATIENT_ID"].unique())
            if missing_patients:
                logger.warning(f"PPF missing for {len(missing_patients)} patients: {missing_patients}")
                protocols = set(ppf_data["PROTOCOL_ID"].unique())
                # Pair each missing patient with every protocol, initialized to None
                missing_combinations = pd.DataFrame([
                    {"PATIENT_ID": pid, "PROTOCOL_ID": protocol_id, "PPF": None, "CONTRIB": None}
                    for pid in missing_patients
                    for protocol_id in protocols
                ])
                # Concatenate the placeholder rows into the existing PPF dataset
                ppf_data = pd.concat([ppf_data, missing_combinations], ignore_index=True)
            logger.debug("PPF data loaded successfully.")
            return ppf_data
except Exception as e:
logger.error(f"Failed to load PPF data: {e}")
raise
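    # Illustration (not executed; hypothetical IDs): if patient 7 has no
    # precomputed PPF and the file covers protocols {11, 12}, load_ppf_data
    # appends placeholder rows so every patient-protocol pair is present:
    #
    #     PATIENT_ID  PROTOCOL_ID  PPF   CONTRIB
    #     7           11           None  None
    #     7           12           None  None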
def load_patient_clinical_data(self, patient_list: List[int]) -> pd.DataFrame:
"""
Load patient clinical data from the RGS interface.
Parameters
----------
patient_list : List[int]
List of patient IDs to fetch clinical data for.
Returns
-------
pd.DataFrame
DataFrame containing clinical data for the specified patients.
"""
try:
clinical_data = self.interface.fetch_clinical_data(patient_list)
if clinical_data.empty:
logger.warning("No clinical data found for the specified patients.")
else:
logger.info(f"Clinical data loaded for {len(clinical_data)} patients.")
return clinical_data
except Exception as e:
logger.error(f"Failed to load patient clinical data: {e}")
raise
    def load_patient_subscales(self, patient_list: Optional[List[int]] = None) -> pd.DataFrame:
        # Subscales are read from the shared clinical scores file; patient_list is unused here.
        return _load_patient_subscales()
# @safe_check_types(PCMSchema)
def load_protocol_similarity(self) -> DataFrame[PCMSchema]:
"""
Load protocol similarity data from internal storage.
Returns
-------
        DataFrame[PCMSchema]
Protocol similarity data with columns: PROTOCOL_ID_1, PROTOCOL_ID_2, SIMILARITY_SCORE.
"""
try:
# Define similarity file paths
output_dir = Path.home() / ".ai_cdss" / "output"
parquet_file = output_dir / "protocol_similarity.parquet"
csv_file = output_dir / "protocol_similarity.csv"
# Check if Parquet file exists
if parquet_file.exists():
            similarity_data = pd.read_parquet(parquet_file).reset_index()
# Fall back to CSV if Parquet file is not found
elif csv_file.exists():
similarity_data = pd.read_csv(csv_file, index_col=0)
else:
raise FileNotFoundError(
"No protocol similarity file found in ~/.ai_cdss/output. "
"Expected either protocol_similarity.parquet or protocol_similarity.csv."
)
logger.debug("Protocol similarity data loaded successfully.")
return similarity_data
except Exception as e:
logger.error(f"Failed to load protocol similarity data: {e}")
raise
    def load_protocol_attributes(self, file_path: Optional[str] = None) -> pd.DataFrame:
        return _load_protocol_attributes(file_path=file_path)
    def load_protocol_init(self) -> pd.DataFrame:
        """
        Load protocol initialization metrics from internal storage.
        Returns
        -------
        pd.DataFrame
            Protocol metrics read from protocol_metrics.csv.
        """
try:
output_dir = Path.home() / ".ai_cdss" / "output"
csv_file = output_dir / "protocol_metrics.csv"
if csv_file.exists():
protocol_metrics = pd.read_csv(csv_file, index_col=0)
else:
raise FileNotFoundError(
"No protocol metrics file found in ~/.ai_cdss/output. "
"Expected protocol_metrics.csv."
)
logger.info("Protocol initialization data loaded successfully.")
return protocol_metrics
except Exception as e:
logger.error(f"Failed to load protocol metrics data: {e}")
raise
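# Example usage (a minimal sketch; assumes a reachable RGS database and uses
# hypothetical patient IDs):
#
#     loader = DataLoader(rgs_mode="plus")
#     sessions = loader.load_session_data([101, 102])
#     timeseries = loader.load_timeseries_data([101, 102])
#     ppf = loader.load_ppf_data([101, 102])
#     similarity = loader.load_protocol_similarity()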
# ---------------------------------------------------------------------
# Local Data Loader
class DataLoaderLocal(DataLoaderBase):
    """
    Loads data from local files. Most methods are placeholder stubs pending implementation.
    """
def load_session_data(self, patient_list: List[int]) -> DataFrame[SessionSchema]:
pass
def load_timeseries_data(self, patient_list: List[int]) -> DataFrame[TimeseriesSchema]:
pass
def load_ppf_data(self, patient_list: List[int]) -> DataFrame[PPFSchema]:
pass
def load_protocol_similarity(self) -> DataFrame[PCMSchema]:
pass
def load_patient_clinical_data(self, patient_list: List[int]) -> pd.DataFrame:
pass
    def load_patient_subscales(self, patient_list: Optional[List[int]] = None) -> pd.DataFrame:
        return _load_patient_subscales(file_path=None)
    def load_protocol_attributes(self, file_path: Optional[str] = None) -> pd.DataFrame:
        return _load_protocol_attributes(file_path=file_path)
def load_protocol_init(self) -> pd.DataFrame:
pass
# ---------------------------------------------------------------------
# Synthetic Data Loader
class DataLoaderMock(DataLoaderBase):
    """
    Generates synthetic data for testing, using shared IDs so that session,
    timeseries, and PPF data refer to the same patients and protocols.
    """
def __init__(self, num_patients: int = 5, num_protocols: int = 3, num_sessions: int = 10):
# Generate and store shared IDs
self.ids = generate_synthetic_ids(
num_patients=num_patients,
num_protocols=num_protocols,
num_sessions=num_sessions,
)
self.num_protocols = num_protocols
@safe_check_types(SessionSchema)
    def load_session_data(self, patient_list: Optional[List[int]] = None) -> DataFrame[SessionSchema]:
return generate_synthetic_session_data(shared_ids=self.ids)
@safe_check_types(TimeseriesSchema)
    def load_timeseries_data(self, patient_list: Optional[List[int]] = None) -> DataFrame[TimeseriesSchema]:
return generate_synthetic_timeseries_data(shared_ids=self.ids)
@pa.check_types
    def load_ppf_data(self, patient_list: Optional[List[int]] = None) -> DataFrame[PPFSchema]:
return generate_synthetic_ppf_data(shared_ids=self.ids)
@pa.check_types
def load_protocol_similarity(self) -> DataFrame[PCMSchema]:
return generate_synthetic_protocol_similarity(num_protocols=self.num_protocols)
def load_protocol_init(self) -> pd.DataFrame:
return generate_synthetic_protocol_metric(num_protocols=self.num_protocols)
    def load_patient_clinical_data(self, patient_list):
        # No synthetic generator exists for clinical data; defer to the abstract base.
        return super().load_patient_clinical_data(patient_list)
    def load_patient_subscales(self, patient_list=None):
        return super().load_patient_subscales(patient_list)
    def load_protocol_attributes(self, file_path=None):
        return super().load_protocol_attributes(file_path)
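# Example usage (sketch): the mock loader needs no database or local files,
# which makes it convenient for tests.
#
#     mock = DataLoaderMock(num_patients=2, num_protocols=3, num_sessions=4)
#     sessions = mock.load_session_data()
#     similarity = mock.load_protocol_similarity()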
# ---------------------------------------------------------------------
# - Utility Functions
# ---------------------------------------------------------------------
def safe_load_csv(file_path: Optional[str] = None, default_filename: Optional[str] = None) -> pd.DataFrame:
    """
    Safely load a CSV file, either from a given file path or from the default data directory.
    Parameters
    ----------
    file_path : str, optional
        Full path to the CSV file. If not provided, `default_filename` is used.
    default_filename : str, optional
        Name of the file in the default data directory.
    Returns
    -------
    pd.DataFrame
        Loaded data.
    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    ValueError
        If the file cannot be read as a valid CSV.
    """
    if file_path is None and default_filename is None:
        raise ValueError("Either file_path or default_filename must be provided.")
    file_path = Path(file_path) if file_path else DEFAULT_DATA_DIR / default_filename
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}. Ensure the correct path is specified.")
try:
df = pd.read_csv(file_path, index_col=0)
# If the file was loaded from outside the default directory, save a copy
default_file_path = DEFAULT_DATA_DIR / file_path.name
if file_path.parent != DEFAULT_DATA_DIR:
shutil.copy(file_path, default_file_path)
            logger.info(f"File copied to default directory: {default_file_path}")
return df
except Exception as e:
        raise ValueError(f"Error reading {file_path}: {e}") from e
def _load_patient_subscales(file_path: str = None) -> pd.DataFrame:
"""Load patient clinical subscale scores from a given file or the default directory."""
return safe_load_csv(file_path, "clinical_scores.csv")
def _load_protocol_attributes(file_path: str = None) -> pd.DataFrame:
"""Load protocol attributes from a given file or the default directory."""
return safe_load_csv(file_path, "protocol_attributes.csv")
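# Example usage (a sketch; the explicit path is hypothetical):
#
#     scores = safe_load_csv(default_filename="clinical_scores.csv")
#     attributes = safe_load_csv(file_path="/data/protocol_attributes.csv")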