# src/pipeline.py
import math
from typing import Dict, List
import pandas as pd
from pandera.typing import DataFrame
from ai_cdss.models import ScoringSchema
from ai_cdss.constants import USAGE_WEEK, DAYS
import logging
logger = logging.getLogger(__name__)
class CDSS:
    """
    Clinical Decision Support System (CDSS).

    Provides per-patient protocol recommendations based on scoring data:
    it either repeats the current prescriptions, swaps the below-average
    ones for substitutes, or builds a new weekly schedule from the
    top-scoring protocols.

    Parameters
    ----------
    scoring : DataFrame
        A DataFrame containing patient protocol scores. The methods of this
        class read the columns ``PATIENT_ID``, ``PROTOCOL_ID``, ``SCORE``,
        ``USAGE`` and ``DAYS``, plus the columns named by the ``USAGE_WEEK``
        and ``DAYS`` constants.
    n : int, optional
        Number of top protocols to recommend, by default 12.
    days : int, optional
        Number of days for scheduling, by default 7.
    protocols_per_day : int, optional
        Maximum number of protocols per day, by default 5.
    """

    def __init__(self, scoring: DataFrame[ScoringSchema], n: int = 12, days: int = 7, protocols_per_day: int = 5):
        """
        Initialize the Clinical Decision Support System.
        """
        self.scoring = scoring
        self.n = n
        self.days = days
        self.protocols_per_day = protocols_per_day

    def recommend(self, patient_id: int, protocol_similarity) -> DataFrame[ScoringSchema]:
        """
        Recommend prescriptions for a patient.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.
        protocol_similarity : DataFrame
            A DataFrame containing protocol similarity scores
            (columns ``PROTOCOL_A``, ``PROTOCOL_B``, ``SIMILARITY``).

        Returns
        -------
        DataFrame
            One row per recommended protocol, carrying the scoring columns
            plus the ``DAYS`` on which the protocol is scheduled. An empty
            DataFrame is returned when the patient has no scoring rows or
            when nothing can be scheduled.
        """
        patient_data = self.scoring[self.scoring["PATIENT_ID"] == patient_id]
        if patient_data.empty:
            return pd.DataFrame()

        # Current prescriptions (rows that already carry scheduled days).
        prescriptions = self.get_prescriptions(patient_id)

        if not prescriptions.empty:
            # A prescription counts as completed when the weekly usage
            # reached the number of days it was scheduled on.
            # NOTE(review): the original comment describes this as "all
            # prescriptions have week usage 0", but the condition is broader
            # (no prescription reached its planned count) - confirm intent.
            completed = prescriptions[USAGE_WEEK] >= prescriptions[DAYS].map(len)
            week_skipped = not completed.any()
            if week_skipped:
                logger.info(
                    "Patient %s, skipped the whole week, cdss repeating prescriptions.",
                    patient_id,
                )
                recommendations = prescriptions
                recommendations.attrs = self.scoring.attrs
                return recommendations
            rows = self._swap_prescriptions(patient_id, prescriptions, protocol_similarity)
        else:
            # No prescriptions -> generate a new schedule.
            rows = self._new_schedule_rows(patient_id)

        if not rows:
            # Nothing to recommend (e.g. n == 0 or days == 0); sorting an
            # empty frame by PROTOCOL_ID would raise KeyError.
            return pd.DataFrame()

        recommendations = pd.DataFrame(rows).sort_values(by="PROTOCOL_ID").reset_index(drop=True)
        recommendations.attrs = self.scoring.attrs
        return recommendations

    def _swap_prescriptions(self, patient_id: int, prescriptions, protocol_similarity) -> List[dict]:
        """Return output rows after replacing below-average prescriptions with substitutes."""
        protocols_to_swap = self.decide_prescription_swap(patient_id)
        # Already prescribed protocols must never be offered as substitutes.
        protocols_excluded = prescriptions["PROTOCOL_ID"].tolist()

        # Prescriptions that are not being swapped are passed through as-is.
        kept = prescriptions[~prescriptions["PROTOCOL_ID"].isin(protocols_to_swap)]
        rows = kept.to_dict("records")

        for protocol_id in protocols_to_swap:
            current = prescriptions[prescriptions["PROTOCOL_ID"] == protocol_id]
            substitute = self.get_substitute(
                patient_id, protocol_id, protocol_similarity, protocol_excluded=protocols_excluded
            )
            # Compare against None explicitly: protocol ID 0 is a valid substitute.
            if substitute is None:
                # No alternative available: keep the current prescription
                # rather than silently dropping it from the plan.
                logger.info(
                    "Patient %s, no substitute found for protocol %s, keeping prescription.",
                    patient_id,
                    protocol_id,
                )
                rows.extend(current.to_dict("records"))
                continue

            protocols_excluded.append(substitute)
            substitute_row = self.get_scores(patient_id, substitute)
            # The substitute inherits the schedule of the protocol it replaces.
            substitute_row["DAYS"] = current["DAYS"].values[0]
            substitute_row["PROTOCOL_ID"] = substitute
            substitute_row["PATIENT_ID"] = patient_id
            rows.append(substitute_row)
        return rows

    def _new_schedule_rows(self, patient_id: int) -> List[dict]:
        """Return output rows for a fresh schedule built from the patient's top protocols."""
        top_protocols = self.get_top_protocols(patient_id)
        schedule = self.schedule_protocols(top_protocols)  # {day: [protocol_id, ...]}

        seen = {}  # protocol_id -> output row
        for day, protocol_ids in schedule.items():
            for protocol_id in protocol_ids:
                if protocol_id in seen:
                    seen[protocol_id]["DAYS"].append(day)
                    continue
                row = self.get_scores(patient_id, protocol_id)
                row["DAYS"] = [day]
                row["PROTOCOL_ID"] = protocol_id
                row["PATIENT_ID"] = patient_id
                seen[protocol_id] = row
        return list(seen.values())

    def schedule_protocols(self, protocols: List[int]):
        """
        Distribute protocols across days while ensuring constraints.

        The protocol list is repeated until ``days * protocols_per_day``
        slots are filled and the slots are assigned to days round-robin.
        A protocol is never placed twice on the same day, so a day may end
        up with fewer than ``protocols_per_day`` protocols.

        Parameters
        ----------
        protocols : list of int
            List of protocol IDs to distribute.

        Returns
        -------
        dict
            A dictionary mapping days (0-indexed, ``0 .. days - 1``) to the
            list of protocols scheduled on that day.
        """
        schedule = {day: [] for day in range(self.days)}  # days are 0-indexed
        if not protocols:
            return schedule

        total_slots = self.days * self.protocols_per_day
        # Repeat the protocols as needed to fill the total number of slots.
        repeats = math.ceil(total_slots / len(protocols))
        repeated_protocols = (protocols * repeats)[:total_slots]

        for slot, protocol in enumerate(repeated_protocols):
            day_protocols = schedule[slot % self.days]  # round-robin over days
            if protocol not in day_protocols:
                day_protocols.append(protocol)
        return schedule

    def decide_prescription_swap(self, patient_id: int) -> List[int]:
        """
        Determine which prescriptions to swap based on their score.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.

        Returns
        -------
        list of int
            Protocol IDs of the current prescriptions whose ``SCORE`` is
            strictly below the mean score of the patient's prescriptions.
        """
        prescriptions = self.get_prescriptions(patient_id)
        scores = prescriptions["SCORE"]
        # Direct vectorized comparison; avoids relying on Series.transform
        # falling back from an element-wise call to a whole-Series call.
        below_mean = scores < scores.mean()
        return prescriptions.loc[below_mean, "PROTOCOL_ID"].tolist()

    def get_substitute(self, patient_id: int, protocol_id: int, protocol_similarity, protocol_excluded: List[int] = None):
        """
        Find a suitable substitute for a given protocol.

        Among the eligible protocols (not the protocol itself, not excluded,
        and present in the patient's scoring rows) the least used ones are
        selected, and of those the one most similar to ``protocol_id`` is
        returned. Ties are resolved by row order in ``protocol_similarity``.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.
        protocol_id : int
            The protocol to be substituted.
        protocol_similarity : DataFrame
            A DataFrame containing protocol similarity scores
            (columns ``PROTOCOL_A``, ``PROTOCOL_B``, ``SIMILARITY``).
        protocol_excluded : list of int, optional
            List of protocols to exclude from consideration, by default None.

        Returns
        -------
        int
            The ID of the substitute protocol, or None if no suitable substitute is found.
        """
        # Usage of every protocol for this patient, indexed by protocol ID.
        usage = (
            self.scoring[self.scoring["PATIENT_ID"] == patient_id]
            .set_index("PROTOCOL_ID")["USAGE"]
        )

        # Similarities from the protocol being replaced, without self-similarity.
        similarities = protocol_similarity[protocol_similarity["PROTOCOL_A"] == protocol_id]
        similarities = similarities[similarities["PROTOCOL_A"] != similarities["PROTOCOL_B"]]
        if protocol_excluded:
            similarities = similarities[~similarities["PROTOCOL_B"].isin(protocol_excluded)]
        # Only protocols the patient has usage data for can be candidates.
        similarities = similarities[similarities["PROTOCOL_B"].isin(usage.index)]
        if similarities.empty:
            return None

        # Take the minimum usage among *eligible* protocols only; using the
        # global minimum would yield no candidate whenever the least used
        # protocol is excluded or is the protocol being replaced.
        eligible_usage = usage[usage.index.isin(similarities["PROTOCOL_B"])]
        min_usage = eligible_usage.min()
        candidates = eligible_usage[eligible_usage == min_usage].index

        candidate_similarities = similarities[similarities["PROTOCOL_B"].isin(candidates)]
        if candidate_similarities.empty:
            return None

        # Among the least used candidates pick the most similar one.
        max_sim = candidate_similarities["SIMILARITY"].max()
        final_candidates = candidate_similarities.loc[
            candidate_similarities["SIMILARITY"] == max_sim, "PROTOCOL_B"
        ]
        return final_candidates.iloc[0] if not final_candidates.empty else None

    def get_top_protocols(self, patient_id: int) -> List[int]:
        """
        Select the top N protocols for a patient based on scores.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.

        Returns
        -------
        list of int
            Up to ``n`` protocol IDs, ordered by descending ``SCORE``.
        """
        patient_data = self.scoring[self.scoring["PATIENT_ID"] == patient_id]
        return patient_data.nlargest(self.n, "SCORE")["PROTOCOL_ID"].tolist()

    def get_prescriptions(self, patient_id: int):
        """
        Retrieve the current prescriptions for a patient.

        A scoring row is a prescription when its ``DAYS`` value is a
        non-empty list.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.

        Returns
        -------
        DataFrame
            The patient's scoring rows that are prescriptions; empty (with
            the scoring columns preserved) when there are none.
        """
        patient_data = self.scoring[self.scoring["PATIENT_ID"] == patient_id]
        # Cast to bool so an empty (object-dtype) result is still treated as
        # a row mask rather than as a list of column labels.
        has_days = patient_data["DAYS"].apply(lambda x: isinstance(x, list) and len(x) > 0).astype(bool)
        return patient_data.loc[has_days]

    def get_scores(self, patient_id: int, protocol_id: int):
        """
        Retrieve scores for a given patient and protocol.

        Parameters
        ----------
        patient_id : int
            The ID of the patient.
        protocol_id : int
            The ID of the protocol.

        Returns
        -------
        dict
            The first matching scoring row as a ``{column: value}`` dictionary.

        Raises
        ------
        IndexError
            If there is no scoring row for the patient/protocol pair.
        """
        matches = self.scoring[
            (self.scoring["PATIENT_ID"] == patient_id) & (self.scoring["PROTOCOL_ID"] == protocol_id)
        ]
        if matches.empty:
            raise IndexError(
                f"No scoring row for patient {patient_id} and protocol {protocol_id}"
            )
        return matches.iloc[0].to_dict()