Source code for corr_vars.core.cohort

# Standard library imports
import copy
import getpass
import multiprocessing
import os
import pickle
import traceback
from datetime import datetime
from typing import Literal

# External imports
import ehrapy as ep
import pandas as pd
from tableone import TableOne
from tqdm import tqdm

# Local imports
import corr_vars.core.extract as ex
import corr_vars.utils as utils
from corr_vars import logger
from corr_vars.static import GLOBAL_VARS, VARS

# pd.set_option('future.no_silent_downcasting', True)


class CohortTracker(ep.tl.CohortTracker):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def plot_cohort_barplot(self, *args, **kwargs):
        """We are not implementing column tracking due to unhashable column types, so this function is disabled"""
        raise NotImplementedError("This method is not yet implemented.")
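
# Tracking sketch (hedged): this subclass is instantiated internally by
# Cohort.add_inclusion()/add_exclusion(); flowcharts still work, only the
# barplot is disabled. `cohort` is an illustrative name.
#
#     ct = cohort.add_exclusion([...])   # returns a CohortTracker
#     ct.plot_flowchart()                # supported (see docstring examples below)
#     ct.plot_cohort_barplot()           # raises NotImplementedError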


class Cohort:
    """Class to build a cohort in the CORR database.

    Args:
        conn_args: Dictionary of [remote_hostname, username]
        password_file: Path to the password file or True if your file is in ~/password.txt.
        database: Default: "db_hypercapnia_prepared"
        extraction_end_date: End date for the extraction in the format "YYYY-MM-DD" (default: today).
        obs_level: Observation level (default: "icu_stay").
        project_vars: Dictionary with local variable definitions.
        merge_consecutive: Whether to merge consecutive ICU stays (default: True). Does not apply to any other obs_level.
        load_default_vars: Whether to load the default variables (default: True).
        filters: Initial filters (must be a valid SQL WHERE clause for the it_ishmed_fall table).

    Attributes:
        obs (pd.DataFrame): Static data for each observation. Contains one row per
            observation (e.g., ICU stay) with columns for static variables like
            demographics and outcomes.

            Example:
                >>> cohort.obs
                  patient_id case_id icu_stay_id       icu_admission       icu_discharge sex ... inhospital_death
                0       P001    C001      C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00   M ...            False
                1       P001    C001      C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00   M ...            False
                2       P002    C002      C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00   F ...            False
                3       P003    C003      C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00   F ...             True
                ...

        obsm (dict of pd.DataFrame): Dynamic data stored as dictionary of DataFrames.
            Each DataFrame contains time-series data for a variable with columns:

            - recordtime: Timestamp of the measurement
            - value: Value of the measurement
            - recordtime_end: End time (only for duration-based variables like therapies)
            - description: Additional information (e.g., medication names)

            Example:
                >>> cohort.obsm["blood_sodium"]
                  icu_stay_id          recordtime  value
                0      C001_1 2023-01-01 09:30:00    138
                1      C001_1 2023-01-02 10:15:00    141
                2      C001_2 2023-01-03 15:00:00    137
                3      C002_1 2023-01-02 10:00:00    142
                4      C003_1 2023-01-04 12:30:00    139
                ...

        variables (dict of Variable): Dictionary of all variable objects in the cohort.
            This is used to keep track of variable metadata.

    Notes:
        - For large cohorts, set ``load_default_vars=False`` to speed up the extraction.
          You can use pre-extracted cohorts as starting points and load them using
          ``Cohort.load()``.
        - Variables can be added using ``cohort.add_variable()``. Static variables will
          be added to ``obs``, dynamic variables to ``obsm``.

    Examples:
        Create a new cohort:

        >>> cohort = Cohort(obs_level="icu_stay",
        ...                 database="db_hypercapnia_prepared",
        ...                 load_default_vars=False,
        ...                 password_file=True)

        Access static data:

        >>> cohort.obs["age_on_admission"]  # Get age for all patients
        >>> cohort.obs.loc[cohort.obs["sex"] == "M"]  # Filter for male patients

        Access time-series data:

        >>> cohort.obsm["blood_sodium"]  # Get all blood sodium measurements
        >>> # Get blood sodium measurements for a specific observation
        >>> cohort.obsm["blood_sodium"].loc[
        ...     cohort.obsm["blood_sodium"][cohort.primary_key] == "12345"
        ... ]
    """

    def __init__(
        self,
        conn_args: dict = {},
        password_file: str | bool | None = None,
        database: Literal[
            "db_hypercapnia_prepared", "db_corror_prepared"
        ] = "db_hypercapnia_prepared",
        extraction_end_date: str = datetime.now().strftime("%Y-%m-%d"),
        obs_level: Literal["icu_stay", "hospital_stay", "procedure"] = "icu_stay",
        project_vars: dict = {},
        merge_consecutive: bool = True,
        load_default_vars: bool = True,
        filters: str = "",
    ):
        # Flag to indicate this is a newly created cohort (not loaded from file)
        self.from_file = False

        self.conn_args = conn_args
        self.database = database
        self.extraction_end_date = extraction_end_date
        self.obs_level = obs_level
        self.merge_consecutive = merge_consecutive
        self.project_vars = project_vars

        if filters.startswith("_d"):
            # Debug shorthand: extract number of months from filter string (e.g. "_d2" -> 2 months)
            months = int(filters[2:])
            start_date = "2023-01-01"
            end_date = (
                datetime.strptime(start_date, "%Y-%m-%d")
                + pd.Timedelta(days=30 * months)
            ).strftime("%Y-%m-%d")
            self.filters = f"c_aufnahme BETWEEN '{start_date}' AND '{end_date}'"
            print(f"DEBUG: Filters set to {self.filters}")
        else:
            self.filters = filters

        self.password_file = password_file
        if password_file:
            self.conn = utils.ImpalaConnector.with_password_file(
                password_file=password_file, **conn_args
            )
        else:
            password = getpass.getpass(prompt="Enter your password: ")
            self.conn = utils.ImpalaConnector(password=password, **conn_args)

        self._set_primary_keys()
        self._load_obs_level_data()

        # Temporary storage for raw and intermediary data (mostly native_dynamic variables)
        self._axiom_cache = {}
        # All variable objects in the cohort (important for keeping metadata of static variables, e.g. tmin, tmax)
        self.variables = {}
        # Finalized time-series data for analysis
        self.obsm = {}

        # Inclusion criteria
        self.inclusion_criteria = []
        self.include_ct = None
        # Exclusion criteria
        self.exclusion_criteria = []
        self.exclude_ct = None

        # Constant variables (not dependent on tmin/tmax)
        self.constant_vars = [
            col for col in self.obs.columns if col not in ["case_id", self.primary_key]
        ]

        # Load default variables
        if load_default_vars:
            self.load_default_vars()

        # Cache for adata compatibility
        self._adata = None
        self._adata_hash = None

        print(
            f"SUCCESS: Extracted data. Cohort has {len(self.obs)} observations ({self.obs_level})."
        )

    def _set_primary_keys(self):
        match self.obs_level:
            case "hospital_stay":
                self.primary_key = "case_id"
                self.t_min = "hospital_admission"
                self.t_max = "hospital_discharge"
                self.t_eligible = "hospital_admission"
                self.t_outcome = "hospital_discharge"
            case "icu_stay":
                self.primary_key = "icu_stay_id"
                self.t_min = "icu_admission"
                self.t_max = "icu_discharge"
                self.t_eligible = "icu_admission"
                self.t_outcome = "icu_discharge"
            case "procedure":
                self.primary_key = "procedure_id"
                self.t_min = "op_start_dtime_any"
                self.t_max = "op_end_dtime_any"
                self.t_eligible = "op_start_dtime_any"
                self.t_outcome = "hospital_discharge"
            case _:
                raise ValueError(f"Observation level {self.obs_level} not supported.")

    def _load_obs_level_data(self):
        self._load_ishmed_fall()
        match self.obs_level:
            case "hospital_stay":
                pass  # No specific data to load / this is the most general level
            case "icu_stay":
                self._load_ishmed_bewegung()
                self._load_icu_stays()
            case "procedure":
                self._load_procedures()
                self._load_procedure_id()

    # System methods
    def __repr__(self):
        return f"Cohort(obs_level={self.obs_level}, len(obs)={len(self.obs)}, variables={list(self.variables.keys())})"

    def __str__(self):
        out = "Cohort object\n"
        out += f"  obs_level: {self.obs_level}\n"
        out += f"  len(obs): {len(self.obs)}\n"
        out += f"  filters: {self.filters}\n"
        out += f"  variables: {self.variables}\n"
        out += "\n"
        out += "obs (first 5 rows):\n"
        out += self.obs.head().to_string()
        out += "\n\n"
        return out
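
    # Construction sketch (hedged): assumes a reachable CORR database and a
    # password file at ~/password.txt; the "_d2" debug shorthand restricts the
    # cohort to two months of admissions starting 2023-01-01 (see __init__).
    #
    #     cohort = Cohort(obs_level="icu_stay", password_file=True, filters="_d2")
    #     print(cohort)  # __str__ summary: obs_level, filters, first rows of obs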
    def debug_print(self):
        """Print debug information about the cohort. Please use this if you are creating a GitHub issue.

        Returns:
            None
        """
        print(
            f"""Cohort object\n
            - obs_level: {self.obs_level}
            - len(obs): {len(self.obs)}
            - filters: {self.filters}
            - obs: {self.obs.columns.to_list()}
            - obsm: {list(self.obsm.keys())}
            - axiom_cache: {list(self._axiom_cache.keys())}
            - t_eligible: {self.t_eligible}
            - t_outcome: {self.t_outcome}
            - inclusion_criteria: {self.inclusion_criteria}
            - exclusion_criteria: {self.exclusion_criteria}
            """
        )
        utils.print_debug_info()
    # Save and load methods
    def save(self, filename: str):
        """
        Save the cohort to a pickle file.

        Args:
            filename: Path to the pickle file.

        Returns:
            None
        """
        # Avoid pickling unpicklable objects (the live database connection)
        conn = self.conn
        self.conn = None

        # Add .corr suffix
        filename = filename + ".corr" if not filename.endswith(".corr") else filename

        # Pickle the rest
        with open(filename, "wb") as f:
            pickle.dump(self, f)

        self.conn = conn
        print(f"SUCCESS: Saved cohort to {filename}")
    @classmethod
    def load(cls, filename: str, conn_args: dict = {}):
        """
        Load a cohort from a pickle file. If this file was saved by a different user,
        you need to pass your database credentials to the function.

        Args:
            filename: Path to the pickle file.
            conn_args: Database credentials [remote_hostname, username, password_file].

        Returns:
            Cohort: A new Cohort object.
        """
        with open(filename, "rb") as f:
            cohort = pickle.load(f)

        if hasattr(cohort, "password_file"):
            password_file = cohort.password_file
        else:
            password_file = None

        if conn_args == {}:
            conn_args = cohort.conn_args

        # Re-init the connection
        try:
            if password_file:
                cohort.conn = utils.ImpalaConnector.with_password_file(
                    password_file=password_file, **conn_args
                )
            else:
                password = getpass.getpass(prompt="Enter your password: ")
                cohort.conn = utils.ImpalaConnector(password=password, **conn_args)
        except Exception as e:
            raise ConnectionError(
                f"Failed to re-initialize connection: {e}\nPlease verify your database credentials."
            )

        cohort.from_file = True
        print(f"SUCCESS: Loaded cohort from {filename}")
        return cohort
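
    # Round-trip sketch (hedged; "my_cohort" is an illustrative filename): save()
    # drops the live connection before pickling, and load() re-initializes it
    # from the stored conn_args / password_file.
    #
    #     cohort.save("my_cohort")                 # writes my_cohort.corr
    #     cohort2 = Cohort.load("my_cohort.corr")  # prompts for password if no password file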
    def __getstate__(self):
        """Called when pickling the object. Removes the unpicklable connection."""
        state = self.__dict__.copy()
        del state["conn"]  # Avoid pickling the connection
        return state

    def __setstate__(self, state):
        """Called when unpickling the object. Re-initializes the connection."""
        self.__dict__.update(state)
        self.conn = None  # Will be re-established when load() is called

    def _extr_end_date(self, table):
        return f"CAST({table}.`_hdl_loadstamp` AS DATE) <= '{self.extraction_end_date}'"

    def _get_case_id_query(self):
        query = f"SELECT DISTINCT c_falnr FROM {self.database}.it_ishmed_fall WHERE {self._extr_end_date('it_ishmed_fall')}"
        if self.filters:
            query += f" AND {self.filters}"
        return query

    def _convert_unhashable_cols(self, df):
        """Convert unhashable column types to hashable ones."""
        df = df.copy()
        for col in df.columns:
            if isinstance(df[col].iloc[0], list):
                df[col] = df[col].apply(lambda x: tuple(x))
            elif isinstance(df[col].iloc[0], dict):
                df[col] = df[col].apply(lambda x: frozenset(x.items()))
        return df

    def _get_data_hash(self):
        """Get a hash of the cohort data."""
        obs = self._convert_unhashable_cols(self.obs)
        obs_hash = hash(pd.util.hash_pandas_object(obs).sum())
        obsm_hash = sum(
            hash(pd.util.hash_pandas_object(df).sum()) for df in self.obsm.values()
        )
        return obs_hash + obsm_hash

    @property
    def adata(self):
        """AnnData representation of the cohort (cached).

        Returns an `AnnData <https://anndata.readthedocs.io/en/stable/generated/anndata.AnnData.html>`_
        object, which stores a data matrix together with annotations of observations,
        variables, and unstructured annotations.

        Warning:
            This returns a copy of the data. Modifications to the returned object will
            not be reflected in the cohort. To modify the cohort through the AnnData
            object, use Cohort._overwrite_from_adata().
        """
        current_hash = self._get_data_hash()
        if self._adata is not None and self._adata_hash == current_hash:
            return self._adata
        else:
            self._adata = self.to_adata()
            self._adata_hash = current_hash
            return self._adata
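
    # Caching sketch (hedged): repeated access to `cohort.adata` is cheap as long
    # as obs/obsm are unchanged, because the property compares _get_data_hash()
    # against the stored _adata_hash before rebuilding.
    #
    #     a1 = cohort.adata          # builds AnnData via to_adata(), caches it
    #     a2 = cohort.adata          # same data hash -> returns the cached object
    #     cohort.obs["flag"] = True  # mutating obs changes the hash
    #     a3 = cohort.adata          # rebuilt on next access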
    def to_adata(self):
        """
        Convert the cohort to an AnnData object.

        Returns:
            AnnData: An AnnData object.
        """
        obs = self._convert_unhashable_cols(self.obs)
        static_cols = [col for col in obs.columns if col not in [self.primary_key]]
        adata = ep.ad.df_to_anndata(
            obs, index_column=self.primary_key, columns_obs_only=static_cols
        )

        # Add dynamic variables
        for var_name, df_var in self.obsm.items():
            cols_agg = {
                col: lambda x: list(x)
                for col in df_var.columns
                if col not in [self.primary_key]
            }
            grouped = df_var.groupby(self.primary_key).agg(cols_agg).reset_index()
            merged_df = adata.obs.merge(grouped, on=self.primary_key, how="left")
            merged_df.set_index(self.primary_key, inplace=True)
            adata.obsm[var_name] = merged_df[cols_agg.keys()]
        return adata
    def to_csv(self, folder: str):
        """
        Save the cohort to CSV files.

        Args:
            folder: Path to the folder.
        """
        os.makedirs(folder, exist_ok=True)
        self.obs.to_csv(os.path.join(folder, "_obs.csv"), index=False)
        for var_name, var_data in self.obsm.items():
            var_data.to_csv(os.path.join(folder, f"{var_name}.csv"), index=False)
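
    # Output sketch (hedged; folder name illustrative): to_csv() writes one file
    # for the static table plus one file per dynamic variable.
    #
    #     cohort.to_csv("export/")
    #     # export/_obs.csv           <- static data (one row per observation)
    #     # export/blood_sodium.csv   <- one CSV per obsm entry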
    def _overwrite_from_adata(self, adata):
        self.obs = adata.obs.copy().reset_index()
        self.obsm = {}
        for var_name in adata.obsm.keys():
            self.obsm[var_name] = adata.obsm[var_name]
    def tableone(
        self,
        ignore_cols: list = [],
        groupby: str = None,
        filter: str = None,
        pval: bool = False,
        **kwargs,
    ) -> TableOne:
        """
        Create a `TableOne <https://tableone.readthedocs.io/en/latest/index.html>`_ object for the cohort.

        Args:
            ignore_cols: Columns to ignore.
            groupby: Column to group by.
            filter: Filter to apply to the data.
            pval: Whether to calculate p-values.
            **kwargs: Additional arguments to pass to TableOne.

        Returns:
            TableOne: A TableOne object.

        Examples:
            >>> tableone = cohort.tableone()
            >>> print(tableone)
            >>> tableone.to_csv("tableone.csv")

            >>> tableone = cohort.tableone(groupby="sex", pval=False)
            >>> print(tableone)
            >>> tableone.to_csv("tableone_sex.csv")
        """
        # Ignore datetime columns
        ignore_cols += [
            col for col in self.obs.select_dtypes(include=["datetime"]).columns
        ]
        # Ignore more columns
        ignore_cols += [
            "case_id",
            "icu_stay_id",
            "patient_id",
            "procedure_id",
            "c_patnr",
        ]
        if groupby is not None:
            ignore_cols += [groupby]
        columns = [col for col in self.obs.columns if col not in ignore_cols]

        # Reorder columns to put icu_id last (because it is so long...)
        if "icu_id" in columns:
            columns = [col for col in columns if col != "icu_id"] + ["icu_id"]

        # Set non-normal distribution for all numeric columns
        nonnormal_cols = [
            col
            for col in self.obs.select_dtypes(include=["number"]).columns
            if col in columns and col not in ["inhospital_death"]
        ]

        # Explicitly set booleans and objects to categorical
        categorical_cols = [
            col
            for col in self.obs.select_dtypes(
                include=["bool", "object", "category"]
            ).columns
            if col in columns
        ]

        # Filter data
        if filter is not None:
            obs_filtered = self.obs.query(filter)
        else:
            obs_filtered = self.obs

        logger.info(f"Columns: {columns}")
        logger.info(f"Nonnormal: {nonnormal_cols}")
        logger.info(f"Categorical: {categorical_cols}")
        logger.info(f"Groupby: {groupby}")

        return TableOne(
            obs_filtered,
            columns=columns,
            nonnormal=nonnormal_cols,
            categorical=categorical_cols,
            groupby=groupby,
            pval=pval,
            **kwargs,
        )
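
    # Additional usage sketch (hedged; column names illustrative): ignore_cols and
    # filter can be combined, where filter is any pandas DataFrame.query() expression
    # evaluated against obs.
    #
    #     t1 = cohort.tableone(
    #         ignore_cols=["icu_id"],
    #         filter="age_on_admission >= 18",
    #         groupby="inhospital_death",
    #     )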
    def _load_ishmed_fall(self):
        """Get the static data from the it_ishmed_fall table."""
        query = f"""
            SELECT
                it_ishmed_fall.c_patnr AS patient_id,
                it_ishmed_fall.c_falnr AS case_id,
                it_ishmed_fall.c_aufnahme AS hospital_admission,
                it_ishmed_fall.c_entlassung AS hospital_discharge,
                it_ishmed_patient.c_gender AS sex,
                it_ishmed_patient.c_birthdate AS birthdate,
                it_ishmed_patient.c_datetimeofdeath AS death_timestamp
            FROM {self.database}.it_ishmed_fall
            LEFT JOIN {self.database}.it_ishmed_patient USING (c_falnr)
            WHERE {self._extr_end_date('it_ishmed_fall')}
        """
        if self.filters:
            query += f" AND {self.filters}"

        obs = self.conn.df_read_sql(query=query)
        for col in [
            "hospital_admission",
            "hospital_discharge",
            "birthdate",
            "death_timestamp",
        ]:
            obs[col] = pd.to_datetime(obs[col], errors="coerce")
        self.obs = obs
        print(f"SUCCESS: Extracted ISHMED_FALL data for {len(self.obs)} cases")

    def _load_ishmed_bewegung(self):
        """Get the ICU stays from the it_ishmed_bewegung table."""
        mv_adm_transfer = (
            GLOBAL_VARS["hospital_data"]["mv_types"]["Admission"]
            + GLOBAL_VARS["hospital_data"]["mv_types"]["Transfer"]
        )
        mv_adm_transfer = utils.json_list_sql(mv_adm_transfer)

        query = f"""
            SELECT
                c_falnr AS case_id,
                c_begin AS icu_admission,
                c_ende AS icu_discharge,
                c_pflege_oe_id AS icu_id
            FROM {self.database}.it_ishmed_bewegung
            WHERE c_bewegungsart IN ({mv_adm_transfer})
                AND c_pflege_oe_id IN ({utils.json_list_sql(GLOBAL_VARS['hospital_data']['icu_ids'])})
                AND {self._extr_end_date('it_ishmed_bewegung')}
                AND c_falnr IN ({self._get_case_id_query()})
        """
        df_bew = self.conn.df_read_sql(query=query)
        for col in ["icu_admission", "icu_discharge"]:
            df_bew[col] = pd.to_datetime(df_bew[col], errors="coerce")

        if self.merge_consecutive:
            print("Merging consecutive ICU stays...")
            grouped = df_bew.groupby("case_id")
            groups_to_merge = [group for name, group in grouped if len(group) > 1]

            # Apply the merging function only to groups with multiple stays
            # Sequential alternative:
            # merged_groups = [utils.merge_consecutive_stays(group) for group in tqdm(groups_to_merge, desc="Merging consecutive stays", unit="case")]
            with multiprocessing.Pool() as pool:
                merged_groups = list(
                    tqdm(
                        pool.imap(utils.merge_consecutive_stays, groups_to_merge),
                        total=len(groups_to_merge),
                        desc="Merging consecutive stays",
                        unit="cases",
                    )
                )

            # Combine the merged groups with the single-stay groups
            single_stay_groups = grouped.filter(lambda x: len(x) == 1)
            df_bew_merged = pd.concat([single_stay_groups] + merged_groups)
            df_bew_merged = df_bew_merged.sort_values(
                by=["case_id", "icu_admission"]
            ).reset_index(drop=True)

            print(f"Original number of stays: {len(df_bew)}")
            print(f"Number of stays after merging: {len(df_bew_merged)}")
            print(
                f"Percentage of stays after merging: {len(df_bew_merged)/len(df_bew):.2%}"
            )
            self.df_bew = df_bew_merged
        else:
            self.df_bew = df_bew
        print("SUCCESS: Extracted ICU stays")

    def _load_icu_stays(self):
        """Merge the ICU stays with the static data to get a dataframe with ICU stay as primary key."""
        df_icu_stays = self.df_bew.merge(self.obs, on="case_id", how="left")
        print(f"Number of stays: {len(df_icu_stays)}")

        df_icu_stays = utils.filter_by_condition(
            df_icu_stays,
            lambda df: df["icu_admission"] < df["hospital_admission"],
            description="icu_admission < hospital_admission",
            mode="drop",
        )
        df_icu_stays = utils.filter_by_condition(
            df_icu_stays,
            lambda df: df["icu_discharge"] > df["hospital_discharge"],
            description="icu_discharge > hospital_discharge",
            mode="drop",
        )

        df_icu_stays["icu_stay_id"] = (
            df_icu_stays.groupby("case_id").cumcount().add(1).astype(str)
        )
        df_icu_stays["icu_stay_id"] = (
            df_icu_stays["case_id"] + "_" + df_icu_stays["icu_stay_id"]
        )

        # Re-order columns
        df_icu_stays = df_icu_stays[
            ["case_id", "icu_stay_id", "icu_admission", "icu_discharge"]
            + [
                col
                for col in df_icu_stays.columns
                if col
                not in ["case_id", "icu_stay_id", "icu_admission", "icu_discharge"]
            ]
        ]
        print("SUCCESS: Merged ICU stays with static data")
        self.obs = df_icu_stays

    def _load_procedures(self):
        """Get the procedures from the it_ishmed_procedure table and assign OP times from COPRA."""
        # Step 1: Get SAP surgeries (OPS code starting with 5)
        df_sap = self.conn.df_read_sql(
            f"""
            SELECT
                c_falnr AS case_id,
                c_prozedur_code AS ops_code,
                c_prozedur_begin
            FROM {self.database}.it_ishmed_procedure
            WHERE c_prozedur_code LIKE '5%'
                AND {self._extr_end_date('it_ishmed_procedure')}
                AND c_falnr IN ({self._get_case_id_query()})
            """,
            datetime_cols=["c_prozedur_begin"],
        )

        # Step 2: Get COPRA surgeries with start and end times.
        # Extract Behandlung_OP_Zeiten (var_id 101322) with parent = Behandlung,
        # then extract BEGAN (anesthesia start) and ENDAN (anesthesia end) from
        # Behandlung_OP_Zeiten_Ereignisname (var_id 101323 [event_timestamp] and
        # 101324 [event_name] with parent = Behandlung_OP_Zeiten).
        #   behandlung: first level - Behandlung (treatment) entries;
        #               gets 1263 (Behandlung_Beginn) and 1264 (Behandlung_Ende)
        #   h2: second level - OP_Zeiten (operation times) entries
        #   n:  event names (BEGAN/ENDAN)
        #   t:  event timestamps
        extract_cols = ["began", "endan", "schni", "naht", "freig"]
        query = f"""
            WITH behandlung AS (
                SELECT
                    b.c_falnr AS case_id,
                    b.c_id AS beh_id,
                    MAX(CASE WHEN h.c_var_id = 1263 THEN h.c_value END) as beh_beginn,
                    MAX(CASE WHEN h.c_var_id = 1264 THEN h.c_value END) as beh_ende
                FROM {self.database}.it_copra6_hierarchy_v2 b
                LEFT JOIN {self.database}.it_copra6_hierarchy_v2 h
                    ON b.c_id = h.c_parent_id AND h.c_var_id IN (1263, 1264)
                WHERE b.c_var_id = 30
                    AND {self._extr_end_date('b')}
                    AND b.c_falnr IN ({self._get_case_id_query()})
                GROUP BY b.c_falnr, b.c_id
            ),
            op_zeiten AS (
                SELECT
                    h2.c_id as zeit_id,
                    b.case_id,
                    b.beh_id,
                    b.beh_beginn,
                    b.beh_ende
                FROM behandlung b
                JOIN {self.database}.it_copra6_hierarchy_v2 h2 ON b.beh_id = h2.c_parent_id
                WHERE h2.c_var_id = 101322
            )
            SELECT
                oz.case_id,
                oz.beh_id,
                oz.beh_beginn,
                oz.beh_ende,
                {", ".join([f"MAX(CASE WHEN n.c_var_id = 101324 AND n.c_value = '{col.upper()}' THEN t.c_value END) as {col}" for col in extract_cols])}
            FROM op_zeiten oz
            JOIN {self.database}.it_copra6_hierarchy_v2 n ON oz.zeit_id = n.c_parent_id
            JOIN {self.database}.it_copra6_hierarchy_v2 t
                ON n.c_parent_id = t.c_parent_id AND t.c_var_id = 101323
            WHERE n.c_var_id = 101324
            GROUP BY oz.case_id, oz.beh_id, oz.beh_beginn, oz.beh_ende
        """
        df_op_times = self.conn.df_read_sql(
            query, datetime_cols=extract_cols + ["beh_beginn", "beh_ende"]
        )

        # Step 3: Merge with SAP surgeries (by closest start time)
        # Step 3.1: Drop missings
        initial_len_sap = len(df_sap)
        initial_len_op_times = len(df_op_times)
        df_sap = df_sap.dropna(subset=["c_prozedur_begin"])
        df_op_times = df_op_times.dropna(subset=["beh_beginn"])
        drop_sap = initial_len_sap - len(df_sap)
        drop_op_times = initial_len_op_times - len(df_op_times)
        print(
            f"DROP: {drop_sap} rows ({drop_sap / initial_len_sap:.2%}) due to c_prozedur_begin in df_sap is NaT"
        )
        print(
            f"DROP: {drop_op_times} rows ({drop_op_times / initial_len_op_times:.2%}) due to beh_beginn in df_op_times is NaT"
        )

        # Step 3.2: Combine multiple surgeries with same start time
        initial_len_sap = len(df_sap)
        df_sap = (
            df_sap.groupby(["case_id", "c_prozedur_begin"])
            .agg({"ops_code": lambda x: list(x)})
            .reset_index()
        )
        combine_sap = initial_len_sap - len(df_sap)
        print(
            f"COMBINE: {combine_sap} rows ({combine_sap / initial_len_sap:.2%}) due to multiple surgeries at the same time"
        )

        df_sap = df_sap.sort_values(by="c_prozedur_begin")
        df_op_times = df_op_times.sort_values(by="beh_beginn")
        df_proc = pd.merge_asof(
            df_sap,
            df_op_times,
            left_on="c_prozedur_begin",
            right_on="beh_beginn",
            by="case_id",
            direction="nearest",
        )

        # Step 4: Clean columns
        df_proc.rename(
            columns={
                "beh_beginn": "op_start_dtime_any",
                "beh_ende": "op_end_dtime_any",
                "c_prozedur_begin": "op_dtime_ops",
                "began": "op_start_anaesthesia",
                "endan": "op_ende_anaesthesia",
                "schni": "op_schnitt_anaesthesia",
                "naht": "op_naht_anaesthesia",
                "freig": "op_freigabe_anaesthesia",
            },
            inplace=True,
        )
        df_proc = df_proc.drop(columns=["beh_id"])

        # Step 5: Save procedures
        self.df_proc = df_proc
        print("SUCCESS: Extracted procedures")

    def _load_procedure_id(self):
        """Merge the procedures with the static data to get a dataframe with procedure as primary key."""
        # Step 1: Merge procedures with static data
        df_individual_proc = self.df_proc.merge(self.obs, on="case_id", how="left")
        print(f"Number of procedures: {len(df_individual_proc)}")

        # Step 2: Filter for viable times (within hospital admission and discharge)
        df_individual_proc = utils.filter_by_condition(
            df_individual_proc,
            lambda df: df["op_start_dtime_any"] < df["hospital_admission"],
            description="op_start_dtime_any < hospital_admission",
            mode="drop",
        )
        df_individual_proc = utils.filter_by_condition(
            df_individual_proc,
            lambda df: df["op_end_dtime_any"] > df["hospital_discharge"],
            description="op_end_dtime_any > hospital_discharge",
            mode="drop",
        )

        # Step 3: Sort procedures (by start time within each case)
        df_individual_proc = df_individual_proc.sort_values(
            by=["case_id", "op_start_dtime_any"]
        )

        # Step 4: Create procedure_id (sequential IDs starting at 1 for each case, concatenated with case_id)
        df_individual_proc["procedure_id"] = (
            df_individual_proc.groupby("case_id").cumcount().add(1).astype(str)
        )
        df_individual_proc["procedure_id"] = (
            df_individual_proc["case_id"] + "_" + df_individual_proc["procedure_id"]
        )

        # Step 5: Re-order columns (for consistency)
        df_individual_proc = df_individual_proc[
            [
                "patient_id",
                "case_id",
                "procedure_id",
                "sex",
                "birthdate",
                "death_timestamp",
                "hospital_admission",
                "hospital_discharge",
                "ops_code",
                "op_dtime_ops",
                "op_start_dtime_any",
                "op_end_dtime_any",
                "op_start_anaesthesia",
                "op_freigabe_anaesthesia",
                "op_ende_anaesthesia",
                "op_schnitt_anaesthesia",
                "op_naht_anaesthesia",
            ]
        ]

        # Step 6: Update self.obs
        self.obs = df_individual_proc
        print("SUCCESS: Merged procedures with static data")
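
    # merge_asof sketch (hedged, toy values): each SAP procedure is matched to the
    # COPRA treatment block whose start time is closest (direction="nearest"),
    # within the same case_id.
    #
    #     df_sap:      case_id="A", c_prozedur_begin=2023-01-01 10:05
    #     df_op_times: case_id="A", beh_beginn=2023-01-01 10:00 and 2023-01-01 14:00
    #     -> the 10:05 procedure is assigned the 10:00 block (op_start_dtime_any)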
    def load_default_vars(self):
        """
        Load the default variables defined in ``vars.json``. It is recommended to use
        this after filtering your cohort for eligibility to speed up the process.

        Returns:
            None: Variables are loaded into the cohort.
        """
        default_vars = VARS["corr_defaults"]["default_vars"]
        apply_defaults = default_vars["global"] + default_vars.get(self.obs_level, [])
        for var_name in apply_defaults:
            var = ex.Variable.from_corr_vars(
                var_name, cohort=self, tmin=self.t_min, tmax=self.t_max
            )
            if isinstance(var, ex.NativeDynamic):
                # For dynamic variables, convert to static to get value on admission
                var = var.on_admission(
                    select="!closest(hospital_admission, 0, 24h) value"
                )
            self.add_variable(var)
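
    # Deferred loading sketch (hedged): for large extractions, skip the defaults
    # at construction time and load them once the cohort has been narrowed down.
    #
    #     cohort = Cohort(load_default_vars=False, password_file=True)
    #     cohort.set_t_eligible("spo2_lt_90")  # after adding a time-anchor variable
    #     cohort.load_default_vars()           # defaults extracted on the smaller cohort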
    @property
    def axiom_cache(self) -> dict:
        """
        This is a deepcopy of cached native dynamic variables that are always stored at
        a hospital stay level (i.e. tmin=hospital_admission and tmax=hospital_discharge).
        If you do not know what to do with this, you probably do not need it.
        """
        # Deepcopy to avoid reference issues to Variable objects in the cache
        return copy.deepcopy(self._axiom_cache)
    def clear_axiom_cache(self) -> None:
        """Delete the axiom cache. Can be useful to free up memory, or for debugging purposes."""
        self._axiom_cache = {}
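
    # Cache inspection sketch (hedged): axiom_cache returns a deepcopy, so reading
    # it never mutates the cohort; clear_axiom_cache() frees the underlying storage.
    #
    #     list(cohort.axiom_cache.keys())  # names of cached hospital-level variables
    #     cohort.clear_axiom_cache()       # drop cached extractions to free memory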
    def _load_axiom(
        self, variable: str | ex.Variable, overwrite: bool = False
    ) -> ex.NativeDynamic:
        """Load a variable from the Axiom cache, add it if not present.

        Args:
            variable: Variable to load.
            overwrite (bool): Whether to overwrite the variable if it is already present in the cache.

        Returns:
            Variable (NativeDynamic): The loaded variable.
        """
        if isinstance(variable, str):
            var_name = variable
        else:
            var_name = variable.var_name

        if var_name in self.axiom_cache.keys() and not overwrite:
            return self.axiom_cache[var_name]
        else:
            if isinstance(variable, str):
                var = ex.NativeDynamic.from_corr_vars(
                    var_name,
                    cohort=self,
                    tmin="hospital_admission",
                    tmax="hospital_discharge",
                )
            else:
                var = variable
                var.tmin = "hospital_admission"
                var.tmax = "hospital_discharge"
            var.extract(self)
            self._axiom_cache[var_name] = var
            return self.axiom_cache[var_name]
    def get_obsm_filtered(self, var_name: str, tmin: str, tmax: str) -> pd.DataFrame:
        """Filter a variable stored in obsm by tmin and tmax.

        You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")),
        in which case it will be relative to the hospital admission time of the patient.

        Args:
            var_name: Name of the variable to filter.
            tmin: Name of the column to use as tmin or tuple (see description).
            tmax: Name of the column to use as tmax or tuple (see description).

        Returns:
            Filtered variable.

        Examples:
            >>> var_data = cohort.get_obsm_filtered(
            ...     var_name="blood_sodium",
            ...     tmin=("hospital_admission", "+1d"),
            ...     tmax="hospital_discharge"
            ... )
        """
        var = self._load_axiom(var_name)
        return self._filter_by_tmin_tmax(var.data, tmin, tmax)
    def add_variable(
        self, variable: str | ex.Variable, save_as=None, tmin=None, tmax=None
    ) -> None:
        """Add a variable to the cohort.

        You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")),
        in which case it will be relative to the hospital admission time of the patient.

        Args:
            variable: Variable to add. Either a string with the variable name (from `vars.json`) or a Variable object.
            save_as: Name of the column to save the variable as. Defaults to variable name.
            tmin: Name of the column to use as tmin or tuple (see description).
            tmax: Name of the column to use as tmax or tuple (see description).

        Returns:
            None: Variable is added to the cohort.

        Examples:
            >>> cohort.add_variable("blood_sodium")

            >>> cohort.add_variable(
            ...     variable="anx_dx_covid_19",
            ...     tmin=("hospital_admission", "-1d"),
            ...     tmax=cohort.t_eligible
            ... )

            >>> cohort.add_variable(
            ...     NativeStatic(
            ...         var_name="highest_hct_before_eligible",
            ...         select="!max value",
            ...         base_var='blood_hematokrit',
            ...         tmax=cohort.t_eligible
            ...     )
            ... )

            >>> cohort.add_variable(
            ...     variable='any_med_glu',
            ...     save_as="glucose_prior_eligible",
            ...     tmin=(cohort.t_eligible, "-48h"),
            ...     tmax=cohort.t_eligible
            ... )
        """
        if isinstance(variable, str):
            tmin = tmin if tmin is not None else self.t_min
            tmax = tmax if tmax is not None else self.t_max
            var = ex.Variable.from_corr_vars(
                variable, cohort=self, tmin=tmin, tmax=tmax
            )
        else:
            assert (
                tmin is None and tmax is None
            ), "Please specify tmin and tmax directly in the Variable object"
            var = variable

        print(f"Extracting {var.var_name} with tmin: {var.tmin} and tmax: {var.tmax}")
        var.extract(self)
        self._save_variable(var, save_as=save_as)
    def _save_variable(self, var: "ex.Variable", save_as=None) -> None:
        """Save a Variable object to the cohort. Will either add a column to obs or add
        an entry to obsm (for dynamic variables). Variables will also be added to
        Cohort.variables as variable objects.

        Args:
            var: Variable to save. (Must be already extracted)
            save_as: Name of the column to save the variable as. Defaults to variable name.

        Returns:
            None: Variable is saved to the cohort.
        """
        self.variables[var.var_name] = var
        data = var.data
        if save_as is None:
            save_as = var.var_name

        if var.dynamic:
            self.obsm[save_as] = data
        else:
            if any(col.startswith(save_as) for col in self.obs.columns):
                for col in self.obs.columns:
                    if col.startswith(save_as):
                        print(f"DROP existing variable: {col} ({len(self.obs)})")
                        self.obs.drop(columns=[col], inplace=True)
            if save_as != var.var_name:
                if isinstance(var, ex.NativeStatic):
                    # This is the only var type that can have multiple columns
                    if len(var.select_cols) > 1:
                        data.columns = [
                            f"{save_as}_{col.split('_', 1)[1]}" for col in data.columns
                        ]
                else:
                    data.rename(columns={var.var_name: save_as}, inplace=True)
            self.obs = pd.merge(self.obs, data, on=self.primary_key, how="left")

        print(f"SUCCESS: Saved variable: {save_as}")
        print(var)
        print(var.data.describe())

    def _filter_by_tmin_tmax(self, df_var, tmin, tmax):
        """Filter a dataframe by tmin and tmax. Used by get_obsm_filtered()."""
        if tmin is None:
            tmin = self.get_time_series(self.t_min)
        if tmax is None:
            tmax = self.get_time_series(self.t_max)

        time_bounds = self.obs.set_index(self.primary_key)[[tmin, tmax]].reset_index()
        df_merged = df_var.merge(time_bounds, on=self.primary_key, how="left")
        df_merged[tmin] = pd.to_datetime(df_merged[tmin], errors="coerce")
        df_merged[tmax] = pd.to_datetime(df_merged[tmax], errors="coerce")
        df_filtered = df_merged[
            (df_merged["recordtime"] >= df_merged[tmin])
            & (df_merged["recordtime"] <= df_merged[tmax])
        ]
        df_filtered = df_filtered.drop([tmin, tmax], axis=1).reset_index(drop=True)
        return df_filtered

    # Time anchors, inclusion and exclusion
    def set_t_eligible(self, t_eligible: str, drop_ineligible: bool = True) -> None:
        """
        Set the time anchor for eligibility. This can be referenced as cohort.t_eligible
        throughout the process and is required to add inclusion or exclusion criteria.

        Args:
            t_eligible: Name of the column to use as t_eligible.
            drop_ineligible: Whether to drop ineligible patients. Defaults to True.

        Returns:
            None: t_eligible is set.

        Examples:
            >>> # Add a suitable time-anchor variable
            >>> cohort.add_variable(NativeStatic(
            ...     var_name="spo2_lt_90",
            ...     base_var="spo2",
            ...     select="!first recordtime",
            ...     where="value < 90",
            ... ))
            >>> # Set the time anchor for eligibility
            >>> cohort.set_t_eligible("spo2_lt_90")
        """
        assert t_eligible in self.obs.columns, f"Column {t_eligible} not found in obs."
        assert pd.api.types.is_datetime64_any_dtype(
            self.obs[t_eligible]
        ), f"Column {t_eligible} is not a datetime column."

        if self.t_eligible is not None:
            print(
                f"WARNING: t_eligible already set to {self.t_eligible}. Will overwrite and set to {t_eligible}.\nCAVE! Previously ineligible patients will not be restored."
            )
        self.t_eligible = t_eligible

        if drop_ineligible:
            # Drop where t_eligible is NaT
            self.obs = utils.filter_by_condition(
                self.obs,
                lambda df: df[t_eligible].isna(),
                description=f"{t_eligible} is NaT",
                mode="drop",
            )

        # Reset axiom cache to account for significant memory reduction
        self.clear_axiom_cache()
    def set_t_outcome(self, t_outcome):
        """
        Set the time anchor for outcome. This can be referenced as cohort.t_outcome
        throughout the process and is recommended to specify for your study.

        Args:
            t_outcome (str): Name of the column to use as t_outcome.

        Returns:
            None: t_outcome is set.

        Examples:
            >>> cohort.set_t_outcome("hospital_discharge")
        """
        assert t_outcome in self.obs.columns, f"Column {t_outcome} not found in obs."
        assert pd.api.types.is_datetime64_any_dtype(
            self.obs[t_outcome]
        ), f"Column {t_outcome} is not a datetime column."

        if self.t_outcome is not None:
            print(
                f"WARNING: t_outcome already set to {self.t_outcome}. Will overwrite and set to {t_outcome}."
            )
        self.t_outcome = t_outcome
    def include(self, *args, **kwargs):
        """
        Add an inclusion criterion to the cohort. It is recommended to use
        ``Cohort.add_inclusion()`` and add all of your inclusion criteria at once.
        However, if you need to specify criteria at a later stage, you can use this method.

        Warning:
            You must call ``Cohort.add_inclusion()`` before calling ``Cohort.include()``
            to ensure that the inclusion criteria are properly tracked.

        Args:
            variable (str | Variable), operation (str), label (str), operations_done (str)
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.

        Examples:
            >>> cohort.include(
            ...     variable="age_on_admission",
            ...     operation=">= 18",
            ...     label="Adult",
            ...     operations_done="Include only adult patients"
            ... )
        """
        return self._include_exclude("include", *args, **kwargs)
    def exclude(self, *args, **kwargs):
        """
        Add an exclusion criterion to the cohort. It is recommended to use
        ``Cohort.add_exclusion()`` and add all of your exclusion criteria at once.
        However, if you need to specify criteria at a later stage, you can use this method.

        Warning:
            You must call ``Cohort.add_exclusion()`` before calling ``Cohort.exclude()``
            to ensure that the exclusion criteria are properly tracked.

        Args:
            variable (str | Variable), operation (str), label (str), operations_done (str)
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.

        Examples:
            >>> cohort.exclude(
            ...     variable="elix_total",
            ...     operation="> 20",
            ...     operations_done="Exclude patients with high Elixhauser score"
            ... )
        """
        return self._include_exclude("exclude", *args, **kwargs)
    def _include_exclude(
        self,
        mode: Literal["include", "exclude"],
        variable: str | ex.Variable,
        operation: str,
        label: str = "",
        operations_done: str = "",
        allow_obs=False,
        tmin: str | None = None,
        tmax: str | None = None,
    ):
        """
        Add an inclusion or exclusion criterion to the cohort.

        Args:
            mode (Literal["include", "exclude"]): Whether to add an inclusion or exclusion criterion.
            variable (str | Variable), operation (str), label (str), operations_done (str)
            allow_obs (bool): Allow using a stored variable in obs instead of re-extracting.
                CAVE: This can be dangerous when trying to set custom time bounds.
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.
        """
        adata = self.to_adata()
        try:
            if operation.lower() in ["true", "false"]:
                operation = "== True" if operation.lower() == "true" else "== False"

            if tmax is None:
                assert (
                    self.t_eligible is not None
                ), "t_eligible is not set. Please set t_eligible before adding inclusion criteria or set tmax manually (not recommended)."
                tmax = self.t_eligible
            if tmin is None:
                tmin = "hospital_admission"

            # Check if this is a defined variable or a custom variable
            if variable in self.constant_vars or (
                allow_obs and variable in self.obs.columns
            ):
                data = self.obs[[variable, self.primary_key]]
                var_name = variable
            else:
                if isinstance(variable, str):
                    var = ex.Variable.from_corr_vars(
                        variable, cohort=self, tmin=tmin, tmax=tmax
                    )
                    var_name = variable
                else:
                    var = variable
                    var_name = variable.var_name
                    var.tmin = tmin
                    var.tmax = tmax
                data = var.extract(self)
                data = data.reset_index()

            op_ids = data.query(f"{var_name}{operation}")[self.primary_key]
            criterion = {
                "variable": var_name,
                "operation": operation,
                "label": label,
                "operations_done": operations_done,
                "tmin": tmin,
                "tmax": tmax,
            }

            match mode:
                case "include":
                    adata = adata[adata.obs_names.isin(op_ids)]
                    self.include_ct(adata, label=label, operations_done=operations_done)
                    self.inclusion_criteria.append(criterion)
                case "exclude":
                    adata = adata[~adata.obs_names.isin(op_ids)]
                    self.exclude_ct(adata, label=label, operations_done=operations_done)
                    self.exclusion_criteria.append(criterion)
        except Exception as e:
            print(
                f"ERROR: Failed to add {mode} criterion ({var_name}{operation}) with data: \n{data.head()} \n{e} \n{traceback.format_exc()}"
            )
        self._overwrite_from_adata(adata)
    def add_inclusion(self, inclusion_list: list[dict] = []):
        """
        Add inclusion criteria to the cohort.

        Args:
            inclusion_list (list): List of inclusion criteria. Each criterion is a dictionary containing:

                * ``variable`` (str | Variable): Variable to use for inclusion
                * ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
                * ``label`` (str): Short label for the inclusion step
                * ``operations_done`` (str): Detailed description of what this inclusion does
                * ``tmin`` (str, optional): Start time for variable extraction
                * ``tmax`` (str, optional): End time for variable extraction

        Returns:
            ct (CohortTracker): CohortTracker object, can be used to plot inclusion chart

        Note:
            Per default, all inclusion criteria are applied from ``tmin="hospital_admission"``
            to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing
            immortality biases. However, in some cases you might want to set custom time bounds.

        Examples:
            >>> ct = cohort.add_inclusion([
            ...     {
            ...         "variable": "age_on_admission",
            ...         "operation": ">= 18",
            ...         "label": "Adult patients",
            ...         "operations_done": "Excluded patients under 18 years old"
            ...     }
            ... ])
            >>> ct.plot_flowchart()
        """
        # Backwards compatibility
        if hasattr(self, "include_ct") and self.include_ct is not None:
            raise ValueError(
                "Inclusion criteria already set. Please use Cohort.include() to add individual inclusion criteria."
            )
        if not hasattr(self, "inclusion_criteria") or self.inclusion_criteria is None:
            self.inclusion_criteria = []
        if not hasattr(self, "include_ct"):
            self.include_ct = None

        adata = self.to_adata()
        self.include_ct = CohortTracker(adata, columns=["patient_id"])
        self.include_ct(adata, label="Initial cohort")
        for inclusion in inclusion_list:
            self.include(**inclusion)
        return self.include_ct
    def add_exclusion(self, exclusion_list: list[dict] = []):
        """
        Add exclusion criteria to the cohort.

        Args:
            exclusion_list (list): List of exclusion criteria. Each criterion is a dictionary containing:

                * ``variable`` (str | Variable): Variable to use for exclusion
                * ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
                * ``label`` (str): Short label for the exclusion step
                * ``operations_done`` (str): Detailed description of what this exclusion does
                * ``tmin`` (str, optional): Start time for variable extraction
                * ``tmax`` (str, optional): End time for variable extraction

        Returns:
            ct (CohortTracker): CohortTracker object, can be used to plot exclusion chart

        Note:
            Per default, all exclusion criteria are applied from ``tmin="hospital_admission"``
            to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing
            immortality biases. However, in some cases you might want to set custom time bounds.

        Examples:
            >>> ct = cohort.add_exclusion([
            ...     {
            ...         "variable": "any_rrt_icu",
            ...         "operation": "true",
            ...         "label": "No RRT",
            ...         "operations_done": "Excluded RRT before hypernatremia"
            ...     },
            ...     {
            ...         "variable": "any_dx_tbi",
            ...         "operation": "true",
            ...         "label": "No TBI",
            ...         "operations_done": "Excluded TBI before hypernatremia"
            ...     },
            ...     {
            ...         "variable": NativeStatic(
            ...             var_name="sodium_count",
            ...             select="!count value",
            ...             base_var="blood_sodium"),
            ...         "operation": "< 1",
            ...         "label": "Final cohort",
            ...         "operations_done": "Excluded cases with less than 1 sodium measurement after hypernatremia",
            ...         "tmin": cohort.t_eligible,
            ...         "tmax": "hospital_discharge"
            ...     }
            ... ])
            >>> ct.plot_flowchart()  # Plot the exclusion flowchart
        """
        # Backwards compatibility
        if hasattr(self, "exclude_ct") and self.exclude_ct is not None:
            raise ValueError(
                "Exclusion criteria already set. Please use Cohort.exclude() to add individual exclusion criteria."
            )
        if not hasattr(self, "exclusion_criteria") or self.exclusion_criteria is None:
            self.exclusion_criteria = []
        if not hasattr(self, "exclude_ct"):
            self.exclude_ct = None

        adata = self.to_adata()
        self.exclude_ct = CohortTracker(adata, columns=["patient_id"])
        self.exclude_ct(adata, label="Meets inclusion criteria")
        for exclusion in exclusion_list:
            self.exclude(**exclusion)
        return self.exclude_ct
    def add_variable_definition(self, var_name: str, var_dict: dict) -> None:
        """Add or update a local variable definition.

        Args:
            var_name: Name of the variable
            var_dict: Dictionary containing variable definition. Can be partial -
                missing fields will be inherited from global definition.

        Examples:
            >>> # Add completely new variable
            >>> cohort.add_variable_definition("my_new_var", {
            ...     "type": "native_dynamic",
            ...     "table": "it_ishmed_labor",
            ...     "where": "c_katalog_leistungtext LIKE '%new%'",
            ...     "value_dtype": "DOUBLE",
            ...     "cleaning": {"value": {"low": 100, "high": 150}}
            ... })

            >>> # Partially override existing variable
            >>> cohort.add_variable_definition("blood_sodium", {
            ...     "where": "c_katalog_leistungtext LIKE '%custom_sodium%'"
            ... })
        """
        if var_name not in self.project_vars:
            self.project_vars[var_name] = {}
        # Update existing definition
        self.project_vars[var_name].update(var_dict)
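
# End-to-end sketch (hedged; variable names and criteria are illustrative and
# taken from the docstring examples above, not a prescribed study design):
#
#     cohort = Cohort(obs_level="icu_stay", password_file=True, load_default_vars=False)
#     cohort.add_variable("blood_sodium")      # dynamic variable -> cohort.obsm
#     ct = cohort.add_inclusion([{
#         "variable": "age_on_admission",
#         "operation": ">= 18",
#         "label": "Adult patients",
#         "operations_done": "Excluded patients under 18 years old",
#     }])
#     print(cohort.tableone())                 # baseline characteristics
#     cohort.save("my_cohort")                 # -> my_cohort.corr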