Source code for corr_vars.core.extract

from __future__ import annotations

import copy
import getpass
import math
import multiprocessing
from functools import partial

import pandas as pd
from tqdm import tqdm

import corr_vars.static as static
import corr_vars.utils as utils
from corr_vars import logger
from corr_vars.static import variables

CHUNK_SIZE = 30_000


class Variable:
    """Base class for all variables.

    Args:
        var_name: The variable name.
        native: True if the variable is native (extracted from the database or
            a simple aggregation of native variables).
        dynamic: True if the variable is dynamic (time-series).
        tmin: The tmin argument. Can either be a string (column name) or a
            tuple of (column name, timedelta).
        tmax: The tmax argument. Can either be a string (column name) or a
            tuple of (column name, timedelta).
        requires: List of variables required to calculate the variable.

    Note that tmin and tmax can be None when you create a Variable object, but
    they must be set before extraction. If you add the variable via
    cohort.add_variable(), they will automatically be set to the cohort's tmin
    and tmax.

    This base class should not be used directly; use one of the subclasses
    instead.
    """

    def __init__(
        self,
        var_name: str,
        native: bool,
        dynamic: bool,
        requires: list[str] | None = None,
        tmin: str | tuple[str, str] | None = None,
        tmax: str | tuple[str, str] | None = None,
    ):
        self.var_name = var_name
        # True if the variable is native (extracted from the database or a
        # simple aggregation of native variables)
        self.native = native
        # True if the variable is dynamic (time-series)
        self.dynamic = dynamic
        # Normalize None to a fresh list to avoid a shared mutable default
        self.requires = requires if requires is not None else []
        self.chunk_size = CHUNK_SIZE
        self.data = None
        self.tmin = tmin
        self.tmax = tmax
    @classmethod
    def from_json(cls, var_name, var_dict, tmin=None, tmax=None):
        """Create a Variable object from a variable dictionary.

        Args:
            var_name (str): The variable name.
            var_dict (dict): The variable dictionary (from vars.json).
            tmin (str | tuple[str, str]): The tmin argument. Can either be a
                string (column name) or a tuple of (column name, timedelta).
            tmax (str | tuple[str, str]): The tmax argument. Can either be a
                string (column name) or a tuple of (column name, timedelta).

        Returns:
            Variable: Variable object, depending on the variable type.
        """
        var_type = var_dict.get("type").lower()
        match var_type:
            case "native_dynamic":
                return NativeDynamic(
                    var_name=var_name,
                    table=var_dict.get("table"),
                    where=var_dict.get("where"),
                    value_dtype=var_dict.get("value_dtype"),
                    cleaning=var_dict.get("cleaning"),
                    tmin=tmin,
                    tmax=tmax,
                )
            case "native_static":
                return NativeStatic(
                    var_name=var_name,
                    select=var_dict.get("select"),
                    base_var=var_dict.get("variable"),
                    where=var_dict.get("where"),
                    tmin=tmin,
                    tmax=tmax,
                )
            case "derived_dynamic":
                return DerivedDynamic(
                    var_name=var_name,
                    requires=var_dict.get("requires", []),
                    cleaning=var_dict.get("cleaning"),
                    tmin=tmin,
                    tmax=tmax,
                )
            case "derived_static":
                return DerivedStatic(
                    var_name=var_name,
                    requires=var_dict.get("requires", []),
                    expression=var_dict.get("expression"),
                    tmin=tmin,
                    tmax=tmax,
                )
            case "complex":
                return ComplexVariable(
                    var_name=var_name,
                    requires=var_dict.get("requires", []),
                    dynamic=var_dict.get("dynamic"),
                    tmin=tmin,
                    tmax=tmax,
                )
            case _:
                # Previously fell through and returned None silently
                raise ValueError(f"Unknown variable type: {var_type}")
    @classmethod
    def from_corr_vars(cls, var_name, cohort=None, tmin=None, tmax=None):
        """Create a Variable object from a variable name.

        Args:
            var_name: The variable name (in vars.json).
            cohort: Cohort object (to pass custom variable definitions).
            tmin: The tmin argument. Can either be a string (column name) or a
                tuple of (column name, timedelta).
            tmax: The tmax argument. Can either be a string (column name) or a
                tuple of (column name, timedelta).

        Returns:
            Variable: Variable object, depending on the variable type.
        """
        var_dict = static.VARS["variables"].get(var_name, None)
        if cohort:
            if var_name in cohort.project_vars:
                if var_dict is None:
                    # Use local definition
                    var_dict = cohort.project_vars[var_name]
                else:
                    # Update with local definition (on a copy, so the shared
                    # definition in static.VARS is not mutated)
                    var_dict = {**var_dict, **cohort.project_vars[var_name]}
        if var_dict is None:
            raise KeyError(f"Variable {var_name} not found in vars.json")
        return cls.from_json(var_name, var_dict, tmin=tmin, tmax=tmax)
    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"""Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
tmin: {self.tmin}, tmax: {self.tmax}
data: {'Not extracted' if self.data is None else self.data.shape}"""

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
    def call_var_function(self, cohort: "Cohort") -> bool:
        """Call the variable function if it exists.

        Args:
            cohort (Cohort)

        Returns:
            True if the function was called, False otherwise. Operations will
            be applied to Variable.data directly.
        """
        has_var_function = hasattr(variables, self.var_name)
        if has_var_function:
            logger.info(f"Calling variable function {self.var_name}")
            var_function = getattr(variables, self.var_name)
            self.data = var_function(var=self, cohort=copy.deepcopy(cohort))
            return True
        else:
            return False

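# Usage sketch for the constructors above (illustrative; "blood_sodium" must
# exist in vars.json and `cohort` is assumed to be a Cohort instance -- neither
# is defined in this module):
#
# >>> var = Variable.from_corr_vars(
# ...     "blood_sodium",
# ...     tmin="hospital_admission",
# ...     tmax=("hospital_admission", "72h"),  # (column name, timedelta) form
# ... )
# >>> cohort.add_variable(var)
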
class NativeVariable(Variable):
    def __init__(
        self,
        var_name,
        native,
        dynamic,
        requires: list[str] | None = None,
        tmin=None,
        tmax=None,
        chunk_size=CHUNK_SIZE,
    ):
        super().__init__(
            var_name,
            native=native,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
        )
        self.chunk_size = chunk_size

    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        # Shared extraction flow; _extract_from_db is implemented in subclasses.
        assert self.tmin is not None, "tmin must be provided for all variables"
        assert self.tmax is not None, "tmax must be provided for all variables"
        self.data = self._extract_chunked(cohort)
        self.call_var_function(cohort)
        # Cleaning bounds are only defined on some subclasses, hence getattr
        cleaning = getattr(self, "cleaning", None)
        if cleaning and self.data is not None and self.dynamic:
            for colname in cleaning:
                if cleaning[colname].get("low"):
                    self.data = self.data[
                        self.data[colname] >= cleaning[colname]["low"]
                    ]
                if cleaning[colname].get("high"):
                    self.data = self.data[
                        self.data[colname] <= cleaning[colname]["high"]
                    ]
        return self.data

    def _extract_chunked(self, cohort: "Cohort"):
        """Manage chunked extraction by dividing the observation data into
        smaller chunks to avoid the SQL query size limit."""
        if cohort.conn.password_persist:
            dbpass = None
        else:
            dbpass = getpass.getpass(prompt="Enter your password: ")
            assert dbpass, "Password not provided."
        num_chunks = math.ceil(len(cohort.obs) / self.chunk_size)
        if num_chunks > 1:
            chunks = [
                (i * self.chunk_size, min((i + 1) * self.chunk_size, len(cohort.obs)))
                for i in range(num_chunks)
            ]
            # Create a lightweight cohort object so less state is pickled per worker
            light_cohort = copy.copy(cohort)
            light_cohort.clear_axiom_cache()
            process_func = partial(
                self._extract_from_db, cohort=light_cohort, dbpass=dbpass
            )
            with multiprocessing.Pool() as pool:
                results = list(
                    tqdm(
                        pool.imap(process_func, chunks),
                        total=num_chunks,
                        desc=f"Extracting {self.var_name}",
                        unit="chunks",
                    )
                )
            print("Processing data...")
            return pd.concat(results, copy=False, ignore_index=True)
        else:
            return self._extract_from_db(
                chunk_info=(0, len(cohort.obs)), cohort=cohort, dbpass=dbpass
            )

    def _extract_from_db(
        self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
    ):
        """Extraction logic from the database. Must be implemented in subclasses."""
        raise NotImplementedError

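# Chunking sketch: for 70,000 observation rows and the default CHUNK_SIZE of
# 30,000, _extract_chunked builds three (start, end) index pairs and runs one
# worker query per pair:
#
# >>> import math
# >>> n, size = 70_000, 30_000
# >>> [(i * size, min((i + 1) * size, n)) for i in range(math.ceil(n / size))]
# [(0, 30000), (30000, 60000), (60000, 70000)]
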
class NativeDynamic(NativeVariable):
    """Native dynamic variables are extracted directly from the database and
    represent time-series data.

    The extracted data will be in long format, including columns like
    recordtime, value, and, depending on the table definition, recordtime_end
    (e.g., for therapy events) and description (e.g., medication names). The
    resulting dataframe will be available as ``Cohort.obsm["var_name"]`` or as
    ``Variable.data``. The cleaning term will be applied at the end of the
    extraction process.

    Args:
        var_name: Name of the variable.
        table: Source table to extract from (e.g., it_copra6_hierachy_v2,
            it_copra6_therapy). See `HDL Hue <https://hdl-edge01.charite.de:8443/gateway/cdp-proxy/hue/hue/metastore/tables/db_hypercapnia_prepared?connector_id=impala&namespace=5ac5d06e-ba46-46eb-a9b3-6f3d91901617>`_
            for a list of available tables.
        where: SQL statement to filter the source table. Must use column names
            available in the source table.
        value_dtype: SQL data type of the value column, e.g., ``BIGINT``,
            ``BOOLEAN``, ``DATETIME``, ``DOUBLE``, ``STRING``.
        cleaning: Dictionary specifying lower and upper bounds for impossible
            values. Note: Should only include physically impossible values,
            not unlikely/percentile-based values. Consult a clinician if
            unsure!
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.

    Examples:
        >>> # Basic lab value extraction
        >>> v = NativeDynamic(
        ...     var_name="blood_urea",
        ...     table="it_ishmed_labor",
        ...     where="c_katalog_leistungtext LIKE '%arnstoff%' AND c_wert <> '0'",
        ...     value_dtype="DOUBLE",
        ...     cleaning={"value": {"low": 2, "high": 500}},
        ...     tmin="hospital_admission",
        ...     tmax="hospital_discharge",
        ... )
        >>> v.extract(cohort)
        >>> v.data

        >>> # Therapy event with start/end times
        >>> # End time will be automatically added as recordtime_end for specified tables
        >>> v = NativeDynamic(
        ...     var_name="ecmo_vva_vav_icu",
        ...     table="it_copra6_therapy",
        ...     where="c_apparat_mode IN ('v-v/a ECMO','v-a/v ECMO')",
        ...     value_dtype="VARCHAR",
        ...     cleaning=None,
        ...     tmin="hospital_admission",
        ...     tmax="hospital_discharge",
        ... )
        >>> v.extract(cohort)
        >>> v.data
    """

    def __init__(
        self,
        var_name: str,
        table: str,
        where: str,
        value_dtype: str,
        cleaning: dict,
        tmin=None,
        tmax=None,
    ):
        requires = [table]
        super().__init__(
            var_name, native=True, dynamic=True, requires=requires, tmin=tmin, tmax=tmax
        )
        self.table = table
        self.where = where
        self.value_dtype = value_dtype
        self.cleaning = cleaning

    def _extract_from_db(
        self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
    ):
        """Extract data from the database."""
        start_idx, end_idx = chunk_info
        df_chunk = cohort.obs.iloc[start_idx:end_idx].copy()
        df_chunk["tmin"], df_chunk["tmax"] = utils.parse_time_args(
            df_chunk, tmin=self.tmin, tmax=self.tmax
        )
        cb_values = utils.get_cb_values(df_chunk, primary_key=cohort.primary_key)
        query, table_info, keep_cols = utils.build_base_query(
            cb_values,
            self,
            cohort.database,
            cohort.primary_key,
            cohort._extr_end_date("v"),
        )
        if dbpass:
            conn = utils.ImpalaConnector(password=dbpass, **cohort.conn_args)
        else:
            conn = utils.ImpalaConnector.with_password_file(
                password_file=cohort.password_file, **cohort.conn_args
            )
        return conn.df_read_sql(
            query=query,
            datetime_cols=[col for col in table_info if "recordtime" in col],
            keep_cols=keep_cols,
        )
    def on_admission(self, select: str = "!first value") -> NativeStatic:
        """Create a new NativeStatic variable, based on the current variable,
        that extracts the value on admission.

        Args:
            select: Select clause specifying aggregation function and columns.
                Defaults to ``!first value``.

        Returns:
            NativeStatic

        Examples:
            >>> # Return the first value
            >>> var_adm = variable.on_admission()
            >>> cohort.add_variable(var_adm)

            >>> # Be more specific with your selection
            >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
            >>> cohort.add_variable(var_adm)
        """
        return NativeStatic(
            var_name=f"{self.var_name}_adm",
            select=select,
            base_var=self.var_name,
            tmin=self.tmin,
            tmax=self.tmax,
        )

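# Cleaning sketch: the bounds from the `cleaning` dict are applied inclusively
# after extraction, e.g. {"value": {"low": 2, "high": 500}} keeps rows with
# 2 <= value <= 500. A minimal pandas illustration of the filter itself (not
# the extraction):
#
# >>> import pandas as pd
# >>> df = pd.DataFrame({"value": [1, 2, 250, 500, 501]})
# >>> df[(df["value"] >= 2) & (df["value"] <= 500)]["value"].tolist()
# [2, 250, 500]
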
class NativeStatic(NativeVariable):
    """NativeStatic variables represent simple aggregations of NativeDynamic
    variables.

    Args:
        var_name (str): Name of the variable.
        select (str): Select clause specifying aggregation function and columns.
        base_var (str | NativeDynamic): The base variable, given as a name or
            an instance (must be a native_dynamic variable).
        where (str, optional): Optional WHERE clause in pandas-style format
            (see below).
        tmin (str, optional): Minimum time for the extraction.
        tmax (str, optional): Maximum time for the extraction.

    The select argument supports several aggregation functions:

    - ``!first [columns]``: Returns the first row within this case

      >>> "!first value"  # Single column
      >>> "!first value, recordtime"  # Multiple columns

    - ``!last [columns]``: Returns the last row within this case

      >>> "!last value"
      >>> "!last value, recordtime"

    - ``!any``: Returns True if any value exists

      >>> "!any"
      >>> "!any value"

    - ``!closest(to_column, timedelta, plusminus) [columns]``: Selects the
      value closest to the specified column

      Args:
          to_column: Column to compare "recordtime" against
          timedelta: Time to add to "to_column" for comparison
          plusminus: Allowed time mismatch (different before/after values can
              be given, separated by a space)

      >>> "!closest(hospital_admission) value, recordtime"  # Closest to admission
      >>> "!closest(hospital_admission, 0, 2h 3h) value"  # 2h before to 3h after
      >>> "!closest(first_intubation_dtime, 6h, 2h) value"  # 6h after intubation ±2h

    - ``!mean [column]``: Calculates the mean value

      >>> "!mean value"

    - ``!median [column]``: Calculates the median value

      >>> "!median value"

    - ``!perc(quantile) [column]``: Calculates the specified percentile

      >>> "!perc(75) value"  # 75th percentile

    The where argument supports Pandas-style boolean expressions. These are
    evaluated in the context of the base variable by pd.eval(). Where also
    supports magic commands (starting with !) to filter the data. Supported
    commands are:

    - ``!isin(column, [values])``: Filters rows where the value in column is
      in values
    - ``!startswith(column, [values])``: Filters rows where the value in column
      starts with any of the values
    - ``!endswith(column, [values])``: Filters rows where the value in column
      ends with any of the values

    A usage sketch follows this class definition.
    """

    def __init__(
        self,
        var_name: str,
        select: str,
        base_var: str | NativeDynamic,
        where: str | None = None,
        tmin=None,
        tmax=None,
    ):
        requires = [base_var]
        super().__init__(
            var_name,
            native=True,
            dynamic=False,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
        )
        self.select = select
        logger.debug(f"NativeStatic: {self.select}, {base_var}, {where}")
        if isinstance(base_var, str):
            self.base_var = Variable.from_corr_vars(base_var, tmin=tmin, tmax=tmax)
        else:
            self.base_var = base_var
        # Will only be loaded if use_cache is True on extract()
        self.base_var_data = None
        self.where = where
        self.agg_func, self.agg_params, self.select_cols = utils.parse_select(
            self.select
        )
        if len(self.select_cols) == 0:
            self.select_cols = ["value"]
    def extract(self, cohort: "Cohort", use_cache: bool = True) -> pd.DataFrame:
        """Extract the variable.

        You do not need to call this yourself, as it is called internally when
        you add the variable to a cohort. However, you may call it directly to
        obtain variable data independently of the cohort. You still need a
        cohort object for case ids and other metadata.

        Args:
            cohort: Cohort object.
            use_cache: Whether to use the cached base variable. This is highly
                recommended, as direct SQL extraction is currently being phased
                out.

        Returns:
            Extracted variable. After extraction, you may also access the data
            as ``Variable.data``.

        Examples:
            >>> var = NativeStatic(
            ...     var_name="first_sodium_recordtime",
            ...     select="!first recordtime",
            ...     base_var="blood_sodium",
            ...     tmin="hospital_admission",
            ...     tmax="hospital_discharge",
            ... )
            >>> var.extract(cohort)  # With var.extract(), the data will not be added to the cohort.
            >>> var.data
        """
        assert (
            self.tmin is not None
        ), "tmin must be provided for all variables upon extraction"
        assert (
            self.tmax is not None
        ), "tmax must be provided for all variables upon extraction"
        if use_cache:
            self.data = self._extract_agg_cached(cohort)
            # Add suffix if multiple select columns
            if len(self.select_cols) > 1:
                self.data.columns = [
                    f"{self.var_name}_{col}" for col in self.select_cols
                ]
            else:
                self.data.rename(
                    columns={self.select_cols[0]: self.var_name}, inplace=True
                )
            self.call_var_function(cohort)
            return self.data
        else:
            logger.warning(
                "The uncached version performs a direct SQL query. While faster "
                "for some variables, not all aggregation functions are "
                "currently supported."
            )
            res = self._extract_chunked(cohort)
            # No column renaming needed here; this is handled in the SQL query
            self.data = res
            self.call_var_function(cohort)
            return self.data
    def _extract_from_db(
        self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
    ):
        # dbpass is accepted for interface compatibility with _extract_chunked
        print(
            "WARNING: Using deprecated SQL-based extraction for NativeStatic "
            "variables. Please use pandas syntax directly."
        )
        start_idx, end_idx = chunk_info
        df_chunk = cohort.obs.iloc[start_idx:end_idx].copy()
        df_chunk["tmin"], df_chunk["tmax"] = utils.parse_time_args(
            df_chunk, tmin=self.tmin, tmax=self.tmax
        )
        drop_cols = ["row_n", "time_diff", "recordtime"]
        cb = utils.get_cb_values(df_chunk, primary_key=cohort.primary_key)
        base_query, table_info, keep_cols = utils.build_base_query(
            cb,
            self.base_var,
            cohort.database,
            cohort.primary_key,
            cohort._extr_end_date("v"),
        )
        match self.agg_func.lower():
            case "first" | "last":
                select = utils.parse_col_list(self.select_cols, self.var_name)
                order_by = f"ORDER BY recordtime {'ASC' if self.agg_func.lower() == 'first' else 'DESC'}"
                query = f"""
                    WITH ranked_values AS (
                        SELECT {select}, {cohort.primary_key},
                            ROW_NUMBER() OVER (PARTITION BY {cohort.primary_key} {order_by}) AS row_n
                        FROM ({base_query}) subquery
                        {f'WHERE ({self.where})' if self.where else ''}
                    )
                    SELECT * FROM ranked_values WHERE row_n = 1
                """
            case "any":
                assert (
                    len(self.select_cols) <= 1
                ), "Cannot use 'any' with multiple select columns."
                assert len(self.agg_params) == 0, "Cannot use 'any' with parameters."
                if len(self.select_cols) == 0:
                    select = "MIN(value)"
                else:
                    select = f"MIN({self.select_cols[0].strip()})"
                query = f"""
                    SELECT {select} AS {self.var_name}, {cohort.primary_key}
                    FROM ({base_query}) subquery
                    {f'WHERE ({self.where})' if self.where else ''}
                    GROUP BY {cohort.primary_key}
                """
            case "closest":
                assert len(self.agg_params) >= 1, (
                    "Closest requires at least one parameter: to [which column "
                    "to compare to], [optional: tdelta, pm]"
                )
                to_col = self.agg_params[0]
                drop_cols.append(to_col)
                tdelta = utils.convert_interval_sql(
                    self.agg_params[1] if len(self.agg_params) > 1 else "0"
                )
                plusminus = self.agg_params[2] if len(self.agg_params) > 2 else "52w"
                if " " in plusminus:
                    pm_before, pm_after = plusminus.split(" ")
                else:
                    pm_before = pm_after = plusminus
                pm_before = utils.convert_interval_sql(pm_before)
                pm_after = utils.convert_interval_sql(pm_after)
                cb_values = utils.get_cb_values(
                    df_chunk, primary_key=cohort.primary_key, ttarget_col=to_col
                )
                base_query, table_info, keep_cols = utils.build_base_query(
                    cb_values,
                    self.base_var,
                    cohort.database,
                    cohort.primary_key,
                    cohort._extr_end_date("v"),
                    ttarget_col=to_col,
                )
                select = utils.parse_col_list(self.select_cols, self.var_name)
                query = f"""
                    WITH ranked_values AS (
                        SELECT {select}, {cohort.primary_key}, recordtime, {to_col},
                            ABS(UNIX_TIMESTAMP(recordtime) - UNIX_TIMESTAMP(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}))) AS time_diff,
                            ROW_NUMBER() OVER (
                                PARTITION BY {cohort.primary_key}
                                ORDER BY ABS(UNIX_TIMESTAMP(recordtime) - UNIX_TIMESTAMP(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta})))
                            ) AS row_n
                        FROM ({base_query}) subquery
                        WHERE recordtime BETWEEN
                            DATE_SUB(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}), {pm_before})
                            AND DATE_ADD(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}), {pm_after})
                        {f'AND ({self.where})' if self.where else ''}
                    )
                    SELECT * FROM ranked_values WHERE row_n = 1
                """
            case _:
                select = utils.parse_col_list(self.select_cols, self.var_name)
                query = f"""
                    SELECT {select}, {cohort.primary_key}
                    FROM ({base_query}) subquery
                    {f'WHERE ({self.where})' if self.where else ''}
                    GROUP BY {cohort.primary_key}
                """
        conn = utils.ImpalaConnector(**cohort.conn_args)
        res = conn.df_read_sql(
            query=query,
            datetime_cols=[col for col in table_info if "recordtime" in col],
        )
        res.set_index(cohort.primary_key, inplace=True)
        for col in drop_cols:
            if col in res.columns:
                res.drop(columns=[col], inplace=True)
        return res

    def _extract_agg_cached(self, cohort):
        """Extract a static variable using the cached base variable in
        Cohort.axiom_cache.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pd.DataFrame: Extracted variable.
        """
        try:
            base_data = cohort._load_axiom(self.base_var).data
        except KeyError:
            # The clauses below are kept for compatibility; cohort._load_axiom()
            # should not raise a KeyError.
            try:
                base_data = cohort.obsm[self.base_var.var_name]
            except KeyError:
                base_data = self.base_var.extract(cohort)
        base_data.sort_values(by="recordtime", ascending=True, inplace=True)
        base_data.reset_index(inplace=True)
        base_data.set_index(cohort.primary_key, inplace=True)
        if self.where:
            if self.where.startswith("!"):
                where_func, where_params, _ = utils.parse_select(self.where)
                match where_func:
                    case "isin":
                        base_data = base_data[
                            base_data[where_params[0]].isin(where_params[1])
                        ]
                    case "startswith":
                        mask = base_data[where_params[0]].str.startswith(
                            tuple(where_params[1])
                        )
                        base_data = base_data[mask]
                    case "endswith":
                        mask = base_data[where_params[0]].str.endswith(
                            tuple(where_params[1])
                        )
                        base_data = base_data[mask]
                    case _:
                        raise ValueError(
                            f"Unsupported where function: {where_func}. "
                            "Supported functions are: isin, startswith, endswith"
                        )
            else:
                base_data = base_data[base_data.eval(utils.sql_to_pandas(self.where))]
        tmin, tmax = utils.parse_time_args(cohort.obs, tmin=self.tmin, tmax=self.tmax)
        obs = cohort.obs.copy()
        obs["tmin"] = tmin
        obs["tmax"] = tmax
        base_data_time = base_data.merge(
            obs[["tmin", "tmax", cohort.primary_key]], on=cohort.primary_key, how="left"
        )
        # Filter base_data based on tmin and tmax
        base_data_time = base_data_time[
            base_data_time["tmin"] <= base_data_time["recordtime"]
        ]
        base_data_time = base_data_time[
            base_data_time["tmax"] >= base_data_time["recordtime"]
        ]
        base_data_grouped = base_data_time.groupby(cohort.primary_key)
        params = self.agg_params
        match self.agg_func.lower():
            case "first":
                res = base_data_grouped.first()
            case "last":
                res = base_data_grouped.last()
            case "mean":
                res = base_data_grouped.mean()
            case "median":
                res = base_data_grouped.median()
            case "min":
                res = base_data_grouped.min()
            case "max":
                res = base_data_grouped.max()
            case "std":
                res = base_data_grouped.std()
            case "perc":
                res = base_data_grouped.quantile(float(params[0]) / 100)
            case "count":
                res = base_data_grouped.size().to_frame(name=self.select_cols[0])
            case "any":
                res = (
                    base_data_grouped.any()
                    .reindex(cohort.obs[cohort.primary_key])
                    .fillna(False)
                )
            case "closest":
                to_col = params[0]
                tdelta = params[1] if len(params) > 1 else None
                plusminus = params[2] if len(params) > 2 else "52w"
                base_data = pd.merge(
                    base_data,
                    cohort.obs[[cohort.primary_key, to_col]],
                    on=cohort.primary_key,
                    how="left",
                )
                base_data_grouped = base_data.groupby(cohort.primary_key)
                if " " in plusminus:
                    pm_before, pm_after = plusminus.split(" ")
                else:
                    pm_before = pm_after = plusminus
                res = base_data_grouped.apply(
                    lambda group: utils.df_find_closest(
                        group=group,
                        to_col=to_col,
                        tdelta=tdelta,
                        pm_before=pm_before,
                        pm_after=pm_after,
                    )
                )
            case _:
                # Previously left `res` unbound for unknown aggregation functions
                raise ValueError(f"Unsupported aggregation function: {self.agg_func}")
        logger.info(
            f"Extracted {self.var_name} using cached base variable. / {res.columns}"
        )
        res = res[self.select_cols]
        # Append suffix if multiple select columns
        if len(self.select_cols) > 1:
            res.columns = [f"{self.var_name}_{col}" for col in self.select_cols]
        else:
            res.rename(columns={self.select_cols[0]: self.var_name}, inplace=True)
        return res

class DerivedVariable(Variable):
    def __init__(
        self,
        var_name,
        dynamic,
        requires: list[str],
        expression: str | None = None,
        tmin=None,
        tmax=None,
    ):
        super().__init__(
            var_name,
            native=False,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
        )
        self.required_vars = {}
        self.expression = expression

    def _get_required_vars(self, cohort: "Cohort"):
        if len(self.requires) > 0:
            for var_name in self.requires:
                if var_name not in cohort.constant_vars:
                    var = Variable.from_corr_vars(
                        var_name, tmin=self.tmin, tmax=self.tmax
                    )
                    if isinstance(var, NativeDynamic):
                        # Use caching for dynamic variables
                        self.required_vars[var_name] = cohort._load_axiom(var.var_name)
                    else:
                        # Static variables will not be cached
                        var.extract(cohort)
                        self.required_vars[var_name] = var

    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        # Implemented in subclasses
        raise NotImplementedError

class DerivedDynamic(DerivedVariable):
    """Derived dynamic variables are extracted using a custom function.

    Warning:
        You cannot add these variables manually yet, as they always require a
        custom function in variables.py. This will be addressed in the future.

    Args:
        var_name: Name of the variable.
        requires: List of required variables.
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.
        cleaning: Cleaning parameters ({column_name: {low: int, high: int}})

    Examples:
        Currently, this example alone does not work, as you additionally need
        a custom function in variables.py.

        >>> DerivedDynamic(
        ...     var_name="pf_ratio",
        ...     requires=["blood_pao2_arterial", "vent_fio2"])
    """

    def __init__(
        self,
        var_name,
        requires: list[str],
        cleaning: dict | None = None,
        tmin=None,
        tmax=None,
    ):
        super().__init__(
            var_name, dynamic=True, requires=requires, tmin=tmin, tmax=tmax
        )
        self.cleaning = cleaning
    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        """Extract the variable.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pd.DataFrame: Extracted variable.
        """
        assert (
            self.tmin is not None
        ), "tmin must be provided for all variables upon extraction"
        assert (
            self.tmax is not None
        ), "tmax must be provided for all variables upon extraction"
        self._get_required_vars(cohort)
        self.call_var_function(cohort)
        if self.cleaning and self.data is not None:
            for colname in self.cleaning:
                if self.cleaning[colname].get("low"):
                    self.data = self.data[
                        self.data[colname] >= self.cleaning[colname]["low"]
                    ]
                if self.cleaning[colname].get("high"):
                    self.data = self.data[
                        self.data[colname] <= self.cleaning[colname]["high"]
                    ]
        return self.data

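# Sketch of the matching custom function in variables.py (illustrative and
# simplified; the exact alignment logic is project-specific). call_var_function()
# passes the Variable as `var` and a deep copy of the cohort as `cohort`, and
# stores the return value in Variable.data:
#
# >>> def pf_ratio(var, cohort):
# ...     pao2 = var.required_vars["blood_pao2_arterial"].data
# ...     fio2 = var.required_vars["vent_fio2"].data
# ...     merged = pd.merge_asof(
# ...         pao2.sort_values("recordtime"),
# ...         fio2.sort_values("recordtime"),
# ...         on="recordtime",
# ...         by=cohort.primary_key,
# ...         suffixes=("_pao2", "_fio2"),
# ...     )
# ...     merged["value"] = merged["value_pao2"] / merged["value_fio2"]
# ...     return merged  # long-format DataFrame, like a NativeDynamic result
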
class DerivedStatic(DerivedVariable):
    """Derived static variables are extracted using an expression.

    Args:
        var_name: Name of the variable.
        requires: List of required variables.
        expression: Expression to extract the variable.
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.

    Note that DerivedStatic variables are evaluated on the cohort.obs dataframe
    and must reference existing columns in cohort.obs. For DerivedStatic
    variables, you may either provide an expression or a custom function in
    variables.py. Use expressions where possible, but custom functions if you
    require more complex logic.

    Examples:
        >>> DerivedStatic(
        ...     var_name="inhospital_death",
        ...     requires=["hospital_discharge", "death_timestamp"],
        ...     expression="hospital_discharge <= death_timestamp"
        ... )

        >>> DerivedStatic(
        ...     var_name="any_va_ecmo_icu",
        ...     requires=["ecmo_va_icu_ops", "ecmo_va_icu"],
        ...     expression="ecmo_va_icu_ops | ecmo_va_icu"
        ... )
    """

    def __init__(
        self,
        var_name,
        requires: list[str],
        expression: str | None = None,
        tmin=None,
        tmax=None,
    ):
        super().__init__(
            var_name, dynamic=False, requires=requires, tmin=tmin, tmax=tmax
        )
        self.expression = expression
        self.computed_expression = False
    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        assert (
            self.tmin is not None
        ), "tmin must be provided for all variables upon extraction"
        assert (
            self.tmax is not None
        ), "tmax must be provided for all variables upon extraction"
        self._get_required_vars(cohort)
        self.data = self._compute_expression(cohort)
        called_var_function = self.call_var_function(cohort)
        assert (
            self.computed_expression or called_var_function
        ), "No expression or variable function found."
        assert (
            self.var_name in self.data.columns
        ), f"DerivedStatic variable {self.var_name} not found in extracted data. ({self.data.columns})"
        self.data = self.data[[cohort.primary_key, self.var_name]]
        return self.data
    def _compute_expression(self, cohort):
        """Compute the expression using the required variables.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pd.DataFrame: Extracted variable. (Returns the original obs
            dataframe if no expression is provided.)
        """
        obs = cohort.obs.copy()
        for var in self.required_vars.values():
            if not var.dynamic:
                obs = obs.merge(var.data, on=cohort.primary_key, how="left")
        if not self.expression:
            return obs
        else:
            obs[self.var_name] = obs.eval(self.expression)
            self.computed_expression = True
            return obs

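# Expression sketch: _compute_expression() evaluates the expression with
# DataFrame.eval() against cohort.obs plus the merged static requirements.
# A minimal self-contained illustration of the inhospital_death example above
# (the column values are made up):
#
# >>> import pandas as pd
# >>> obs = pd.DataFrame({
# ...     "hospital_discharge": pd.to_datetime(["2020-01-05", "2020-01-07"]),
# ...     "death_timestamp": pd.to_datetime(["2020-01-06", pd.NaT]),
# ... })
# >>> obs.eval("hospital_discharge <= death_timestamp").tolist()
# [True, False]
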
class ComplexVariable(DerivedVariable):
    """A derived variable that requires a custom function to be called.

    Apart from not taking an expression, this is identical to a
    DerivedVariable.

    Args:
        var_name: Name of the variable.
        dynamic: Whether the variable is dynamic.
        requires: List of required variables.
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.

    The ComplexVariable requires a custom function in variables.py.
    """

    def __init__(self, var_name, dynamic, requires: list[str], tmin=None, tmax=None):
        super().__init__(var_name, dynamic, requires, tmin=tmin, tmax=tmax)
    def extract(self, cohort):
        self._get_required_vars(cohort)
        self.call_var_function(cohort)
        return self.data

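# Usage sketch (illustrative; "sofa_score" and "blood_creatinine" are assumed
# names, and a function called "sofa_score" must exist in variables.py for
# extract() to produce data, per call_var_function() above):
#
# >>> var = Variable.from_json(
# ...     "sofa_score",
# ...     {"type": "complex", "dynamic": True, "requires": ["blood_creatinine"]},
# ...     tmin="hospital_admission",
# ...     tmax="hospital_discharge",
# ... )
# >>> isinstance(var, ComplexVariable)
# True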