# Standard library imports
import copy
import getpass
import multiprocessing
import os
import pickle
import traceback
from datetime import datetime
from typing import Literal

# External imports
import ehrapy as ep
import pandas as pd
from tableone import TableOne
from tqdm import tqdm

# Local imports
import corr_vars.core.extract as ex
import corr_vars.utils as utils
from corr_vars import logger
from corr_vars.static import GLOBAL_VARS, VARS
# pd.set_option('future.no_silent_downcasting', True)
class CohortTracker(ep.tl.CohortTracker):
def plot_cohort_barplot(self, *args, **kwargs):
"""We are not implementing column tracking due to unhashable column types, so this function is disabled"""
raise NotImplementedError("This method is not yet implemented.")
class Cohort:
"""Class to build a cohort in the CORR database.
Args:
        conn_args:
            Dictionary of connection arguments with keys ``remote_hostname`` and ``username``.
        password_file:
            Path to the password file, or ``True`` if your file is at ``~/password.txt``.
        database:
            Name of the database to extract from (default: ``"db_hypercapnia_prepared"``).
extraction_end_date:
End date for the extraction in the format "YYYY-MM-DD" (default: today).
obs_level:
Observation level (default: "icu_stay").
project_vars:
Dictionary with local variable definitions.
        merge_consecutive:
            Whether to merge consecutive ICU stays (default: True). Only applies when ``obs_level="icu_stay"``.
load_default_vars:
Whether to load the default variables (default: True).
filters:
Initial filters (must be a valid SQL WHERE clause for the it_ishmed_fall table).
Attributes:
obs (pd.DataFrame):
Static data for each observation. Contains one row per observation (e.g., ICU stay)
with columns for static variables like demographics and outcomes.
Example:
>>> cohort.obs
patient_id case_id icu_stay_id icu_admission icu_discharge sex ... inhospital_death
0 P001 C001 C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00 M ... False
1 P001 C001 C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00 M ... False
2 P002 C002 C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00 F ... False
3 P003 C003 C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00 F ... True
...
obsm (dict of pd.DataFrame):
Dynamic data stored as dictionary of DataFrames. Each DataFrame contains time-series
data for a variable with columns:
- recordtime: Timestamp of the measurement
- value: Value of the measurement
- recordtime_end: End time (only for duration-based variables like therapies)
- description: Additional information (e.g., medication names)
Example:
>>> cohort.obsm["blood_sodium"]
icu_stay_id recordtime value
0 C001_1 2023-01-01 09:30:00 138
1 C001_1 2023-01-02 10:15:00 141
2 C001_2 2023-01-03 15:00:00 137
3 C002_1 2023-01-02 10:00:00 142
4 C003_1 2023-01-04 12:30:00 139
...
variables (dict of Variable):
Dictionary of all variable objects in the cohort. This is used to keep track of variable metadata.
Notes:
- For large cohorts, set ``load_default_vars=False`` to speed up the extraction. You can use
pre-extracted cohorts as starting points and load them using ``Cohort.load()``.
- Variables can be added using ``cohort.add_variable()``. Static variables will be added to ``obs``,
dynamic variables to ``obsm``.
Examples:
Create a new cohort:
>>> cohort = Cohort(obs_level="icu_stay",
... database="db_hypercapnia_prepared",
... load_default_vars=False,
... password_file=True)
Access static data:
>>> cohort.obs["age_on_admission"] # Get age for all patients
>>> cohort.obs.loc[cohort.obs["sex"] == "M"] # Filter for male patients
Access time-series data:
>>> cohort.obsm["blood_sodium"] # Get all blood sodium measurements
>>> # Get blood sodium measurements for a specific observation
>>> cohort.obsm["blood_sodium"].loc[
... cohort.obsm["blood_sodium"][cohort.primary_key] == "12345"
... ]
"""
    def __init__(
        self,
        conn_args: dict | None = None,
        password_file: str | bool | None = None,
        database: Literal[
            "db_hypercapnia_prepared", "db_corror_prepared"
        ] = "db_hypercapnia_prepared",
        extraction_end_date: str | None = None,
        obs_level: Literal["icu_stay", "hospital_stay", "procedure"] = "icu_stay",
        project_vars: dict | None = None,
        merge_consecutive: bool = True,
        load_default_vars: bool = True,
        filters: str = "",
    ):
        # Add flag to indicate this is a newly created cohort
        self.from_file = False
        # Resolve defaults at call time: mutable defaults ({}) would be shared
        # across instances, and datetime.now() would be frozen at import time
        conn_args = conn_args if conn_args is not None else {}
        project_vars = project_vars if project_vars is not None else {}
        if extraction_end_date is None:
            extraction_end_date = datetime.now().strftime("%Y-%m-%d")
        self.conn_args = conn_args
        self.database = database
        self.extraction_end_date = extraction_end_date
        self.obs_level = obs_level
        self.merge_consecutive = merge_consecutive
        self.project_vars = project_vars
if filters.startswith(
"_d"
): # Debug Shorthand: Extract number of months from filter string (e.g. "_d2" -> 2 months)
months = int(filters[2:])
start_date = "2023-01-01"
end_date = (
datetime.strptime(start_date, "%Y-%m-%d")
+ pd.Timedelta(days=30 * months)
).strftime("%Y-%m-%d")
self.filters = f"c_aufnahme BETWEEN '{start_date}' AND '{end_date}'"
print(f"DEBUG: Filters set to {self.filters}")
else:
self.filters = filters
self.password_file = password_file
if password_file:
self.conn = utils.ImpalaConnector.with_password_file(
password_file=password_file, **conn_args
)
else:
password = getpass.getpass(prompt="Enter your password: ")
self.conn = utils.ImpalaConnector(password=password, **conn_args)
self._set_primary_keys()
self._load_obs_level_data()
# Temporary storage for raw and intermediary data (mostly native_dynamic variables)
self._axiom_cache = {}
# All variable objects in the cohort (important for keeping metadata of static variables, e.g. tmin, tmax)
self.variables = {}
# Finalized time-series data for analysis
self.obsm = {}
# Inclusion criteria
self.inclusion_criteria = []
self.include_ct = None
# Exclusion criteria
self.exclusion_criteria = []
self.exclude_ct = None
# Constant variables (not dependent on tmin tmax)
self.constant_vars = [
col for col in self.obs.columns if col not in ["case_id", self.primary_key]
]
# Load default variables
if load_default_vars:
self.load_default_vars()
# Cache for adata compatibility
self._adata = None
self._adata_hash = None
print(
f"SUCCESS: Extracted data. Cohort has {len(self.obs)} observations ({self.obs_level})."
)
def _set_primary_keys(self):
match self.obs_level:
case "hospital_stay":
self.primary_key = "case_id"
self.t_min = "hospital_admission"
self.t_max = "hospital_discharge"
self.t_eligible = "hospital_admission"
self.t_outcome = "hospital_discharge"
case "icu_stay":
self.primary_key = "icu_stay_id"
self.t_min = "icu_admission"
self.t_max = "icu_discharge"
self.t_eligible = "icu_admission"
self.t_outcome = "icu_discharge"
case "procedure":
self.primary_key = "procedure_id"
self.t_min = "op_start_dtime_any"
self.t_max = "op_end_dtime_any"
self.t_eligible = "op_start_dtime_any"
self.t_outcome = "hospital_discharge"
case _:
raise ValueError(f"Observation level {self.obs_level} not supported.")
def _load_obs_level_data(self):
self._load_ishmed_fall()
match self.obs_level:
case "hospital_stay":
pass # No specific data to load / this is the most general level
case "icu_stay":
self._load_ishmed_bewegung()
self._load_icu_stays()
case "procedure":
self._load_procedures()
self._load_procedure_id()
# System methods
def __repr__(self):
return f"Cohort(obs_level={self.obs_level}, len(obs)={len(self.obs)}, variables={list(self.variables.keys())})"
def __str__(self):
out = "Cohort object\n"
out += f" obs_level: {self.obs_level}\n"
out += f" len(obs): {len(self.obs)}\n"
out += f" filters: {self.filters}\n"
out += f" variables: {self.variables}\n"
out += "\n"
out += "obs (first 5 rows):\n"
out += self.obs.head().to_string()
out += "\n\n"
return out
def debug_print(self):
"""Print debug information about the cohort. Please use this if you are creating a GitHub issue.
Returns:
None
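        Examples:
            >>> cohort.debug_print()  # Paste the printed summary into your issue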
"""
print(
f"""Cohort object\n
- obs_level: {self.obs_level}
- len(obs): {len(self.obs)}
- filters: {self.filters}
- obs: {self.obs.columns.to_list()}
- obsm: {list(self.obsm.keys())}
- axiom_cache: {list(self._axiom_cache.keys())}
- t_eligible: {self.t_eligible}
- t_outcome: {self.t_outcome}
- inclusion_criteria: {self.inclusion_criteria}
- exclusion_criteria: {self.exclusion_criteria}
"""
)
utils.print_debug_info()
# Save and load methods
def save(self, filename: str):
"""
Save the cohort to a pickle file.
Args:
filename: Path to the pickle file.
Returns:
None
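        Examples:
            >>> cohort.save("my_cohort")  # hypothetical path; saved as my_cohort.corr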
"""
        # Avoid pickling unpicklable objects
conn = self.conn
self.conn = None
# Add .corr suffix
filename = filename + ".corr" if not filename.endswith(".corr") else filename
# Pickle the rest
with open(filename, "wb") as f:
pickle.dump(self, f)
self.conn = conn
print(f"SUCCESS: Saved cohort to {filename}")
@classmethod
    def load(cls, filename: str, conn_args: dict | None = None):
"""
Load a cohort from a pickle file. If this file was saved by a different user, you need to pass your database credentials to the function.
Args:
filename: Path to the pickle file.
conn_args: Database credentials [remote_hostname, username, password_file].
Returns:
Cohort: A new Cohort object.
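        Examples:
            >>> cohort = Cohort.load("my_cohort.corr")  # hypothetical path
            >>> # If the file was saved by a different user, pass your own credentials
            >>> # (placeholder values shown):
            >>> cohort = Cohort.load(
            ...     "my_cohort.corr",
            ...     conn_args={"remote_hostname": "host.example.org", "username": "me"},
            ... )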
"""
with open(filename, "rb") as f:
cohort = pickle.load(f)
if hasattr(cohort, "password_file"):
password_file = cohort.password_file
else:
password_file = None
        if not conn_args:
conn_args = cohort.conn_args
# Re-init the connection
try:
if password_file:
cohort.conn = utils.ImpalaConnector.with_password_file(
password_file=password_file, **conn_args
)
else:
password = getpass.getpass(prompt="Enter your password: ")
cohort.conn = utils.ImpalaConnector(password=password, **conn_args)
except Exception as e:
            raise ConnectionError(
                f"Failed to re-initialize connection: {e}\nPlease verify your database credentials."
            ) from e
cohort.from_file = True
print(f"SUCCESS: Loaded cohort from {filename}")
return cohort
def __getstate__(self):
"""
Called when pickling the object. Removes the unpickleable connection.
"""
state = self.__dict__.copy()
del state["conn"] # Avoid pickling the connection
return state
def __setstate__(self, state):
"""
Called when unpickling the object. Re-initializes the connection.
"""
self.__dict__.update(state)
self.conn = None # Will be re-established when load() is called
def _extr_end_date(self, table):
return f"CAST({table}.`_hdl_loadstamp` AS DATE) <= '{self.extraction_end_date}'"
def _get_case_id_query(self):
query = f"SELECT DISTINCT c_falnr FROM {self.database}.it_ishmed_fall WHERE {self._extr_end_date('it_ishmed_fall')}"
if self.filters:
query += f" AND {self.filters}"
return query
def _convert_unhashable_cols(self, df):
"""
Convert unhashable column types to hashable ones.
"""
        df = df.copy()
        if df.empty:
            return df
        for col in df.columns:
            if isinstance(df[col].iloc[0], list):
                df[col] = df[col].apply(tuple)
            elif isinstance(df[col].iloc[0], dict):
                df[col] = df[col].apply(lambda x: frozenset(x.items()))
return df
def _get_data_hash(self):
"""
Get a hash of the cohort data.
"""
obs = self._convert_unhashable_cols(self.obs)
obs_hash = hash(pd.util.hash_pandas_object(obs).sum())
obsm_hash = sum(
hash(pd.util.hash_pandas_object(df).sum()) for df in self.obsm.values()
)
return obs_hash + obsm_hash
@property
def adata(self):
"""AnnotatedData representation of the cohort (NEW - Cached version)
Returns an `AnnData <https://anndata.readthedocs.io/en/stable/generated/anndata.AnnData.html>`_ object, which stores a data matrix together with annotations
of observations, variables, and unstructured annotations.
Warning:
            This returns a copy of the data. Modifications to the returned object will not be reflected in the cohort.
            To modify the cohort through the AnnData object, use ``Cohort._overwrite_from_adata()``.
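
        Example:
            >>> adata = cohort.adata  # cached; recomputed only when the cohort data changes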
"""
current_hash = self._get_data_hash()
if self._adata is not None and self._adata_hash == current_hash:
return self._adata
else:
self._adata = self.to_adata()
self._adata_hash = current_hash
return self._adata
def to_adata(self):
"""
Convert the cohort to an AnnData object.
Returns:
AnnData: An AnnData object.
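        Examples:
            >>> adata = cohort.to_adata()
            >>> adata.obs.head()            # static variables
            >>> adata.obsm["blood_sodium"]  # per-observation lists of time-series values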
"""
obs = self._convert_unhashable_cols(self.obs)
static_cols = [col for col in obs.columns if col not in [self.primary_key]]
adata = ep.ad.df_to_anndata(
obs, index_column=self.primary_key, columns_obs_only=static_cols
)
# Add dynamic variables
for var_name, df_var in self.obsm.items():
cols_agg = {
col: lambda x: list(x)
for col in df_var.columns
if col not in [self.primary_key]
}
grouped = df_var.groupby(self.primary_key).agg(cols_agg).reset_index()
merged_df = adata.obs.merge(grouped, on=self.primary_key, how="left")
merged_df.set_index(self.primary_key, inplace=True)
            adata.obsm[var_name] = merged_df[list(cols_agg.keys())]
return adata
def to_csv(self, folder: str):
"""
Save the cohort to CSV files.
Args:
folder: Path to the folder.
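        Examples:
            >>> cohort.to_csv("exports/my_cohort")  # hypothetical folder; writes _obs.csv plus one CSV per obsm variable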
"""
os.makedirs(folder, exist_ok=True)
self.obs.to_csv(os.path.join(folder, "_obs.csv"), index=False)
for var_name, var_data in self.obsm.items():
var_data.to_csv(os.path.join(folder, f"{var_name}.csv"), index=False)
def _overwrite_from_adata(self, adata):
self.obs = adata.obs.copy().reset_index()
self.obsm = {}
for var_name in adata.obsm.keys():
self.obsm[var_name] = adata.obsm[var_name]
    def tableone(
        self,
        ignore_cols: list | None = None,
        groupby: str | None = None,
        filter: str | None = None,
        pval: bool = False,
        **kwargs,
    ) -> TableOne:
"""
Create a `TableOne <https://tableone.readthedocs.io/en/latest/index.html>`_ object for the cohort.
Args:
ignore_cols: Columns to ignore.
groupby: Column to group by.
filter: Filter to apply to the data.
pval: Whether to calculate p-values.
**kwargs: Additional arguments to pass to TableOne.
Returns:
TableOne: A TableOne object.
Examples:
>>> tableone = cohort.tableone()
>>> print(tableone)
>>> tableone.to_csv("tableone.csv")
>>> tableone = cohort.tableone(groupby="sex", pval=False)
>>> print(tableone)
>>> tableone.to_csv("tableone_sex.csv")
"""
        # Work on a copy so the caller's list (and the default) is never mutated
        ignore_cols = list(ignore_cols) if ignore_cols else []
        # Ignore datetime columns
        ignore_cols += list(self.obs.select_dtypes(include=["datetime"]).columns)
# Ignore more columns
ignore_cols += [
"case_id",
"icu_stay_id",
"patient_id",
"procedure_id",
"c_patnr",
]
if groupby is not None:
ignore_cols += [groupby]
columns = [col for col in self.obs.columns if col not in ignore_cols]
# Reorder columns to put icu_id last (because it is so long...)
if "icu_id" in columns:
columns = [col for col in columns if col != "icu_id"] + ["icu_id"]
# Set non-normal distribution for all numeric columns
nonnormal_cols = [
col
for col in self.obs.select_dtypes(include=["number"]).columns
if col in columns and col not in ["inhospital_death"]
]
# Explicitly set booleans and objects to categorical
categorical_cols = [
col
for col in self.obs.select_dtypes(
include=["bool", "object", "category"]
).columns
if col in columns
]
# Filter data
if filter is not None:
obs_filtered = self.obs.query(filter)
else:
obs_filtered = self.obs
logger.info(f"Columns: {columns}")
logger.info(f"Nonnormal: {nonnormal_cols}")
logger.info(f"Categorical: {categorical_cols}")
logger.info(f"Groupby: {groupby}")
return TableOne(
obs_filtered,
columns=columns,
nonnormal=nonnormal_cols,
categorical=categorical_cols,
groupby=groupby,
pval=pval,
**kwargs,
)
def _load_ishmed_fall(self):
"""
Get the static data from the it_ishmed_fall table.
"""
query = f"""
SELECT
it_ishmed_fall.c_patnr AS patient_id,
it_ishmed_fall.c_falnr AS case_id,
it_ishmed_fall.c_aufnahme AS hospital_admission,
it_ishmed_fall.c_entlassung AS hospital_discharge,
it_ishmed_patient.c_gender AS sex,
it_ishmed_patient.c_birthdate AS birthdate,
it_ishmed_patient.c_datetimeofdeath AS death_timestamp
FROM {self.database}.it_ishmed_fall
LEFT JOIN {self.database}.it_ishmed_patient USING (c_falnr)
WHERE {self._extr_end_date('it_ishmed_fall')}
"""
if self.filters:
query += f" AND {self.filters}"
obs = self.conn.df_read_sql(query=query)
for col in [
"hospital_admission",
"hospital_discharge",
"birthdate",
"death_timestamp",
]:
obs[col] = pd.to_datetime(obs[col], errors="coerce")
self.obs = obs
print(f"SUCCESS: Extracted ISHMED_FALL data for {len(self.obs)} cases")
def _load_ishmed_bewegung(self):
"""
Get the ICU stays from the it_ishmed_bewegung table.
"""
mv_adm_transfer = (
GLOBAL_VARS["hospital_data"]["mv_types"]["Admission"]
+ GLOBAL_VARS["hospital_data"]["mv_types"]["Transfer"]
)
mv_adm_transfer = utils.json_list_sql(mv_adm_transfer)
query = f"""
SELECT
c_falnr AS case_id,
c_begin AS icu_admission,
c_ende AS icu_discharge,
c_pflege_oe_id AS icu_id
FROM
{self.database}.it_ishmed_bewegung
WHERE
c_bewegungsart IN ({mv_adm_transfer})
AND c_pflege_oe_id IN ({utils.json_list_sql(GLOBAL_VARS['hospital_data']['icu_ids'])})
AND {self._extr_end_date('it_ishmed_bewegung')}
AND c_falnr IN ({self._get_case_id_query()})
"""
df_bew = self.conn.df_read_sql(query=query)
for col in ["icu_admission", "icu_discharge"]:
df_bew[col] = pd.to_datetime(df_bew[col], errors="coerce")
if self.merge_consecutive:
print("Merging consecutive ICU stays...")
grouped = df_bew.groupby("case_id")
groups_to_merge = [group for name, group in grouped if len(group) > 1]
# Apply the merging function only to groups with multiple stays
# merged_groups = [utils.merge_consecutive_stays(group) for group in tqdm(groups_to_merge, desc="Merging consecutive stays", unit="case")]
with multiprocessing.Pool() as pool:
merged_groups = list(
tqdm(
pool.imap(utils.merge_consecutive_stays, groups_to_merge),
total=len(groups_to_merge),
desc="Merging consecutive stays",
unit="cases",
)
)
# Combine the merged groups with the single-stay groups
single_stay_groups = grouped.filter(lambda x: len(x) == 1)
df_bew_merged = pd.concat([single_stay_groups] + merged_groups)
df_bew_merged = df_bew_merged.sort_values(
by=["case_id", "icu_admission"]
).reset_index(drop=True)
print(f"Original number of stays: {len(df_bew)}")
print(f"Number of stays after merging: {len(df_bew_merged)}")
print(
f"Percentage of stays after merging: {len(df_bew_merged)/len(df_bew):.2%}"
)
self.df_bew = df_bew_merged
else:
self.df_bew = df_bew
print("SUCCESS: Extracted ICU stays")
def _load_icu_stays(self):
"""
Merge the ICU stays with the static data to get a dataframe with icu stay as primary key.
"""
df_icu_stays = self.df_bew.merge(self.obs, on="case_id", how="left")
print(f"Number of stays: {len(df_icu_stays)}")
df_icu_stays = utils.filter_by_condition(
df_icu_stays,
lambda df: df["icu_admission"] < df["hospital_admission"],
description="icu_admission < hospital_admission",
mode="drop",
)
df_icu_stays = utils.filter_by_condition(
df_icu_stays,
lambda df: df["icu_discharge"] > df["hospital_discharge"],
description="icu_discharge > hospital_discharge",
mode="drop",
)
df_icu_stays["icu_stay_id"] = (
df_icu_stays.groupby("case_id").cumcount().add(1).astype(str)
)
df_icu_stays["icu_stay_id"] = (
df_icu_stays["case_id"] + "_" + df_icu_stays["icu_stay_id"]
)
# Re-order columns
df_icu_stays = df_icu_stays[
["case_id", "icu_stay_id", "icu_admission", "icu_discharge"]
+ [
col
for col in df_icu_stays.columns
if col
not in ["case_id", "icu_stay_id", "icu_admission", "icu_discharge"]
]
]
print("SUCCESS: Merged ICU stays with static data")
self.obs = df_icu_stays
def _load_procedures(self):
"""
Get the procedures from the it_ishmed_procedure table and assign OP times from COPRA.
"""
# Step 1: Get SAP surgeries (OPS code starting with 5)
df_sap = self.conn.df_read_sql(
f"""
SELECT
c_falnr AS case_id,
c_prozedur_code AS ops_code,
c_prozedur_begin
FROM
{self.database}.it_ishmed_procedure
WHERE
c_prozedur_code LIKE '5%'
AND {self._extr_end_date('it_ishmed_procedure')}
AND c_falnr IN ({self._get_case_id_query()})
""",
datetime_cols=["c_prozedur_begin"],
)
# Step 2: Get COPRA surgeries with start and end times
# Extract Behandlung_OP_Zeiten (var_id 101322) with parent = Behandlung
# Then extract BEGAN (Anesthesia start) and ENDAN (Anesthesia end) from Behandlung_OP_Zeiten_Ereignisname (var_id 101323[event_timestamp] and 101324[event_name] with parent = Behandlung_OP_Zeiten)
# behandlung: First level - Behandlung (Treatment) entries
# Get 1263 (Behandlung_Beginn) and 1264 (Behandlung_Ende)
# h2: Second level - OP_Zeiten (Operation Times) entries
# n: Event Names (BEGAN/ENDAN)
# t: Event Timestamps
extract_cols = ["began", "endan", "schni", "naht", "freig"]
query = f"""
WITH behandlung AS (
SELECT
b.c_falnr AS case_id,
b.c_id AS beh_id,
MAX(CASE WHEN h.c_var_id = 1263 THEN h.c_value END) as beh_beginn,
MAX(CASE WHEN h.c_var_id = 1264 THEN h.c_value END) as beh_ende
FROM {self.database}.it_copra6_hierarchy_v2 b
LEFT JOIN {self.database}.it_copra6_hierarchy_v2 h
ON b.c_id = h.c_parent_id
AND h.c_var_id IN (1263, 1264)
WHERE
b.c_var_id = 30
AND {self._extr_end_date('b')}
AND b.c_falnr IN ({self._get_case_id_query()})
GROUP BY
b.c_falnr,
b.c_id
),
op_zeiten AS (
SELECT h2.c_id as zeit_id, b.case_id, b.beh_id, b.beh_beginn, b.beh_ende
FROM behandlung b
JOIN {self.database}.it_copra6_hierarchy_v2 h2
ON b.beh_id = h2.c_parent_id
WHERE h2.c_var_id = 101322
)
SELECT
oz.case_id,
oz.beh_id,
oz.beh_beginn,
oz.beh_ende,
{", ".join([f"MAX(CASE WHEN n.c_var_id = 101324 AND n.c_value = '{col.upper()}' THEN t.c_value END) as {col}" for col in extract_cols])}
FROM op_zeiten oz
JOIN {self.database}.it_copra6_hierarchy_v2 n
ON oz.zeit_id = n.c_parent_id
JOIN {self.database}.it_copra6_hierarchy_v2 t
ON n.c_parent_id = t.c_parent_id
AND t.c_var_id = 101323
WHERE
n.c_var_id = 101324
GROUP BY
oz.case_id,
oz.beh_id,
oz.beh_beginn,
oz.beh_ende
"""
df_op_times = self.conn.df_read_sql(
query, datetime_cols=extract_cols + ["beh_beginn", "beh_ende"]
)
# Step 3: Merge with SAP surgeries (by closest start time)
# Step 3.1: Drop missings
initial_len_sap = len(df_sap)
initial_len_op_times = len(df_op_times)
df_sap = df_sap.dropna(subset=["c_prozedur_begin"])
df_op_times = df_op_times.dropna(subset=["beh_beginn"])
drop_sap = initial_len_sap - len(df_sap)
drop_op_times = initial_len_op_times - len(df_op_times)
        print(
            f"DROP: {drop_sap} rows ({drop_sap / initial_len_sap:.2%}) because c_prozedur_begin is NaT in df_sap"
        )
        print(
            f"DROP: {drop_op_times} rows ({drop_op_times / initial_len_op_times:.2%}) because beh_beginn is NaT in df_op_times"
        )
# Step 3.2: Combine multiple surgeries with same start time
initial_len_sap = len(df_sap)
df_sap = (
df_sap.groupby(["case_id", "c_prozedur_begin"])
.agg({"ops_code": lambda x: list(x)})
.reset_index()
)
combine_sap = initial_len_sap - len(df_sap)
print(
f"COMBINE: {combine_sap} rows ({combine_sap / initial_len_sap:.2%}) due to multiple surgeries at the same time"
)
df_sap = df_sap.sort_values(by="c_prozedur_begin")
df_op_times = df_op_times.sort_values(by="beh_beginn")
df_proc = pd.merge_asof(
df_sap,
df_op_times,
left_on="c_prozedur_begin",
right_on="beh_beginn",
by="case_id",
direction="nearest",
)
# Step 4: Clean columns
df_proc.rename(
columns={
"beh_beginn": "op_start_dtime_any",
"beh_ende": "op_end_dtime_any",
"c_prozedur_begin": "op_dtime_ops",
"began": "op_start_anaesthesia",
"endan": "op_ende_anaesthesia",
"schni": "op_schnitt_anaesthesia",
"naht": "op_naht_anaesthesia",
"freig": "op_freigabe_anaesthesia",
},
inplace=True,
)
df_proc = df_proc.drop(columns=["beh_id"])
# Step 5: Save procedures
self.df_proc = df_proc
print("SUCCESS: Extracted procedures")
def _load_procedure_id(self):
"""
Merge the procedures with the static data to get a dataframe with procedure as primary key.
"""
# Step 1: Merge procedures with static data
df_individual_proc = self.df_proc.merge(self.obs, on="case_id", how="left")
print(f"Number of procedures: {len(df_individual_proc)}")
# Step 2: Filter for viable times (within hospital admission and discharge)
df_individual_proc = utils.filter_by_condition(
df_individual_proc,
lambda df: df["op_start_dtime_any"] < df["hospital_admission"],
description="op_start_dtime_any < hospital_admission",
mode="drop",
)
df_individual_proc = utils.filter_by_condition(
df_individual_proc,
lambda df: df["op_end_dtime_any"] > df["hospital_discharge"],
description="op_end_dtime_any > hospital_discharge",
mode="drop",
)
# Step 3: Sort procedures (by start time within each case)
df_individual_proc = df_individual_proc.sort_values(
by=["case_id", "op_start_dtime_any"]
)
# Step 4: Create procedure_id (as sequential IDs starting at 1 for each case concatenated with case_id)
df_individual_proc["procedure_id"] = (
df_individual_proc.groupby("case_id").cumcount().add(1).astype(str)
)
df_individual_proc["procedure_id"] = (
df_individual_proc["case_id"] + "_" + df_individual_proc["procedure_id"]
)
# Step 5: Re-order columns (for consistency)
df_individual_proc = df_individual_proc[
[
"patient_id",
"case_id",
"procedure_id",
"sex",
"birthdate",
"death_timestamp",
"hospital_admission",
"hospital_discharge",
"ops_code",
"op_dtime_ops",
"op_start_dtime_any",
"op_end_dtime_any",
"op_start_anaesthesia",
"op_freigabe_anaesthesia",
"op_ende_anaesthesia",
"op_schnitt_anaesthesia",
"op_naht_anaesthesia",
]
]
# Step 6: Update self.obs
self.obs = df_individual_proc
print("SUCCESS: Merged procedures with static data")
def load_default_vars(self):
"""
Load the default variables defined in ``vars.json``. It is recommended to use this after filtering your cohort for eligibility to speed up the process.
Returns:
None: Variables are loaded into the cohort.
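        Examples:
            >>> cohort = Cohort(obs_level="icu_stay", load_default_vars=False, password_file=True)
            >>> # ... apply inclusion/exclusion criteria first, then:
            >>> cohort.load_default_vars()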
"""
default_vars = VARS["corr_defaults"]["default_vars"]
apply_defaults = default_vars["global"] + default_vars.get(self.obs_level, [])
for var_name in apply_defaults:
var = ex.Variable.from_corr_vars(
var_name, cohort=self, tmin=self.t_min, tmax=self.t_max
)
if isinstance(
var, ex.NativeDynamic
): # For dynamic variables, convert to static to get value on admission
var = var.on_admission(
select="!closest(hospital_admission, 0, 24h) value"
)
self.add_variable(var)
@property
def axiom_cache(self) -> dict:
"""
This is a deepcopy of cached native dynamic variables that are always stored at a hospital stay level (i.e. tmin=hospital_admission and tmax=hospital_discharge).
If you do not know what to do with this, you probably do not need it.
"""
return copy.deepcopy(
self._axiom_cache
) # Deepcopy to avoid reference issues to Variable objects in the cache
def clear_axiom_cache(self) -> None:
"""Delete the axiom cache. Can be useful to free up memory, or for debugging purposes."""
self._axiom_cache = {}
def _load_axiom(
self, variable: str | ex.Variable, overwrite: bool = False
) -> ex.NativeDynamic:
"""Load a variable from the Axiom cache, add it if not present.
Args:
variable: Variable to load.
overwrite (bool): Whether to overwrite the variable if it is already present in the cache.
Returns:
Variable (NativeDynamic): The loaded variable.
"""
if isinstance(variable, str):
var_name = variable
else:
var_name = variable.var_name
        if var_name in self._axiom_cache and not overwrite:  # Check the raw cache; the property would deepcopy it
return self.axiom_cache[var_name]
else:
if isinstance(variable, str):
var = ex.NativeDynamic.from_corr_vars(
var_name,
cohort=self,
tmin="hospital_admission",
tmax="hospital_discharge",
)
else:
var = variable
var.tmin = "hospital_admission"
var.tmax = "hospital_discharge"
var.extract(self)
self._axiom_cache[var_name] = var
return self.axiom_cache[var_name]
def get_obsm_filtered(self, var_name: str, tmin: str, tmax: str) -> pd.DataFrame:
"""Filter a variable stored in obsm by tmin and tmax.
You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")), in which case it will be relative to the hospital admission time of the patient.
Args:
var_name: Name of the variable to filter.
tmin: Name of the column to use as tmin or tuple (see description).
tmax: Name of the column to use as tmax or tuple (see description).
Returns:
Filtered variable.
Examples:
>>> var_data = cohort.get_obsm_filtered(
... var_name="blood_sodium",
... tmin=("hospital_admission", "+1d"),
... tmax="hospital_discharge"
... )
"""
var = self._load_axiom(var_name)
return self._filter_by_tmin_tmax(var.data, tmin, tmax)
def add_variable(
self, variable: str | ex.Variable, save_as=None, tmin=None, tmax=None
) -> None:
"""Add a variable to the cohort.
You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")), in which case it will be relative to the hospital admission time of the patient.
Args:
variable: Variable to add. Either a string with the variable name (from `vars.json`) or a Variable object.
save_as: Name of the column to save the variable as. Defaults to variable name.
tmin: Name of the column to use as tmin or tuple (see description).
tmax: Name of the column to use as tmax or tuple (see description).
Returns:
None: Variable is added to the cohort.
Examples:
>>> cohort.add_variable("blood_sodium")
>>> cohort.add_variable(
... variable="anx_dx_covid_19",
... tmin=("hospital_admission", "-1d"),
... tmax=cohort.t_eligible
... )
>>> cohort.add_variable(
... NativeStatic(
... var_name="highest_hct_before_eligible",
... select="!max value",
... base_var='blood_hematokrit',
... tmax=cohort.t_eligible
... )
... )
>>> cohort.add_variable(
... variable='any_med_glu',
... save_as="glucose_prior_eligible",
... tmin=(cohort.t_eligible, "-48h"),
... tmax=cohort.t_eligible
... )
"""
if isinstance(variable, str):
tmin = tmin if tmin is not None else self.t_min
tmax = tmax if tmax is not None else self.t_max
var = ex.Variable.from_corr_vars(
variable, cohort=self, tmin=tmin, tmax=tmax
)
else:
assert (
tmin is None and tmax is None
), "Please specify tmin and tmax directly in the Variable object"
var = variable
print(f"Extracting {var.var_name} with tmin: {var.tmin} and tmax: {var.tmax}")
var.extract(self)
self._save_variable(var, save_as=save_as)
def _save_variable(self, var: "ex.Variable", save_as=None) -> None:
"""Save a Variable object to the cohort.
Will either add a column to obs or add an entry to obsm (for dynamic variables).
Variables will also be added to Cohort.variables as variable objects.
Args:
var: Variable to save. (Must be already extracted)
save_as: Name of the column to save the variable as. Defaults to variable name.
Returns:
None: Variable is saved to the cohort.
"""
self.variables[var.var_name] = var
data = var.data
if save_as is None:
save_as = var.var_name
if var.dynamic:
self.obsm[save_as] = data
else:
if any(col.startswith(save_as) for col in self.obs.columns):
for col in self.obs.columns:
if col.startswith(save_as):
print(f"DROP existing variable: {col} ({len(self.obs)})")
self.obs.drop(columns=[col], inplace=True)
if save_as != var.var_name:
if isinstance(
var, ex.NativeStatic
): # This is the only var type that can have multiple columns
if len(var.select_cols) > 1:
data.columns = [
f"{save_as}_{col.split('_', 1)[1]}" for col in data.columns
]
else:
data.rename(columns={var.var_name: save_as}, inplace=True)
self.obs = pd.merge(self.obs, data, on=self.primary_key, how="left")
print(f"SUCCESS: Saved variable: {save_as}")
print(var)
print(var.data.describe())
def _filter_by_tmin_tmax(self, df_var, tmin, tmax):
"""
Filter a dataframe by tmin and tmax.
        Used by ``get_obsm_filtered()``.
"""
        if tmin is None:
            tmin = self.t_min
        if tmax is None:
            tmax = self.t_max
time_bounds = self.obs.set_index(self.primary_key)[[tmin, tmax]].reset_index()
df_merged = df_var.merge(time_bounds, on=self.primary_key, how="left")
df_merged[tmin] = pd.to_datetime(df_merged[tmin], errors="coerce")
df_merged[tmax] = pd.to_datetime(df_merged[tmax], errors="coerce")
df_filtered = df_merged[
(df_merged["recordtime"] >= df_merged[tmin])
& (df_merged["recordtime"] <= df_merged[tmax])
]
df_filtered = df_filtered.drop([tmin, tmax], axis=1).reset_index(drop=True)
return df_filtered
# Time anchors, inclusion and exclusion
def set_t_eligible(self, t_eligible: str, drop_ineligible: bool = True) -> None:
"""
Set the time anchor for eligibility. This can be referenced as cohort.t_eligible throughout the process and is required to add inclusion or exclusion criteria.
Args:
t_eligible: Name of the column to use as t_eligible.
drop_ineligible: Whether to drop ineligible patients. Defaults to True.
Returns:
None: t_eligible is set.
Examples:
>>> # Add a suitable time-anchor variable
>>> cohort.add_variable(NativeStatic(
... var_name="spo2_lt_90",
... base_var="spo2",
... select="!first recordtime",
... where="value < 90",
... ))
>>> # Set the time anchor for eligibility
>>> cohort.set_t_eligible("spo2_lt_90")
"""
assert t_eligible in self.obs.columns, f"Column {t_eligible} not found in obs."
assert pd.api.types.is_datetime64_any_dtype(
self.obs[t_eligible]
), f"Column {t_eligible} is not a datetime column."
if self.t_eligible is not None:
print(
f"WARNING: t_eligible already set to {self.t_eligible}. Will overwrite and set to {t_eligible}. \nCAVE! Previously ineligible patients will not be restored."
)
self.t_eligible = t_eligible
if drop_ineligible:
# Drop where t_eligible is NaT
self.obs = utils.filter_by_condition(
self.obs,
lambda df: df[t_eligible].isna(),
description=f"{t_eligible} is NaT",
mode="drop",
)
            # Reset axiom cache to free memory (the cohort may have shrunk significantly)
self.clear_axiom_cache()
def set_t_outcome(self, t_outcome):
"""
Set the time anchor for outcome. This can be referenced as cohort.t_outcome throughout the process and is recommended to specify for your study.
Args:
t_outcome (str): Name of the column to use as t_outcome.
Returns:
None: t_outcome is set.
Examples:
>>> cohort.set_t_outcome("hospital_discharge")
"""
assert t_outcome in self.obs.columns, f"Column {t_outcome} not found in obs."
assert pd.api.types.is_datetime64_any_dtype(
self.obs[t_outcome]
), f"Column {t_outcome} is not a datetime column."
if self.t_outcome is not None:
print(
f"WARNING: t_outcome already set to {self.t_outcome}. Will overwrite and set to {t_outcome}."
)
self.t_outcome = t_outcome
def include(self, *args, **kwargs):
"""
Add an inclusion criterion to the cohort. It is recommended to use ``Cohort.add_inclusion()`` and add all of your inclusion criteria at once. However, if you need to specify criteria at a later stage, you can use this method.
Warning:
You must call ``Cohort.add_inclusion()`` before calling ``Cohort.include()`` to ensure that the inclusion criteria are properly tracked.
Args:
            variable (str | Variable): Variable to evaluate.
            operation (str): Operation to apply (e.g., ">= 18", "== True").
            label (str): Short label for the inclusion step.
            operations_done (str): Detailed description of the step.
            tmin (str, optional): Start time for variable extraction.
            tmax (str, optional): End time for variable extraction.
Returns:
None: Criterion is added to the cohort.
Examples:
>>> cohort.include(
... variable="age_on_admission",
... operation=">= 18",
... label="Adult",
... operations_done="Include only adult patients"
... )
"""
return self._include_exclude("include", *args, **kwargs)
def exclude(self, *args, **kwargs):
"""
Add an exclusion criterion to the cohort. It is recommended to use ``Cohort.add_exclusion()`` and add all of your exclusion criteria at once. However, if you need to specify criteria at a later stage, you can use this method.
Warning:
You must call ``Cohort.add_exclusion()`` before calling ``Cohort.exclude()`` to ensure that the exclusion criteria are properly tracked.
Args:
            variable (str | Variable): Variable to evaluate.
            operation (str): Operation to apply (e.g., "> 20", "== True").
            label (str): Short label for the exclusion step.
            operations_done (str): Detailed description of the step.
            tmin (str, optional): Start time for variable extraction.
            tmax (str, optional): End time for variable extraction.
Returns:
None: Criterion is added to the cohort.
Examples:
>>> cohort.exclude(
... variable="elix_total",
... operation="> 20",
... operations_done="Exclude patients with high Elixhauser score"
... )
"""
return self._include_exclude("exclude", *args, **kwargs)
def _include_exclude(
self,
mode: Literal["include", "exclude"],
variable: str | ex.Variable,
operation: str,
label: str = "",
operations_done: str = "",
allow_obs=False,
tmin: str | None = None,
tmax: str | None = None,
):
"""
Add an inclusion or exclusion criterion to the cohort.
Args:
mode (Literal["include", "exclude"]): Whether to add an inclusion or exclusion criterion.
variable (str | Variable),
operation (str),
label (str),
operations_done (str)
allow_obs (bool): Allow using a stored variable in obs insteead of re-extracting. CAVE: This can be dangerous when trying to set custom time bounds.
[Optional: tmin, tmax]
Returns:
None: Criterion is added to the cohort.
"""
        adata = self.to_adata()
        # Resolve var_name and data up front so the except block can always reference them
        var_name = variable if isinstance(variable, str) else variable.var_name
        data = pd.DataFrame()
        try:
if operation.lower() in ["true", "false"]:
operation = "== True" if operation.lower() == "true" else "== False"
if tmax is None:
assert (
self.t_eligible is not None
), "t_eligible is not set. Please set t_eligible before adding inclusion criteria or set tmax manually (not recommended)."
tmax = self.t_eligible
if tmin is None:
tmin = "hospital_admission"
# Check if this is a defined variable or a custom variable
if variable in self.constant_vars or (
allow_obs and variable in self.obs.columns
):
data = self.obs[[variable, self.primary_key]]
var_name = variable
else:
if isinstance(variable, str):
var = ex.Variable.from_corr_vars(
variable, cohort=self, tmin=tmin, tmax=tmax
)
var_name = variable
else:
var = variable
var_name = variable.var_name
var.tmin = tmin
var.tmax = tmax
data = var.extract(self)
data = data.reset_index()
op_ids = data.query(f"{var_name}{operation}")[self.primary_key]
criterion = {
"variable": var_name,
"operation": operation,
"label": label,
"operations_done": operations_done,
"tmin": tmin,
"tmax": tmax,
}
match mode:
case "include":
adata = adata[adata.obs_names.isin(op_ids)]
self.include_ct(adata, label=label, operations_done=operations_done)
self.inclusion_criteria.append(criterion)
case "exclude":
adata = adata[~adata.obs_names.isin(op_ids)]
self.exclude_ct(adata, label=label, operations_done=operations_done)
self.exclusion_criteria.append(criterion)
except Exception as e:
print(
f"ERROR: Failed to add {mode} criterion ({var_name}{operation}) with data: \n{data.head()} \n{e} \n{traceback.format_exc()}"
)
self._overwrite_from_adata(adata)
    def add_inclusion(self, inclusion_list: list[dict] | None = None):
"""
        Add inclusion criteria to the cohort.
        Args:
            inclusion_list (list): List of inclusion criteria. Each criterion is a dictionary containing:
                * ``variable`` (str | Variable): Variable to use for inclusion
                * ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
                * ``label`` (str): Short label for the inclusion step
                * ``operations_done`` (str): Detailed description of what this inclusion does
* ``tmin`` (str, optional): Start time for variable extraction
* ``tmax`` (str, optional): End time for variable extraction
Returns:
ct (CohortTracker): CohortTracker object, can be used to plot inclusion chart
Note:
            By default, all inclusion criteria are applied from ``tmin=cohort.t_min`` to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal time bias. However, in some cases you might want to set custom time bounds.
Examples:
>>> ct = cohort.add_inclusion([
... {
... "variable": "age_on_admission",
... "operation": ">= 18",
... "label": "Adult patients",
... "operations_done": "Excluded patients under 18 years old"
... }
... ])
>>> ct.plot_flowchart()
"""
# Backwards compatibility
if hasattr(self, "include_ct") and self.include_ct is not None:
raise ValueError(
"Inclusion criteria already set. Please use Cohort.include() to add individual inclusion criteria."
)
if not hasattr(self, "inclusion_criteria") or self.inclusion_criteria is None:
self.inclusion_criteria = []
if not hasattr(self, "include_ct"):
self.include_ct = None
adata = self.to_adata()
self.include_ct = CohortTracker(adata, columns=["patient_id"])
self.include_ct(adata, label="Initial cohort")
        for inclusion in inclusion_list or []:
self.include(**inclusion)
return self.include_ct
    def add_exclusion(self, exclusion_list: list[dict] | None = None):
"""
        Add exclusion criteria to the cohort.
Args:
exclusion_list (list): List of exclusion criteria. Each criterion is a dictionary containing:
* ``variable`` (str | Variable): Variable to use for exclusion
* ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
* ``label`` (str): Short label for the exclusion step
* ``operations_done`` (str): Detailed description of what this exclusion does
* ``tmin`` (str, optional): Start time for variable extraction
* ``tmax`` (str, optional): End time for variable extraction
Returns:
ct (CohortTracker): CohortTracker object, can be used to plot exclusion chart
Note:
            By default, all exclusion criteria are applied from ``tmin=cohort.t_min`` to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal time bias. However, in some cases you might want to set custom time bounds.
Examples:
>>> ct = cohort.add_exclusion([
... {
... "variable": "any_rrt_icu",
... "operation": "true",
... "label": "No RRT",
... "operations_done": "Excluded RRT before hypernatremia"
... },
... {
... "variable": "any_dx_tbi",
... "operation": "true",
... "label": "No TBI",
... "operations_done": "Excluded TBI before hypernatremia"
... },
... {
... "variable": NativeStatic(
... var_name="sodium_count",
... select="!count value",
... base_var="blood_sodium"),
... "operation": "< 1",
... "label": "Final cohort",
... "operations_done": "Excluded cases with less than 1 sodium measurement after hypernatremia",
... "tmin": cohort.t_eligible,
... "tmax": "hospital_discharge"
... }
... ])
>>> ct.plot_flowchart() # Plot the exclusion flowchart
"""
# Backwards compatibility
if hasattr(self, "exclude_ct") and self.exclude_ct is not None:
raise ValueError(
"Exclusion criteria already set. Please use Cohort.exclude() to add individual exclusion criteria."
)
if not hasattr(self, "exclusion_criteria") or self.exclusion_criteria is None:
self.exclusion_criteria = []
if not hasattr(self, "exclude_ct"):
self.exclude_ct = None
adata = self.to_adata()
self.exclude_ct = CohortTracker(adata, columns=["patient_id"])
self.exclude_ct(adata, label="Meets inclusion criteria")
        for exclusion in exclusion_list or []:
self.exclude(**exclusion)
return self.exclude_ct
def add_variable_definition(self, var_name: str, var_dict: dict) -> None:
"""Add or update a local variable definition.
Args:
var_name: Name of the variable
var_dict: Dictionary containing variable definition. Can be partial -
missing fields will be inherited from global definition.
Examples:
>>> # Add completely new variable
>>> cohort.add_variable_definition("my_new_var", {
... "type": "native_dynamic",
... "table": "it_ishmed_labor",
... "where": "c_katalog_leistungtext LIKE '%new%'",
... "value_dtype": "DOUBLE",
... "cleaning": {"value": {"low": 100, "high": 150}}
... })
>>> # Partially override existing variable
>>> cohort.add_variable_definition("blood_sodium", {
... "where": "c_katalog_leistungtext LIKE '%custom_sodium%'"
... })
"""
if var_name not in self.project_vars:
self.project_vars[var_name] = {}
# Update existing definition
self.project_vars[var_name].update(var_dict)