Source code for corr_vars.utils.helpers

import json
import math
import multiprocessing
import os
import re
from functools import partial

import numpy as np
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy
from tqdm import tqdm

import corr_vars.utils as utils

# Local imports
from corr_vars.static import VARS


def _get_password_from_file(password_file: str | bool | None) -> str:
    """Get the password from a file.

    Args:
        password_file (str | bool | None): The path to the password file. If None or
            True, defaults to ``~/password.txt``.

    Returns:
        str: The password.
    """
    if not password_file or password_file is True:
        password_file = f"{os.path.expanduser('~')}/password.txt"
    try:
        with open(password_file, "r") as file:
            password = file.read().strip()
    except Exception as e:
        raise FileNotFoundError(f"Password file not found: {password_file} \n {e}")
    return password


def json_list_sql(json_obj):
    """Convert a JSON list to a string for a SQL IN clause.

    Args:
        json_obj (list[str]): The JSON list.

    Returns:
        str: The SQL list.
    """
    return ", ".join(f"'{item}'" for item in json_obj)
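

# A minimal usage sketch of json_list_sql (illustrative values, not part of the module):
#
#   >>> json_list_sql(["a", "b"])
#   "'a', 'b'"
#   # i.e. ready to interpolate into "... WHERE col IN (...)"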


def filter_by_condition(
    df: pd.DataFrame, condition_func, description="", verbose=True, mode="drop"
) -> pd.DataFrame:
    """
    Drop or keep rows of a DataFrame based on a condition function.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        condition_func (callable): A function that takes the DataFrame and returns a boolean Series.
        description (str): Description of the condition.
        verbose (bool): Whether to print a summary of the affected rows.
        mode (str): Whether to drop or keep the matching rows. Can be "drop" or "keep".

    Returns:
        pd.DataFrame: The DataFrame with rows dropped (or kept) based on the condition.
    """
    mask = condition_func(df)
    if mode == "drop":
        df2 = df[~mask]
    elif mode == "keep":
        df2 = df[mask]
    else:
        raise ValueError(f"Mode must be 'drop' or 'keep'. Got {mode}")
    if verbose:
        print(
            f"{mode.upper()}: {mask.sum()} rows ({mask.sum() / len(df):.2%}) due to {description}"
        )
    return df2
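

# Usage sketch for filter_by_condition (hypothetical dataframe, not part of the module):
#
#   >>> df = pd.DataFrame({"age": [17, 45, 80]})
#   >>> adults = filter_by_condition(df, lambda d: d["age"] < 18, description="age < 18")
#   DROP: 1 rows (33.33%) due to age < 18
#   # `adults` keeps only the rows with age >= 18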


def merge_consecutive_stays(group):
    """Merge consecutive ICU stays in a dataframe.

    Args:
        group (pd.DataFrame): ICU stays for a single case.

    Returns:
        pd.DataFrame: The merged dataframe.
    """
    threshold = pd.Timedelta(hours=24)
    group = group.sort_values(by=["icu_id", "icu_admission"]).reset_index(
        drop=True
    )  # Reset index to use shift function
    group["last_discharge_same_icu"] = group.groupby(by="icu_id")[
        "icu_discharge"
    ].shift(1)
    group["time_since_last_discharge_same_icu"] = (
        group["icu_admission"] - group["last_discharge_same_icu"]
    )
    group["merge_with_previous"] = (
        group["time_since_last_discharge_same_icu"] < threshold
    )
    group["merge_group"] = (~group["merge_with_previous"]).cumsum()
    gr_merged = (
        group.groupby(by=["icu_id", "merge_group"])
        .agg({"case_id": "first", "icu_admission": "first", "icu_discharge": "last"})
        .reset_index()
    )
    gr_merged = gr_merged.drop(columns=["merge_group"])
    return gr_merged
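

# Usage sketch for merge_consecutive_stays, typically applied per case via groupby().apply()
# (hypothetical timestamps; the 10 h gap below is under the 24 h threshold, so both rows merge):
#
#   >>> stays = pd.DataFrame({
#   ...     "case_id": ["c1", "c1"],
#   ...     "icu_id": ["icu_a", "icu_a"],
#   ...     "icu_admission": pd.to_datetime(["2020-01-01 08:00", "2020-01-02 06:00"]),
#   ...     "icu_discharge": pd.to_datetime(["2020-01-01 20:00", "2020-01-03 12:00"]),
#   ... })
#   >>> merge_consecutive_stays(stays)  # -> one stay from 2020-01-01 08:00 to 2020-01-03 12:00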


def df_find_closest(
    group: pd.DataFrame,
    to_col: str,
    tdelta: str = "0",
    pm_before: str = "52w",
    pm_after: str = "52w",
):
    """Find the closest record to the target time.

    Args:
        group (pd.DataFrame): The input dataframe.
        to_col (str): The column containing the target time.
        tdelta (str): The time delta. Defaults to "0".
        pm_before (str): The time range before the target time. Defaults to "52w".
        pm_after (str): The time range after the target time. Defaults to "52w".

    Returns:
        pd.Series: The closest record.
    """
    tdelta = pd.to_timedelta(tdelta)
    pm_before = pd.to_timedelta(pm_before)
    pm_after = pd.to_timedelta(pm_after)

    target_time = group[to_col] + tdelta
    group["timediff"] = (group["recordtime"] - target_time).abs()
    within_range = group[
        (group["recordtime"] >= (target_time - pm_before))
        & (group["recordtime"] <= (target_time + pm_after))
    ]
    if within_range.empty:
        # Return an all-NaN record without the helper "timediff" column
        cols = group.columns.drop("timediff")
        return pd.Series([np.nan] * len(cols), index=cols)
    closest = within_range.loc[within_range["timediff"].idxmin()]
    # `closest` is a Series, so drop the helper entry by label rather than by column
    closest = closest.drop(labels=["timediff"])
    return closest
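

# Usage sketch for df_find_closest (hypothetical observations; usually applied per case via groupby().apply()):
#
#   >>> obs = pd.DataFrame({
#   ...     "admission": pd.to_datetime(["2020-01-01", "2020-01-01"]),
#   ...     "recordtime": pd.to_datetime(["2020-01-02", "2020-01-10"]),
#   ...     "value": [7.4, 7.1],
#   ... })
#   >>> df_find_closest(obs, to_col="admission", tdelta="1d", pm_before="2d", pm_after="2d")
#   # -> the row with recordtime 2020-01-02 (closest to admission + 1 day, within the +/- 2 day window)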


def aggregate_column(
    group: DataFrameGroupBy, agg_func: str, column: str, params: list[str] = []
):
    """Apply an aggregation function to a column of a grouped dataframe.

    Args:
        group (DataFrameGroupBy): The grouped dataframe.
        agg_func (str): One of "median", "mean", "perc" or "count".
        column (str): The column to aggregate.
        params (list[str]): Optional parameters, e.g. the percentile for "perc".

    Returns:
        pd.Series: The aggregated values per group.
    """
    match agg_func:
        case "median":
            return group[column].median()
        case "mean":
            return group[column].mean()
        case "perc":
            return group[column].quantile(float(params[0]) / 100)
        case "count":
            return group.size()
        case _:
            raise ValueError(f"Unknown aggregation function: {agg_func}")
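

# Usage sketch for aggregate_column (hypothetical grouped dataframe):
#
#   >>> grouped = df.groupby("case_id")
#   >>> aggregate_column(grouped, "median", "value")              # median value per case
#   >>> aggregate_column(grouped, "perc", "value", params=["90"]) # 90th percentile per case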


def parse_select(select_str):
    """Parse the select syntax to extract function name, parameters and columns.

    Args:
        select_str (str): The select string.

    Returns:
        tuple: (function_name, params, columns)

    Examples:
        - !first value, recordtime
        - !closest(timestamp, 1d, 2d) value, recordtime
        - !any value
        - !last value, recordtime
    """
    select_str = select_str.strip()

    # Extract function name
    if select_str.startswith("!"):
        if "(" in select_str:
            # Handle function with parameters
            func_name = select_str[1 : select_str.find("(")].strip()
            params_str = select_str[
                select_str.find("(") + 1 : select_str.find(")")
            ].strip()
            params = [p.strip() for p in params_str.split(",")]
            cols_str = select_str[select_str.find(")") + 1 :].strip()
        else:
            # Handle function without parameters
            parts = select_str[1:].split(" ", 1)
            func_name = parts[0].strip()
            params = []
            cols_str = parts[1] if len(parts) > 1 else ""
    else:
        # No function: the whole string is a plain column list
        func_name, params = None, None
        cols = [c.strip() for c in select_str.split(",")]
        return func_name, params, cols

    # Extract columns to select
    if cols_str:
        columns = [c.strip() for c in cols_str.split(",")]
    else:
        columns = []

    return func_name, params, columns
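

# Illustrative inputs for parse_select and the tuples they parse to:
#
#   >>> parse_select("!closest(timestamp, 1d, 2d) value, recordtime")
#   ('closest', ['timestamp', '1d', '2d'], ['value', 'recordtime'])
#   >>> parse_select("!first value, recordtime")
#   ('first', [], ['value', 'recordtime'])
#   >>> parse_select("value, recordtime")
#   (None, None, ['value', 'recordtime'])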


def parse_col_list(cols: list[str], var_name: str):
    """Parse a list of columns to a select string.

    Args:
        cols (list[str]): The columns to parse.
        var_name (str): The variable name.

    Returns:
        str: The select string.

    Examples:
        - ["value", "recordtime"] -> "value AS var_name_value, recordtime AS var_name_recordtime"
        - ["value"] -> "value AS var_name"
    """
    if len(cols) > 1:
        select = ", ".join(
            [f"{col.strip()} AS {var_name}_{col.strip()}" for col in cols]
        )
    else:
        select = f"{cols[0].strip()} AS {var_name}"
    return select


def convert_interval_sql(interval_str):
    """Convert shorthand interval notation to SQL INTERVAL syntax.

    Supports: y->year, M->month, w->week, d->day, h->hour, m->minute, s->second

    Args:
        interval_str (str): The interval string.

    Returns:
        str: The SQL INTERVAL syntax.

    Examples:
        - "2h" -> "INTERVAL 2 HOUR"
        - "-2h" -> "INTERVAL -2 HOUR"
        - "30m" -> "INTERVAL 30 MINUTE"
    """
    if not interval_str or interval_str == "0":
        return "INTERVAL 0 SECOND"

    # Remove leading + if present
    interval_str = interval_str.strip().lstrip("+")

    # Parse the interval
    unit_map = {
        "y": "YEAR",
        "M": "MONTH",
        "w": "WEEK",
        "d": "DAY",
        "h": "HOUR",
        "m": "MINUTE",
        "s": "SECOND",
    }

    # Extract number and unit
    match = re.match(r"^(-?\d+)([yMwdhms])$", interval_str)
    if not match:
        raise ValueError(
            f"Invalid interval format: {interval_str}. Use format like '2h', '-2h', '30m', etc. "
            "(Supports y->year, M->month, w->week, d->day, h->hour, m->minute, s->second)"
        )

    number, unit = match.groups()
    sql_unit = unit_map.get(unit)
    if not sql_unit:
        raise ValueError(
            f"Invalid unit: {unit}. Valid units are: {', '.join(unit_map.keys())}"
        )

    return f"INTERVAL {number} {sql_unit}"
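

# Illustrative conversions for convert_interval_sql:
#
#   >>> convert_interval_sql("2h")
#   'INTERVAL 2 HOUR'
#   >>> convert_interval_sql("-30m")
#   'INTERVAL -30 MINUTE'
#   >>> convert_interval_sql("0")
#   'INTERVAL 0 SECOND'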


def get_cb_values(
    df_chunk, primary_key="case_id", tmin_col="tmin", tmax_col="tmax", ttarget_col=None
):
    """Get the case bounds CTE.

    This is used to apply a time filter to a native dynamic variable.

    Args:
        df_chunk (pd.DataFrame): A chunk of the cohort dataframe.
        primary_key (str): The primary key column. Defaults to "case_id".
        tmin_col (str): The column containing the tmin. Defaults to "tmin".
        tmax_col (str): The column containing the tmax. Defaults to "tmax".
        ttarget_col (str): The column containing the target time. Defaults to None. (For !closest)

    Returns:
        str: The case bounds CTE.
    """
    # Drop rows with NaN tmin or tmax
    df_chunk = df_chunk.dropna(subset=[tmin_col, tmax_col])

    values = []
    for _, row in df_chunk.iterrows():
        value = f"('{row['case_id']}', '{row[tmin_col]}', '{row[tmax_col]}'"
        if ttarget_col:
            value += f", '{row[ttarget_col]}'"
        if primary_key != "case_id":
            value += f", '{row[primary_key]}'"
        value += ")"
        values.append(value)

    cb_values = ", ".join(values)
    cb = f"""WITH case_bounds (case_id, case_tmin, case_tmax{', ' + ttarget_col if ttarget_col else ''}{', ' + primary_key if primary_key != 'case_id' else ''}) AS (
        VALUES
        {cb_values}
    )"""
    return cb
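

# Sketch of the CTE produced by get_cb_values for a one-row chunk (hypothetical case id and bounds);
# the exact whitespace differs, but the generated SQL is roughly:
#
#   WITH case_bounds (case_id, case_tmin, case_tmax) AS (
#       VALUES
#       ('c1', '2020-01-01 00:00:00', '2020-01-05 00:00:00')
#   )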


def build_base_query(
    cb: str,
    var: "NativeDynamic",
    database: str,
    primary_key: str,
    extr_end_date: str,
    ttarget_col: str = None,
):
    """Build the base query for a native dynamic variable.

    Args:
        cb (str): The case bounds CTE.
        var (NativeDynamic): The native dynamic variable.
        database (str): The database name.
        primary_key (str): The primary key column. Defaults to "case_id".
        extr_end_date (str): The extraction end date.
        ttarget_col (str): The target time column. Defaults to None.

    Returns:
        tuple: Query, table info, and columns to keep.
    """
    table = var.table
    table_info = VARS["tables"][table]
    value_col, recordtime_col, case_id_col = (
        table_info["value"],
        table_info["recordtime"],
        table_info["case_id"],
    )
    optional_columns = {
        col: table_info[col]
        for col in table_info
        if col not in ["value", "recordtime", "case_id"]
    }
    optional_columns_str = ", ".join(
        [f"v.{table_info[col]} AS {col}" for col in optional_columns]
    )

    value = (
        f"CAST(v.{value_col} AS {var.value_dtype}) AS value"
        if var.value_dtype
        else f"v.{value_col} AS value"
    )

    query = f"""
    {cb}
    SELECT
        v.{case_id_col} AS case_id,
        v.{recordtime_col} AS recordtime,
        {value}
        {f', {optional_columns_str}' if optional_columns_str else ''}
        {f', cb.{primary_key}' if primary_key != 'case_id' else ''}
        {f', cb.{ttarget_col}' if ttarget_col else ''}
    FROM {database}.{table} v
    JOIN case_bounds cb ON v.{case_id_col} = cb.case_id
    WHERE v.{case_id_col} IN (SELECT DISTINCT case_id FROM case_bounds)
        AND v.{recordtime_col} BETWEEN cb.case_tmin AND cb.case_tmax
        AND {extr_end_date}
        {f"AND ({var.where})" if var.where else ''}
    """

    keep_cols = [primary_key, "recordtime", "value"] + [col for col in optional_columns]
    return query, table_info, keep_cols


def get_time_series(obs: pd.DataFrame, col_name: str, tdelta: str = "0"):
    """Return a series with the values of col_name (in datetime format) plus tdelta (in pd.Timedelta format).

    Args:
        obs (pd.DataFrame): The observation dataframe.
        col_name (str): The column name.
        tdelta (str): The time delta. Defaults to "0".

    Returns:
        pd.Series: A series with the time column + tdelta.
    """
    tdelta = pd.to_timedelta(tdelta)
    return obs[col_name] + tdelta


def parse_time_args(
    obs: pd.DataFrame,
    tmin: str | tuple[str, str] | None = None,
    tmax: str | tuple[str, str] | None = None,
):
    """Parse the time arguments to get the time series.

    Args:
        obs (pd.DataFrame): The observation dataframe.
        tmin (str | tuple[str, str] | None): The tmin argument. Defaults to None.
        tmax (str | tuple[str, str] | None): The tmax argument. Defaults to None.

    Returns:
        tuple: tmin (pd.Series), tmax (pd.Series).
    """
    if isinstance(tmin, tuple):
        col_name, tdelta = tmin
        tmin = get_time_series(obs, col_name, tdelta=tdelta)
    else:
        tmin = get_time_series(obs, col_name=tmin)

    if isinstance(tmax, tuple):
        col_name, tdelta = tmax
        tmax = get_time_series(obs, col_name, tdelta=tdelta)
    else:
        tmax = get_time_series(obs, col_name=tmax)

    return tmin, tmax
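

# Usage sketch for parse_time_args (hypothetical column names in the observation dataframe):
#
#   >>> tmin, tmax = parse_time_args(obs, tmin=("icu_admission", "-6h"), tmax="icu_discharge")
#   # tmin = icu_admission shifted back by 6 hours, tmax = icu_discharge unchanged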


def sql_to_pandas(query: str) -> str:
    """
    Converts SQL-like expressions to pandas-compatible expressions.
    Supported operations include IN, NOT IN, LIKE, AND, OR.

    Args:
        query (str): The SQL-like expression.

    Returns:
        str: The equivalent pandas-compatible expression.
    """
    if not any(keyword in query for keyword in ["NOT IN", "IN", "LIKE", "OR", "AND"]):
        return query

    # Convert NOT IN expressions
    query = re.sub(r"(\w+)\s+NOT IN\s+\(([^)]+)\)", r"~\1.isin([\2])", query)

    # Convert IN expressions
    query = re.sub(r"(\w+)\s+IN\s+\(([^)]+)\)", r"\1.isin([\2])", query)

    # Convert LIKE expressions
    query = re.sub(
        r"(\w+)\s+LIKE\s+'([^']+)%'", r"\1.str.contains(r'^\2', regex=True)", query
    )

    # Convert OR conditions
    query = re.sub(r"\s+OR\s+", " | ", query)

    # Convert AND conditions
    query = re.sub(r"\s+AND\s+", " & ", query)

    print(
        "Warning: SQL-expressions for NativeStatic variables are deprecated and will be removed "
        "in a future version. Please create new variables using the pandas syntax directly."
    )
    return query
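

# Illustrative conversion by sql_to_pandas (also prints the deprecation warning):
#
#   >>> sql_to_pandas("code IN ('A', 'B') AND name LIKE 'Nor%'")
#   "code.isin(['A', 'B']) & name.str.contains(r'^Nor', regex=True)"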


def get_variables(var_name=""):
    """
    Get variables from the VARS dictionary.
    """
    if not var_name:
        return json.dumps(VARS["variables"], indent=4)
    else:
        return json.dumps(VARS["variables"][var_name], indent=4)


def extract_df_data(
    df: pd.DataFrame,
    col_dict: dict = None,
    filter_dict: "dict[str, list[str]]" = None,
    exact_match: bool = False,
    remove_prefix: bool = False,
    drop: bool = False,
) -> pd.DataFrame:
    """
    Extracts data from a DataFrame.

    Parameters:
        df (pandas.DataFrame): The DataFrame to operate on.
        col_dict (dict, optional): A dictionary mapping column names to new names. Defaults to None.
        filter_dict (dict[str, list[str]], optional): A dictionary where keys are column names and values are lists of
            values to filter rows by (may include regex patterns for exact_match=False). Defaults to None.
        exact_match (bool, optional): If True, performs exact matching when filtering. Defaults to False.
        remove_prefix (bool, optional): If True, removes prefix from default_key. Defaults to False.
        drop (bool, optional): If True, drop all columns not specified in col_dict.

    Returns:
        pandas.DataFrame: A DataFrame containing the extracted data from the original DataFrame.
    """
    # Rename columns if applicable
    if col_dict:
        df = df.rename(columns=col_dict)
        if drop:
            df = df.filter(items=col_dict.values())

    # Filter rows for given values
    if filter_dict:
        for column, values in filter_dict.items():
            if exact_match:
                df = df[df[column].isin(values)]
            else:
                if df[column].dtype == "object":
                    regex = "|".join(values)
                    df = df[df[column].str.contains(regex, na=False, regex=True)]
                else:
                    df = df[df[column].isin(values)]

    return df
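

# Usage sketch for extract_df_data (hypothetical column names and filter values):
#
#   >>> extract_df_data(
#   ...     df,
#   ...     col_dict={"c_value": "value", "c_date": "recordtime"},
#   ...     filter_dict={"value": ["ventilat"]},  # substring/regex match, since exact_match=False
#   ...     drop=True,                            # keep only the renamed columns
#   ... )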


def merge_consecutive(
    data: pd.DataFrame,
    primary_key: str,
    recordtime: str = "recordtime",
    recordtime_end: str = "recordtime_end",
    time_threshold: pd.Timedelta = pd.Timedelta(minutes=30),
) -> pd.DataFrame:
    """
    Combine consecutive sessions (e.g. of ecmo_vv_icu) separated by less than ``time_threshold`` into a single session.

    Args:
        data (pd.DataFrame): The data to merge.
        primary_key (str): The primary key column.
        recordtime (str): The recordtime column. Defaults to "recordtime".
        recordtime_end (str): The recordtime_end column. Defaults to "recordtime_end".
        time_threshold (pd.Timedelta): The time threshold. Defaults to 30 minutes.

    Returns:
        pd.DataFrame: The merged data.
    """
    # Sort by primary_key and recordtime
    data = data.sort_values(by=[primary_key, recordtime]).reset_index(drop=True)

    # Note: the helper below assumes the default column names "recordtime"/"recordtime_end"
    def combine_consecutive_sessions(group):
        group["last_event_recordtime_end"] = group["recordtime_end"].shift(1)
        group["time_since_last_event"] = (
            group["recordtime"] - group["last_event_recordtime_end"]
        )
        group["merge_with_previous"] = group["time_since_last_event"] < time_threshold
        group["merge_group"] = (~group["merge_with_previous"]).cumsum()
        gr_merged = (
            group.groupby(by=[primary_key, "merge_group"])
            .agg({"value": "first", "recordtime": "first", "recordtime_end": "last"})
            .reset_index()
        )
        return gr_merged

    grouped = data.groupby(primary_key)
    grouped = grouped.apply(combine_consecutive_sessions)
    return grouped.reset_index(drop=True)
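

# Usage sketch for merge_consecutive (hypothetical session data; the 10 min gap is below the
# default 30 min threshold, so the two rows merge into a single session from 08:00 to 18:00):
#
#   >>> sessions = pd.DataFrame({
#   ...     "case_id": ["c1", "c1"],
#   ...     "value": [1, 1],
#   ...     "recordtime": pd.to_datetime(["2020-01-01 08:00", "2020-01-01 12:10"]),
#   ...     "recordtime_end": pd.to_datetime(["2020-01-01 12:00", "2020-01-01 18:00"]),
#   ... })
#   >>> merge_consecutive(sessions, primary_key="case_id")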


def _extract_with_co6_parent_chunked(
    chunk_info: tuple[int, int],
    df: pd.DataFrame,
    dbpass: str | None,
    conn_args: dict,
    password_file: str | None,
    primary_key: str,
    table: str,
    parent_name: str,
    suffixes: dict[str, str],
    database: str,
):
    start_idx, end_idx = chunk_info

    if dbpass:
        conn = utils.ImpalaConnector(password=dbpass, **conn_args)
    else:
        conn = utils.ImpalaConnector.with_password_file(
            password_file=password_file, **conn_args
        )

    # Build the case bounds CTE
    cb = get_cb_values(
        df.iloc[start_idx:end_idx],
        primary_key=primary_key,
        tmin_col="tmin",
        tmax_col="tmax",
    )

    # Build JOIN clauses and SELECT expressions dynamically
    joins = []
    selects = []
    for suffix, col_name in suffixes.items():
        alias = f"xc_{col_name}"
        joins.append(
            f"""
    JOIN {database}.{table} {alias}
        ON {alias}.c_parent_id = p.c_id
        AND {alias}.c_var_name = '{parent_name}{suffix}'"""
        )
        cast = (
            f"CAST({alias}.c_value AS DATE)"
            if col_name == "recordtime"
            else f"{alias}.c_value"
        )
        selects.append(f"{cast} AS {col_name}")

    # Build the query
    query = f"""
    {cb}
    SELECT
        p.c_falnr AS case_id,
        {', '.join(selects)}
        {f', cb.{primary_key}' if primary_key != 'case_id' else ''}
    FROM {database}.{table} p
    {''.join(joins)}
    JOIN case_bounds cb ON p.c_falnr = cb.case_id
    WHERE p.c_var_name = '{parent_name}'
        AND CAST(xc_recordtime.c_value AS DATE) BETWEEN cb.case_tmin AND cb.case_tmax
    """

    # Execute the query
    df = conn.df_read_sql(query, datetime_cols=["recordtime"])
    return df


def extract_with_co6_parent(
    parent_name: str,
    suffixes: dict[str, str],
    df: pd.DataFrame,
    cohort: "Cohort",
    dbpass: str | None = None,
    table: str = "it_copra6_hierarchy_v2",
    chunk_size: int = 30000,
):
    """Extract variables from the specified table following the Co6 hierarchy schema.

    Will merge child elements of a parent variable. You may reference
    `this file <https://health-data.charite.de/_media/data-model/copra6/copra6parents_02_05_22.xlsx>`_
    for available parent and child relations.

    Args:
        parent_name (str): The parent name.
        suffixes (dict[str, str]): Dictionary mapping suffixes to column names. Must include a suffix that maps to "recordtime".
        df (pd.DataFrame): Dataframe with case_id, tmin, tmax columns. [Copy of cohort.obs]
        cohort (Cohort): The cohort.
        dbpass (str): The database password (only to be passed from the Variable.extract() method, do not specify passwords in your code).
        table (str): The table to extract from. Defaults to "it_copra6_hierarchy_v2".
        chunk_size (int): The chunk size. Defaults to 30000.

    Returns:
        pd.DataFrame: The extracted data (columns: case_id, recordtime, plus any additional columns specified)

    Examples:
        >>> df = extract_with_co6_parent(
        ...     parent_name="Score_SOFA",
        ...     suffixes={
        ...         "_Wert": "value",
        ...         "_Date": "recordtime",
        ...     },
        ...     df=cohort.obs,
        ...     cohort=cohort,
        ... )
    """
    if "recordtime" not in suffixes.values():
        raise ValueError("suffixes must include one that maps to 'recordtime'")

    f_args = {
        "df": df,
        "dbpass": dbpass,
        "conn_args": cohort.conn_args,
        "password_file": cohort.password_file,
        "database": cohort.database,
        "primary_key": cohort.primary_key,
        "table": table,
        "parent_name": parent_name,
        "suffixes": suffixes,
    }

    num_chunks = math.ceil(len(df) / chunk_size)
    if num_chunks == 1:
        df = _extract_with_co6_parent_chunked((0, len(df)), **f_args)
    else:
        process_func = partial(_extract_with_co6_parent_chunked, **f_args)
        chunks = [
            (i * chunk_size, min((i + 1) * chunk_size, len(df)))
            for i in range(num_chunks)
        ]
        with multiprocessing.Pool() as pool:
            results = list(
                tqdm(
                    pool.imap(process_func, chunks),
                    total=num_chunks,
                    desc=f"Extracting {parent_name}",
                    unit="chunks",
                )
            )
        df = pd.concat(results, copy=False, ignore_index=True)

    return df