Source code for ai_cdss.cdss

# src/pipeline.py
import logging
import math
from typing import Dict, List, Optional

import pandas as pd
from ai_cdss.constants import (
    DAYS,
    PATIENT_ID,
    PROTOCOL_A,
    PROTOCOL_B,
    PROTOCOL_ID,
    SCORE,
    SIMILARITY,
    USAGE,
    USAGE_WEEK,
)
from ai_cdss.models import ScoringSchema
from pandera.typing import DataFrame

logger = logging.getLogger(__name__)


class CDSS:
    """
    Clinical Decision Support System (CDSS) Class.

    This system provides personalized recommendations for patients based on
    scoring data. It allows protocol recommendations, scheduling, and
    prescription adjustments.

    Parameters
    ----------
    scoring : DataFrame
        A DataFrame containing patient protocol scores.
    n : int, optional
        Number of top protocols to recommend, by default 12.
    days : int, optional
        Number of days for scheduling, by default 7.
    protocols_per_day : int, optional
        Maximum number of protocols per day, by default 5.
    """

    def __init__(
        self,
        scoring: pd.DataFrame,
        n: int = 12,
        days: int = 7,
        protocols_per_day: int = 5,
    ):
        """
        Initialize the Clinical Decision Support System.
        """
        self.scoring = scoring
        self.n = n
        self.days = days
        self.protocols_per_day = protocols_per_day

    ###########################################################################
    # Recommendation method
    ###########################################################################

    def recommend(
        self, patient_id: int, protocol_similarity: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Recommend prescriptions for a patient.

        Three cases are handled, in order:

        1. The patient has no current prescriptions: a new schedule is built
           from the top-scoring protocols.
        2. Every scheduled prescription has zero weekly usage: the existing
           prescriptions are returned unchanged.
        3. Otherwise: below-mean prescriptions are swapped for substitutes.

        Parameters
        ----------
        patient_id : int
            Identifier of the patient in the scoring data.
        protocol_similarity : DataFrame
            Pairwise protocol similarity table used to pick substitutes.

        Returns
        -------
        DataFrame
            The recommended prescriptions, carrying ``self.scoring.attrs``.

        Raises
        ------
        ValueError
            If the scoring data contains no rows for ``patient_id``.
        """
        if not self._has_patient_data(patient_id):
            raise ValueError(f"Patient {patient_id} has no data.")

        prescriptions = self._get_prescriptions(patient_id)

        if prescriptions.empty:
            return self._generate_new_recommendations(patient_id)

        if self._is_week_skipped(prescriptions):
            return self._repeat_prescriptions(prescriptions)

        return self._update_existing_recommendations(
            patient_id, prescriptions, protocol_similarity
        )

    ###########################################################################
    # Patient Bootstrap
    ###########################################################################

    def _generate_new_recommendations(self, patient_id: int) -> pd.DataFrame:
        """
        Generate a new schedule of protocols for a patient with no prescriptions.

        Returns one row per scheduled protocol; the days column holds the list
        of day indices on which that protocol is scheduled.
        """
        top_protocols = self._get_top_protocols(patient_id)
        schedule = self._schedule_protocols(top_protocols)

        # One row per protocol, accumulating every day it is scheduled on.
        rows_by_protocol: Dict[int, dict] = {}
        for day, protocol_ids in schedule.items():
            for protocol_id in protocol_ids:
                row = rows_by_protocol.get(protocol_id)
                if row is None:
                    row = self._get_scores(patient_id, protocol_id)
                    # Use the shared column constants so this frame has the
                    # same columns as the one built on the swap path.
                    row[DAYS] = []
                    row[PROTOCOL_ID] = protocol_id
                    row[PATIENT_ID] = patient_id
                    rows_by_protocol[protocol_id] = row
                row[DAYS].append(day)

        if rows_by_protocol:
            recommendations = (
                pd.DataFrame(list(rows_by_protocol.values()))
                .sort_values(by=PROTOCOL_ID)
                .reset_index(drop=True)
            )
        else:
            # Nothing could be scheduled (e.g. n, days or protocols_per_day is
            # 0); sorting a column-less frame would raise KeyError.
            recommendations = pd.DataFrame(columns=self.scoring.columns)

        recommendations.attrs = self.scoring.attrs
        return recommendations

    def _get_top_protocols(self, patient_id: int) -> List[int]:
        """
        Select the top N protocols for a patient based on scores.
        """
        patient_data = self.scoring[self.scoring[PATIENT_ID] == patient_id]
        return patient_data.nlargest(self.n, SCORE)[PROTOCOL_ID].tolist()

    def _schedule_protocols(self, protocols: List[int]) -> Dict[int, List[int]]:
        """
        Distribute protocols across days in a round-robin fashion.

        The protocol list is repeated until it fills
        ``days * protocols_per_day`` slots; slot ``i`` is assigned to day
        ``i % days``. A protocol is never scheduled twice on the same day, so
        a day may end up with fewer than ``protocols_per_day`` entries.

        Parameters
        ----------
        protocols : list of int
            Protocol IDs to schedule, in priority order.

        Returns
        -------
        dict
            Mapping ``day -> [protocol_id, ...]`` with days numbered from 0 to
            ``days - 1``.
        """
        # Days are 0-indexed: 0 .. self.days - 1
        schedule: Dict[int, List[int]] = {day: [] for day in range(self.days)}
        if not protocols:
            return schedule

        total_slots = self.days * self.protocols_per_day

        # Repeat protocols as needed to fill the total slots
        repetitions = math.ceil(total_slots / len(protocols))
        repeated_protocols = (protocols * repetitions)[:total_slots]

        for i, protocol in enumerate(repeated_protocols):
            day = i % self.days
            if protocol not in schedule[day]:
                schedule[day].append(protocol)

        return schedule  # day: [protocol, ...]

    ###########################################################################
    # Prescription Updates (Substitution Logic)
    ###########################################################################

    def _update_existing_recommendations(
        self,
        patient_id: int,
        prescriptions: pd.DataFrame,
        protocol_similarity: pd.DataFrame,
    ) -> pd.DataFrame:
        """
        Update recommendations by swapping out underperforming protocols for
        better alternatives.

        Protocols already prescribed, and substitutes chosen earlier in the
        same call, are excluded from the pool of possible substitutes.
        """
        # Identify protocols to swap and those to exclude from substitution
        protocols_to_swap: List[int] = self._decide_prescription_swap(patient_id)
        protocols_excluded: List[int] = prescriptions[PROTOCOL_ID].tolist()

        # Start with prescriptions that are not being swapped
        keep_mask = ~prescriptions[PROTOCOL_ID].isin(protocols_to_swap)
        updated_rows: List[dict] = prescriptions[keep_mask].to_dict("records")

        # Swap out underperforming protocols
        for protocol_id in protocols_to_swap:
            substitute_row = self._swap_protocol(
                patient_id,
                protocol_id,
                prescriptions,
                protocol_similarity,
                protocols_excluded=protocols_excluded,
            )
            logger.debug(
                "Swapping %s for %s for patient %s",
                protocol_id,
                substitute_row[PROTOCOL_ID],
                patient_id,
            )
            updated_rows.append(substitute_row)
            # A substitute must not be handed out twice within one update.
            protocols_excluded.append(substitute_row[PROTOCOL_ID])

        recommendations = (
            pd.DataFrame(updated_rows)
            .sort_values(by=PROTOCOL_ID)
            .reset_index(drop=True)
        )
        recommendations.attrs = self.scoring.attrs
        return recommendations

    ###########################################################################
    # Marginal Value Theorem (Swapping Criteria)

    def _decide_prescription_swap(self, patient_id: int) -> List[int]:
        """
        Determine which prescriptions to swap based on their score.

        Returns the protocol IDs of the patient's current prescriptions whose
        score is strictly below the mean score of those prescriptions.
        """
        prescriptions = self._get_prescriptions(patient_id)
        scores = prescriptions[SCORE]
        # Compare the whole column against its mean directly, instead of
        # relying on Series.transform falling back to a whole-Series call.
        return prescriptions.loc[scores < scores.mean(), PROTOCOL_ID].tolist()

    ###########################################################################
    # Substitution Logic

    def _swap_protocol(
        self,
        patient_id: int,
        protocol_id: int,
        prescriptions: pd.DataFrame,
        protocol_similarity,
        protocols_excluded: List[int],
    ) -> dict:
        """
        Find and return a substitute protocol row for a given protocol_id, or
        the same protocol if not found (all protocols are prescribed).

        The substitute inherits the scheduled days of the protocol it replaces.
        """
        substitute = self._get_substitute(
            patient_id,
            protocol_id,
            protocol_similarity,
            protocols_excluded=protocols_excluded,
        )

        # `is not None` rather than truthiness: protocol ID 0 is a valid ID.
        if substitute is not None:
            substitute_row = self._get_scores(patient_id, substitute)
            substitute_row[DAYS] = prescriptions.loc[
                prescriptions[PROTOCOL_ID] == protocol_id, DAYS
            ].values[0]
            substitute_row[PROTOCOL_ID] = substitute
            substitute_row[PATIENT_ID] = patient_id
            return substitute_row

        # Else return same protocol
        return self._get_scores(patient_id, protocol_id)

    def _get_substitute(
        self,
        patient_id: int,
        protocol_id: int,
        protocol_similarity: pd.DataFrame,
        protocols_excluded: Optional[List[int]] = None,
    ) -> Optional[int]:
        """
        Find a suitable substitute for a given protocol.

        Selection order:

        1. Among protocols the patient has never used, the one most similar to
           ``protocol_id``.
        2. Otherwise, among the 5 most similar protocols, the least used one
           (ties broken by similarity).

        Returns the protocol ID of the substitute, or None if not found.
        """
        usage = self._get_patient_protocol_usage(patient_id)
        similarities = self._get_protocol_similarities(
            protocol_id, protocol_similarity, protocols_excluded
        )

        # Try to find unused protocols first
        unused_candidates = self._get_unused_candidates(usage)
        if unused_candidates:
            logger.info(
                "No usage for %s, selecting most similar from %s",
                protocol_id,
                unused_candidates,
            )
            substitute = self._select_most_similar(unused_candidates, similarities)
            # Unused protocols may all be excluded or missing from the
            # similarity table; in that case fall through to the next strategy
            # instead of giving up.
            if substitute is not None:
                return substitute

        # Otherwise, pick from top 5 similar protocols the least used
        top_similar_protocols = self._get_top_similar_protocols(similarities, top_n=5)
        least_used_candidates = self._get_least_used_candidates(
            usage, top_similar_protocols
        )
        if least_used_candidates:
            logger.info(
                "No unused protocols for %s, selecting least used from %s",
                protocol_id,
                least_used_candidates,
            )
            return self._select_most_similar(least_used_candidates, similarities)

        # If no candidates found
        return None

    ###########################################################################
    # USAGE

    def _get_patient_protocol_usage(self, patient_id: int) -> pd.Series:
        """Return protocol usage for the given patient, indexed by protocol ID."""
        patient_data = self.scoring[self.scoring[PATIENT_ID] == patient_id]
        return patient_data.set_index(PROTOCOL_ID)[USAGE]

    def _get_unused_candidates(self, usage: pd.Series) -> List[int]:
        """Return protocol IDs with zero usage."""
        return usage[usage == 0].index.tolist()

    def _get_least_used_candidates(
        self, usage: pd.Series, candidate_protocols: List[int]
    ) -> List[int]:
        """Return protocol IDs among candidates with the least usage."""
        candidate_usage = usage[usage.index.isin(candidate_protocols)]
        if candidate_usage.empty:
            return []
        min_usage = candidate_usage.min()
        return candidate_usage[candidate_usage == min_usage].index.tolist()

    ###########################################################################
    # SIMILARITY

    def _get_protocol_similarities(
        self,
        protocol_id: int,
        protocol_similarity: pd.DataFrame,
        protocol_excluded: Optional[List[int]],
    ) -> pd.DataFrame:
        """Return similarities for a protocol, excluding self and any excluded protocols."""
        similarities = protocol_similarity[
            protocol_similarity[PROTOCOL_A] == protocol_id
        ]
        # Drop the self-similarity row
        similarities = similarities[
            similarities[PROTOCOL_A] != similarities[PROTOCOL_B]
        ]
        if protocol_excluded:
            similarities = similarities[
                ~similarities[PROTOCOL_B].isin(protocol_excluded)
            ]
        return similarities

    def _get_top_similar_protocols(
        self, similarities: pd.DataFrame, top_n: int = 5
    ) -> List[int]:
        """Return the protocol IDs of the top N most similar protocols."""
        return similarities.nlargest(top_n, SIMILARITY)[PROTOCOL_B].tolist()

    def _select_most_similar(
        self, candidates: List[int], similarities: pd.DataFrame
    ) -> Optional[int]:
        """
        Return the candidate protocol with the highest similarity.

        Returns None when no candidate appears in ``similarities``. On ties the
        first matching row wins.
        """
        candidate_similarities = similarities[similarities[PROTOCOL_B].isin(candidates)]
        if candidate_similarities.empty:
            return None
        max_sim = candidate_similarities[SIMILARITY].max()
        final_candidates = candidate_similarities[
            candidate_similarities[SIMILARITY] == max_sim
        ][PROTOCOL_B]
        return final_candidates.iloc[0] if not final_candidates.empty else None

    ###########################################################################
    # Validation Utilities
    ###########################################################################

    def _is_week_skipped(self, prescriptions: pd.DataFrame) -> bool:
        """
        Return True if and only if all prescriptions scheduled this week were
        skipped (i.e., no session was performed at all).

        We consider only rows that have at least one scheduled day this week
        (len(DAYS) > 0). For the week to be 'skipped', each of those rows must
        have USAGE_WEEK == 0.
        """
        if prescriptions.empty:
            return False  # can't say 'skipped' if nothing is scheduled

        # Keep only prescriptions that actually have days scheduled this week
        scheduled_mask = prescriptions[DAYS].apply(lambda d: len(d) > 0)
        scheduled = prescriptions[scheduled_mask]

        if scheduled.empty:
            return False  # nothing scheduled -> not 'skipped'

        # Week is skipped if and only if none of the scheduled prescriptions
        # recorded any usage. bool() converts numpy.bool_ to a plain bool.
        return bool((scheduled[USAGE_WEEK] == 0).all())

    def _is_partially_week_skipped(self, prescriptions: pd.DataFrame) -> bool:
        """
        Check if all prescriptions are partially skipped for the whole week.

        True only if every prescription was used fewer times than its number
        of scheduled days. Returns False when there are no prescriptions.
        """
        if prescriptions.empty:
            return False  # nothing scheduled -> nothing can be skipped

        # Per row: was the prescription performed at least as many times as it
        # was scheduled? If that holds for any row, the week is not partially
        # skipped.
        fully_performed = prescriptions.apply(
            lambda x: x[USAGE_WEEK] >= len(x[DAYS]), axis=1
        )
        return not bool(fully_performed.any())

    def _has_patient_data(self, patient_id: int) -> bool:
        """Check if patient has scoring data."""
        patient_data = self.scoring[self.scoring[PATIENT_ID] == patient_id]
        return not patient_data.empty

    def _repeat_prescriptions(self, prescriptions) -> pd.DataFrame:
        """Repeat existing prescriptions when week was skipped."""
        logger.info(
            "Patient %s, skipped the whole week, cdss repeating prescriptions.",
            prescriptions[PATIENT_ID].iloc[0] if not prescriptions.empty else "unknown",
        )
        df = prescriptions.copy()
        df.attrs = getattr(self.scoring, "attrs", {})
        return df  # type: ignore

    ###########################################################################
    # General Utilities
    ###########################################################################

    def _get_scores(self, patient_id: int, protocol_id: int) -> dict:
        """
        Retrieve scores for a given patient and protocol.

        Returns the first matching scoring row as a dict keyed by column name.

        Raises
        ------
        IndexError
            If the scoring data has no row for this patient/protocol pair.
        """
        # Filter scoring DataFrame for the given patient and protocol
        match = self.scoring[
            (self.scoring[PATIENT_ID] == patient_id)
            & (self.scoring[PROTOCOL_ID] == protocol_id)
        ]
        return match.iloc[0].to_dict()

    def _get_prescriptions(self, patient_id: int) -> pd.DataFrame:
        """
        Retrieve the current prescriptions for a patient.

        A scoring row counts as a prescription when its days value is a
        non-empty list.
        """
        patient_data = self.scoring[self.scoring[PATIENT_ID] == patient_id]
        prescriptions = patient_data[
            patient_data[DAYS].apply(lambda x: isinstance(x, list) and len(x) > 0)
        ]
        return prescriptions