from __future__ import annotations
import copy
import getpass
import math
import multiprocessing
from functools import partial
import pandas as pd
from tqdm import tqdm
import corr_vars.static as static
import corr_vars.utils as utils
from corr_vars import logger
from corr_vars.static import variables
CHUNK_SIZE = 30_000
class Variable:
"""Base class for all variables.
Args:
var_name: The variable name.
native: True if the variable is native (extracted from the database or simple aggregation of native variables).
dynamic: True if the variable is dynamic (time-series).
tmin: The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
tmax: The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
requires: List of variables required to calculate the variable.
Note that tmin and tmax can be None when you create a Variable object, but must be set before extraction.
If you add the variable via cohort.add_variable(), it will be automatically set to the cohort's tmin and tmax.
This base class should not be used directly; use one of the subclasses instead.
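
    Examples:
        tmin/tmax may be a column name or a (column name, timedelta) tuple; the
        timedelta string below is an assumption based on formats used elsewhere:
        >>> tmin = "hospital_admission"
        >>> tmax = ("hospital_admission", "72h")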
"""
def __init__(
self,
var_name: str,
native: bool,
dynamic: bool,
        requires: list[str] | None = None,
tmin: str | tuple[str, str] | None = None,
tmax: str | tuple[str, str] | None = None,
):
self.var_name = var_name
self.native = native # True if the variable is native (extracted from the database or simple aggregation of native variables)
self.dynamic = dynamic # True if the variable is dynamic (time-series)
        self.requires = requires if requires is not None else []
self.chunk_size = CHUNK_SIZE
self.data = None
self.tmin = tmin
self.tmax = tmax
@classmethod
def from_json(cls, var_name, var_dict, tmin=None, tmax=None):
"""Create a Variable object from a variable dictionary.
Args:
var_name (str): The variable name.
var_dict (dict): The variable dictionary (from vars.json).
tmin (str | tuple[str, str]): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
tmax (str | tuple[str, str]): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
Returns:
Variable: Variable object, depending on the variable type.
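
        Examples:
            A minimal sketch; the dictionary mirrors a vars.json entry:
            >>> var_dict = {
            ...     "type": "derived_static",
            ...     "requires": ["hospital_discharge", "death_timestamp"],
            ...     "expression": "hospital_discharge <= death_timestamp",
            ... }
            >>> v = Variable.from_json("inhospital_death", var_dict)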
"""
        var_type = var_dict.get("type", "").lower()
match var_type:
case "native_dynamic":
return NativeDynamic(
var_name=var_name,
table=var_dict.get("table"),
where=var_dict.get("where"),
value_dtype=var_dict.get("value_dtype"),
cleaning=var_dict.get("cleaning"),
tmin=tmin,
tmax=tmax,
)
case "native_static":
return NativeStatic(
var_name=var_name,
select=var_dict.get("select"),
base_var=var_dict.get("variable"),
where=var_dict.get("where"),
tmin=tmin,
tmax=tmax,
)
case "derived_dynamic":
return DerivedDynamic(
var_name=var_name,
requires=var_dict.get("requires", []),
cleaning=var_dict.get("cleaning"),
tmin=tmin,
tmax=tmax,
)
case "derived_static":
return DerivedStatic(
var_name=var_name,
requires=var_dict.get("requires", []),
expression=var_dict.get("expression"),
tmin=tmin,
tmax=tmax,
)
case "complex":
return ComplexVariable(
var_name=var_name,
requires=var_dict.get("requires", []),
dynamic=var_dict.get("dynamic"),
tmin=tmin,
tmax=tmax,
)
@classmethod
def from_corr_vars(cls, var_name, cohort=None, tmin=None, tmax=None):
"""Create a Variable object from a variable name.
Args:
var_name: The variable name (in vars.json).
cohort: Cohort object (to pass custom variable definitions).
tmin: The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
tmax: The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
Returns:
Variable: Variable object, depending on the variable type.
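
        Examples:
            A minimal sketch, assuming "blood_urea" is defined in vars.json:
            >>> v = Variable.from_corr_vars(
            ...     "blood_urea", tmin="hospital_admission", tmax="hospital_discharge"
            ... )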
"""
var_dict = static.VARS["variables"].get(var_name, None)
        if cohort and var_name in cohort.project_vars:
            if var_dict is None:
                var_dict = cohort.project_vars[var_name]  # Use local definition
            else:
                # Merge into a copy so the shared definition in static.VARS is not mutated
                var_dict = {**var_dict, **cohort.project_vars[var_name]}
if var_dict is None:
raise KeyError(f"Variable {var_name} not found in vars.json")
return cls.from_json(var_name, var_dict, tmin=tmin, tmax=tmax)
def __repr__(self):
return self.__str__()
def __str__(self):
return f"""Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
tmin: {self.tmin}, tmax: {self.tmax}
data: {'Not extracted' if self.data is None else self.data.shape}"""
def __getstate__(self):
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)
def call_var_function(self, cohort: "Cohort") -> bool:
"""
Call the variable function if it exists.
Args:
cohort (Cohort)
Returns:
True if the function was called,
False otherwise.
Operations will be applied to Variable.data directly.
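
        Examples:
            A hypothetical sketch of a matching function in variables.py; the function
            name must equal ``var_name`` and its result is stored in ``Variable.data``:
            >>> def pf_ratio(var, cohort):
            ...     ...  # compute and return a DataFrame for the variable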
"""
has_var_function = hasattr(variables, self.var_name)
if has_var_function:
logger.info(f"Calling variable function {self.var_name}")
var_function = getattr(variables, self.var_name)
self.data = var_function(var=self, cohort=copy.deepcopy(cohort))
return True
else:
return False
class NativeVariable(Variable):
def __init__(
self,
var_name,
native,
dynamic,
        requires: list[str] | None = None,
tmin=None,
tmax=None,
chunk_size=CHUNK_SIZE,
):
super().__init__(
var_name,
native=native,
dynamic=dynamic,
requires=requires,
tmin=tmin,
tmax=tmax,
)
self.chunk_size = chunk_size
    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        """Run the chunked extraction, then apply the variable function and cleaning bounds."""
        assert self.tmin is not None, "tmin must be provided for all variables"
        assert self.tmax is not None, "tmax must be provided for all variables"
        self.data = self._extract_chunked(cohort)
        self.call_var_function(cohort)
        cleaning = getattr(self, "cleaning", None)
        if cleaning and self.data is not None and self.dynamic:
            for colname in cleaning:
                # Use explicit None checks so a bound of 0 is not skipped
                if cleaning[colname].get("low") is not None:
                    self.data = self.data[
                        self.data[colname] >= cleaning[colname]["low"]
                    ]
                if cleaning[colname].get("high") is not None:
                    self.data = self.data[
                        self.data[colname] <= cleaning[colname]["high"]
                    ]
        return self.data
def _extract_chunked(self, cohort: "Cohort"):
"""Manages chunked extraction by dividing the observation data into smaller chunks
to avoid the SQL query size limit."""
if cohort.conn.password_persist:
dbpass = None
else:
dbpass = getpass.getpass(prompt="Enter your password: ")
assert dbpass, "Password not provided."
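        # Each chunk is a (start_idx, end_idx) half-open slice of cohort.obs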
num_chunks = math.ceil(len(cohort.obs) / self.chunk_size)
if num_chunks > 1:
chunks = [
(i * self.chunk_size, min((i + 1) * self.chunk_size, len(cohort.obs)))
for i in range(num_chunks)
]
# Create lightweight cohort object
light_cohort = copy.copy(cohort)
light_cohort.clear_axiom_cache()
process_func = partial(
self._extract_from_db, cohort=light_cohort, dbpass=dbpass
)
with multiprocessing.Pool() as pool:
results = list(
tqdm(
pool.imap(process_func, chunks),
total=num_chunks,
desc=f"Extracting {self.var_name}",
unit="chunks",
)
)
print("Processing data...")
return pd.concat(results, copy=False, ignore_index=True)
else:
return self._extract_from_db(
chunk_info=(0, len(cohort.obs)), cohort=cohort, dbpass=dbpass
)
def _extract_from_db(
self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
):
"""Placeholder for extraction logic from the database.
Must be implemented in subclasses."""
raise NotImplementedError
class NativeDynamic(NativeVariable):
"""Native dynamic variables are extracted directly from the database and represent time-series data.
The extracted data will be in long format, including columns like recordtime, value, and depending
on the table definition, recordtime_end (e.g., for therapy events) and description (e.g., medication names).
The resulting dataframe will be available as ``Cohort.obsm["var_name"]`` or as ``Variable.data``.
The cleaning term will be applied at the end of the extraction process.
Args:
var_name: Name of the variable.
table: Source table to extract from (e.g., it_copra6_hierachy_v2, it_copra6_therapy).
See `HDL Hue <https://hdl-edge01.charite.de:8443/gateway/cdp-proxy/hue/hue/metastore/tables/db_hypercapnia_prepared?connector_id=impala&namespace=5ac5d06e-ba46-46eb-a9b3-6f3d91901617>`_
for a list of available tables.
where: SQL statement to filter the source table. Must use column names available
in the source table.
value_dtype: SQL data type of the value column, e.g., `BIGINT`, `BOOLEAN`, `DATETIME`, `DOUBLE`, `STRING`.
cleaning: Dictionary specifying lower and upper bounds for impossible values.
Note: Should only include physically impossible values, not unlikely/percentile-based values.
Consult a clinician if unsure!
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
Examples:
>>> # Basic lab value extraction
>>> v = NativeDynamic(
... var_name="blood_urea",
... table="it_ishmed_labor",
... where="c_katalog_leistungtext LIKE '%arnstoff%' AND c_wert <> '0'",
... value_dtype="DOUBLE",
        ...     cleaning={"value": {"low": 2, "high": 500}},
... tmin="hospital_admission",
... tmax="hospital_discharge"
... )
>>> v.extract(cohort)
>>> v.data
>>> # Therapy event with start/end times
>>> # End time will be automatically added as recordtime_end for specified tables
>>> v = NativeDynamic(
... var_name="ecmo_vva_vav_icu",
... table="it_copra6_therapy",
... where="c_apparat_mode IN ('v-v/a ECMO','v-a/v ECMO')",
... value_dtype="VARCHAR",
... cleaning=None,
... tmin="hospital_admission",
... tmax="hospital_discharge"
... )
>>> v.extract(cohort)
>>> v.data
"""
def __init__(
self,
var_name: str,
table: str,
where: str,
value_dtype: str,
cleaning: dict,
tmin=None,
tmax=None,
):
requires = [table]
super().__init__(
var_name, native=True, dynamic=True, requires=requires, tmin=tmin, tmax=tmax
)
self.table = table
self.where = where
self.value_dtype = value_dtype
self.cleaning = cleaning
def _extract_from_db(
self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
):
"""Extract data from the database."""
start_idx, end_idx = chunk_info
df_chunk = cohort.obs.iloc[start_idx:end_idx].copy()
df_chunk["tmin"], df_chunk["tmax"] = utils.parse_time_args(
df_chunk, tmin=self.tmin, tmax=self.tmax
)
cb_values = utils.get_cb_values(df_chunk, primary_key=cohort.primary_key)
query, table_info, keep_cols = utils.build_base_query(
cb_values,
self,
cohort.database,
cohort.primary_key,
cohort._extr_end_date("v"),
)
if dbpass:
conn = utils.ImpalaConnector(password=dbpass, **cohort.conn_args)
else:
conn = utils.ImpalaConnector.with_password_file(
password_file=cohort.password_file, **cohort.conn_args
)
return conn.df_read_sql(
query=query,
datetime_cols=[col for col in table_info if "recordtime" in col],
keep_cols=keep_cols,
)
def on_admission(self, select: str = "!first value") -> NativeStatic:
"""Create a new NativeStatic variable based on the current variable that extracts the value on admission.
Args:
select: Select clause specifying aggregation function and columns. Defaults to ``!first value``.
Returns:
NativeStatic
Examples:
>>> # Return the first value
>>> var_adm = variable.on_admission()
>>> cohort.add_variable(var_adm)
>>> # Be more specific with your selection
>>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
>>> cohort.add_variable(var_adm)
"""
return NativeStatic(
var_name=f"{self.var_name}_adm",
select=select,
base_var=self.var_name,
tmin=self.tmin,
tmax=self.tmax,
)
class NativeStatic(NativeVariable):
"""NativeStatic variables represent simple aggregations of NativeDynamic variables.
Args:
var_name (str): Name of the variable.
select (str): Select clause specifying aggregation function and columns.
base_var (str): Name of the base variable (must be a native_dynamic variable).
        where (str, optional): Optional WHERE clause (pandas-style; see below).
tmin (str, optional): Minimum time for the extraction.
tmax (str, optional): Maximum time for the extraction.
The select argument supports several aggregation functions:
- ``!first [columns]``: Returns the first row within this case
>>> "!first value" # Single column
>>> "!first value, recordtime" # Multiple columns
- ``!last [columns]``: Returns the last row within this case
>>> "!last value"
>>> "!last value, recordtime"
- ``!any``: Returns True if any value exists
>>> "!any"
>>> "!any value"
- ``!closest(to_column, timedelta, plusminus) [columns]``: Selects value closest to specified column
Args:
to_column: Column to compare "recordtime" against
timedelta: Time to add to "to_column" for comparison
plusminus: Allowed time mismatch (can specify different before/after with space)
>>> "!closest(hospital_admission) value, recordtime" # Closest to admission
>>> "!closest(hospital_admission, 0, 2h 3h) value" # 2h before to 3h after
>>> "!closest(first_intubation_dtime, 6h, 2h) value" # 6h after intubation ±2h
    - ``!mean [column]``: Calculates mean value
        >>> "!mean value"
    - ``!median [column]``: Calculates median value
        >>> "!median value"
    - ``!min``, ``!max``, ``!std``, ``!count``: Further aggregations over the matching rows
        >>> "!max value"
    - ``!perc(quantile) [column]``: Calculates specified percentile
        >>> "!perc(75) value"  # 75th percentile
The where argument supports Pandas-style boolean expressions. These are evaluated in the context of the base variable by pd.eval().
    The where argument also supports magic commands (starting with ``!``) to filter the data. Supported commands are:
- ``!isin(column, [values])``: Filters rows where the value in column is in values
- ``!startswith(column, [values])``: Filters rows where the value in column starts with any of the values
- ``!endswith(column, [values])``: Filters rows where the value in column ends with any of the values
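
    Examples:
        A minimal sketch, assuming "blood_urea" is a native_dynamic variable in vars.json:
        >>> v = NativeStatic(
        ...     var_name="blood_urea_adm",
        ...     select="!closest(hospital_admission, 0, 2h) value",
        ...     base_var="blood_urea",
        ...     tmin="hospital_admission",
        ...     tmax="hospital_discharge",
        ... )
        >>> v.extract(cohort)
        >>> v.data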
"""
def __init__(
self,
var_name: str,
select: str,
base_var: str | NativeDynamic,
where: str | None = None,
tmin=None,
tmax=None,
):
        # Store the base variable's name in requires, even when an object is passed
        requires = [base_var if isinstance(base_var, str) else base_var.var_name]
super().__init__(
var_name,
native=True,
dynamic=False,
requires=requires,
tmin=tmin,
tmax=tmax,
)
self.select = select
logger.debug(f"NativeStatic: {self.select}, {base_var}, {where}")
if isinstance(base_var, str):
self.base_var = Variable.from_corr_vars(base_var, tmin=tmin, tmax=tmax)
else:
self.base_var = base_var
        # Will only be loaded if use_cache is True on extract()
        self.base_var_data = None
self.where = where
self.agg_func, self.agg_params, self.select_cols = utils.parse_select(
self.select
)
if len(self.select_cols) == 0:
self.select_cols = ["value"]
    def _extract_from_db(
        self, chunk_info: tuple[int, int], cohort: "Cohort", dbpass: str | None = None
    ):
        logger.warning(
            "Using deprecated SQL-based extraction for NativeStatic variables. "
            "Please use pandas syntax directly."
        )
start_idx, end_idx = chunk_info
df_chunk = cohort.obs.iloc[start_idx:end_idx].copy()
df_chunk["tmin"], df_chunk["tmax"] = utils.parse_time_args(
df_chunk, tmin=self.tmin, tmax=self.tmax
)
drop_cols = ["row_n", "time_diff", "recordtime"]
cb = utils.get_cb_values(df_chunk, primary_key=cohort.primary_key)
base_query, table_info, keep_cols = utils.build_base_query(
cb,
self.base_var,
cohort.database,
cohort.primary_key,
cohort._extr_end_date("v"),
)
match self.agg_func.lower():
case "first" | "last":
select = utils.parse_col_list(self.select_cols, self.var_name)
order_by = f"ORDER BY recordtime {'ASC' if self.agg_func.lower() == 'first' else 'DESC'}"
query = f"""
WITH ranked_values AS (
SELECT
{select},
{cohort.primary_key},
ROW_NUMBER() OVER (PARTITION BY {cohort.primary_key}
{order_by}) AS row_n
FROM
({base_query}) subquery
{f'WHERE ({self.where})' if self.where else ''}
)
SELECT * FROM ranked_values WHERE row_n = 1
"""
case "any":
assert (
len(self.select_cols) <= 1
), "Cannot use 'any' with multiple select columns."
assert len(self.agg_params) == 0, "Cannot use 'any' with parameters."
if len(self.select_cols) == 0:
select = "MIN(value)"
else:
select = f"MIN({self.select_cols[0].strip()})"
query = f"""
SELECT
{select} AS {self.var_name},
{cohort.primary_key}
FROM
({base_query}) subquery
{f'WHERE ({self.where})' if self.where else ''}
GROUP BY {cohort.primary_key}
"""
case "closest":
assert (
len(self.agg_params) >= 1
), "Closest requires at least one parameter: to [which column to compare to], [optional: tdelta, pm]"
to_col = self.agg_params[0]
drop_cols.append(to_col)
tdelta = utils.convert_interval_sql(
self.agg_params[1] if len(self.agg_params) > 1 else "0"
)
plusminus = self.agg_params[2] if len(self.agg_params) > 2 else "52w"
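                # plusminus may specify distinct before/after windows separated by a space, e.g. "2h 3h"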
if " " in plusminus:
pm_before, pm_after = plusminus.split(" ")
else:
pm_before = pm_after = plusminus
pm_before = utils.convert_interval_sql(pm_before)
pm_after = utils.convert_interval_sql(pm_after)
cb_values = utils.get_cb_values(
df_chunk, primary_key=cohort.primary_key, ttarget_col=to_col
)
base_query, table_info, keep_cols = utils.build_base_query(
cb_values,
self.base_var,
cohort.database,
cohort.primary_key,
cohort._extr_end_date("v"),
ttarget_col=to_col,
)
select = utils.parse_col_list(self.select_cols, self.var_name)
query = f"""
WITH ranked_values AS (
SELECT
{select},
{cohort.primary_key},
recordtime,
{to_col},
ABS(UNIX_TIMESTAMP(recordtime) - UNIX_TIMESTAMP(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}))) as time_diff,
ROW_NUMBER() OVER (
PARTITION BY {cohort.primary_key}
ORDER BY ABS(UNIX_TIMESTAMP(recordtime) - UNIX_TIMESTAMP(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta})))
) as row_n
FROM ({base_query}) subquery
WHERE recordtime BETWEEN
DATE_SUB(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}), {pm_before})
AND DATE_ADD(DATE_ADD(CAST({to_col} AS TIMESTAMP), {tdelta}), {pm_after})
{f'AND ({self.where})' if self.where else ''}
)
SELECT * FROM ranked_values WHERE row_n = 1
"""
case _:
select = utils.parse_col_list(self.select_cols, self.var_name)
query = f"""
SELECT
{select},
{cohort.primary_key}
FROM
({base_query}) subquery
{f'WHERE ({self.where})' if self.where else ''}
GROUP BY {cohort.primary_key}
"""
        if dbpass:
            conn = utils.ImpalaConnector(password=dbpass, **cohort.conn_args)
        else:
            conn = utils.ImpalaConnector(**cohort.conn_args)
res = conn.df_read_sql(
query=query,
datetime_cols=[col for col in table_info if "recordtime" in col],
)
res.set_index(cohort.primary_key, inplace=True)
for col in drop_cols:
if col in res.columns:
res.drop(columns=[col], inplace=True)
return res
def _extract_agg_cached(self, cohort):
"""Extract static variable using the cached base variable in Cohort.axiom_cache
Args:
cohort (Cohort): Cohort object.
Returns:
pd.DataFrame: Extracted variable.
"""
        try:
            base_data = cohort._load_axiom(self.base_var).data
        except KeyError:
            # Kept for compatibility; cohort._load_axiom() should not raise KeyError.
            try:
                base_data = cohort.obsm[self.base_var.var_name]
            except KeyError:
                base_data = self.base_var.extract(cohort)
base_data.sort_values(by="recordtime", ascending=True, inplace=True)
base_data.reset_index(inplace=True)
base_data.set_index(cohort.primary_key, inplace=True)
if self.where:
if self.where.startswith("!"):
where_func, where_params, _ = utils.parse_select(self.where)
match where_func:
case "isin":
base_data = base_data[
base_data[where_params[0]].isin(where_params[1])
]
case "startswith":
mask = base_data[where_params[0]].str.startswith(
tuple(where_params[1])
)
base_data = base_data[mask]
case "endswith":
mask = base_data[where_params[0]].str.endswith(
tuple(where_params[1])
)
base_data = base_data[mask]
case _:
raise ValueError(
f"Unsupported where function: {where_func}. Supported functions are: isin, startswith, endswith"
)
else:
base_data = base_data[base_data.eval(utils.sql_to_pandas(self.where))]
tmin, tmax = utils.parse_time_args(cohort.obs, tmin=self.tmin, tmax=self.tmax)
obs = cohort.obs.copy()
obs["tmin"] = tmin
obs["tmax"] = tmax
base_data_time = base_data.merge(
obs[["tmin", "tmax", cohort.primary_key]], on=cohort.primary_key, how="left"
)
# Filter base_data based on tmin and tmax
base_data_time = base_data_time[
base_data_time["tmin"] <= base_data_time["recordtime"]
]
base_data_time = base_data_time[
base_data_time["tmax"] >= base_data_time["recordtime"]
]
base_data_grouped = base_data_time.groupby(cohort.primary_key)
params = self.agg_params
match self.agg_func.lower():
case "first":
res = base_data_grouped.first()
case "last":
res = base_data_grouped.last()
case "mean":
res = base_data_grouped.mean()
case "median":
res = base_data_grouped.median()
case "min":
res = base_data_grouped.min()
case "max":
res = base_data_grouped.max()
case "std":
res = base_data_grouped.std()
case "perc":
res = base_data_grouped.quantile(float(params[0]) / 100)
case "count":
res = base_data_grouped.size().to_frame(name=self.select_cols[0])
case "any":
res = (
base_data_grouped.any()
.reindex(cohort.obs[cohort.primary_key])
.fillna(False)
)
case "closest":
to_col = params[0]
tdelta = params[1] if len(params) > 1 else None
plusminus = params[2] if len(params) > 2 else "52w"
base_data = pd.merge(
base_data,
cohort.obs[[cohort.primary_key, to_col]],
on=cohort.primary_key,
how="left",
)
base_data_grouped = base_data.groupby(cohort.primary_key)
if " " in plusminus:
pm_before, pm_after = plusminus.split(" ")
else:
pm_before = pm_after = plusminus
res = base_data_grouped.apply(
lambda group: utils.df_find_closest(
group=group,
to_col=to_col,
tdelta=tdelta,
pm_before=pm_before,
pm_after=pm_after,
)
)
        logger.info(
            f"Extracted {self.var_name} using cached base variable "
            f"(columns: {list(res.columns)})"
        )
res = res[self.select_cols]
# Append suffix if multiple select columns
if len(self.select_cols) > 1:
res.columns = [f"{self.var_name}_{col}" for col in self.select_cols]
else:
res.rename(columns={self.select_cols[0]: self.var_name}, inplace=True)
return res
class DerivedVariable(Variable):
def __init__(
self,
var_name,
dynamic,
requires: list[str],
expression: str | None = None,
tmin=None,
tmax=None,
):
super().__init__(
var_name,
native=False,
dynamic=dynamic,
requires=requires,
tmin=tmin,
tmax=tmax,
)
self.required_vars = {}
self.expression = expression
    def _get_required_vars(self, cohort: "Cohort"):
        for var_name in self.requires:
            if var_name not in cohort.constant_vars:
                var = Variable.from_corr_vars(var_name, tmin=self.tmin, tmax=self.tmax)
                if isinstance(var, NativeDynamic):
                    # Use caching for dynamic variables
                    self.required_vars[var_name] = cohort._load_axiom(var.var_name)
                else:
                    # Static variables are not cached
                    var.extract(cohort)
                    self.required_vars[var_name] = var
    def extract(self, cohort: "Cohort") -> pd.DataFrame:
        # Implemented in subclasses
        pass
class DerivedDynamic(DerivedVariable):
"""Derived dynamic variables are extracted using a custom function.
Warning:
You cannot add these variables manually yet, as they always require a custom function in variables.py. This will be addressed in the future.
Args:
var_name: Name of the variable.
requires: List of required variables.
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
cleaning: Cleaning parameters ({column_name: {low: int, high: int}})
    Examples:
        On its own this does not work yet; a matching custom function must be added in variables.py.
>>> DerivedDynamic(
... var_name="pf_ratio",
... requires=["blood_pao2_arterial", "vent_fio2"])
"""
def __init__(
self,
var_name,
requires: list[str],
cleaning: dict | None = None,
tmin=None,
tmax=None,
):
super().__init__(
var_name, dynamic=True, requires=requires, tmin=tmin, tmax=tmax
)
self.cleaning = cleaning
class DerivedStatic(DerivedVariable):
"""Derived static variables are extracted using an expression.
Args:
var_name: Name of the variable.
requires: List of required variables.
expression: Expression to extract the variable.
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
Note that DerivedStatic variables are executed on the cohort.obs dataframe and must reference existing columns in cohort.obs.
For DerivedStatic variables, you may either provide an expression or a custom function in variables.py.
Use expressions where possible, but custom functions if you require more complex logic.
Examples:
>>> DerivedStatic(
... var_name="inhospital_death",
... requires=["hospital_discharge", "death_timestamp"],
... expression="hospital_discharge <= death_timestamp"
... )
        >>> DerivedStatic(
        ...     var_name="any_va_ecmo_icu",
        ...     requires=["ecmo_va_icu_ops", "ecmo_va_icu"],
        ...     expression="ecmo_va_icu_ops | ecmo_va_icu"
        ... )
"""
def __init__(
self,
var_name,
requires: list[str],
expression: str | None = None,
tmin=None,
tmax=None,
):
super().__init__(
var_name, dynamic=False, requires=requires, tmin=tmin, tmax=tmax
)
self.expression = expression
self.computed_expression = False
def _compute_expression(self, cohort):
"""Compute the expression using the required variables.
Args:
cohort (Cohort): Cohort object.
Returns:
pd.DataFrame: Extracted variable. (Returns the original obs dataframe if no expression is provided.)
"""
obs = cohort.obs.copy()
for var in self.required_vars.values():
if not var.dynamic:
obs = obs.merge(var.data, on=cohort.primary_key, how="left")
if not self.expression:
return obs
else:
obs[self.var_name] = obs.eval(self.expression)
self.computed_expression = True
return obs
class ComplexVariable(DerivedVariable):
"""A derived variable that requires a custom function to be called.
    Apart from not taking an expression, it is identical to a DerivedVariable.
Args:
var_name: Name of the variable.
dynamic: Whether the variable is dynamic.
requires: List of required variables.
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
The ComplexVariable requires a custom function in variables.py.
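
    Examples:
        A minimal sketch; a custom function named after the variable must exist in
        variables.py (the variable name below is hypothetical):
        >>> ComplexVariable(
        ...     var_name="my_complex_var",
        ...     dynamic=False,
        ...     requires=["hospital_admission"],
        ... )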
"""
def __init__(self, var_name, dynamic, requires: list[str], tmin=None, tmax=None):
super().__init__(var_name, dynamic, requires, tmin=tmin, tmax=tmax)