import json
import math
import multiprocessing
import os
import re
from functools import partial
import numpy as np
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy
from tqdm import tqdm
import corr_vars.utils as utils
# Local imports
from corr_vars.static import VARS
def _get_password_from_file(password_file: str | None) -> str:
"""Get the password from a file or prompt for it.
Args:
password_file (str): The path to the password file.
Returns:
str: The password.
"""
if not password_file or password_file is True:
password_file = f"{os.path.expanduser('~')}/password.txt"
try:
with open(password_file, "r") as file:
password = file.read().strip()
except Exception as e:
raise FileNotFoundError(f"Password file not found: {password_file} \n {e}")
return password
def json_list_sql(json_obj):
"""Convert a JSON list to string for SQL IN clause.
Args:
json_obj (list[str]): The JSON list.
Returns:
str: The SQL list.
"""
return ", ".join(f"'{item}'" for item in json_obj)
def filter_by_condition(
df: pd.DataFrame, condition_func, description="", verbose=True, mode="drop"
) -> pd.DataFrame:
"""
Drop rows from a DataFrame based on a condition function.
Parameters:
df (pd.DataFrame): The input DataFrame.
condition_func (callable): A function that takes the DataFrame and returns a boolean Series.
description (str): Description of the condition.
mode (str): Whether to drop or keep the rows. Can be "drop" or "keep".
Returns:
pd.DataFrame: The DataFrame with rows dropped based on the condition.
"""
mask = condition_func(df)
if mode == "drop":
df2 = df[~mask]
elif mode == "keep":
df2 = df[mask]
else:
raise ValueError(f"Mode must be 'drop' or 'keep'. Got {mode}")
if verbose:
print(
f"{mode.upper()}: {mask.sum()} rows ({mask.sum()/len(df):.2%}) due to {description}"
)
return df2
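# Usage sketch (illustrative; the column name is assumed for the example):
#   cohort = filter_by_condition(
#       cohort,
#       lambda d: d["icu_admission"].isna(),
#       description="missing ICU admission time",
#       mode="drop",
#   )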
def merge_consecutive_stays(group):
"""Merge consecutive ICU stays in a dataframe.
Args:
group (pd.DataFrame): ICU stays for a single case
Returns:
pd.DataFrame: The merged dataframe.
"""
threshold = pd.Timedelta(hours=24)
group = group.sort_values(by=["icu_id", "icu_admission"]).reset_index(
drop=True
) # Reset index to use shift function
group["last_discharge_same_icu"] = group.groupby(by="icu_id")[
"icu_discharge"
].shift(1)
group["time_since_last_discharge_same_icu"] = (
group["icu_admission"] - group["last_discharge_same_icu"]
)
group["merge_with_previous"] = (
group["time_since_last_discharge_same_icu"] < treshold
)
group["merge_group"] = (~group["merge_with_previous"]).cumsum()
gr_merged = (
group.groupby(by=["icu_id", "merge_group"])
.agg({"case_id": "first", "icu_admission": "first", "icu_discharge": "last"})
.reset_index()
)
gr_merged = gr_merged.drop(columns=["merge_group"])
return gr_merged
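# Usage sketch (illustrative, column names as used above): applied per case via
# groupby, so stays of the same case are merged independently.
#   merged = (
#       icu_stays.groupby("case_id", group_keys=False)
#       .apply(merge_consecutive_stays)
#       .reset_index(drop=True)
#   )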
def df_find_closest(
group: pd.DataFrame,
to_col: str,
tdelta: str = "0",
pm_before: str = "52w",
pm_after: str = "52w",
):
"""Find the closest record to the target time.
Args:
group (pd.DataFrame): The input dataframe.
to_col (str): The column containing the target time.
tdelta (str): The time delta. Defaults to "0".
pm_before (str): The time range before the target time. Defaults to "52w".
pm_after (str): The time range after the target time. Defaults to "52w".
Returns:
pd.Series: The closest record.
"""
tdelta = pd.to_timedelta(tdelta)
pm_before = pd.to_timedelta(pm_before)
pm_after = pd.to_timedelta(pm_after)
target_time = group[to_col] + tdelta
group["timediff"] = (group["recordtime"] - target_time).abs()
within_range = group[
(group["recordtime"] >= (target_time - pm_before))
& (group["recordtime"] <= (target_time + pm_after))
]
if within_range.empty:
# Return an all-NaN record with the original columns (without the helper "timediff")
out_cols = group.columns.drop("timediff")
return pd.Series([np.nan] * len(out_cols), index=out_cols)
closest = within_range.loc[within_range["timediff"].idxmin()]
# `closest` is a Series, so drop the helper column by label (Series.drop with columns= is a no-op)
closest = closest.drop(labels=["timediff"])
return closest
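# Usage sketch (illustrative; assumes "recordtime" and "icu_admission" columns):
# find, per case, the record closest to 24h after ICU admission, looking at
# most 6h before and 6h after that target time.
#   closest = observations.groupby("case_id", group_keys=False).apply(
#       df_find_closest, to_col="icu_admission", tdelta="24h",
#       pm_before="6h", pm_after="6h",
#   )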
def aggregate_column(
group: DataFrameGroupBy, agg_func: str, column: str, params: list[str] = []
):
"""Apply an aggregation function to a grouped column.
Args:
group (DataFrameGroupBy): The grouped dataframe.
agg_func (str): The aggregation function ("median", "mean", "perc" or "count").
column (str): The column to aggregate.
params (list[str]): Optional parameters, e.g. the percentile for "perc".
Returns:
pd.Series: The aggregated value per group.
"""
match agg_func:
case "median":
return group[column].median()
case "mean":
return group[column].mean()
case "perc":
return group[column].quantile(float(params[0]) / 100)
case "count":
return group.size()
case _:
raise ValueError(f"Unknown aggregation function: {agg_func}")
def parse_select(select_str):
"""Parse the select syntax to extract function name, parameters and columns.
Args:
select_str (str): The select string.
Returns:
tuple: (function_name, params, columns)
Examples:
- !first value, recordtime
- !closest(timestamp, 1d, 2d) value, recordtime
- !any value
- !last value, recordtime
"""
select_str = select_str.strip()
# Extract function name
if select_str.startswith("!"):
if "(" in select_str:
# Handle function with parameters
func_name = select_str[1 : select_str.find("(")].strip()
params_str = select_str[
select_str.find("(") + 1 : select_str.find(")")
].strip()
params = [p.strip() for p in params_str.split(",")]
cols_str = select_str[select_str.find(")") + 1 :].strip()
else:
# Handle function without parameters
parts = select_str[1:].split(" ", 1)
func_name = parts[0].strip()
params = []
cols_str = parts[1] if len(parts) > 1 else ""
else:
func_name, params = None, None
cols = [c.strip() for c in select_str.split(",")]
return func_name, params, cols
# Extract columns to select
if cols_str:
columns = [c.strip() for c in cols_str.split(",")]
else:
columns = []
return func_name, params, columns
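# Parsing sketch (illustrative return values):
#   parse_select("!closest(timestamp, 1d, 2d) value, recordtime")
#     -> ("closest", ["timestamp", "1d", "2d"], ["value", "recordtime"])
#   parse_select("!any value")        -> ("any", [], ["value"])
#   parse_select("value, recordtime") -> (None, None, ["value", "recordtime"])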
def parse_col_list(cols: list[str], var_name: str):
"""Parse a list of columns to a select string.
Args:
cols (list[str]): The columns to parse.
var_name (str): The variable name.
Returns:
str: The select string.
Examples:
- ["value", "recordtime"] -> "value as var_name_value, recordtime as var_name_recordtime"
- ["value"] -> "value AS var_name"
"""
if len(cols) > 1:
select = ", ".join(
[f"{col.strip()} AS {var_name}_{col.strip()}" for col in cols]
)
else:
select = f"{cols[0].strip()} AS {var_name}"
return select
def convert_interval_sql(interval_str):
"""Convert shorthand interval notation to SQL INTERVAL syntax.
Supports: y->year, M->month, w->week, d->day, h->hour, m->minute, s->second
Args:
interval_str (str): The interval string.
Returns:
str: The SQL INTERVAL syntax.
Examples:
- "2h" -> "INTERVAL '2' HOUR"
- "-2h" -> "INTERVAL '-2' HOUR"
- "30m" -> "INTERVAL '30' MINUTE"
"""
if not interval_str or interval_str == "0":
return "INTERVAL 0 SECOND"
# Remove leading + if present
interval_str = interval_str.strip().lstrip("+")
# Parse the interval
unit_map = {
"y": "YEAR",
"M": "MONTH",
"w": "WEEK",
"d": "DAY",
"h": "HOUR",
"m": "MINUTE",
"s": "SECOND",
}
# Extract number and unit
match = re.match(r"^(-?\d+)([yMwdhms])$", interval_str)
if not match:
raise ValueError(
f"Invalid interval format: {interval_str}. Use format like '2h', '-2h', '30m', etc. (Supports y->year, M->month, w->week, d->day, h->hour, m->minute, s->second)"
)
number, unit = match.groups()
sql_unit = unit_map.get(unit)
if not sql_unit:
raise ValueError(
f"Invalid unit: {unit}. Valid units are: {', '.join(unit_map.keys())}"
)
return f"INTERVAL {number} {sql_unit}"
def get_cb_values(
df_chunk, primary_key="case_id", tmin_col="tmin", tmax_col="tmax", ttarget_col=None
):
"""Get the case bounds CTE. This is used to apply a time filter to a native dynamic variable.
Args:
df_chunk (pd.DataFrame): A chunk of the cohort dataframe.
primary_key (str): The primary key column. Defaults to "case_id".
tmin_col (str): The column containing the tmin. Defaults to "tmin".
tmax_col (str): The column containing the tmax. Defaults to "tmax".
ttarget_col (str): The column containing the target time. Defaults to None. (For !closest)
Returns:
str: The case bounds CTE.
"""
# Drop rows with NaN tmin or tmax
df_chunk = df_chunk.dropna(subset=[tmin_col, tmax_col])
values = []
for _, row in df_chunk.iterrows():
value = f"('{row['case_id']}', '{row[tmin_col]}', '{row[tmax_col]}'"
if ttarget_col:
value += f", '{row[ttarget_col]}'"
if primary_key != "case_id":
value += f", '{row[primary_key]}'"
value += ")"
values.append(value)
cb_values = ", ".join(values)
cb = f"""WITH case_bounds (case_id, case_tmin, case_tmax{', ' + ttarget_col if ttarget_col else ''}{', ' + primary_key if primary_key != 'case_id' else ''}) AS (
VALUES {cb_values}
)"""
return cb
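# Output sketch (illustrative, two cohort rows with the default primary key;
# the generated string has this shape, whitespace aside):
#   WITH case_bounds (case_id, case_tmin, case_tmax) AS (
#       VALUES ('123', '2020-01-01 08:00:00', '2020-01-02 08:00:00'),
#              ('456', '2020-01-03 10:00:00', '2020-01-04 10:00:00')
#   )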
def build_base_query(
cb: str,
var: "NativeDynamic",
database: str,
primary_key: str,
extr_end_date: str,
ttarget_col: str | None = None,
):
"""Build the base query for a native dynamic variable.
Args:
cb (str): The case bounds CTE.
var (NativeDynamic): The native dynamic variable.
database (str): The database name.
primary_key (str): The primary key column. Defaults to "case_id".
extr_end_date (str): The extraction end date.
ttarget_col (str): The target time column. Defaults to None.
Returns:
tuple: Query, table info, and columns to keep.
"""
table = var.table
table_info = VARS["tables"][table]
value_col, recordtime_col, case_id_col = (
table_info["value"],
table_info["recordtime"],
table_info["case_id"],
)
optional_columns = {
col: table_info[col]
for col in table_info
if col not in ["value", "recordtime", "case_id"]
}
optional_columns_str = ", ".join(
[f"v.{table_info[col]} AS {col}" for col in optional_columns]
)
value = (
f"CAST(v.{value_col} AS {var.value_dtype}) AS value"
if var.value_dtype
else f"v.{value_col} AS value"
)
query = f"""
{cb}
SELECT
v.{case_id_col} AS case_id,
v.{recordtime_col} AS recordtime,
{value}
{f', {optional_columns_str}' if optional_columns_str else ''}
{f', cb.{primary_key}' if primary_key != 'case_id' else ''}
{f', cb.{ttarget_col}' if ttarget_col else ''}
FROM
{database}.{table} v
JOIN
case_bounds cb ON v.{case_id_col} = cb.case_id
WHERE
v.{case_id_col} IN (SELECT DISTINCT case_id FROM case_bounds)
AND v.{recordtime_col} BETWEEN cb.case_tmin AND cb.case_tmax
AND {extr_end_date}
{f"AND ({var.where})" if var.where else ''}
"""
keep_cols = [primary_key, "recordtime", "value"] + [col for col in optional_columns]
return query, table_info, keep_cols
def get_time_series(obs: pd.DataFrame, col_name: str, tdelta: str = "0"):
"""Returns a series with the time since col_name (in datetime format) plus tdelta (in pd.Timedelta format).
Args:
obs (pd.DataFrame): The observation dataframe.
col_name (str): The column name.
tdelta (str): The time delta. Defaults to "0".
Returns:
pd.Series: A series with the time column + tdelta.
"""
tdelta = pd.to_timedelta(tdelta)
return obs[col_name] + tdelta
def parse_time_args(
obs: pd.DataFrame,
tmin: str | tuple[str, str] | None = None,
tmax: str | tuple[str, str] | None = None,
):
"""Parse the time arguments to get the time series.
Args:
obs (pd.DataFrame): The observation dataframe.
tmin (str | tuple[str, str] | None): The tmin argument: a column name, or a (column name, time delta) tuple. Defaults to None.
tmax (str | tuple[str, str] | None): The tmax argument: a column name, or a (column name, time delta) tuple. Defaults to None.
Returns:
tuple: tmin (pd.Series), tmax (pd.Series).
"""
if isinstance(tmin, tuple):
col_name, tdelta = tmin
tmin = get_time_series(obs, col_name, tdelta=tdelta)
else:
tmin = get_time_series(obs, col_name=tmin)
if isinstance(tmax, tuple):
col_name, tdelta = tmax
tmax = get_time_series(obs, col_name, tdelta=tdelta)
else:
tmax = get_time_series(obs, col_name=tmax)
return tmin, tmax
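# Usage sketch (illustrative column names): tmin is 6 hours before ICU
# admission, tmax is the ICU discharge column as-is.
#   tmin, tmax = parse_time_args(
#       cohort, tmin=("icu_admission", "-6h"), tmax="icu_discharge"
#   )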
def sql_to_pandas(query: str) -> str:
"""
Converts SQL-like expressions to pandas-compatible expressions. Supported operations include IN, NOT IN, LIKE, AND, OR.
Args:
query (str): The SQL-like expression.
Returns:
str: The equivalent pandas-compatible expression.
"""
if not any(keyword in query for keyword in ["NOT IN", "IN", "LIKE", "OR", "AND"]):
return query
# Convert NOT IN expressions
query = re.sub(r"(\w+)\s+NOT IN\s+\(([^)]+)\)", r"~\1.isin([\2])", query)
# Convert IN expressions
query = re.sub(r"(\w+)\s+IN\s+\(([^)]+)\)", r"\1.isin([\2])", query)
# Convert LIKE expressions
query = re.sub(
r"(\w+)\s+LIKE\s+'([^']+)%'", r"\1.str.contains(r'^\2', regex=True)", query
)
# Convert OR conditions
query = re.sub(r"\s+OR\s+", " | ", query)
# Convert AND conditions
query = re.sub(r"\s+AND\s+", " & ", query)
print(
"Warning: SQL-expressions for NativeStatic variables are deprecated and will be removed in a future version. Please create new variables using the pandas syntax directly."
)
return query
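# Conversion sketch (illustrative):
#   sql_to_pandas("ward IN ('A', 'B') AND name LIKE 'ICU%'")
#     -> "ward.isin(['A', 'B']) & name.str.contains(r'^ICU', regex=True)"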
def get_variables(var_name=""):
"""
Get variables from the VARS dictionary.
"""
if not var_name:
return json.dumps(VARS["variables"], indent=4)
else:
return json.dumps(VARS["variables"][var_name], indent=4)
def merge_consecutive(
data: pd.DataFrame,
primary_key: str,
recordtime: str = "recordtime",
recordtime_end: str = "recordtime_end",
time_threshold: pd.Timedelta = pd.Timedelta(minutes=30),
) -> pd.DataFrame:
"""
Combine consecutive sessions separated by less than time_threshold (default 30 minutes) into a single session.
Args:
data (pd.DataFrame): The data to merge.
primary_key (str): The primary key column.
recordtime (str): The recordtime column. Defaults to "recordtime".
recordtime_end (str): The recordtime_end column. Defaults to "recordtime_end".
time_threshold (pd.Timedelta): The time threshold. Defaults to 30 minutes.
Returns:
pd.DataFrame: The merged data.
"""
# Sort by primary_key and recordtime
data = data.sort_values(by=[primary_key, recordtime]).reset_index(drop=True)
def combine_consecutive_sessions(group):
# Use the configured recordtime columns rather than hardcoded names
group["last_event_recordtime_end"] = group[recordtime_end].shift(1)
group["time_since_last_event"] = (
group[recordtime] - group["last_event_recordtime_end"]
)
group["merge_with_previous"] = group["time_since_last_event"] < time_threshold
group["merge_group"] = (~group["merge_with_previous"]).cumsum()
gr_merged = (
group.groupby(by=[primary_key, "merge_group"])
.agg({"value": "first", recordtime: "first", recordtime_end: "last"})
.reset_index()
)
return gr_merged
grouped = data.groupby(primary_key)
grouped = grouped.apply(combine_consecutive_sessions)
return grouped.reset_index(drop=True)
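# Usage sketch (illustrative): merge sessions of the same case that are
# separated by less than one hour.
#   merged = merge_consecutive(
#       sessions, primary_key="case_id", time_threshold=pd.Timedelta(hours=1)
#   )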
def _extract_with_co6_parent_chunked(
chunk_info: tuple[int, int],
df: pd.DataFrame,
dbpass: str | None,
conn_args: dict,
password_file: str | None,
primary_key: str,
table: str,
parent_name: str,
suffixes: dict[str, str],
database: str,
):
"""Extract a parent variable and its suffixed child values for one chunk of the cohort.
Helper for chunked extraction: builds a case-bounds CTE for the chunk, joins the suffixed child rows onto the parent rows and reads the result from the database.
Args:
chunk_info (tuple[int, int]): Start and end row indices of the chunk within df.
df (pd.DataFrame): The cohort dataframe.
dbpass (str | None): The database password; if None, password_file is used instead.
conn_args (dict): Additional connection arguments for the ImpalaConnector.
password_file (str | None): Path to the password file.
primary_key (str): The primary key column.
table (str): The source table.
parent_name (str): The parent variable name.
suffixes (dict[str, str]): Mapping of variable-name suffixes to output column names.
database (str): The database name.
Returns:
pd.DataFrame: The extracted records for this chunk.
"""
start_idx, end_idx = chunk_info
if dbpass:
conn = utils.ImpalaConnector(password=dbpass, **conn_args)
else:
conn = utils.ImpalaConnector.with_password_file(
password_file=password_file, **conn_args
)
# Build the case bounds CTE
cb = get_cb_values(
df.iloc[start_idx:end_idx],
primary_key=primary_key,
tmin_col="tmin",
tmax_col="tmax",
)
# Build JOIN clauses and SELECT expressions dynamically
joins = []
selects = []
for suffix, col_name in suffixes.items():
alias = f"xc_{col_name}"
joins.append(
f"""
JOIN {database}.{table} {alias}
ON {alias}.c_parent_id = p.c_id
AND {alias}.c_var_name = '{parent_name}{suffix}'"""
)
cast = (
f"CAST({alias}.c_value AS DATE)"
if col_name == "recordtime"
else f"{alias}.c_value"
)
selects.append(f"{cast} AS {col_name}")
# Build the query
query = f"""
{cb}
SELECT
p.c_falnr AS case_id,
{', '.join(selects)}
{f', cb.{primary_key}' if primary_key != 'case_id' else ''}
FROM
{database}.{table} p
{''.join(joins)}
JOIN case_bounds cb
ON p.c_falnr = cb.case_id
WHERE
p.c_var_name = '{parent_name}'
AND CAST(xc_recordtime.c_value AS DATE) BETWEEN cb.case_tmin AND cb.case_tmax
"""
# Execute the query (avoid shadowing the cohort dataframe argument)
result = conn.df_read_sql(query, datetime_cols=["recordtime"])
return result