Model Deployment: Estimating Heart Failure Survival Risk Profiles From Cardiovascular, Hematologic And Metabolic Markers¶
- 1. Table of Contents
- 1.1 Data Background
- 1.2 Data Description
- 1.3 Data Quality Assessment
- 1.4 Data Preprocessing
- 1.5 Data Exploration
- 1.6 Predictive Model Development
- 1.6.1 Pre-Modelling Data Preparation
- 1.6.2 Data Splitting
- 1.6.3 Modelling Pipeline Development
- 1.6.4 Cox Proportional Hazards Regression Model Fitting | Hyperparameter Tuning | Validation
- 1.6.5 Cox Net Survival Model Fitting | Hyperparameter Tuning | Validation
- 1.6.6 Survival Tree Model Fitting | Hyperparameter Tuning | Validation
- 1.6.7 Random Survival Forest Model Fitting | Hyperparameter Tuning | Validation
- 1.6.8 Gradient Boosted Survival Model Fitting | Hyperparameter Tuning | Validation
- 1.6.9 Model Selection
- 1.6.10 Model Testing
- 1.6.11 Model Inference
- 1.7 Predictive Model Deployment Using Streamlit and Streamlit Community Cloud
- 2. Summary
- 3. References
1. Table of Contents ¶
This project implements the Cox Proportional Hazards Regression, Cox Net Survival, Survival Tree, Random Survival Forest, and Gradient Boosted Survival models as independent base learners using various helpful packages in Python to estimate the survival probabilities of right-censored survival time and status responses. The resulting predictions derived from the candidate models were evaluated in terms of their discrimination power using Harrell's Concordance Index metric. Penalties including Ridge Regularization and Elastic Net Regularization were evaluated to impose constraints on the model coefficient updates, as applicable. Additionally, survival probability functions were estimated for model risk-groups and sampled individual cases. The final model was deployed as a prototype application with a web interface via Streamlit. All results were consolidated in a Summary presented at the end of the document.
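Since Harrell's Concordance Index is the primary discrimination metric used throughout the model evaluation, a minimal sketch of its computation with the scikit-survival API (imported later in this notebook) is shown below; all values are toy examples rather than project results.

##################################
# Minimal illustrative sketch of
# Harrell's Concordance Index
# (toy values for demonstration only)
##################################
import numpy as np
from sksurv.metrics import concordance_index_censored

# Event indicators (True = death observed), observed durations, and model risk scores
toy_event = np.array([True, False, True, True, False])
toy_time = np.array([10.0, 120.0, 35.0, 60.0, 200.0])
toy_risk = np.array([0.9, 0.1, 0.7, 0.5, 0.2])

# Higher risk scores should correspond to shorter survival times
cindex, concordant, discordant, tied_risk, tied_time = concordance_index_censored(toy_event, toy_time, toy_risk)
print(f'Concordance Index: {cindex:.3f}')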
Survival Analysis deals with the analysis of time-to-event data. It focuses on the expected duration of time until one or more events of interest occur, such as death, failure, or relapse. This type of analysis is used to study and model the time until the occurrence of an event, taking into account that the event might not have occurred for all subjects during the study period. Several key aspects of survival analysis include the survival function which refers to the probability that an individual survives longer than a certain time, hazard function which describes the instantaneous rate at which events occur, given no prior event, and censoring pertaining to a condition where the event of interest has not occurred for some subjects during the observation period.
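In notation, using the standard definitions for a survival time $T$, the survival and hazard functions described above are:

$$
S(t) = P(T > t), \qquad h(t) = \lim_{\Delta t \to 0} \frac{P(t \le T < t + \Delta t \mid T \ge t)}{\Delta t}
$$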
Right-Censored Survival Data occurs when the event of interest has not happened for some subjects by the end of the study period or the last follow-up time. This type of censoring is common in survival analysis because not all individuals may experience the event before the study ends, or they might drop out or be lost to follow-up. Right-censored data is crucial in survival analysis as it allows the inclusion of all subjects in the analysis, providing more accurate and reliable estimates.
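In the scikit-survival workflow applied later in this notebook, right-censored responses are encoded as a structured array holding the event indicator and the observed time; the minimal sketch below uses illustrative values only.

##################################
# Minimal sketch for encoding a
# right-censored response as a
# structured array (illustrative values)
##################################
import numpy as np
from sksurv.util import Surv

# Event indicator: True if the death event was observed, False if censored
toy_event = np.array([True, False, True])
# Observed follow-up time in days (censored cases contribute time at risk only)
toy_time = np.array([4.0, 285.0, 115.0])

toy_response = Surv.from_arrays(event=toy_event, time=toy_time)
print(toy_response.dtype.names)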
Survival Models refer to statistical methods used to analyze survival data, accounting for censored observations. These models aim to describe the relationship between survival time and one or more predictor variables, and to estimate the survival function and hazard function. Survival models are essential for understanding the factors that influence time-to-event data, allowing for predictions and comparisons between different groups or treatment effects. They are widely used in clinical trials, reliability engineering, and other areas where time-to-event data is prevalent.
The Cox Proportional Hazards Regression Model relates the time until an event occurs (such as death or disease progression) to one or more predictor variables. The model is expressed through its hazard function, which represents the risk of the event happening at a particular time for an individual, given that the individual has survived up to that time. The equation combines a baseline hazard function with an exponential term that modifies this baseline according to the individual's covariates. The baseline hazard refers to the hazard for an individual whose covariates are all zero, representing the inherent risk of the event over time; it is not directly estimated in the Cox model, as the focus is on how the covariates shift the hazard relative to this baseline. Each covariate is associated with a regression coefficient that measures the strength and direction of its effect on the hazard, and the exponential function ensures that the hazard is always positive, as hazard values cannot be negative. The proportional hazards assumption means that the ratio of hazards between any two individuals is constant over time and is determined by the differences in their covariates. The Cox model does not require a specific form for the baseline hazard, making it flexible, while properly accounting for censored data, which is common in survival studies.
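In symbols, consistent with the description above, the hazard for an individual with covariate vector $x = (x_1, \dots, x_p)$ is:

$$
h(t \mid x) = h_0(t)\,\exp\left(\beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p\right)
$$

where $h_0(t)$ is the baseline hazard and $\beta_1, \dots, \beta_p$ are the regression coefficients; under the proportional hazards assumption, the hazard ratio between two individuals, $\exp\left(\beta^\top (x_a - x_b)\right)$, does not depend on time.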
Regularization Methods, when applied to Cox Proportional Hazards regression, are primarily used to prevent overfitting and enhance the model's ability to generalize to new data. Overfitting in this context occurs when the model captures not only the true relationship between covariates and the hazard but also the noise in the training data, leading to poor performance on unseen data. To address this, regularization introduces a penalty for large coefficient values, which helps control model complexity. In Cox regression, this is achieved by adding a regularization term to the objective function, penalizing large coefficients and encouraging simpler models. By constraining the size of the coefficients, regularization reduces the risk of overfitting, helping the model to capture the underlying patterns in the data more effectively. Ultimately, this enables the model to generalize better, as it is less likely to fit the training data too closely and more likely to identify the true associations between predictors and the hazard.
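As a brief illustration of how these penalties enter the scikit-survival estimators used later in this notebook, the sketch below instantiates a ridge-penalized Cox model and an elastic-net-penalized Cox Net model; the penalty values are arbitrary placeholders, with the actual values determined later through hyperparameter tuning.

##################################
# Minimal sketch of penalized Cox estimators
# (penalty values are placeholders;
# tuning is performed later in the notebook)
##################################
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis

# Ridge regularization: an L2 penalty shrinks all coefficients toward zero
ridge_cox = CoxPHSurvivalAnalysis(alpha=1.0)

# Elastic net regularization: a weighted mix of L1 and L2 penalties,
# which can shrink some coefficients exactly to zero
elastic_net_cox = CoxnetSurvivalAnalysis(l1_ratio=0.5)

# Both estimators are fitted on a feature matrix X and a structured
# right-censored response y, for example: ridge_cox.fit(X, y)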
Streamlit is an open-source Python library that simplifies the creation and deployment of web applications for machine learning and data science projects. It allows developers and data scientists to turn Python scripts into interactive web apps quickly without requiring extensive web development knowledge. Streamlit seamlessly integrates with popular Python libraries such as Pandas, Matplotlib, Plotly, and TensorFlow, allowing one to leverage existing data processing and visualization tools within the application. Streamlit apps can be easily deployed on various platforms, including Streamlit Community Cloud, Heroku, or any cloud service that supports Python web applications.
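To make this concrete, a minimal, hypothetical Streamlit script is sketched below; the file name, widgets, and value ranges are illustrative assumptions and do not represent the deployed application developed later in this project.

##################################
# Minimal illustrative Streamlit script
# (hypothetical file: app_sketch.py;
# run with: streamlit run app_sketch.py)
##################################
import pandas as pd
import streamlit as st

st.title('Heart Failure Survival Risk Estimation')

# Collect a few illustrative inputs from the user
age = st.slider('AGE (Years)', min_value=40, max_value=95, value=60)
ejection_fraction = st.slider('EJECTION_FRACTION (%)', min_value=14, max_value=80, value=38)

if st.button('Estimate Risk Profile'):
    # A trained survival model would be loaded here (e.g. with joblib)
    # and applied to the collected inputs to estimate the risk profile
    st.write(pd.DataFrame({'AGE': [age], 'EJECTION_FRACTION': [ejection_fraction]}))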
Streamlit Community Cloud, formerly known as Streamlit Sharing, is a free cloud-based platform provided by Streamlit that allows users to easily deploy and share Streamlit apps online. It is particularly popular among data scientists, machine learning engineers, and developers for quickly showcasing projects, creating interactive demos, and sharing data-driven applications with a wider audience without needing to manage server infrastructure. Significant features include:
- Free hosting: Streamlit Community Cloud provides free hosting for Streamlit apps, making it accessible for users who want to share their work without incurring hosting costs.
- Easy deployment: users can connect their GitHub repository to Streamlit Community Cloud, and the app is automatically deployed from the repository.
- Continuous deployment: if the code in the connected GitHub repository is updated, the app is automatically redeployed with the latest changes.
- Sharing capabilities: once deployed, apps can be shared with others via a simple URL, making it easy for collaborators, stakeholders, or the general public to access and interact with the app.
- Built-in authentication: users can restrict access to their apps using GitHub-based authentication, allowing control over who can view and interact with the app.
- Community support: the platform is supported by a community of users and developers who share knowledge, templates, and best practices for building and deploying Streamlit apps.
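For context, deployment on Streamlit Community Cloud typically requires the entry-point script and a dependency manifest to be present in the connected GitHub repository; a hypothetical minimal layout is sketched below (all file and folder names are illustrative assumptions, not the actual project structure).

heart-failure-survival-app/            # GitHub repository connected to Streamlit Community Cloud
├── app.py                             # Streamlit entry-point script
├── requirements.txt                   # dependencies, e.g. streamlit, scikit-survival, lifelines, joblib
└── models/
    └── final_survival_model.joblib    # serialized model artifact loaded by the app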
1.1. Data Background ¶
An open Heart Failure Dataset from Papers With Code (with all credits attributed to Saurav Mishra) was used for the analysis, as consolidated from the following primary sources:
- Research Paper entitled A Comparative Study for Time-to-Event Analysis and Survival Prediction for Heart Failure Condition using Machine Learning Techniques from the Journal of Electronics, Electromedical Engineering, and Medical Informatics
- Research Paper entitled Machine Learning Can Predict Survival of Patients with Heart Failure from Serum Creatinine and Ejection Fraction Alone from the BMC Medical Informatics and Decision Making Journal
This study hypothesized that cardiovascular, hematologic, and metabolic markers influence heart failure survival risks between patients.
The event status and survival duration variables for the study are:
- DEATH_EVENT - Status of the patient within the follow-up period (0, censored | 1, death)
- TIME - Follow-up period (Days)
The predictor variables for the study are:
- AGE - Patient's age (Years)
- ANAEMIA - Hematologic marker for the indication of anaemia (decrease of red blood cells or hemoglobin level in the blood) (0, Absent | 1, Present)
- CREATININE_PHOSPHOKINASE - Metabolic marker for the level of the CPK enzyme in the blood (mcg/L)
- DIABETES - Metabolic marker for the indication of diabetes (0, Absent | 1, Present)
- EJECTION_FRACTION - Cardiovascular marker for the ejection fraction (percentage of blood leaving the heart at each contraction) (%)
- HIGH_BLOOD_PRESSURE - Cardiovascular marker for the indication of hypertension (0, Absent | 1, Present)
- PLATELETS - Hematologic marker for the platelets in the blood (kiloplatelets/mL)
- SERUM_CREATININE - Metabolic marker for the level of creatinine in the blood (mg/dL)
- SERUM_SODIUM - Metabolic marker for the level of sodium in the blood (mEq/L)
- SEX - Patient's sex (0, Female | 1, Male)
- SMOKING - Cardiovascular marker for the indication of smoking (0, Absent | 1, Present)
1.2. Data Description ¶
- The dataset is comprised of:
- 299 rows (observations)
- 13 columns (variables)
- 2/13 event | duration (object | numeric)
- DEATH_EVENT
- TIME
- 6/13 predictor (numeric)
- AGE
- CREATININE_PHOSPHOKINASE
- EJECTION_FRACTION
- PLATELETS
- SERUM_CREATININE
- SERUM_SODIUM
- 5/13 predictor (object)
- ANAEMIA
- DIABETES
- HIGH_BLOOD_PRESSURE
- SEX
- SMOKING
##################################
# Loading Python Libraries
##################################
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import itertools
import joblib
%matplotlib inline
from operator import add,mul,truediv
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PowerTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import KFold, RepeatedKFold
from sklearn.inspection import permutation_importance
from sklearn.feature_selection import SelectKBest
from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy import stats
from scipy.stats import ttest_ind, chi2_contingency
from scipy.stats import pointbiserialr
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.utils import concordance_index
from lifelines.statistics import logrank_test
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest, GradientBoostingSurvivalAnalysis
from sksurv.tree import SurvivalTree
from sksurv.metrics import concordance_index_censored
from sksurv.nonparametric import kaplan_meier_estimator
import shap
import warnings
warnings.filterwarnings('ignore')
##################################
# Defining file paths
##################################
DATASETS_ORIGINAL_PATH = r"datasets\original"
DATASETS_PREPROCESSED_PATH = r"datasets\preprocessed"
DATASETS_FINAL_PATH = r"datasets\final\complete"
DATASETS_FINAL_TRAIN_PATH = r"datasets\final\train"
DATASETS_FINAL_TRAIN_FEATURES_PATH = r"datasets\final\train\features"
DATASETS_FINAL_TRAIN_TARGET_PATH = r"datasets\final\train\target"
DATASETS_FINAL_VALIDATION_PATH = r"datasets\final\validation"
DATASETS_FINAL_VALIDATION_FEATURES_PATH = r"datasets\final\validation\features"
DATASETS_FINAL_VALIDATION_TARGET_PATH = r"datasets\final\validation\target"
DATASETS_FINAL_TEST_PATH = r"datasets\final\test"
DATASETS_FINAL_TEST_FEATURES_PATH = r"datasets\final\test\features"
DATASETS_FINAL_TEST_TARGET_PATH = r"datasets\final\test\target"
MODELS_PATH = r"models"
PARAMETERS_PATH = r"parameters"
PIPELINES_PATH = r"pipelines"
##################################
# Loading the dataset
# from the DATASETS_ORIGINAL_PATH
##################################
heart_failure = pd.read_csv(os.path.join("..", DATASETS_ORIGINAL_PATH, "heart_failure_clinical_records_dataset.csv"))
##################################
# Performing a general exploration of the dataset
##################################
print('Dataset Dimensions: ')
display(heart_failure.shape)
Dataset Dimensions:
(299, 13)
##################################
# Verifying the column names
##################################
print('Column Names: ')
display(heart_failure.columns)
Column Names:
Index(['age', 'anaemia', 'creatinine_phosphokinase', 'diabetes', 'ejection_fraction', 'high_blood_pressure', 'platelets', 'serum_creatinine', 'serum_sodium', 'sex', 'smoking', 'time', 'DEATH_EVENT'], dtype='object')
##################################
# Removing trailing white spaces
# in column names
##################################
heart_failure.columns = [x.strip() for x in heart_failure.columns]
##################################
# Standardizing the column names
##################################
heart_failure.columns = ['AGE',
'ANAEMIA',
'CREATININE_PHOSPHOKINASE',
'DIABETES',
'EJECTION_FRACTION',
'HIGH_BLOOD_PRESSURE',
'PLATELETS',
'SERUM_CREATININE',
'SERUM_SODIUM',
'SEX',
'SMOKING',
'TIME',
'DEATH_EVENT']
##################################
# Verifying the corrected column names
##################################
print('Column Names: ')
display(heart_failure.columns)
Column Names:
Index(['AGE', 'ANAEMIA', 'CREATININE_PHOSPHOKINASE', 'DIABETES', 'EJECTION_FRACTION', 'HIGH_BLOOD_PRESSURE', 'PLATELETS', 'SERUM_CREATININE', 'SERUM_SODIUM', 'SEX', 'SMOKING', 'TIME', 'DEATH_EVENT'], dtype='object')
##################################
# Listing the column names and data types
##################################
print('Column Names and Data Types:')
display(heart_failure.dtypes)
Column Names and Data Types:
AGE                         float64
ANAEMIA                       int64
CREATININE_PHOSPHOKINASE      int64
DIABETES                      int64
EJECTION_FRACTION             int64
HIGH_BLOOD_PRESSURE           int64
PLATELETS                   float64
SERUM_CREATININE            float64
SERUM_SODIUM                  int64
SEX                           int64
SMOKING                       int64
TIME                          int64
DEATH_EVENT                   int64
dtype: object
##################################
# Taking a snapshot of the dataset
##################################
heart_failure.head()
AGE | ANAEMIA | CREATININE_PHOSPHOKINASE | DIABETES | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | SEX | SMOKING | TIME | DEATH_EVENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 75.0 | 0 | 582 | 0 | 20 | 1 | 265000.00 | 1.9 | 130 | 1 | 0 | 4 | 1 |
1 | 55.0 | 0 | 7861 | 0 | 38 | 0 | 263358.03 | 1.1 | 136 | 1 | 0 | 6 | 1 |
2 | 65.0 | 0 | 146 | 0 | 20 | 0 | 162000.00 | 1.3 | 129 | 1 | 1 | 7 | 1 |
3 | 50.0 | 1 | 111 | 0 | 20 | 0 | 210000.00 | 1.9 | 137 | 1 | 0 | 7 | 1 |
4 | 65.0 | 1 | 160 | 1 | 20 | 0 | 327000.00 | 2.7 | 116 | 0 | 0 | 8 | 1 |
##################################
# Setting certain integer variables
# to float values
##################################
float_columns = ['AGE',
'CREATININE_PHOSPHOKINASE',
'EJECTION_FRACTION',
'PLATELETS',
'SERUM_CREATININE',
'SERUM_SODIUM',
'TIME']
heart_failure[float_columns] = heart_failure[float_columns].astype(float)
##################################
# Setting certain integer variables
# to object or categorical values
##################################
int_columns = ['ANAEMIA',
'DIABETES',
'HIGH_BLOOD_PRESSURE',
'SMOKING',
'SEX']
heart_failure[int_columns] = heart_failure[int_columns].astype(object)
heart_failure['DEATH_EVENT'] = heart_failure['DEATH_EVENT'].astype('category')
##################################
# Saving a copy of the original dataset
##################################
heart_failure_original = heart_failure.copy()
##################################
# Setting the levels of the dichotomous categorical variables
# to descriptive categorical labels
##################################
heart_failure['DEATH_EVENT'] = heart_failure['DEATH_EVENT'].cat.set_categories([0, 1], ordered=True)
heart_failure['SEX'] = heart_failure['SEX'].replace({0: 'Female', 1: 'Male'})
heart_failure[int_columns] = heart_failure[int_columns].replace({0: 'Absent', 1: 'Present'})
##################################
# Listing the column names and data types
##################################
print('Column Names and Data Types:')
display(heart_failure.dtypes)
Column Names and Data Types:
AGE                          float64
ANAEMIA                       object
CREATININE_PHOSPHOKINASE     float64
DIABETES                      object
EJECTION_FRACTION            float64
HIGH_BLOOD_PRESSURE           object
PLATELETS                    float64
SERUM_CREATININE             float64
SERUM_SODIUM                 float64
SEX                           object
SMOKING                       object
TIME                         float64
DEATH_EVENT                 category
dtype: object
##################################
# Taking a snapshot of the dataset
##################################
heart_failure.head()
AGE | ANAEMIA | CREATININE_PHOSPHOKINASE | DIABETES | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | SEX | SMOKING | TIME | DEATH_EVENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 75.0 | Absent | 582.0 | Absent | 20.0 | Present | 265000.00 | 1.9 | 130.0 | Male | Absent | 4.0 | 1 |
1 | 55.0 | Absent | 7861.0 | Absent | 38.0 | Absent | 263358.03 | 1.1 | 136.0 | Male | Absent | 6.0 | 1 |
2 | 65.0 | Absent | 146.0 | Absent | 20.0 | Absent | 162000.00 | 1.3 | 129.0 | Male | Present | 7.0 | 1 |
3 | 50.0 | Present | 111.0 | Absent | 20.0 | Absent | 210000.00 | 1.9 | 137.0 | Male | Absent | 7.0 | 1 |
4 | 65.0 | Present | 160.0 | Present | 20.0 | Absent | 327000.00 | 2.7 | 116.0 | Female | Absent | 8.0 | 1 |
##################################
# Performing a general exploration
# of the numeric variables
##################################
print('Numeric Variable Summary:')
display(heart_failure.describe(include='number').transpose())
Numeric Variable Summary:
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
AGE | 299.0 | 60.833893 | 11.894809 | 40.0 | 51.0 | 60.0 | 70.0 | 95.0 |
CREATININE_PHOSPHOKINASE | 299.0 | 581.839465 | 970.287881 | 23.0 | 116.5 | 250.0 | 582.0 | 7861.0 |
EJECTION_FRACTION | 299.0 | 38.083612 | 11.834841 | 14.0 | 30.0 | 38.0 | 45.0 | 80.0 |
PLATELETS | 299.0 | 263358.029264 | 97804.236869 | 25100.0 | 212500.0 | 262000.0 | 303500.0 | 850000.0 |
SERUM_CREATININE | 299.0 | 1.393880 | 1.034510 | 0.5 | 0.9 | 1.1 | 1.4 | 9.4 |
SERUM_SODIUM | 299.0 | 136.625418 | 4.412477 | 113.0 | 134.0 | 137.0 | 140.0 | 148.0 |
TIME | 299.0 | 130.260870 | 77.614208 | 4.0 | 73.0 | 115.0 | 203.0 | 285.0 |
##################################
# Performing a general exploration
# of the object and categorical variables
##################################
print('Categorical Variable Summary:')
display(heart_failure.describe(include=['category','object']).transpose())
Categorical Variable Summary:
count | unique | top | freq | |
---|---|---|---|---|
ANAEMIA | 299 | 2 | Absent | 170 |
DIABETES | 299 | 2 | Absent | 174 |
HIGH_BLOOD_PRESSURE | 299 | 2 | Absent | 194 |
SEX | 299 | 2 | Male | 194 |
SMOKING | 299 | 2 | Absent | 203 |
DEATH_EVENT | 299 | 2 | 0 | 203 |
1.3. Data Quality Assessment ¶
Data quality findings based on assessment are as follows:
- No duplicated rows observed. All entries are unique.
- No missing data noted; no variable showed Null.Count>0 or Fill.Rate<1.0.
- Low variance observed for two numeric predictors with First.Second.Mode.Ratio>5.
- CREATININE_PHOSPHOKINASE: First.Second.Mode.Ratio = 11.75
- PLATELETS: First.Second.Mode.Ratio = 6.25
- High skewness observed for two numeric predictors with Skewness>3 or Skewness<(-3).
- CREATININE_PHOSPHOKINASE: Skewness = +4.46
- SERUM_CREATININE: Skewness = +4.46
- No excessively high cardinality observed for the numeric and categorical predictors with Unique.Count.Ratio>10.
##################################
# Counting the number of duplicated rows
##################################
heart_failure.duplicated().sum()
0
##################################
# Gathering the data types for each column
##################################
data_type_list = list(heart_failure.dtypes)
##################################
# Gathering the variable names for each column
##################################
variable_name_list = list(heart_failure.columns)
##################################
# Gathering the number of observations for each column
##################################
row_count_list = list([len(heart_failure)] * len(heart_failure.columns))
##################################
# Gathering the number of missing data for each column
##################################
null_count_list = list(heart_failure.isna().sum(axis=0))
##################################
# Gathering the number of non-missing data for each column
##################################
non_null_count_list = list(heart_failure.count())
##################################
# Gathering the fill rate (non-missing data proportion) for each column
##################################
fill_rate_list = map(truediv, non_null_count_list, row_count_list)
##################################
# Formulating the summary
# for all columns
##################################
all_column_quality_summary = pd.DataFrame(zip(variable_name_list,
data_type_list,
row_count_list,
non_null_count_list,
null_count_list,
fill_rate_list),
columns=['Column.Name',
'Column.Type',
'Row.Count',
'Non.Null.Count',
'Null.Count',
'Fill.Rate'])
display(all_column_quality_summary)
Column.Name | Column.Type | Row.Count | Non.Null.Count | Null.Count | Fill.Rate | |
---|---|---|---|---|---|---|
0 | AGE | float64 | 299 | 299 | 0 | 1.0 |
1 | ANAEMIA | object | 299 | 299 | 0 | 1.0 |
2 | CREATININE_PHOSPHOKINASE | float64 | 299 | 299 | 0 | 1.0 |
3 | DIABETES | object | 299 | 299 | 0 | 1.0 |
4 | EJECTION_FRACTION | float64 | 299 | 299 | 0 | 1.0 |
5 | HIGH_BLOOD_PRESSURE | object | 299 | 299 | 0 | 1.0 |
6 | PLATELETS | float64 | 299 | 299 | 0 | 1.0 |
7 | SERUM_CREATININE | float64 | 299 | 299 | 0 | 1.0 |
8 | SERUM_SODIUM | float64 | 299 | 299 | 0 | 1.0 |
9 | SEX | object | 299 | 299 | 0 | 1.0 |
10 | SMOKING | object | 299 | 299 | 0 | 1.0 |
11 | TIME | float64 | 299 | 299 | 0 | 1.0 |
12 | DEATH_EVENT | category | 299 | 299 | 0 | 1.0 |
##################################
# Counting the number of columns
# with Fill.Rate < 1.00
##################################
print('Number of Columns with Missing Data:', str(len(all_column_quality_summary[(all_column_quality_summary['Fill.Rate']<1)])))
Number of Columns with Missing Data: 0
##################################
# Gathering the metadata labels for each observation
##################################
row_metadata_list = heart_failure.index.values.tolist()
##################################
# Gathering the number of columns for each observation
##################################
column_count_list = list([len(heart_failure.columns)] * len(heart_failure))
##################################
# Gathering the number of missing data for each row
##################################
null_row_list = list(heart_failure.isna().sum(axis=1))
##################################
# Gathering the missing data percentage for each row
##################################
missing_rate_list = map(truediv, null_row_list, column_count_list)
##################################
# Exploring the rows
# for missing data
##################################
all_row_quality_summary = pd.DataFrame(zip(row_metadata_list,
column_count_list,
null_row_list,
missing_rate_list),
columns=['Row.Name',
'Column.Count',
'Null.Count',
'Missing.Rate'])
display(all_row_quality_summary)
Row.Name | Column.Count | Null.Count | Missing.Rate | |
---|---|---|---|---|
0 | 0 | 13 | 0 | 0.0 |
1 | 1 | 13 | 0 | 0.0 |
2 | 2 | 13 | 0 | 0.0 |
3 | 3 | 13 | 0 | 0.0 |
4 | 4 | 13 | 0 | 0.0 |
... | ... | ... | ... | ... |
294 | 294 | 13 | 0 | 0.0 |
295 | 295 | 13 | 0 | 0.0 |
296 | 296 | 13 | 0 | 0.0 |
297 | 297 | 13 | 0 | 0.0 |
298 | 298 | 13 | 0 | 0.0 |
299 rows × 4 columns
##################################
# Counting the number of rows
# with Fill.Rate < 1.00
##################################
print('Number of Rows with Missing Data:',str(len(all_row_quality_summary[all_row_quality_summary['Missing.Rate']>0])))
Number of Rows with Missing Data: 0
##################################
# Formulating the dataset
# with numeric columns only
##################################
heart_failure_numeric = heart_failure.select_dtypes(include=['number','int'])
##################################
# Gathering the variable names for each numeric column
##################################
numeric_variable_name_list = heart_failure_numeric.columns
##################################
# Gathering the minimum value for each numeric column
##################################
numeric_minimum_list = heart_failure_numeric.min()
##################################
# Gathering the mean value for each numeric column
##################################
numeric_mean_list = heart_failure_numeric.mean()
##################################
# Gathering the median value for each numeric column
##################################
numeric_median_list = heart_failure_numeric.median()
##################################
# Gathering the maximum value for each numeric column
##################################
numeric_maximum_list = heart_failure_numeric.max()
##################################
# Gathering the first mode values for each numeric column
##################################
numeric_first_mode_list = [heart_failure[x].value_counts(dropna=True).index.tolist()[0] for x in heart_failure_numeric]
##################################
# Gathering the second mode values for each numeric column
##################################
numeric_second_mode_list = [heart_failure[x].value_counts(dropna=True).index.tolist()[1] for x in heart_failure_numeric]
##################################
# Gathering the count of first mode values for each numeric column
##################################
numeric_first_mode_count_list = [heart_failure_numeric[x].isin([heart_failure[x].value_counts(dropna=True).index.tolist()[0]]).sum() for x in heart_failure_numeric]
##################################
# Gathering the count of second mode values for each numeric column
##################################
numeric_second_mode_count_list = [heart_failure_numeric[x].isin([heart_failure[x].value_counts(dropna=True).index.tolist()[1]]).sum() for x in heart_failure_numeric]
##################################
# Gathering the first mode to second mode ratio for each numeric column
##################################
numeric_first_second_mode_ratio_list = map(truediv, numeric_first_mode_count_list, numeric_second_mode_count_list)
##################################
# Gathering the count of unique values for each numeric column
##################################
numeric_unique_count_list = heart_failure_numeric.nunique(dropna=True)
##################################
# Gathering the number of observations for each numeric column
##################################
numeric_row_count_list = list([len(heart_failure_numeric)] * len(heart_failure_numeric.columns))
##################################
# Gathering the unique to count ratio for each numeric column
##################################
numeric_unique_count_ratio_list = map(truediv, numeric_unique_count_list, numeric_row_count_list)
##################################
# Gathering the skewness value for each numeric column
##################################
numeric_skewness_list = heart_failure_numeric.skew()
##################################
# Gathering the kurtosis value for each numeric column
##################################
numeric_kurtosis_list = heart_failure_numeric.kurtosis()
numeric_column_quality_summary = pd.DataFrame(zip(numeric_variable_name_list,
numeric_minimum_list,
numeric_mean_list,
numeric_median_list,
numeric_maximum_list,
numeric_first_mode_list,
numeric_second_mode_list,
numeric_first_mode_count_list,
numeric_second_mode_count_list,
numeric_first_second_mode_ratio_list,
numeric_unique_count_list,
numeric_row_count_list,
numeric_unique_count_ratio_list,
numeric_skewness_list,
numeric_kurtosis_list),
columns=['Numeric.Column.Name',
'Minimum',
'Mean',
'Median',
'Maximum',
'First.Mode',
'Second.Mode',
'First.Mode.Count',
'Second.Mode.Count',
'First.Second.Mode.Ratio',
'Unique.Count',
'Row.Count',
'Unique.Count.Ratio',
'Skewness',
'Kurtosis'])
display(numeric_column_quality_summary)
Numeric.Column.Name | Minimum | Mean | Median | Maximum | First.Mode | Second.Mode | First.Mode.Count | Second.Mode.Count | First.Second.Mode.Ratio | Unique.Count | Row.Count | Unique.Count.Ratio | Skewness | Kurtosis | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | AGE | 40.0 | 60.833893 | 60.0 | 95.0 | 60.00 | 50.0 | 33 | 27 | 1.222222 | 47 | 299 | 0.157191 | 0.423062 | -0.184871 |
1 | CREATININE_PHOSPHOKINASE | 23.0 | 581.839465 | 250.0 | 7861.0 | 582.00 | 66.0 | 47 | 4 | 11.750000 | 208 | 299 | 0.695652 | 4.463110 | 25.149046 |
2 | EJECTION_FRACTION | 14.0 | 38.083612 | 38.0 | 80.0 | 35.00 | 38.0 | 49 | 40 | 1.225000 | 17 | 299 | 0.056856 | 0.555383 | 0.041409 |
3 | PLATELETS | 25100.0 | 263358.029264 | 262000.0 | 850000.0 | 263358.03 | 221000.0 | 25 | 4 | 6.250000 | 176 | 299 | 0.588629 | 1.462321 | 6.209255 |
4 | SERUM_CREATININE | 0.5 | 1.393880 | 1.1 | 9.4 | 1.00 | 1.1 | 50 | 32 | 1.562500 | 40 | 299 | 0.133779 | 4.455996 | 25.828239 |
5 | SERUM_SODIUM | 113.0 | 136.625418 | 137.0 | 148.0 | 136.00 | 137.0 | 40 | 38 | 1.052632 | 27 | 299 | 0.090301 | -1.048136 | 4.119712 |
6 | TIME | 4.0 | 130.260870 | 115.0 | 285.0 | 250.00 | 187.0 | 7 | 7 | 1.000000 | 148 | 299 | 0.494983 | 0.127803 | -1.212048 |
##################################
# Counting the number of numeric columns
# with First.Second.Mode.Ratio > 5.00
##################################
len(numeric_column_quality_summary[(numeric_column_quality_summary['First.Second.Mode.Ratio']>5)])
2
##################################
# Identifying the numeric columns
# with First.Second.Mode.Ratio > 5.00
##################################
display(numeric_column_quality_summary[(numeric_column_quality_summary['First.Second.Mode.Ratio']>5)].sort_values(by=['First.Second.Mode.Ratio'], ascending=False))
Numeric.Column.Name | Minimum | Mean | Median | Maximum | First.Mode | Second.Mode | First.Mode.Count | Second.Mode.Count | First.Second.Mode.Ratio | Unique.Count | Row.Count | Unique.Count.Ratio | Skewness | Kurtosis | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | CREATININE_PHOSPHOKINASE | 23.0 | 581.839465 | 250.0 | 7861.0 | 582.00 | 66.0 | 47 | 4 | 11.75 | 208 | 299 | 0.695652 | 4.463110 | 25.149046 |
3 | PLATELETS | 25100.0 | 263358.029264 | 262000.0 | 850000.0 | 263358.03 | 221000.0 | 25 | 4 | 6.25 | 176 | 299 | 0.588629 | 1.462321 | 6.209255 |
##################################
# Counting the number of numeric columns
# with Unique.Count.Ratio > 10.00
##################################
len(numeric_column_quality_summary[(numeric_column_quality_summary['Unique.Count.Ratio']>10)])
0
##################################
# Counting the number of numeric columns
# with Skewness > 3.00 or Skewness < -3.00
##################################
len(numeric_column_quality_summary[(numeric_column_quality_summary['Skewness']>3) | (numeric_column_quality_summary['Skewness']<(-3))])
2
##################################
# Identifying the numeric columns
# with Skewness > 3.00 or Skewness < -3.00
##################################
display(numeric_column_quality_summary[(numeric_column_quality_summary['Skewness']>3) | (numeric_column_quality_summary['Skewness']<(-3))])
Numeric.Column.Name | Minimum | Mean | Median | Maximum | First.Mode | Second.Mode | First.Mode.Count | Second.Mode.Count | First.Second.Mode.Ratio | Unique.Count | Row.Count | Unique.Count.Ratio | Skewness | Kurtosis | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | CREATININE_PHOSPHOKINASE | 23.0 | 581.839465 | 250.0 | 7861.0 | 582.0 | 66.0 | 47 | 4 | 11.7500 | 208 | 299 | 0.695652 | 4.463110 | 25.149046 |
4 | SERUM_CREATININE | 0.5 | 1.393880 | 1.1 | 9.4 | 1.0 | 1.1 | 50 | 32 | 1.5625 | 40 | 299 | 0.133779 | 4.455996 | 25.828239 |
##################################
# Formulating the dataset
# with object or categorical column only
##################################
heart_failure_object = heart_failure.select_dtypes(include=['object','category'])
##################################
# Gathering the variable names for the object or categorical column
##################################
categorical_variable_name_list = heart_failure_object.columns
##################################
# Gathering the first mode values for the object or categorical column
##################################
categorical_first_mode_list = [heart_failure[x].value_counts().index.tolist()[0] for x in heart_failure_object]
##################################
# Gathering the second mode values for each object or categorical column
##################################
categorical_second_mode_list = [heart_failure[x].value_counts().index.tolist()[1] for x in heart_failure_object]
##################################
# Gathering the count of first mode values for each object or categorical column
##################################
categorical_first_mode_count_list = [heart_failure_object[x].isin([heart_failure[x].value_counts(dropna=True).index.tolist()[0]]).sum() for x in heart_failure_object]
##################################
# Gathering the count of second mode values for each object or categorical column
##################################
categorical_second_mode_count_list = [heart_failure_object[x].isin([heart_failure[x].value_counts(dropna=True).index.tolist()[1]]).sum() for x in heart_failure_object]
##################################
# Gathering the first mode to second mode ratio for each object or categorical column
##################################
categorical_first_second_mode_ratio_list = map(truediv, categorical_first_mode_count_list, categorical_second_mode_count_list)
##################################
# Gathering the count of unique values for each object or categorical column
##################################
categorical_unique_count_list = heart_failure_object.nunique(dropna=True)
##################################
# Gathering the number of observations for each object or categorical column
##################################
categorical_row_count_list = list([len(heart_failure_object)] * len(heart_failure_object.columns))
##################################
# Gathering the unique to count ratio for each object or categorical column
##################################
categorical_unique_count_ratio_list = map(truediv, categorical_unique_count_list, categorical_row_count_list)
categorical_column_quality_summary = pd.DataFrame(zip(categorical_variable_name_list,
categorical_first_mode_list,
categorical_second_mode_list,
categorical_first_mode_count_list,
categorical_second_mode_count_list,
categorical_first_second_mode_ratio_list,
categorical_unique_count_list,
categorical_row_count_list,
categorical_unique_count_ratio_list),
columns=['Categorical.Column.Name',
'First.Mode',
'Second.Mode',
'First.Mode.Count',
'Second.Mode.Count',
'First.Second.Mode.Ratio',
'Unique.Count',
'Row.Count',
'Unique.Count.Ratio'])
display(categorical_column_quality_summary)
Categorical.Column.Name | First.Mode | Second.Mode | First.Mode.Count | Second.Mode.Count | First.Second.Mode.Ratio | Unique.Count | Row.Count | Unique.Count.Ratio | |
---|---|---|---|---|---|---|---|---|---|
0 | ANAEMIA | Absent | Present | 170 | 129 | 1.317829 | 2 | 299 | 0.006689 |
1 | DIABETES | Absent | Present | 174 | 125 | 1.392000 | 2 | 299 | 0.006689 |
2 | HIGH_BLOOD_PRESSURE | Absent | Present | 194 | 105 | 1.847619 | 2 | 299 | 0.006689 |
3 | SEX | Male | Female | 194 | 105 | 1.847619 | 2 | 299 | 0.006689 |
4 | SMOKING | Absent | Present | 203 | 96 | 2.114583 | 2 | 299 | 0.006689 |
5 | DEATH_EVENT | 0 | 1 | 203 | 96 | 2.114583 | 2 | 299 | 0.006689 |
##################################
# Counting the number of object or categorical columns
# with First.Second.Mode.Ratio > 5.00
##################################
len(categorical_column_quality_summary[(categorical_column_quality_summary['First.Second.Mode.Ratio']>5)])
0
##################################
# Counting the number of object or categorical columns
# with Unique.Count.Ratio > 10.00
##################################
len(categorical_column_quality_summary[(categorical_column_quality_summary['Unique.Count.Ratio']>10)])
0
1.4. Data Preprocessing ¶
The Yeo-Johnson Transformation applies a family of power transformations that can be used without restrictions on the sign of the data, extending many of the good properties of the Box-Cox power family. Similar to the Box-Cox transformation, the method estimates the optimal value of lambda, but it can transform both positive and negative values, inflating low-variance regions and deflating high-variance regions to produce a more uniform distribution. While there are no restrictions on the applicable values, the transformed values are less interpretable than those produced by other methods.
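For reference, the Yeo-Johnson transformation of a value $y$ under the estimated power parameter $\lambda$ is defined piecewise as:

$$
\psi(y, \lambda) =
\begin{cases}
\dfrac{(y + 1)^{\lambda} - 1}{\lambda} & y \ge 0,\ \lambda \ne 0 \\
\ln(y + 1) & y \ge 0,\ \lambda = 0 \\
-\dfrac{(1 - y)^{2 - \lambda} - 1}{2 - \lambda} & y < 0,\ \lambda \ne 2 \\
-\ln(1 - y) & y < 0,\ \lambda = 2
\end{cases}
$$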
- Data transformation and scaling are necessary to address excessive outliers and high skewness as observed in several numeric predictors:
- CREATININE_PHOSPHOKINASE: Skewness = +4.463, Outlier.Count = 29, Outlier.Ratio = 0.096
- SERUM_CREATININE: Skewness = +4.456, Outlier.Count = 29, Outlier Ratio = 0.096
- PLATELETS: Skewness = +1.462, Outlier.Count = 21, Outlier.Ratio = 0.070
- Most variables achieved symmetrical distributions with minimal outliers after evaluating a Yeo-Johnson transformation, except for:
- PLATELETS: Skewness = +1.155, Outlier.Count = 18, Outlier.Ratio = 0.060
- Among pairwise combinations of variables in the training subset, sufficiently high correlation values were observed but with no excessive multicollinearity noted:
- TIME and DEATH_EVENT: Point.Biserial.Correlation = -0.530
- SMOKING and SEX: Phi.Coefficient = +0.450
- SERUM_CREATININE and DEATH_EVENT: Point.Biserial.Correlation = +0.290
- AGE and DEATH_EVENT: Point.Biserial.Correlation = +0.250
#################################
# Creating a dataset copy
# for correlation analysis
##################################
heart_failure_correlation = heart_failure_original.copy()
display(heart_failure_correlation)
AGE | ANAEMIA | CREATININE_PHOSPHOKINASE | DIABETES | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | SEX | SMOKING | TIME | DEATH_EVENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 75.0 | 0 | 582.0 | 0 | 20.0 | 1 | 265000.00 | 1.9 | 130.0 | 1 | 0 | 4.0 | 1 |
1 | 55.0 | 0 | 7861.0 | 0 | 38.0 | 0 | 263358.03 | 1.1 | 136.0 | 1 | 0 | 6.0 | 1 |
2 | 65.0 | 0 | 146.0 | 0 | 20.0 | 0 | 162000.00 | 1.3 | 129.0 | 1 | 1 | 7.0 | 1 |
3 | 50.0 | 1 | 111.0 | 0 | 20.0 | 0 | 210000.00 | 1.9 | 137.0 | 1 | 0 | 7.0 | 1 |
4 | 65.0 | 1 | 160.0 | 1 | 20.0 | 0 | 327000.00 | 2.7 | 116.0 | 0 | 0 | 8.0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
294 | 62.0 | 0 | 61.0 | 1 | 38.0 | 1 | 155000.00 | 1.1 | 143.0 | 1 | 1 | 270.0 | 0 |
295 | 55.0 | 0 | 1820.0 | 0 | 38.0 | 0 | 270000.00 | 1.2 | 139.0 | 0 | 0 | 271.0 | 0 |
296 | 45.0 | 0 | 2060.0 | 1 | 60.0 | 0 | 742000.00 | 0.8 | 138.0 | 0 | 0 | 278.0 | 0 |
297 | 45.0 | 0 | 2413.0 | 0 | 38.0 | 0 | 140000.00 | 1.4 | 140.0 | 1 | 1 | 280.0 | 0 |
298 | 50.0 | 0 | 196.0 | 0 | 45.0 | 0 | 395000.00 | 1.6 | 136.0 | 1 | 1 | 285.0 | 0 |
299 rows × 13 columns
##################################
# Initializing the correlation matrix
##################################
heart_failure_correlation_matrix = pd.DataFrame(np.zeros((len(heart_failure_correlation.columns), len(heart_failure_correlation.columns))),
columns=heart_failure_correlation.columns,
index=heart_failure_correlation.columns)
##################################
# Calculating different types
# of correlation coefficients
# per variable type
##################################
for i in range(len(heart_failure_correlation.columns)):
for j in range(i, len(heart_failure_correlation.columns)):
if i == j:
heart_failure_correlation_matrix.iloc[i, j] = 1.0
else:
if heart_failure_correlation.dtypes.iloc[i] == 'float64' and heart_failure_correlation.dtypes.iloc[j] == 'float64':
# Pearson correlation for two continuous variables
corr = heart_failure_correlation.iloc[:, i].corr(heart_failure_correlation.iloc[:, j])
elif heart_failure_correlation.dtypes.iloc[i] == 'int64' or heart_failure_correlation.dtypes.iloc[j] == 'int64':
# Point-biserial correlation for one continuous and one binary variable
continuous_var = heart_failure_correlation.iloc[:, j] if heart_failure_correlation.dtypes.iloc[i] == 'int64' else heart_failure_correlation.iloc[:, i]
binary_var = heart_failure_correlation.iloc[:, i] if heart_failure_correlation.dtypes.iloc[i] == 'int64' else heart_failure_correlation.iloc[:, j]
corr, _ = pointbiserialr(continuous_var, binary_var)
else:
# Phi coefficient for two binary variables
corr = heart_failure_correlation.iloc[:, i].corr(heart_failure_correlation.iloc[:, j])
heart_failure_correlation_matrix.iloc[i, j] = corr
heart_failure_correlation_matrix.iloc[j, i] = corr
##################################
# Plotting the correlation matrix
# for all pairwise combinations
# of numeric and categorical columns
##################################
plt.figure(figsize=(17, 8))
sns.heatmap(heart_failure_correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.show()
##################################
# Formulating the dataset
# with numeric columns only
##################################
heart_failure_numeric = heart_failure.select_dtypes(include=['number','int'])
##################################
# Gathering the variable names for each numeric column
##################################
numeric_variable_name_list = heart_failure_numeric.columns
##################################
# Gathering the skewness value for each numeric column
##################################
numeric_skewness_list = heart_failure_numeric.skew()
##################################
# Computing the interquartile range
# for all columns
##################################
heart_failure_numeric_q1 = heart_failure_numeric.quantile(0.25)
heart_failure_numeric_q3 = heart_failure_numeric.quantile(0.75)
heart_failure_numeric_iqr = heart_failure_numeric_q3 - heart_failure_numeric_q1
##################################
# Gathering the outlier count for each numeric column
# based on the interquartile range criterion
##################################
numeric_outlier_count_list = ((heart_failure_numeric < (heart_failure_numeric_q1 - 1.5 * heart_failure_numeric_iqr)) | (heart_failure_numeric > (heart_failure_numeric_q3 + 1.5 * heart_failure_numeric_iqr))).sum()
##################################
# Gathering the number of observations for each column
##################################
numeric_row_count_list = list([len(heart_failure_numeric)] * len(heart_failure_numeric.columns))
##################################
# Gathering the outlier ratio for each numeric column
##################################
numeric_outlier_ratio_list = map(truediv, numeric_outlier_count_list, numeric_row_count_list)
##################################
# Formulating the outlier summary
# for all numeric columns
##################################
numeric_column_outlier_summary = pd.DataFrame(zip(numeric_variable_name_list,
numeric_skewness_list,
numeric_outlier_count_list,
numeric_row_count_list,
numeric_outlier_ratio_list),
columns=['Numeric.Column.Name',
'Skewness',
'Outlier.Count',
'Row.Count',
'Outlier.Ratio'])
display(numeric_column_outlier_summary)
Numeric.Column.Name | Skewness | Outlier.Count | Row.Count | Outlier.Ratio | |
---|---|---|---|---|---|
0 | AGE | 0.423062 | 0 | 299 | 0.000000 |
1 | CREATININE_PHOSPHOKINASE | 4.463110 | 29 | 299 | 0.096990 |
2 | EJECTION_FRACTION | 0.555383 | 2 | 299 | 0.006689 |
3 | PLATELETS | 1.462321 | 21 | 299 | 0.070234 |
4 | SERUM_CREATININE | 4.455996 | 29 | 299 | 0.096990 |
5 | SERUM_SODIUM | -1.048136 | 4 | 299 | 0.013378 |
6 | TIME | 0.127803 | 0 | 299 | 0.000000 |
##################################
# Formulating the individual boxplots
# for all numeric columns
##################################
for column in heart_failure_numeric:
plt.figure(figsize=(17,1))
sns.boxplot(data=heart_failure_numeric, x=column)
##################################
# Formulating the dataset
# with numeric predictor columns only
##################################
heart_failure_numeric_predictor = heart_failure_numeric.drop('TIME', axis=1)
##################################
# Formulating the dataset
# with categorical or object columns only
##################################
heart_failure_categorical = heart_failure_original.select_dtypes(include=['category','object'])
##################################
# Evaluating a Yeo-Johnson Transformation
# to address the distributional
# shape of the variables
##################################
yeo_johnson_transformer = PowerTransformer(method='yeo-johnson',
standardize=True)
heart_failure_numeric_predictor_transformed_array = yeo_johnson_transformer.fit_transform(heart_failure_numeric_predictor)
##################################
# Formulating a new dataset object
# for the transformed data
##################################
heart_failure_numeric_predictor_transformed = pd.DataFrame(heart_failure_numeric_predictor_transformed_array,
columns=heart_failure_numeric_predictor.columns)
##################################
# Formulating the individual boxplots
# for all transformed numeric predictor columns
##################################
for column in heart_failure_numeric_predictor_transformed:
plt.figure(figsize=(17,1))
sns.boxplot(data=heart_failure_numeric_predictor_transformed, x=column)
##################################
# Formulating the outlier summary
# for all numeric predictor columns
##################################
numeric_variable_name_list = heart_failure_numeric_predictor_transformed.columns
numeric_skewness_list = heart_failure_numeric_predictor_transformed.skew()
heart_failure_numeric_predictor_transformed_q1 = heart_failure_numeric_predictor_transformed.quantile(0.25)
heart_failure_numeric_predictor_transformed_q3 = heart_failure_numeric_predictor_transformed.quantile(0.75)
heart_failure_numeric_predictor_transformed_iqr = heart_failure_numeric_predictor_transformed_q3 - heart_failure_numeric_predictor_transformed_q1
numeric_outlier_count_list = ((heart_failure_numeric_predictor_transformed < (heart_failure_numeric_predictor_transformed_q1 - 1.5 * heart_failure_numeric_predictor_transformed_iqr)) | (heart_failure_numeric_predictor_transformed > (heart_failure_numeric_predictor_transformed_q3 + 1.5 * heart_failure_numeric_predictor_transformed_iqr))).sum()
numeric_row_count_list = list([len(heart_failure_numeric_predictor_transformed)] * len(heart_failure_numeric_predictor_transformed.columns))
numeric_outlier_ratio_list = map(truediv, numeric_outlier_count_list, numeric_row_count_list)
numeric_column_outlier_summary = pd.DataFrame(zip(numeric_variable_name_list,
numeric_skewness_list,
numeric_outlier_count_list,
numeric_row_count_list,
numeric_outlier_ratio_list),
columns=['Numeric.Column.Name',
'Skewness',
'Outlier.Count',
'Row.Count',
'Outlier.Ratio'])
display(numeric_column_outlier_summary)
Numeric.Column.Name | Skewness | Outlier.Count | Row.Count | Outlier.Ratio | |
---|---|---|---|---|---|
0 | AGE | -0.000746 | 0 | 299 | 0.000000 |
1 | CREATININE_PHOSPHOKINASE | 0.044225 | 0 | 299 | 0.000000 |
2 | EJECTION_FRACTION | -0.006637 | 2 | 299 | 0.006689 |
3 | PLATELETS | 0.155360 | 18 | 299 | 0.060201 |
4 | SERUM_CREATININE | 0.150380 | 1 | 299 | 0.003344 |
5 | SERUM_SODIUM | 0.082305 | 3 | 299 | 0.010033 |
1.5. Data Exploration ¶
1.5.1 Exploratory Data Analysis ¶
- In the estimated baseline survival plot, the survival probability never dropped below 50% over the observed time period; the last observed survival probability was 58% at TIME=258. The median survival time therefore could not be determined from the current data, indicating that the majority of individuals in the cohort maintained a survival probability above 50% throughout the follow-up period.
- Bivariate analysis identified individual predictors with potential association to the event status based on visual inspection.
- Higher values for the following numeric predictors are associated with DEATH_EVENT=True:
- AGE
- SERUM_CREATININE
- Lower values for the following numeric predictors are associated with DEATH_EVENT=True:
- EJECTION_FRACTION
- SERUM_SODIUM
- Higher counts for the following object predictors are associated with better differentiation between DEATH_EVENT=True and DEATH_EVENT=False:
- HIGH_BLOOD_PRESSURE
- Bivariate analysis identified individual predictors with potential association to the survival time based on visual inspection.
- No numeric predictors were associated with TIME.
- Levels for the following object predictors are associated with differences in TIME between DEATH_EVENT=True and DEATH_EVENT=False:
- HIGH_BLOOD_PRESSURE
##################################
# Formulating a complete dataframe
##################################
heart_failure_EDA = pd.concat([heart_failure_numeric_predictor_transformed,
heart_failure.select_dtypes(include=['category','object']),
heart_failure_numeric['TIME']],
axis=1)
heart_failure_EDA['DEATH_EVENT'] = heart_failure_EDA['DEATH_EVENT'].replace({0: False, 1: True})
heart_failure_EDA.head()
AGE | CREATININE_PHOSPHOKINASE | EJECTION_FRACTION | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | DIABETES | HIGH_BLOOD_PRESSURE | SEX | SMOKING | DEATH_EVENT | TIME | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.173233 | 0.691615 | -1.773346 | 0.110528 | 1.212227 | -1.468519 | Absent | Absent | Present | Male | Absent | True | 4.0 |
1 | -0.423454 | 2.401701 | 0.100914 | 0.093441 | -0.087641 | -0.244181 | Absent | Absent | Absent | Male | Absent | True | 6.0 |
2 | 0.434332 | -0.553424 | -1.773346 | -1.093142 | 0.381817 | -1.642143 | Absent | Absent | Absent | Male | Present | True | 7.0 |
3 | -0.910411 | -0.833885 | -1.773346 | -0.494713 | 1.212227 | -0.006503 | Present | Absent | Absent | Male | Absent | True | 7.0 |
4 | 0.434332 | -0.462335 | -1.773346 | 0.720277 | 1.715066 | -3.285073 | Present | Present | Absent | Female | Absent | True | 8.0 |
##################################
# Saving the EDA data
# to the DATASETS_PREPROCESSED_PATH
##################################
heart_failure_EDA.to_csv(os.path.join("..", DATASETS_PREPROCESSED_PATH, "heart_failure_EDA.csv"), index=True)
##################################
# Plotting the baseline survival curve
# and computing the survival rates
##################################
kmf = KaplanMeierFitter()
kmf.fit(durations=heart_failure_EDA['TIME'], event_observed=heart_failure_EDA['DEATH_EVENT'])
plt.figure(figsize=(17, 8))
kmf.plot_survival_function()
plt.title('Kaplan-Meier Baseline Survival Plot')
plt.ylim(0, 1.05)
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
##################################
# Determining the at-risk numbers
##################################
at_risk_counts = kmf.event_table.at_risk
survival_probabilities = kmf.survival_function_.values.flatten()
time_points = kmf.survival_function_.index
for time, prob, at_risk in zip(time_points, survival_probabilities, at_risk_counts):
if time % 50 == 0:
plt.text(time, prob, f'{prob:.2f} : {at_risk}', ha='left', fontsize=10)
median_survival_time = kmf.median_survival_time_
plt.axvline(x=median_survival_time, color='r', linestyle='--')
plt.axhline(y=0.5, color='r', linestyle='--')
plt.show()
##################################
# Computing the median survival time
##################################
median_survival_time = kmf.median_survival_time_
print(f'Median Survival Time: {median_survival_time}')
Median Survival Time: inf
##################################
# Exploring the relationships between
# the numeric predictors and event status
##################################
plt.figure(figsize=(17, 12))
for i in range(1, 7):
plt.subplot(2, 3, i)
sns.boxplot(x='DEATH_EVENT', y=heart_failure_numeric_predictor.columns[i-1], hue='DEATH_EVENT', data=heart_failure_EDA)
plt.title(f'{heart_failure_numeric_predictor.columns[i-1]} vs DEATH_EVENT Status')
plt.legend(loc='upper center')
plt.tight_layout()
plt.show()
##################################
# Exploring the relationships between
# the object predictors and event status
##################################
heart_failure_categorical_predictor = heart_failure_categorical.drop('DEATH_EVENT',axis=1)
heart_failure_EDA[int_columns] = heart_failure_EDA[int_columns].astype(object)
plt.figure(figsize=(17, 12))
for i in range(1, 6):
plt.subplot(2, 3, i)
sns.countplot(hue='DEATH_EVENT', x=heart_failure_categorical_predictor.columns[i-1], data=heart_failure_EDA)
plt.title(f'{heart_failure_categorical_predictor.columns[i-1]} vs DEATH_EVENT Status')
plt.legend(loc='upper center')
plt.tight_layout()
plt.show()
##################################
# Exploring the relationships between
# the numeric predictors and survival time
##################################
plt.figure(figsize=(17, 12))
for i in range(1, 7):
plt.subplot(2, 3, i)
sns.scatterplot(x='TIME', y=heart_failure_numeric_predictor.columns[i-1], hue='DEATH_EVENT', data=heart_failure_EDA)
loess_smoothed = lowess(heart_failure_EDA['TIME'], heart_failure_EDA[heart_failure_numeric_predictor.columns[i-1]], frac=0.3)
plt.plot(loess_smoothed[:, 1], loess_smoothed[:, 0], color='red')
plt.title(f'{heart_failure_numeric_predictor.columns[i-1]} vs Survival Time')
plt.legend(loc='upper center')
plt.tight_layout()
plt.show()
##################################
# Exploring the relationships between
# the object predictors and survival time
##################################
plt.figure(figsize=(17, 12))
for i in range(1, 6):
plt.subplot(2, 3, i)
sns.boxplot(x=heart_failure_categorical_predictor.columns[i-1], y='TIME', hue='DEATH_EVENT', data=heart_failure_EDA)
plt.title(f'{heart_failure_categorical_predictor.columns[i-1]} vs Survival Time')
plt.legend(loc='upper center')
plt.tight_layout()
plt.show()
1.5.2 Hypothesis Testing ¶
- The relationship between the numeric predictors and the DEATH_EVENT event variable was statistically evaluated using the following hypotheses:
- Null: Difference in the means between groups True and False is equal to zero
- Alternative: Difference in the means between groups True and False is not equal to zero
- There is sufficient evidence to conclude a statistically significant difference between the means of the numeric measurements obtained from the Status groups for 4 numeric predictors, given their high t-test statistics with reported p-values below the 0.05 significance level.
- SERUM_CREATININE: T.Test.Statistic=-6.825, T.Test.PValue=0.000
- EJECTION_FRACTION: T.Test.Statistic=+5.495, T.Test.PValue=0.000
- AGE: T.Test.Statistic=-4.274, T.Test.PValue=0.000
- SERUM_SODIUM: T.Test.Statistic=+3.229, T.Test.PValue=0.001
- The relationship between the object predictors and the DEATH_EVENT event variable was statistically evaluated using the following hypotheses:
- Null: The object predictor is independent of the event variable
- Alternative: The object predictor is dependent on the event variable
- No categorical predictors demonstrated sufficient evidence of a statistically significant relationship between their individual categories and the Status groups; none showed high chi-square statistics with reported p-values below the 0.05 significance level.
- The relationship between the object predictors and the DEATH_EVENT and TIME variables was statistically evaluated using the following hypotheses:
- Null: There is no difference in survival probabilities among cases belonging to each category of the object predictor.
- Alternative: There is a difference in survival probabilities among cases belonging to each category of the object predictor.
- There is sufficient evidence to conclude that survival probabilities with respect to the survival duration TIME differ significantly between the categories of 1 categorical predictor, given its high log-rank test statistic and a p-value below the 0.05 significance level.
- HIGH_BLOOD_PRESSURE: LR.Test.Statistic=4.406, LR.Test.PValue=0.035
- The relationship between the binned numeric predictors and the DEATH_EVENT and TIME variables was statistically evaluated using the following hypotheses:
- Null: There is no difference in survival probabilities among cases belonging to each category of the binned numeric predictor.
- Alternative: There is a difference in survival probabilities among cases belonging to each category of the binned numeric predictor.
- There is sufficient evidence to conclude that survival probabilities with respect to the survival duration TIME differ significantly between the categories of 4 binned numeric predictors, given their high log-rank test statistic values and p-values below the 0.05 significance level (the log-rank statistic is sketched after this list).
- Binned_SERUM_CREATININE: LR.Test.Statistic=21.190, LR.Test.PValue=0.000
- Binned_EJECTION_FRACTION: LR.Test.Statistic=9.469, LR.Test.PValue=0.002
- Binned_AGE: LR.Test.Statistic=4.951, LR.Test.PValue=0.026
- Binned_SERUM_SODIUM: LR.Test.Statistic=4.887, LR.Test.PValue=0.027
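For reference, the two test statistics used above can be sketched as follows (a standard formulation; the t-test assumes equal variances, consistent with the computation below): the pooled two-sample t-statistic comparing a numeric predictor between the DEATH_EVENT=False and DEATH_EVENT=True groups, and the two-group log-rank chi-square statistic.
$$
t = \frac{\bar{x}_{\text{False}} - \bar{x}_{\text{True}}}{s_p \sqrt{\frac{1}{n_{\text{False}}} + \frac{1}{n_{\text{True}}}}},
\qquad
s_p^2 = \frac{(n_{\text{False}}-1)\,s_{\text{False}}^2 + (n_{\text{True}}-1)\,s_{\text{True}}^2}{n_{\text{False}} + n_{\text{True}} - 2}
$$
$$
\chi^2_{\text{log-rank}} = \frac{\left(\sum_{k}\left(O_{1k} - E_{1k}\right)\right)^{2}}{\sum_{k} V_{k}}
$$
where $\bar{x}$, $s^2$, and $n$ denote the group mean, variance, and size for a numeric predictor, $O_{1k}$ and $E_{1k}$ are the observed and expected event counts in one group at the $k$-th distinct event time, and $V_{k}$ is the corresponding hypergeometric variance; under the null hypothesis the log-rank statistic follows a chi-square distribution with one degree of freedom.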
##################################
# Formulating a complete dataframe
##################################
heart_failure_HT = pd.concat([heart_failure_numeric_predictor_transformed,
heart_failure_categorical,
heart_failure_numeric['TIME']],
axis=1)
heart_failure_HT['DEATH_EVENT'] = heart_failure_HT['DEATH_EVENT'].replace({0: False, 1: True})
heart_failure_HT.head()
AGE | CREATININE_PHOSPHOKINASE | EJECTION_FRACTION | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | DIABETES | HIGH_BLOOD_PRESSURE | SEX | SMOKING | DEATH_EVENT | TIME | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.173233 | 0.691615 | -1.773346 | 0.110528 | 1.212227 | -1.468519 | 0 | 0 | 1 | 1 | 0 | True | 4.0 |
1 | -0.423454 | 2.401701 | 0.100914 | 0.093441 | -0.087641 | -0.244181 | 0 | 0 | 0 | 1 | 0 | True | 6.0 |
2 | 0.434332 | -0.553424 | -1.773346 | -1.093142 | 0.381817 | -1.642143 | 0 | 0 | 0 | 1 | 1 | True | 7.0 |
3 | -0.910411 | -0.833885 | -1.773346 | -0.494713 | 1.212227 | -0.006503 | 1 | 0 | 0 | 1 | 0 | True | 7.0 |
4 | 0.434332 | -0.462335 | -1.773346 | 0.720277 | 1.715066 | -3.285073 | 1 | 1 | 0 | 0 | 0 | True | 8.0 |
##################################
# Computing the t-test
# statistic and p-values
# between the event variable
# and numeric predictor columns
##################################
heart_failure_numeric_ttest_event = {}
for numeric_column in heart_failure_numeric_predictor.columns:
    group_0 = heart_failure_HT[heart_failure_HT.loc[:,'DEATH_EVENT']==False]
    group_1 = heart_failure_HT[heart_failure_HT.loc[:,'DEATH_EVENT']==True]
    heart_failure_numeric_ttest_event['DEATH_EVENT_' + numeric_column] = stats.ttest_ind(
        group_0[numeric_column],
        group_1[numeric_column],
        equal_var=True)
##################################
# Formulating the pairwise ttest summary
# between the event variable
# and numeric predictor columns
##################################
heart_failure_numeric_ttest_summary = pd.DataFrame.from_dict(heart_failure_numeric_ttest_event, orient='index')
heart_failure_numeric_ttest_summary.columns = ['T.Test.Statistic', 'T.Test.PValue']
display(heart_failure_numeric_ttest_summary.sort_values(by=['T.Test.PValue'], ascending=True))
T.Test.Statistic | T.Test.PValue | |
---|---|---|
DEATH_EVENT_SERUM_CREATININE | -6.825678 | 4.927143e-11 |
DEATH_EVENT_EJECTION_FRACTION | 5.495673 | 8.382875e-08 |
DEATH_EVENT_AGE | -4.274623 | 2.582635e-05 |
DEATH_EVENT_SERUM_SODIUM | 3.229580 | 1.378737e-03 |
DEATH_EVENT_PLATELETS | 1.031261 | 3.032576e-01 |
DEATH_EVENT_CREATININE_PHOSPHOKINASE | -0.565564 | 5.721174e-01 |
##################################
# Computing the chisquare
# statistic and p-values
# between the event variable
# and categorical predictor columns
##################################
heart_failure_categorical_chisquare_event = {}
for categorical_column in heart_failure_categorical_predictor.columns:
    contingency_table = pd.crosstab(heart_failure_HT[categorical_column],
                                    heart_failure_HT['DEATH_EVENT'])
    heart_failure_categorical_chisquare_event['DEATH_EVENT_' + categorical_column] = stats.chi2_contingency(
        contingency_table)[0:2]
##################################
# Formulating the pairwise chisquare summary
# between the event variable
# and categorical predictor columns
##################################
heart_failure_categorical_chisquare_event_summary = pd.DataFrame.from_dict(heart_failure_categorical_chisquare_event, orient='index')
heart_failure_categorical_chisquare_event_summary.columns = ['ChiSquare.Test.Statistic', 'ChiSquare.Test.PValue']
display(heart_failure_categorical_chisquare_event_summary.sort_values(by=['ChiSquare.Test.PValue'], ascending=True))
ChiSquare.Test.Statistic | ChiSquare.Test.PValue | |
---|---|---|
DEATH_EVENT_HIGH_BLOOD_PRESSURE | 1.543461 | 0.214103 |
DEATH_EVENT_ANAEMIA | 1.042175 | 0.307316 |
DEATH_EVENT_SMOKING | 0.007331 | 0.931765 |
DEATH_EVENT_DIABETES | 0.000000 | 1.000000 |
DEATH_EVENT_SEX | 0.000000 | 1.000000 |
##################################
# Exploring the relationships between
# the categorical predictors with
# survival event and duration
##################################
plt.figure(figsize=(17, 18))
for i in range(1, 6):
    ax = plt.subplot(3, 2, i)
    for group in [0, 1]:
        kmf.fit(durations=heart_failure_HT[heart_failure_HT[heart_failure_categorical_predictor.columns[i-1]] == group]['TIME'],
                event_observed=heart_failure_HT[heart_failure_HT[heart_failure_categorical_predictor.columns[i-1]] == group]['DEATH_EVENT'], label=group)
        kmf.plot_survival_function(ax=ax)
    plt.title(f'Survival Probabilities by {heart_failure_categorical_predictor.columns[i-1]} Categories')
    plt.xlabel('TIME')
    plt.ylabel('DEATH_EVENT Survival Probability')
plt.tight_layout()
plt.show()
##################################
# Computing the log-rank test
# statistic and p-values
# between the event and duration variables
# with the categorical predictor columns
##################################
heart_failure_categorical_lrtest_event = {}
for categorical_column in heart_failure_categorical_predictor.columns:
    groups = [0, 1]
    group_0_event = heart_failure_HT[heart_failure_HT[categorical_column] == groups[0]]['DEATH_EVENT']
    group_1_event = heart_failure_HT[heart_failure_HT[categorical_column] == groups[1]]['DEATH_EVENT']
    group_0_duration = heart_failure_HT[heart_failure_HT[categorical_column] == groups[0]]['TIME']
    group_1_duration = heart_failure_HT[heart_failure_HT[categorical_column] == groups[1]]['TIME']
    lr_test = logrank_test(group_0_duration, group_1_duration, event_observed_A=group_0_event, event_observed_B=group_1_event)
    heart_failure_categorical_lrtest_event['DEATH_EVENT_TIME_' + categorical_column] = (lr_test.test_statistic, lr_test.p_value)
##################################
# Formulating the log-rank test summary
# between the event and duration variables
# with the categorical predictor columns
##################################
heart_failure_categorical_lrtest_summary = pd.DataFrame.from_dict(heart_failure_categorical_lrtest_event, orient='index')
heart_failure_categorical_lrtest_summary.columns = ['LR.Test.Statistic', 'LR.Test.PValue']
display(heart_failure_categorical_lrtest_summary.sort_values(by=['LR.Test.PValue'], ascending=True))
LR.Test.Statistic | LR.Test.PValue | |
---|---|---|
DEATH_EVENT_TIME_HIGH_BLOOD_PRESSURE | 4.406248 | 0.035808 |
DEATH_EVENT_TIME_ANAEMIA | 2.726464 | 0.098698 |
DEATH_EVENT_TIME_DIABETES | 0.040528 | 0.840452 |
DEATH_EVENT_TIME_SEX | 0.003971 | 0.949752 |
DEATH_EVENT_TIME_SMOKING | 0.002042 | 0.963960 |
##################################
# Creating an alternate copy of the
# EDA data which will utilize
# binning for numeric predictors
##################################
heart_failure_HT_binned = heart_failure_HT.copy()
##################################
# Creating a function to bin
# numeric predictors into two groups
##################################
def bin_numeric_predictor(df, predictor):
    median = df[predictor].median()
    df[f'Binned_{predictor}'] = np.where(df[predictor] <= median, 0, 1)
    return df
##################################
# Binning the numeric predictors
# in the alternate data into two groups
##################################
for numeric_column in heart_failure_numeric_predictor.columns:
    heart_failure_HT_binned = bin_numeric_predictor(heart_failure_HT_binned, numeric_column)
##################################
# Formulating the binned numeric predictors
##################################
heart_failure_binned_numeric_predictor = ["Binned_" + predictor for predictor in heart_failure_numeric_predictor.columns]
##################################
# Exploring the relationships between
# the binned numeric predictors with
# survival event and duration
##################################
plt.figure(figsize=(17, 18))
for i in range(1, 7):
    ax = plt.subplot(3, 2, i)
    for group in [0, 1]:
        kmf.fit(durations=heart_failure_HT_binned[heart_failure_HT_binned[heart_failure_binned_numeric_predictor[i-1]] == group]['TIME'],
                event_observed=heart_failure_HT_binned[heart_failure_HT_binned[heart_failure_binned_numeric_predictor[i-1]] == group]['DEATH_EVENT'], label=group)
        kmf.plot_survival_function(ax=ax)
    plt.title(f'Survival Probabilities by {heart_failure_binned_numeric_predictor[i-1]} Categories')
    plt.xlabel('TIME')
    plt.ylabel('DEATH_EVENT Survival Probability')
plt.tight_layout()
plt.show()
##################################
# Computing the log-rank test
# statistic and p-values
# between the event and duration variables
# with the binned numeric predictor columns
##################################
heart_failure_binned_numeric_lrtest_event = {}
for binned_numeric_column in heart_failure_binned_numeric_predictor:
    groups = [0, 1]
    group_0_event = heart_failure_HT_binned[heart_failure_HT_binned[binned_numeric_column] == groups[0]]['DEATH_EVENT']
    group_1_event = heart_failure_HT_binned[heart_failure_HT_binned[binned_numeric_column] == groups[1]]['DEATH_EVENT']
    group_0_duration = heart_failure_HT_binned[heart_failure_HT_binned[binned_numeric_column] == groups[0]]['TIME']
    group_1_duration = heart_failure_HT_binned[heart_failure_HT_binned[binned_numeric_column] == groups[1]]['TIME']
    lr_test = logrank_test(group_0_duration, group_1_duration, event_observed_A=group_0_event, event_observed_B=group_1_event)
    heart_failure_binned_numeric_lrtest_event['DEATH_EVENT_TIME_' + binned_numeric_column] = (lr_test.test_statistic, lr_test.p_value)
##################################
# Formulating the log-rank test summary
# between the event and duration variables
# with the binned numeric predictor columns
##################################
heart_failure_binned_numeric_lrtest_summary = pd.DataFrame.from_dict(heart_failure_binned_numeric_lrtest_event, orient='index')
heart_failure_binned_numeric_lrtest_summary.columns = ['LR.Test.Statistic', 'LR.Test.PValue']
display(heart_failure_binned_numeric_lrtest_summary.sort_values(by=['LR.Test.PValue'], ascending=True))
LR.Test.Statistic | LR.Test.PValue | |
---|---|---|
DEATH_EVENT_TIME_Binned_SERUM_CREATININE | 21.190414 | 0.000004 |
DEATH_EVENT_TIME_Binned_EJECTION_FRACTION | 9.469633 | 0.002089 |
DEATH_EVENT_TIME_Binned_AGE | 4.951760 | 0.026064 |
DEATH_EVENT_TIME_Binned_SERUM_SODIUM | 4.887878 | 0.027046 |
DEATH_EVENT_TIME_Binned_CREATININE_PHOSPHOKINASE | 0.055576 | 0.813630 |
DEATH_EVENT_TIME_Binned_PLATELETS | 0.009122 | 0.923912 |
1.6. Predictive Model Development ¶
1.6.1 Pre-Modelling Data Preparation ¶
- All dichotomous categorical predictors and the target variable were one-hot encoded for the downstream modelling process.
- Predictors determined to have insufficient association with the DEATH_EVENT and TIME variables were excluded from the subsequent modelling steps.
- DIABETES: LR.Test.Statistic=0.040, LR.Test.PValue=0.840
- SEX: LR.Test.Statistic=0.003, LR.Test.PValue=0.949
- SMOKING: LR.Test.Statistic=0.002, LR.Test.PValue=0.963
- CREATININE_PHOSPHOKINASE: LR.Test.Statistic=0.055, LR.Test.PValue=0.813
- PLATELETS: LR.Test.Statistic=0.009, LR.Test.PValue=0.923
##################################
# Creating a dataset copy
# for data splitting and modelling
##################################
heart_failure_transformed = heart_failure_original.copy()
heart_failure_transformed['DEATH_EVENT'] = heart_failure_transformed['DEATH_EVENT'].replace({0: False, 1: True})
display(heart_failure_transformed)
AGE | ANAEMIA | CREATININE_PHOSPHOKINASE | DIABETES | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | PLATELETS | SERUM_CREATININE | SERUM_SODIUM | SEX | SMOKING | TIME | DEATH_EVENT | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 75.0 | 0 | 582.0 | 0 | 20.0 | 1 | 265000.00 | 1.9 | 130.0 | 1 | 0 | 4.0 | True |
1 | 55.0 | 0 | 7861.0 | 0 | 38.0 | 0 | 263358.03 | 1.1 | 136.0 | 1 | 0 | 6.0 | True |
2 | 65.0 | 0 | 146.0 | 0 | 20.0 | 0 | 162000.00 | 1.3 | 129.0 | 1 | 1 | 7.0 | True |
3 | 50.0 | 1 | 111.0 | 0 | 20.0 | 0 | 210000.00 | 1.9 | 137.0 | 1 | 0 | 7.0 | True |
4 | 65.0 | 1 | 160.0 | 1 | 20.0 | 0 | 327000.00 | 2.7 | 116.0 | 0 | 0 | 8.0 | True |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
294 | 62.0 | 0 | 61.0 | 1 | 38.0 | 1 | 155000.00 | 1.1 | 143.0 | 1 | 1 | 270.0 | False |
295 | 55.0 | 0 | 1820.0 | 0 | 38.0 | 0 | 270000.00 | 1.2 | 139.0 | 0 | 0 | 271.0 | False |
296 | 45.0 | 0 | 2060.0 | 1 | 60.0 | 0 | 742000.00 | 0.8 | 138.0 | 0 | 0 | 278.0 | False |
297 | 45.0 | 0 | 2413.0 | 0 | 38.0 | 0 | 140000.00 | 1.4 | 140.0 | 1 | 1 | 280.0 | False |
298 | 50.0 | 0 | 196.0 | 0 | 45.0 | 0 | 395000.00 | 1.6 | 136.0 | 1 | 1 | 285.0 | False |
299 rows × 13 columns
##################################
# Saving the transformed data
# to the DATASETS_PREPROCESSED_PATH
##################################
heart_failure_transformed.to_csv(os.path.join("..", DATASETS_PREPROCESSED_PATH, "heart_failure_transformed.csv"), index=True)
##################################
# Filtering out predictors that did not exhibit
# sufficient discrimination of the target variable
# Saving the filtered data
# to the DATASETS_FINAL_PATH
##################################
heart_failure_filtered = heart_failure_transformed.drop(['DIABETES','SEX', 'SMOKING', 'CREATININE_PHOSPHOKINASE','PLATELETS'], axis=1)
heart_failure_filtered.to_csv(os.path.join("..", DATASETS_FINAL_PATH, "heart_failure_final.csv"), index=True)
display(heart_failure_filtered)
AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM | TIME | DEATH_EVENT | |
---|---|---|---|---|---|---|---|---|
0 | 75.0 | 0 | 20.0 | 1 | 1.9 | 130.0 | 4.0 | True |
1 | 55.0 | 0 | 38.0 | 0 | 1.1 | 136.0 | 6.0 | True |
2 | 65.0 | 0 | 20.0 | 0 | 1.3 | 129.0 | 7.0 | True |
3 | 50.0 | 1 | 20.0 | 0 | 1.9 | 137.0 | 7.0 | True |
4 | 65.0 | 1 | 20.0 | 0 | 2.7 | 116.0 | 8.0 | True |
... | ... | ... | ... | ... | ... | ... | ... | ... |
294 | 62.0 | 0 | 38.0 | 1 | 1.1 | 143.0 | 270.0 | False |
295 | 55.0 | 0 | 38.0 | 0 | 1.2 | 139.0 | 271.0 | False |
296 | 45.0 | 0 | 60.0 | 0 | 0.8 | 138.0 | 278.0 | False |
297 | 45.0 | 0 | 38.0 | 0 | 1.4 | 140.0 | 280.0 | False |
298 | 50.0 | 0 | 45.0 | 0 | 1.6 | 136.0 | 285.0 | False |
299 rows × 8 columns
1.6.2 Data Splitting ¶
- The preprocessed dataset was divided into three subsets using a fixed random seed:
- test data: 25% of the original data with class stratification applied
- train data (initial): 75% of the original data with class stratification applied
- train data (final): 75% of the train (initial) data with class stratification applied
- validation data: 25% of the train (initial) data with class stratification applied
- Although a moderate class imbalance between DEATH_EVENT=True and DEATH_EVENT=False was observed, maintaining the time-to-event distribution is crucial for survival analysis. Resampling would require synthetic imputation of event times, which could introduce additional noise and bias into the model. Given the nature of the data, preserving the integrity of the time variable is of higher importance than correcting for a moderate class imbalance.
- Models were developed from the train data (final). Using the same dataset, candidate models with optimal hyperparameters were identified based on cross-validation.
- Among the candidate models with optimal hyperparameters, the final model was selected based on performance during cross-validation and independent validation.
- Performance of the selected final model (and of the other candidate models, for post-model-selection comparison) was evaluated using the test data.
- The preprocessed data is comprised of:
- 299 rows (observations)
- 96 DEATH_EVENT=True: 32.11%
- 203 DEATH_EVENT=False: 67.89%
- 8 columns (variables)
- 2/8 event | duration (object | numeric)
- DEATH_EVENT
- TIME
- 4/8 predictor (numeric)
- AGE
- EJECTION_FRACTION
- SERUM_CREATININE
- SERUM_SODIUM
- 2/8 predictor (object)
- ANAEMIA
- HIGH_BLOOD_PRESSURE
- The train data (final) subset is comprised of:
- 168 rows (observations)
- 54 DEATH_EVENT=True: 32.14%
- 114 DEATH_EVENT=False: 67.86%
- 8 columns (variables)
- The validation data subset is comprised of:
- 56 rows (observations)
- 18 DEATH_EVENT=True: 32.14%
- 38 DEATH_EVENT=False: 67.86%
- 8 columns (variables)
- The test data subset is comprised of:
- 75 rows (observations)
- 24 DEATH_EVENT=True: 32.00%
- 51 DEATH_EVENT=False: 68.00%
- 8 columns (variables)
##################################
# Creating a dataset copy
# of the filtered data
##################################
heart_failure_final = heart_failure_filtered.copy()
##################################
# Performing a general exploration
# of the final dataset
##################################
print('Final Dataset Dimensions: ')
display(heart_failure_final.shape)
Final Dataset Dimensions:
(299, 8)
print('Target Variable Breakdown: ')
heart_failure_breakdown = heart_failure_final.groupby('DEATH_EVENT', observed=True).size().reset_index(name='Count')
heart_failure_breakdown['Percentage'] = (heart_failure_breakdown['Count'] / len(heart_failure_final)) * 100
display(heart_failure_breakdown)
Target Variable Breakdown:
DEATH_EVENT | Count | Percentage | |
---|---|---|---|
0 | False | 203 | 67.892977 |
1 | True | 96 | 32.107023 |
##################################
# Formulating the train and test data
# from the final dataset
# by applying stratification and
# using a 75-25 ratio
##################################
heart_failure_train_initial, heart_failure_test = train_test_split(heart_failure_final,
test_size=0.25,
stratify=heart_failure_final['DEATH_EVENT'],
random_state=88888888)
##################################
# Performing a general exploration
# of the initial training dataset
##################################
X_train_initial = heart_failure_train_initial.drop(['DEATH_EVENT', 'TIME'], axis=1)
y_train_initial = heart_failure_train_initial[['DEATH_EVENT', 'TIME']]
print('Initial Training Dataset Dimensions: ')
display(X_train_initial.shape)
display(y_train_initial.shape)
print('Initial Training Target Variable Breakdown: ')
display(y_train_initial['DEATH_EVENT'].value_counts())
print('Initial Training Target Variable Proportion: ')
display(y_train_initial['DEATH_EVENT'].value_counts(normalize = True))
Initial Training Dataset Dimensions:
(224, 6)
(224, 2)
Initial Training Target Variable Breakdown:
DEATH_EVENT False 152 True 72 Name: count, dtype: int64
Initial Training Target Variable Proportion:
DEATH_EVENT False 0.678571 True 0.321429 Name: proportion, dtype: float64
##################################
# Performing a general exploration
# of the test dataset
##################################
X_test = heart_failure_test.drop(['DEATH_EVENT', 'TIME'], axis=1)
y_test = heart_failure_test[['DEATH_EVENT', 'TIME']]
print('Test Dataset Dimensions: ')
display(X_test.shape)
display(y_test.shape)
print('Test Target Variable Breakdown: ')
display(y_test['DEATH_EVENT'].value_counts())
print('Test Target Variable Proportion: ')
display(y_test['DEATH_EVENT'].value_counts(normalize = True))
Test Dataset Dimensions:
(75, 6)
(75, 2)
Test Target Variable Breakdown:
DEATH_EVENT False 51 True 24 Name: count, dtype: int64
Test Target Variable Proportion:
DEATH_EVENT False 0.68 True 0.32 Name: proportion, dtype: float64
##################################
# Formulating the train and validation data
# from the train dataset
# by applying stratification and
# using a 75-25 ratio
##################################
heart_failure_train, heart_failure_validation = train_test_split(heart_failure_train_initial,
test_size=0.25,
stratify=heart_failure_train_initial['DEATH_EVENT'],
random_state=88888888)
##################################
# Performing a general exploration
# of the final training dataset
##################################
X_train = heart_failure_train.drop(columns=['DEATH_EVENT', 'TIME'], axis=1)
y_train = heart_failure_train[['DEATH_EVENT', 'TIME']]
print('Final Training Dataset Dimensions: ')
display(X_train.shape)
display(y_train.shape)
print('Final Training Target Variable Breakdown: ')
display(y_train['DEATH_EVENT'].value_counts())
print('Final Training Target Variable Proportion: ')
display(y_train['DEATH_EVENT'].value_counts(normalize = True))
Final Training Dataset Dimensions:
(168, 6)
(168, 2)
Final Training Target Variable Breakdown:
DEATH_EVENT False 114 True 54 Name: count, dtype: int64
Final Training Target Variable Proportion:
DEATH_EVENT False 0.678571 True 0.321429 Name: proportion, dtype: float64
##################################
# Performing a general exploration
# of the validation dataset
##################################
X_validation = heart_failure_validation.drop(columns=['DEATH_EVENT', 'TIME'], axis = 1)
y_validation = heart_failure_validation[['DEATH_EVENT', 'TIME']]
print('Validation Dataset Dimensions: ')
display(X_validation.shape)
display(y_validation.shape)
print('Validation Target Variable Breakdown: ')
display(y_validation['DEATH_EVENT'].value_counts())
print('Validation Target Variable Proportion: ')
display(y_validation['DEATH_EVENT'].value_counts(normalize = True))
Validation Dataset Dimensions:
(56, 6)
(56, 2)
Validation Target Variable Breakdown:
DEATH_EVENT False 38 True 18 Name: count, dtype: int64
Validation Target Variable Proportion:
DEATH_EVENT False 0.678571 True 0.321429 Name: proportion, dtype: float64
##################################
# Saving the training data
# to the DATASETS_FINAL_TRAIN_PATH
# and DATASETS_FINAL_TRAIN_FEATURES_PATH
# and DATASETS_FINAL_TRAIN_TARGET_PATH
##################################
heart_failure_train.to_csv(os.path.join("..", DATASETS_FINAL_TRAIN_PATH, "heart_failure_train.csv"), index=True)
X_train.to_csv(os.path.join("..", DATASETS_FINAL_TRAIN_FEATURES_PATH, "X_train.csv"), index=True)
y_train.to_csv(os.path.join("..", DATASETS_FINAL_TRAIN_TARGET_PATH, "y_train.csv"), index=True)
##################################
# Saving the validation data
# to the DATASETS_FINAL_VALIDATION_PATH
# and DATASETS_FINAL_VALIDATION_FEATURE_PATH
# and DATASETS_FINAL_VALIDATION_TARGET_PATH
##################################
heart_failure_validation.to_csv(os.path.join("..", DATASETS_FINAL_VALIDATION_PATH, "heart_failure_validation.csv"), index=True)
X_validation.to_csv(os.path.join("..", DATASETS_FINAL_VALIDATION_FEATURES_PATH, "X_validation.csv"), index=True)
y_validation.to_csv(os.path.join("..", DATASETS_FINAL_VALIDATION_TARGET_PATH, "y_validation.csv"), index=True)
##################################
# Saving the test data
# to the DATASETS_FINAL_TEST_PATH
# and DATASETS_FINAL_TEST_FEATURES_PATH
# and DATASETS_FINAL_TEST_TARGET_PATH
##################################
heart_failure_test.to_csv(os.path.join("..", DATASETS_FINAL_TEST_PATH, "heart_failure_test.csv"), index=True)
X_test.to_csv(os.path.join("..", DATASETS_FINAL_TEST_FEATURES_PATH, "X_test.csv"), index=True)
y_test.to_csv(os.path.join("..", DATASETS_FINAL_TEST_TARGET_PATH, "y_test.csv"), index=True)
##################################
# Converting the event and duration variables
# for the train, validation and test sets
# to array as preparation for modeling
##################################
y_train_array = np.array([(row.DEATH_EVENT, row.TIME) for index, row in y_train.iterrows()], dtype=[('DEATH_EVENT', 'bool'), ('TIME', 'int')])
y_validation_array = np.array([(row.DEATH_EVENT, row.TIME) for index, row in y_validation.iterrows()], dtype=[('DEATH_EVENT', 'bool'), ('TIME', 'int')])
y_test_array = np.array([(row.DEATH_EVENT, row.TIME) for index, row in y_test.iterrows()], dtype=[('DEATH_EVENT', 'bool'), ('TIME', 'int')])
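As a cross-check, the same structured arrays can also be built with the Surv utility from scikit-survival. The snippet below is a minimal sketch assuming the y_train, y_validation, and y_test frames defined above; note that this utility keeps the TIME field in its original floating-point type rather than casting it to integer as in the manual conversion.
##################################
# Alternative construction of the structured
# survival arrays using the scikit-survival
# Surv utility (sketch; equivalent to the manual
# conversion above except for the TIME field dtype)
##################################
from sksurv.util import Surv
y_train_array_alt = Surv.from_dataframe(event='DEATH_EVENT', time='TIME', data=y_train)
y_validation_array_alt = Surv.from_dataframe(event='DEATH_EVENT', time='TIME', data=y_validation)
y_test_array_alt = Surv.from_dataframe(event='DEATH_EVENT', time='TIME', data=y_test)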
1.6.3 Modelling Pipeline Development ¶
1.6.3.1 Cox Proportional Hazards Regression ¶
Cox Proportional Hazards Regression is a semiparametric model used to study the relationship between the survival time of subjects and one or more predictor variables. The model assumes that the hazard ratio (the risk of the event occurring at a specific time) is a product of a baseline hazard function and an exponential function of the predictor variables. It also does not require the baseline hazard to be specified, thus making it a semiparametric model. As a method, it is well-established and widely used in survival analysis, can handle time-dependent covariates and provides a relatively straightforward interpretation. However, the process assumes proportional hazards, which may not hold in all datasets, and may be less flexible in capturing complex relationships between variables and survival times compared to some machine learning models. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves defining the partial likelihood function for the Cox model (which only considers the relative ordering of survival times); using optimization techniques to estimate the regression coefficients by maximizing the log-partial likelihood; estimating the baseline hazard function (although it is not explicitly required for predictions); and calculating the hazard function and survival function for new data using the estimated coefficients and baseline hazard.
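For reference, the partial likelihood maximized by the model can be sketched in the notation of this project, where $\delta_i$ is the DEATH_EVENT indicator, $t_i$ the observed TIME, $x_i$ the predictor vector, and $R(t_i)$ the risk set of cases still under observation at $t_i$:
$$
L(\beta) = \prod_{i:\ \delta_i = 1} \frac{\exp\!\left(x_i^{\top}\beta\right)}{\sum_{j \in R(t_i)} \exp\!\left(x_j^{\top}\beta\right)}
$$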
Yeo-Johnson Transformation applies a new family of distributions that can be used without restrictions, extending many of the good properties of the Box-Cox power family. Similar to the Box-Cox transformation, the method also estimates the optimal value of lambda but has the ability to transform both positive and negative values by inflating low variance data and deflating high variance data to create a more uniform data set. While there are no restrictions in terms of the applicable values, the interpretability of the transformed values is more diminished as compared to the other methods.
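As a sketch, the mapping applied to each numeric predictor is the standard Yeo-Johnson transform, with the parameter $\lambda$ estimated by maximum likelihood (as done by the PowerTransformer used in the pipelines below):
$$
\psi(y, \lambda) =
\begin{cases}
\dfrac{(y + 1)^{\lambda} - 1}{\lambda}, & y \ge 0,\ \lambda \ne 0 \\
\ln(y + 1), & y \ge 0,\ \lambda = 0 \\
-\dfrac{(1 - y)^{2 - \lambda} - 1}{2 - \lambda}, & y < 0,\ \lambda \ne 2 \\
-\ln(1 - y), & y < 0,\ \lambda = 2
\end{cases}
$$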
Concordance Index measures the model's ability to correctly order pairs of observations based on their predicted survival times. Values range from 0.5 to 1.0 indicating no predictive power (random guessing) and perfect predictions, respectively. As a metric, it provides a measure of discriminative ability and useful for ranking predictions. However, it does not provide information on the magnitude of errors and may be insensitive to the calibration of predicted survival probabilities.
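As an illustration of how the metric is obtained in this project's stack, the snippet below is a minimal sketch using the concordance_index_censored function from sksurv.metrics; fitted_pipeline is a placeholder for any of the fitted survival pipelines defined in the following subsections, evaluated against the validation arrays prepared during data splitting.
##################################
# Illustrative concordance index computation
# (sketch; "fitted_pipeline" is a placeholder
# for any fitted survival pipeline below)
##################################
from sksurv.metrics import concordance_index_censored
risk_scores = fitted_pipeline.predict(X_validation)
validation_cindex = concordance_index_censored(y_validation_array['DEATH_EVENT'],
                                               y_validation_array['TIME'],
                                               risk_scores)[0]
print(f'Validation Concordance Index: {validation_cindex:.4f}')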
- A modelling pipeline was implemented with the following steps:
- Yeo-Johnson transformation from the sklearn.preprocessing Python library API applied to the numeric predictors only. Categorical predictors were excluded from the transformation.
- Cox proportional hazards regression model from the sksurv.linear_model Python library API with 1 hyperparameter:
- alpha = regularization parameter for the ridge regression penalty made to vary between 0.01, 0.10, 1.00 and 10.00
- Hyperparameter tuning was conducted using the 5-fold cross-validation method for 5 repeats with optimal model performance determined using the concordance index.
##################################
# Defining the modelling pipeline
# using the Cox Proportional Hazards Regression Model
##################################
coxph_pipeline_preprocessor = ColumnTransformer(
transformers=[
# Applying PowerTransformer to numeric columns only
('numeric_predictors', PowerTransformer(method='yeo-johnson', standardize=True), ['AGE', 'EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM'])
# Keeping the categorical columns unchanged
], remainder='passthrough'
)
coxph_pipeline = Pipeline([
('yeo_johnson', coxph_pipeline_preprocessor),
('coxph', CoxPHSurvivalAnalysis())])
##################################
# Saving the model pipeline
# developed from the original training data
# for downstream processes
##################################
coxph_pipeline.fit(X_train, y_train_array)
joblib.dump(coxph_pipeline,
os.path.join("..", PIPELINES_PATH, "coxph_pipeline.pkl"))
['..\\pipelines\\coxph_pipeline.pkl']
##################################
# Defining the hyperparameters for grid search
##################################
coxph_hyperparameter_grid = {'coxph__alpha': [0.01, 0.10, 1.00, 10.00]}
##################################
# Setting up the GridSearchCV with 5-fold cross-validation
# and using concordance index as the model evaluation metric
##################################
coxph_grid_search = GridSearchCV(estimator=coxph_pipeline,
param_grid=coxph_hyperparameter_grid,
cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=88888888),
return_train_score=False,
n_jobs=-1,
verbose=1)
1.6.3.2 Cox Net Survival ¶
Cox Net Survival is a regularized version of the Cox Proportional Hazards model, which incorporates both L1 (Lasso) and L2 (Ridge) penalties. The model is useful when dealing with high-dimensional data where the number of predictors can be larger than the number of observations. The elastic net penalty helps in both variable selection (via L1) and multicollinearity handling (via L2). As a method, it can handle high-dimensional data and perform variable selection. Additionally, it balances between L1 and L2 penalties, offering flexibility in modeling. However, the process requires tuning of penalty parameters, which can be computationally intensive. Additionally, interpretation is more complex due to the regularization terms. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves defining the penalized partial likelihood function, incorporating both L1 (Lasso) and L2 (Ridge) penalties; application of regularization techniques to estimate the regression coefficients by maximizing the penalized log-partial likelihood; performing cross-validation to select optimal values for the penalty parameters (alpha and l1_ratio); and the calculation of the hazard function and survival function for new data using the estimated regularized coefficients.
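As a sketch of the objective, the regularized coefficients are obtained by maximizing the log partial likelihood subject to the elastic net penalty, where $r$ corresponds to the l1_ratio hyperparameter and $\alpha$ traverses a regularization path whose smallest value is controlled by alpha_min_ratio:
$$
\hat{\beta} = \arg\max_{\beta}\ \log PL(\beta) - \alpha \sum_{j=1}^{p} \left( r\,\lvert\beta_j\rvert + \frac{1 - r}{2}\,\beta_j^{2} \right)
$$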
- A modelling pipeline was implemented with the following steps:
- Yeo-Johnson transformation from the sklearn.preprocessing Python library API applied to the numeric predictors only. Categorical predictors were excluded from the transformation.
- Cox net survival model from the sksurv.linear_model Python library API with 2 hyperparameters:
- l1_ratio = ElasticNet mixing parameter made to vary between 0.10, 0.50 and 1.00
- alpha_min_ratio = minimum alpha of the regularization path made to vary between 0.0001 and 0.01
- Hyperparameter tuning was conducted using the 5-fold cross-validation method for 5 repeats with optimal model performance determined using the concordance index.
##################################
# Defining the modelling pipeline
# using the cox net survival analysis model
##################################
coxns_pipeline_preprocessor = ColumnTransformer(
transformers=[
# Applying PowerTransformer to numeric columns only
('numeric_predictors', PowerTransformer(method='yeo-johnson', standardize=True), ['AGE', 'EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM'])
# Keeping the categorical columns unchanged
], remainder='passthrough'
)
coxns_pipeline = Pipeline([
('yeo_johnson', coxns_pipeline_preprocessor),
('coxns', CoxnetSurvivalAnalysis())])
##################################
# Saving the model pipeline
# developed from the original training data
# for downstream processes
##################################
coxns_pipeline.fit(X_train, y_train_array)
joblib.dump(coxns_pipeline,
os.path.join("..", PIPELINES_PATH, "coxns_pipeline.pkl"))
['..\\pipelines\\coxns_pipeline.pkl']
##################################
# Defining the hyperparameters for grid search
##################################
coxns_hyperparameter_grid = {'coxns__l1_ratio': [0.10, 0.50, 1.00],
'coxns__alpha_min_ratio': [0.0001, 0.01],
'coxns__fit_baseline_model': [True]}
##################################
# Setting up the GridSearchCV with 5-fold cross-validation
# and using concordance index as the model evaluation metric
##################################
coxns_grid_search = GridSearchCV(estimator=coxns_pipeline,
param_grid=coxns_hyperparameter_grid,
cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=88888888),
return_train_score=False,
n_jobs=-1,
verbose=1)
1.6.3.3 Survival Tree ¶
Survival Trees are non-parametric models that partition the data into subgroups (nodes) based on the values of predictor variables, creating a tree-like structure. The tree is built by recursively splitting the data at nodes where the differences in survival times between subgroups are maximized. Each terminal node represents a different survival function. The method have no assumptions about the underlying distribution of survival times, can capture interactions between variables naturally and applies an interpretable visual representation. However, the process can be prone to overfitting, especially with small datasets, and may be less accurate compared to ensemble methods like Random Survival Forest. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves recursively splitting the data at nodes to maximize the differences in survival times between subgroups with the splitting criteria often involving statistical tests (e.g., log-rank test); choosing the best predictor variable and split point at each node that maximizes the separation of survival times; continuously splitting until stopping criteria are met (e.g., minimum number of observations in a node, maximum tree depth); and estimating the survival function based on the survival times of the observations at each terminal node.
- A modelling pipeline was implemented with the following steps:
- Yeo-Johnson transformation from the sklearn.preprocessing Python library API applied to the numeric predictors only. Categorical predictors were excluded from the transformation.
- Survival tree model from the sksurv.tree Python library API with 2 hyperparameters:
- min_samples_split = minimum number of samples required to split an internal node made to vary between 10, 15 and 20
- min_samples_leaf = minimum number of samples required to be at a leaf node made to vary between 3 and 6
- Hyperparameter tuning was conducted using the 5-fold cross-validation method for 5 repeats with optimal model performance determined using the concordance index.
##################################
# Defining the modelling pipeline
# using the survival tree model
##################################
stree_pipeline_preprocessor = ColumnTransformer(
transformers=[
# Applying PowerTransformer to numeric columns only
('numeric_predictors', PowerTransformer(method='yeo-johnson', standardize=True), ['AGE', 'EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM'])
# Keeping the categorical columns unchanged
], remainder='passthrough'
)
stree_pipeline = Pipeline([
('yeo_johnson', stree_pipeline_preprocessor),
('stree', SurvivalTree())])
##################################
# Saving the model pipeline
# developed from the original training data
# for downstream processes
##################################
stree_pipeline.fit(X_train, y_train_array)
joblib.dump(stree_pipeline,
os.path.join("..", PIPELINES_PATH, "stree_pipeline.pkl"))
['..\\pipelines\\stree_pipeline.pkl']
##################################
# Defining the hyperparameters for grid search
##################################
stree_hyperparameter_grid = {'stree__min_samples_split': [10, 15, 20],
'stree__min_samples_leaf': [3, 6],
'stree__random_state': [88888888]}
##################################
# Setting up the GridSearchCV with 5-fold cross-validation
# and using concordance index as the model evaluation metric
##################################
stree_grid_search = GridSearchCV(estimator=stree_pipeline,
param_grid=stree_hyperparameter_grid,
cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=88888888),
return_train_score=False,
n_jobs=-1,
verbose=1)
1.6.3.4 Random Survival Forest ¶
Random Survival Forest is an ensemble method that builds multiple survival trees and averages their predictions. The model combines the predictions of multiple survival trees, each built on a bootstrap sample of the data and a random subset of predictors. It uses the concept of ensemble learning to improve predictive accuracy and robustness. As a method, it handles high-dimensional data and complex interactions between variables well; can be more accurate and robust than a single survival tree; and provides measures of variable importance. However, the process can be computationally intensive due to the need to build multiple trees, and may be less interpretable than single trees or parametric models like the Cox model. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves generating multiple bootstrap samples from the original dataset; building a survival tree by recursively splitting the data at nodes using a random subset of predictor variables for each bootstrap sample; combining the predictions of all survival trees to form the random survival forest and averaging the survival functions predicted by all trees in the forest to obtain the final survival function for new data.
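As a sketch, the ensemble prediction for a new case $x$ is the average of the survival functions estimated by the $B$ individual trees, where $B$ corresponds to the n_estimators hyperparameter:
$$
\hat{S}(t \mid x) = \frac{1}{B} \sum_{b=1}^{B} \hat{S}_b(t \mid x)
$$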
- A modelling pipeline was implemented with the following steps:
- Yeo-Johnson transformation from the sklearn.preprocessing Python library API applied to the numeric predictors only. Categorical predictors were excluded from the transformation.
- Random survival forest model from the sksurv.ensemble Python library API with 2 hyperparameters:
- n_estimators = number of trees in the forest made to vary between 100, 200 and 300
- min_samples_split = minimum number of samples required to split an internal node made to vary between 10, 15 and 20
- Hyperparameter tuning was conducted using the 5-fold cross-validation method for 5 repeats with optimal model performance determined using the concordance index.
##################################
# Defining the modelling pipeline
# using the random survival forest model
##################################
rsf_pipeline_preprocessor = ColumnTransformer(
transformers=[
# Applying PowerTransformer to numeric columns only
('numeric_predictors', PowerTransformer(method='yeo-johnson', standardize=True), ['AGE', 'EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM'])
# Keeping the categorical columns unchanged
], remainder='passthrough'
)
rsf_pipeline = Pipeline([
('yeo_johnson', rsf_pipeline_preprocessor),
('rsf', RandomSurvivalForest())])
##################################
# Saving the model pipeline
# developed from the original training data
# for downstream processes
##################################
rsf_pipeline.fit(X_train, y_train_array)
joblib.dump(rsf_pipeline,
os.path.join("..", PIPELINES_PATH, "rsf_pipeline.pkl"))
['..\\pipelines\\rsf_pipeline.pkl']
##################################
# Defining the hyperparameters for grid search
##################################
rsf_hyperparameter_grid = {'rsf__n_estimators': [100, 200, 300],
'rsf__min_samples_split': [10, 15, 20],
'rsf__random_state': [88888888]}
##################################
# Setting up the GridSearchCV with 5-fold cross-validation
# and using concordance index as the model evaluation metric
##################################
rsf_grid_search = GridSearchCV(estimator=rsf_pipeline,
param_grid=rsf_hyperparameter_grid,
cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=88888888),
return_train_score=False,
n_jobs=-1,
verbose=1)
1.6.3.5 Gradient Boosted Survival ¶
Gradient Boosted Survival is an ensemble technique that builds a series of survival trees, where each tree tries to correct the errors of the previous one. The model uses boosting, a sequential technique where each new tree is fit to the residuals of the combined previous trees, and combines the predictions of all the trees to produce a final prediction. As a method, it has high predictive accuracy, the ability to model complex relationships, and reduces bias and variance compared to single-tree models. However, the process can even be more computationally intensive than Random Survival Forest, requires careful tuning of multiple hyperparameters, and makes interpretation challenging due to the complex nature of the model. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves starting with an initial prediction (often the median survival time or a simple model); calculating the residuals (errors) of the current model's predictions; fitting a survival tree to the residuals to learn the errors made by the current model; updating the current model by adding the new tree weighted by a learning rate parameter; repeating previous steps for a fixed number of iterations or until convergence; and summing the predictions of all trees in the sequence to obtain the final survival function for new data.
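As a sketch, each boosting stage $m$ adds a new regression tree $h_m$ fitted to the negative gradient of the loss, scaled by the learning rate $\nu$ (the learning_rate hyperparameter), for $M$ stages (the n_estimators hyperparameter):
$$
\hat{f}_m(x) = \hat{f}_{m-1}(x) + \nu\, h_m(x), \qquad m = 1, \dots, M
$$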
- A modelling pipeline was implemented with the following steps:
- Yeo-Johnson transformation from the sklearn.preprocessing Python library API applied to the numeric predictors only. Categorical predictors were excluded from the transformation.
- Gradient boosted survival model from the sksurv.ensemble Python library API with 2 hyperparameters:
- n_estimators = number of regression trees to create made to vary between 100, 200 and 300
- learning_rate = shrinkage parameter for the contribution of each tree made to vary between 0.05, 0.10 and 0.15
- Hyperparameter tuning was conducted using the 5-fold cross-validation method for 5 repeats with optimal model performance determined using the concordance index.
##################################
# Defining the modelling pipeline
# using the gradient boosted survival model
##################################
gbs_pipeline_preprocessor = ColumnTransformer(
    transformers=[
        # Applying PowerTransformer to numeric columns only
        ('numeric_predictors',
         PowerTransformer(method='yeo-johnson', standardize=True),
         ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])
        # Keeping the categorical columns unchanged
    ], remainder='passthrough'
)
gbs_pipeline = Pipeline([
    ('yeo_johnson', gbs_pipeline_preprocessor),
    ('gbs', GradientBoostingSurvivalAnalysis())])
##################################
# Saving the model pipeline
# developed from the original training data
# for downstream processes
##################################
gbs_pipeline.fit(X_train, y_train_array)
joblib.dump(gbs_pipeline,
os.path.join("..", PIPELINES_PATH, "gbs_pipeline.pkl"))
['..\\pipelines\\gbs_pipeline.pkl']
##################################
# Defining the hyperparameters for grid search
##################################
gbs_hyperparameter_grid = {'gbs__n_estimators': [100, 200, 300],
'gbs__learning_rate': [0.05, 0.10, 0.15],
'gbs__random_state': [88888888]}
##################################
# Setting up the GridSearchCV with 5-fold cross-validation
# and using concordance index as the model evaluation metric
##################################
gbs_grid_search = GridSearchCV(estimator=gbs_pipeline,
param_grid=gbs_hyperparameter_grid,
cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=88888888),
return_train_score=False,
n_jobs=-1,
verbose=1)
1.6.4 Cox Proportional Hazards Regression Model Fitting | Hyperparameter Tuning | Validation ¶
Cox Proportional Hazards Regression is a semiparametric model used to study the relationship between the survival time of subjects and one or more predictor variables. The model assumes that the hazard ratio (the risk of the event occurring at a specific time) is a product of a baseline hazard function and an exponential function of the predictor variables. It also does not require the baseline hazard to be specified, thus making it a semiparametric model. As a method, it is well-established and widely used in survival analysis, can handle time-dependent covariates and provides a relatively straightforward interpretation. However, the process assumes proportional hazards, which may not hold in all datasets, and may be less flexible in capturing complex relationships between variables and survival times compared to some machine learning models. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves defining the partial likelihood function for the Cox model (which only considers the relative ordering of survival times); using optimization techniques to estimate the regression coefficients by maximizing the log-partial likelihood; estimating the baseline hazard function (although it is not explicitly required for predictions); and calculating the hazard function and survival function for new data using the estimated coefficients and baseline hazard.
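The proportional hazards structure described above is commonly written as follows (standard notation, added here for clarity):

$$
h(t \mid x) = h_0(t) \, \exp(\beta^\top x),
\qquad
L(\beta) = \prod_{i \, : \, \delta_i = 1} \frac{\exp(\beta^\top x_i)}{\sum_{j \in R(t_i)} \exp(\beta^\top x_j)},
$$

where $h_0(t)$ is the unspecified baseline hazard, $\delta_i$ indicates an observed event, and $R(t_i)$ is the risk set at time $t_i$; the coefficients $\beta$ are estimated by maximizing the partial likelihood $L(\beta)$, which depends only on the ordering of the observed event times. The alpha hyperparameter tuned below adds a ridge ($L_2$) penalty on $\beta$ to this objective.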
- The cox proportional hazards regression model from the sksurv.linear_model Python library API was implemented.
- The model implementation used 1 hyperparameter:
- alpha = regularization parameter for ridge regression penalty made to vary between 0.01, 0.10, 1.00 and 10.00
- Hyperparameter tuning was conducted using the 5-fold cross-validation method repeated 5 times with optimal model performance using the concordance index determined for:
- alpha = 10.00
- The cross-validated model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7073
- The apparent model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7419
- The independent validation model performance of the final model is summarized as follows:
- Concordance Index = 0.7394
- Considerable difference in the apparent and cross-validated model performance observed, indicative of the presence of moderate model overfitting.
- Survival probability curves obtained from the groups generated by dichotomizing the risk scores demonstrated sufficient differentiation across the entire duration.
- Hazard and survival probability estimations for 5 sampled cases demonstrated reasonably smooth profiles.
##################################
# Performing hyperparameter tuning
# through K-fold cross-validation
# using the Cox Proportional Hazards Regression Model
##################################
coxph_grid_search.fit(X_train, y_train_array)
Fitting 25 folds for each of 4 candidates, totalling 100 fits
GridSearchCV(cv=RepeatedKFold(n_repeats=5, n_splits=5, random_state=88888888), estimator=Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('coxph', CoxPHSurvivalAnalysis())]), n_jobs=-1, param_grid={'coxph__alpha': [0.01, 0.1, 1.0, 10.0]}, verbose=1)
##################################
# Summarizing the hyperparameter tuning
# results from K-fold cross-validation
##################################
coxph_grid_search_results = pd.DataFrame(coxph_grid_search.cv_results_).sort_values(by='mean_test_score', ascending=False)
coxph_grid_search_results.loc[:, ~coxph_grid_search_results.columns.str.endswith('_time')]
param_coxph__alpha | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | split6_test_score | split7_test_score | ... | split18_test_score | split19_test_score | split20_test_score | split21_test_score | split22_test_score | split23_test_score | split24_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
3 | 10.00 | {'coxph__alpha': 10.0} | 0.758410 | 0.681115 | 0.780612 | 0.840000 | 0.521429 | 0.588957 | 0.637821 | 0.862500 | ... | 0.62 | 0.683019 | 0.732143 | 0.712264 | 0.700272 | 0.525641 | 0.791667 | 0.707318 | 0.084671 | 1 |
2 | 1.00 | {'coxph__alpha': 1.0} | 0.764526 | 0.715170 | 0.785714 | 0.848889 | 0.492857 | 0.616564 | 0.628205 | 0.850000 | ... | 0.63 | 0.675472 | 0.736607 | 0.669811 | 0.694823 | 0.521368 | 0.802083 | 0.701906 | 0.086235 | 2 |
1 | 0.10 | {'coxph__alpha': 0.1} | 0.758410 | 0.702786 | 0.790816 | 0.848889 | 0.492857 | 0.625767 | 0.631410 | 0.854167 | ... | 0.62 | 0.679245 | 0.709821 | 0.669811 | 0.697548 | 0.529915 | 0.796875 | 0.701768 | 0.085985 | 3 |
0 | 0.01 | {'coxph__alpha': 0.01} | 0.758410 | 0.702786 | 0.790816 | 0.848889 | 0.492857 | 0.625767 | 0.631410 | 0.854167 | ... | 0.62 | 0.679245 | 0.709821 | 0.669811 | 0.694823 | 0.529915 | 0.796875 | 0.701134 | 0.086022 | 4 |
4 rows × 30 columns
##################################
# Identifying the best model
##################################
coxph_best_model_train_cv = coxph_grid_search.best_estimator_
print('Best Cox Proportional Hazards Regression Model using the Cross-Validated Train Data: ')
print(f"Best Model Parameters: {coxph_grid_search.best_params_}")
Best Cox Proportional Hazards Regression Model using the Cross-Validated Train Data: 
Best Model Parameters: {'coxph__alpha': 10.0}
##################################
# Obtaining the cross-validation model performance of the
# optimal Cox Proportional Hazards Regression Model
# on the train set
##################################
optimal_coxph_heart_failure_y_crossvalidation_ci = coxph_grid_search.best_score_
print(f"Cross-Validation Concordance Index: {optimal_coxph_heart_failure_y_crossvalidation_ci}")
Cross-Validation Concordance Index: 0.7073178688853099
##################################
# Formulating a Cox Proportional Hazards Regression Model
# with optimal hyperparameters
##################################
optimal_coxph_model = coxph_grid_search.best_estimator_
optimal_coxph_model.fit(X_train, y_train_array)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('coxph', CoxPHSurvivalAnalysis(alpha=10.0))])
##################################
# Measuring model performance of the
# optimal Cox Proportional Hazards Regression Model
# on the train set
##################################
optimal_coxph_heart_failure_y_train_pred = optimal_coxph_model.predict(X_train)
optimal_coxph_heart_failure_y_train_ci = concordance_index_censored(y_train_array['DEATH_EVENT'],
y_train_array['TIME'],
optimal_coxph_heart_failure_y_train_pred)[0]
print(f"Apparent Concordance Index: {optimal_coxph_heart_failure_y_train_ci}")
Apparent Concordance Index: 0.7419406319821258
##################################
# Measuring model performance of the
# optimal Cox Proportional Hazards Regression Model
# on the validation set
##################################
optimal_coxph_heart_failure_y_validation_pred = optimal_coxph_model.predict(X_validation)
optimal_coxph_heart_failure_y_validation_ci = concordance_index_censored(y_validation_array['DEATH_EVENT'],
y_validation_array['TIME'],
optimal_coxph_heart_failure_y_validation_pred)[0]
print(f"Validation Concordance Index: {optimal_coxph_heart_failure_y_validation_ci}")
Validation Concordance Index: 0.7394270122783083
##################################
# Gathering the concordance indices
# from the train and test sets for
# Cox Proportional Hazards Regression Model
##################################
coxph_set = pd.DataFrame(["Train","Cross-Validation","Validation"])
coxph_ci_values = pd.DataFrame([optimal_coxph_heart_failure_y_train_ci,
optimal_coxph_heart_failure_y_crossvalidation_ci,
optimal_coxph_heart_failure_y_validation_ci])
coxph_method = pd.DataFrame(["COXPH"]*3)
coxph_summary = pd.concat([coxph_set,
coxph_ci_values,
coxph_method], axis=1)
coxph_summary.columns = ['Set', 'Concordance.Index', 'Method']
coxph_summary.reset_index(inplace=True, drop=True)
display(coxph_summary)
| | Set | Concordance.Index | Method |
|---|---|---|---|
| 0 | Train | 0.741941 | COXPH |
| 1 | Cross-Validation | 0.707318 | COXPH |
| 2 | Validation | 0.739427 | COXPH |
##################################
# Binning the predicted risks
# into dichotomous groups and
# exploring the relationships with
# survival event and duration
##################################
heart_failure_validation.reset_index(drop=True, inplace=True)
kmf = KaplanMeierFitter()
heart_failure_validation['Predicted_Risks_CoxPH'] = optimal_coxph_heart_failure_y_validation_pred
heart_failure_validation['Predicted_RiskGroups_CoxPH'] = risk_groups = pd.qcut(heart_failure_validation['Predicted_Risks_CoxPH'], 2, labels=['Low-Risk', 'High-Risk'])
plt.figure(figsize=(17, 8))
for group in risk_groups.unique():
    group_data = heart_failure_validation[risk_groups == group]
    kmf.fit(group_data['TIME'], event_observed=group_data['DEATH_EVENT'], label=group)
    kmf.plot_survival_function()
plt.title('COXPH Survival Probabilities by Predicted Risk Groups on Validation Set')
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
plt.show()
##################################
# Gathering the predictor information
# for 5 test case samples
##################################
validation_case_details = X_validation.iloc[[5, 10, 15, 20, 25]]
display(validation_case_details)
| | AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM |
|---|---|---|---|---|---|---|
| 291 | 60.0 | 0 | 35.0 | 0 | 1.4 | 139.0 |
| 66 | 42.0 | 1 | 15.0 | 0 | 1.3 | 136.0 |
| 112 | 50.0 | 0 | 25.0 | 0 | 1.6 | 136.0 |
| 89 | 57.0 | 1 | 25.0 | 1 | 1.1 | 144.0 |
| 17 | 45.0 | 0 | 14.0 | 0 | 0.8 | 127.0 |
##################################
# Gathering the event and duration information
# for 5 test case samples
##################################
print(y_validation_array[[5, 10, 15, 20, 25]])
[(False, 258) ( True, 65) (False, 90) (False, 79) ( True, 14)]
##################################
# Gathering the risk-groups
# for 5 test case samples
##################################
print(heart_failure_validation.loc[[5, 10, 15, 20, 25]][['Predicted_RiskGroups_CoxPH']])
   Predicted_RiskGroups_CoxPH
5                    Low-Risk
10                  High-Risk
15                  High-Risk
20                  High-Risk
25                  High-Risk
##################################
# Estimating the cumulative hazard
# and survival functions
# for 5 validation cases
##################################
validation_case = X_validation.iloc[[5, 10, 15, 20, 25]]
validation_case_labels = ['Patient_5','Patient_10','Patient_15','Patient_20','Patient_25',]
validation_case_cumulative_hazard_function = optimal_coxph_model.predict_cumulative_hazard_function(validation_case)
validation_case_survival_function = optimal_coxph_model.predict_survival_function(validation_case)
fig, ax = plt.subplots(1,2,figsize=(17, 8))
for hazard_prediction, survival_prediction in zip(validation_case_cumulative_hazard_function, validation_case_survival_function):
    ax[0].step(hazard_prediction.x, hazard_prediction(hazard_prediction.x), where='post')
    ax[1].step(survival_prediction.x, survival_prediction(survival_prediction.x), where='post')
ax[0].set_title('COXPH Cumulative Hazard for 5 Validation Cases')
ax[0].set_xlabel('TIME')
ax[0].set_ylim(0,2)
ax[0].set_ylabel('Cumulative Hazard')
ax[0].legend(validation_case_labels, loc="upper left")
ax[1].set_title('COXPH Survival Function for 5 Validation Cases')
ax[1].set_xlabel('TIME')
ax[1].set_ylabel('DEATH_EVENT Survival Probability')
ax[1].legend(validation_case_labels, loc="lower left")
plt.show()
##################################
# Saving the best Cox Proportional Hazards Regression Model
# developed from the original training data
##################################
joblib.dump(coxph_best_model_train_cv,
os.path.join("..", MODELS_PATH, "coxph_best_model.pkl"))
['..\\models\\coxph_best_model.pkl']
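As a quick illustration of how the persisted pipeline could be reused downstream, a minimal sketch is shown below; it simply mirrors the saving cell above (the reload location assumes the same MODELS_PATH) and is not part of the original notebook.

##################################
# Illustrative reload of the persisted Cox PH pipeline
# (a sketch mirroring the saving cell above)
##################################
import os
import joblib

# Reload the persisted pipeline (Yeo-Johnson preprocessing + fitted model)
coxph_reloaded = joblib.load(os.path.join("..", MODELS_PATH, "coxph_best_model.pkl"))

# Predict risk scores for unseen cases with the same predictor columns
reloaded_risk_scores = coxph_reloaded.predict(X_validation.head())
print(reloaded_risk_scores)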
1.6.5 Cox Net Survival Model Fitting | Hyperparameter Tuning | Validation ¶
Cox Net Survival is a regularized version of the Cox Proportional Hazards model, which incorporates both L1 (Lasso) and L2 (Ridge) penalties. The model is useful when dealing with high-dimensional data where the number of predictors can be larger than the number of observations. The elastic net penalty helps in both variable selection (via L1) and multicollinearity handling (via L2). As a method, it can handle high-dimensional data and perform variable selection. Additionally, it balances between L1 and L2 penalties, offering flexibility in modeling. However, the process requires tuning of penalty parameters, which can be computationally intensive. Additionally, interpretation is more complex due to the regularization terms. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves defining the penalized partial likelihood function, incorporating both L1 (Lasso) and L2 (Ridge) penalties; applying regularization techniques to estimate the regression coefficients by maximizing the penalized log-partial likelihood; performing cross-validation to select optimal values for the penalty parameters (alpha and l1_ratio); and calculating the hazard function and survival function for new data using the estimated regularized coefficients.
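The elastic net penalty referred to above can be sketched as follows (standard notation, added here for clarity): the coefficients are chosen to maximize the penalized log partial likelihood

$$
\hat{\beta} = \arg\max_{\beta} \; \log \mathrm{PL}(\beta) - \alpha \left( r \sum_{k} |\beta_k| + \frac{1 - r}{2} \sum_{k} \beta_k^2 \right),
$$

where $r$ corresponds to the l1_ratio mixing parameter tuned in this section ($r = 1$ gives a pure Lasso penalty, $r \to 0$ a pure Ridge penalty) and $\alpha$ controls the overall penalty strength along the regularization path bounded below by alpha_min_ratio.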
- The cox net survival model from the sksurv.linear_model Python library API was implemented.
- The model implementation used 2 hyperparameters:
- l1_ratio = ElasticNet mixing parameter made to vary between 0.10, 0.50 and 1.00
- alpha_min_ratio = minimum alpha of the regularization path made to vary between 0.0001 and 0.01
- Hyperparameter tuning was conducted using the 5-fold cross-validation method repeated 5 times with optimal model performance using the concordance index determined for:
- l1_ratio = 0.10
- alpha_min_ratio = 0.01
- The cross-validated model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7014
- The apparent model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7419
- The independent validation model performance of the final model is summarized as follows:
- Concordance Index = 0.7299
- Considerable difference in the apparent and cross-validated model performance observed, indicative of the presence of moderate model overfitting.
- Survival probability curves obtained from the groups generated by dichotomizing the risk scores demonstrated sufficient differentiation across the entire duration.
- Hazard and survival probability estimations for 5 sampled cases demonstrated reasonably smooth profiles.
##################################
# Performing hyperparameter tuning
# through K-fold cross-validation
# using the Cox Net Survival Model
##################################
coxns_grid_search.fit(X_train, y_train_array)
Fitting 25 folds for each of 6 candidates, totalling 150 fits
GridSearchCV(cv=RepeatedKFold(n_repeats=5, n_splits=5, random_state=88888888), estimator=Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('coxns', CoxnetSurvivalAnalysis())]), n_jobs=-1, param_grid={'coxns__alpha_min_ratio': [0.0001, 0.01], 'coxns__fit_baseline_model': [True], 'coxns__l1_ratio': [0.1, 0.5, 1.0]}, verbose=1)
##################################
# Summarizing the hyperparameter tuning
# results from K-fold cross-validation
##################################
coxns_grid_search_results = pd.DataFrame(coxns_grid_search.cv_results_).sort_values(by='mean_test_score', ascending=False)
coxns_grid_search_results.loc[:, ~coxns_grid_search_results.columns.str.endswith('_time')]
param_coxns__alpha_min_ratio | param_coxns__fit_baseline_model | param_coxns__l1_ratio | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | ... | split18_test_score | split19_test_score | split20_test_score | split21_test_score | split22_test_score | split23_test_score | split24_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
3 | 0.0100 | True | 0.1 | {'coxns__alpha_min_ratio': 0.01, 'coxns__fit_b... | 0.761468 | 0.705882 | 0.785714 | 0.844444 | 0.514286 | 0.601227 | ... | 0.630 | 0.675472 | 0.727679 | 0.683962 | 0.694823 | 0.525641 | 0.796875 | 0.701369 | 0.085525 | 1 |
0 | 0.0001 | True | 0.1 | {'coxns__alpha_min_ratio': 0.0001, 'coxns__fit... | 0.761468 | 0.702786 | 0.790816 | 0.848889 | 0.492857 | 0.619632 | ... | 0.625 | 0.675472 | 0.723214 | 0.669811 | 0.694823 | 0.521368 | 0.802083 | 0.701345 | 0.086059 | 2 |
1 | 0.0001 | True | 0.5 | {'coxns__alpha_min_ratio': 0.0001, 'coxns__fit... | 0.761468 | 0.702786 | 0.785714 | 0.848889 | 0.492857 | 0.619632 | ... | 0.625 | 0.675472 | 0.718750 | 0.665094 | 0.694823 | 0.525641 | 0.802083 | 0.701266 | 0.086197 | 3 |
2 | 0.0001 | True | 1.0 | {'coxns__alpha_min_ratio': 0.0001, 'coxns__fit... | 0.761468 | 0.702786 | 0.785714 | 0.848889 | 0.492857 | 0.619632 | ... | 0.625 | 0.675472 | 0.718750 | 0.665094 | 0.694823 | 0.525641 | 0.796875 | 0.700668 | 0.086233 | 4 |
4 | 0.0100 | True | 0.5 | {'coxns__alpha_min_ratio': 0.01, 'coxns__fit_b... | 0.758410 | 0.708978 | 0.790816 | 0.848889 | 0.492857 | 0.619632 | ... | 0.625 | 0.675472 | 0.718750 | 0.669811 | 0.694823 | 0.517094 | 0.802083 | 0.700332 | 0.086584 | 5 |
5 | 0.0100 | True | 1.0 | {'coxns__alpha_min_ratio': 0.01, 'coxns__fit_b... | 0.758410 | 0.705882 | 0.790816 | 0.848889 | 0.492857 | 0.616564 | ... | 0.625 | 0.675472 | 0.714286 | 0.669811 | 0.694823 | 0.517094 | 0.802083 | 0.700281 | 0.086774 | 6 |
6 rows × 32 columns
##################################
# Identifying the best model
##################################
coxns_best_model_train_cv = coxns_grid_search.best_estimator_
print('Best Cox Net Survival Model using the Cross-Validated Train Data: ')
print(f"Best Model Parameters: {coxns_grid_search.best_params_}")
Best Cox Net Survival Model using the Cross-Validated Train Data: 
Best Model Parameters: {'coxns__alpha_min_ratio': 0.01, 'coxns__fit_baseline_model': True, 'coxns__l1_ratio': 0.1}
##################################
# Obtaining the cross-validation model performance of the
# optimal Cox Net Survival Model
# on the train set
##################################
optimal_coxns_heart_failure_y_crossvalidation_ci = coxns_grid_search.best_score_
print(f"Cross-Validation Concordance Index: {optimal_coxns_heart_failure_y_crossvalidation_ci}")
Cross-Validation Concordance Index: 0.7013694603679497
##################################
# Formulating a Cox Net Survival Model
# with optimal hyperparameters
##################################
optimal_coxns_model = coxns_grid_search.best_estimator_
optimal_coxns_model.fit(X_train, y_train_array)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('coxns', CoxnetSurvivalAnalysis(alpha_min_ratio=0.01, fit_baseline_model=True, l1_ratio=0.1))])
##################################
# Measuring model performance of the
# optimal Cox Net Survival Model
# on the train set
##################################
optimal_coxns_heart_failure_y_train_pred = optimal_coxns_model.predict(X_train)
optimal_coxns_heart_failure_y_train_ci = concordance_index_censored(y_train_array['DEATH_EVENT'],
y_train_array['TIME'],
optimal_coxns_heart_failure_y_train_pred)[0]
print(f"Apparent Concordance Index: {optimal_coxns_heart_failure_y_train_ci}")
Apparent Concordance Index: 0.7419406319821258
##################################
# Measuring model performance of the
# optimal Cox Net Survival Model
# on the validation set
##################################
optimal_coxns_heart_failure_y_validation_pred = optimal_coxns_model.predict(X_validation)
optimal_coxns_heart_failure_y_validation_ci = concordance_index_censored(y_validation_array['DEATH_EVENT'],
y_validation_array['TIME'],
optimal_coxns_heart_failure_y_validation_pred)[0]
print(f"Validation Concordance Index: {optimal_coxns_heart_failure_y_validation_ci}")
Validation Concordance Index: 0.7298772169167803
##################################
# Gathering the concordance indices
# from the train and test sets for
# Cox Net Survival Model
##################################
coxns_set = pd.DataFrame(["Train","Cross-Validation","Validation"])
coxns_ci_values = pd.DataFrame([optimal_coxns_heart_failure_y_train_ci,
optimal_coxns_heart_failure_y_crossvalidation_ci,
optimal_coxns_heart_failure_y_validation_ci])
coxns_method = pd.DataFrame(["COXNS"]*3)
coxns_summary = pd.concat([coxns_set,
coxns_ci_values,
coxns_method], axis=1)
coxns_summary.columns = ['Set', 'Concordance.Index', 'Method']
coxns_summary.reset_index(inplace=True, drop=True)
display(coxns_summary)
| | Set | Concordance.Index | Method |
|---|---|---|---|
| 0 | Train | 0.741941 | COXNS |
| 1 | Cross-Validation | 0.701369 | COXNS |
| 2 | Validation | 0.729877 | COXNS |
##################################
# Binning the predicted risks
# into dichotomous groups and
# exploring the relationships with
# survival event and duration
##################################
heart_failure_validation.reset_index(drop=True, inplace=True)
kmf = KaplanMeierFitter()
heart_failure_validation['Predicted_Risks_CoxNS'] = optimal_coxns_heart_failure_y_validation_pred
heart_failure_validation['Predicted_RiskGroups_CoxNS'] = risk_groups = pd.qcut(heart_failure_validation['Predicted_Risks_CoxNS'], 2, labels=['Low-Risk', 'High-Risk'])
plt.figure(figsize=(17, 8))
for group in risk_groups.unique():
    group_data = heart_failure_validation[risk_groups == group]
    kmf.fit(group_data['TIME'], event_observed=group_data['DEATH_EVENT'], label=group)
    kmf.plot_survival_function()
plt.title('COXNS Survival Probabilities by Predicted Risk Groups on Validation Set')
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
plt.show()
##################################
# Gathering the predictor information
# for 5 test case samples
##################################
validation_case_details = X_validation.iloc[[5, 10, 15, 20, 25]]
display(validation_case_details)
| | AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM |
|---|---|---|---|---|---|---|
| 291 | 60.0 | 0 | 35.0 | 0 | 1.4 | 139.0 |
| 66 | 42.0 | 1 | 15.0 | 0 | 1.3 | 136.0 |
| 112 | 50.0 | 0 | 25.0 | 0 | 1.6 | 136.0 |
| 89 | 57.0 | 1 | 25.0 | 1 | 1.1 | 144.0 |
| 17 | 45.0 | 0 | 14.0 | 0 | 0.8 | 127.0 |
##################################
# Gathering the event and duration information
# for 5 test case samples
##################################
print(y_validation_array[[5, 10, 15, 20, 25]])
[(False, 258) ( True, 65) (False, 90) (False, 79) ( True, 14)]
##################################
# Gathering the risk-groups
# for 5 test case samples
##################################
print(heart_failure_validation.loc[[5, 10, 15, 20, 25]][['Predicted_RiskGroups_CoxNS']])
   Predicted_RiskGroups_CoxNS
5                    Low-Risk
10                  High-Risk
15                  High-Risk
20                  High-Risk
25                  High-Risk
##################################
# Estimating the cumulative hazard
# and survival functions
# for 5 validation cases
##################################
validation_case = X_validation.iloc[[5, 10, 15, 20, 25]]
validation_case_labels = ['Patient_5','Patient_10','Patient_15','Patient_20','Patient_25',]
validation_case_cumulative_hazard_function = optimal_coxns_model.predict_cumulative_hazard_function(validation_case)
validation_case_survival_function = optimal_coxns_model.predict_survival_function(validation_case)
fig, ax = plt.subplots(1,2,figsize=(17, 8))
for hazard_prediction, survival_prediction in zip(validation_case_cumulative_hazard_function, validation_case_survival_function):
    ax[0].step(hazard_prediction.x, hazard_prediction(hazard_prediction.x), where='post')
    ax[1].step(survival_prediction.x, survival_prediction(survival_prediction.x), where='post')
ax[0].set_title('COXNS Cumulative Hazard for 5 Validation Cases')
ax[0].set_xlabel('TIME')
ax[0].set_ylim(0,2)
ax[0].set_ylabel('Cumulative Hazard')
ax[0].legend(validation_case_labels, loc="upper left")
ax[1].set_title('COXNS Survival Function for 5 Validation Cases')
ax[1].set_xlabel('TIME')
ax[1].set_ylabel('DEATH_EVENT Survival Probability')
ax[1].legend(validation_case_labels, loc="lower left")
plt.show()
##################################
# Saving the best Cox Net Survival Model
# developed from the original training data
##################################
joblib.dump(coxns_best_model_train_cv,
os.path.join("..", MODELS_PATH, "coxns_best_model.pkl"))
['..\\models\\coxns_best_model.pkl']
1.6.6 Survival Tree Model Fitting | Hyperparameter Tuning | Validation ¶
Survival Trees are non-parametric models that partition the data into subgroups (nodes) based on the values of predictor variables, creating a tree-like structure. The tree is built by recursively splitting the data at nodes where the differences in survival times between subgroups are maximized. Each terminal node represents a different survival function. The method makes no assumptions about the underlying distribution of survival times, can capture interactions between variables naturally, and provides an interpretable visual representation. However, the process can be prone to overfitting, especially with small datasets, and may be less accurate compared to ensemble methods like Random Survival Forest. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves recursively splitting the data at nodes to maximize the differences in survival times between subgroups, with the splitting criterion often based on statistical tests (e.g., the log-rank test); choosing the best predictor variable and split point at each node that maximizes the separation of survival times; continuously splitting until stopping criteria are met (e.g., minimum number of observations in a node, maximum tree depth); and estimating the survival function based on the survival times of the observations at each terminal node.
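The log-rank splitting criterion mentioned above can be summarized as follows (a standard formulation, added here for clarity and not taken from the original text): for a candidate split into two child nodes, the statistic

$$
\chi = \frac{\sum_{j} \left( d_{1j} - Y_{1j} \dfrac{d_j}{Y_j} \right)}{\sqrt{\sum_{j} \dfrac{Y_{1j}}{Y_j} \left( 1 - \dfrac{Y_{1j}}{Y_j} \right) \left( \dfrac{Y_j - d_j}{Y_j - 1} \right) d_j}}
$$

compares the observed number of events in the first child node ($d_{1j}$) with its expected share at each distinct event time $t_j$, where $Y_{1j}$ and $Y_j$ denote the numbers at risk in that node and in both nodes combined and $d_j$ the total events at $t_j$; the split with the largest $|\chi|$ is selected.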
- The survival tree model from the sksurv.tree Python library API was implemented.
- The model implementation used 2 hyperparameters:
- min_samples_split = minimum number of samples required to split an internal node made to vary between 10, 15 and 20
- min_samples_leaf = minimum number of samples required to be at a leaf node made to vary between 3 and 6
- Hyperparameter tuning was conducted using the 5-fold cross-validation method repeated 5 times with optimal model performance using the concordance index determined for:
- min_samples_split = 20
- min_samples_leaf = 6
- The cross-validated model performance of the optimal model is summarized as follows:
- Concordance Index = 0.6542
- The apparent model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7992
- The independent validation model performance of the final model is summarized as follows:
- Concordance Index = 0.6446
- Significant difference in the apparent and cross-validated model performance observed, indicative of the presence of excessive model overfitting.
- Survival probability curves obtained from the groups generated by dichotomizing the risk scores demonstrated non-optimal differentiation across the entire duration.
- Hazard and survival probability estimations for 5 sampled cases demonstrated non-optimal profiles.
##################################
# Performing hyperparameter tuning
# through K-fold cross-validation
# using the Survival Tree Model
##################################
stree_grid_search.fit(X_train, y_train_array)
Fitting 25 folds for each of 6 candidates, totalling 150 fits
GridSearchCV(cv=RepeatedKFold(n_repeats=5, n_splits=5, random_state=88888888), estimator=Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('stree', SurvivalTree())]), n_jobs=-1, param_grid={'stree__min_samples_leaf': [3, 6], 'stree__min_samples_split': [10, 15, 20], 'stree__random_state': [88888888]}, verbose=1)
##################################
# Summarizing the hyperparameter tuning
# results from K-fold cross-validation
##################################
stree_grid_search_results = pd.DataFrame(stree_grid_search.cv_results_).sort_values(by='mean_test_score', ascending=False)
stree_grid_search_results.loc[:, ~stree_grid_search_results.columns.str.endswith('_time')]
param_stree__min_samples_leaf | param_stree__min_samples_split | param_stree__random_state | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | ... | split18_test_score | split19_test_score | split20_test_score | split21_test_score | split22_test_score | split23_test_score | split24_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
5 | 6 | 20 | 88888888 | {'stree__min_samples_leaf': 6, 'stree__min_sam... | 0.749235 | 0.715170 | 0.721939 | 0.726667 | 0.560714 | 0.562883 | ... | 0.7250 | 0.609434 | 0.680804 | 0.627358 | 0.698910 | 0.611111 | 0.494792 | 0.654169 | 0.072607 | 1 |
2 | 3 | 20 | 88888888 | {'stree__min_samples_leaf': 3, 'stree__min_sam... | 0.669725 | 0.690402 | 0.653061 | 0.777778 | 0.553571 | 0.475460 | ... | 0.7250 | 0.588679 | 0.671875 | 0.632075 | 0.716621 | 0.638889 | 0.505208 | 0.646178 | 0.076711 | 2 |
4 | 6 | 15 | 88888888 | {'stree__min_samples_leaf': 6, 'stree__min_sam... | 0.743119 | 0.687307 | 0.798469 | 0.726667 | 0.539286 | 0.553681 | ... | 0.6325 | 0.658491 | 0.629464 | 0.627358 | 0.660763 | 0.621795 | 0.486979 | 0.636490 | 0.079302 | 3 |
3 | 6 | 10 | 88888888 | {'stree__min_samples_leaf': 6, 'stree__min_sam... | 0.718654 | 0.681115 | 0.801020 | 0.724444 | 0.546429 | 0.558282 | ... | 0.6500 | 0.715094 | 0.642857 | 0.587264 | 0.643052 | 0.611111 | 0.486979 | 0.634086 | 0.078157 | 4 |
1 | 3 | 15 | 88888888 | {'stree__min_samples_leaf': 3, 'stree__min_sam... | 0.669725 | 0.673375 | 0.673469 | 0.746667 | 0.564286 | 0.438650 | ... | 0.6425 | 0.598113 | 0.609375 | 0.632075 | 0.663488 | 0.608974 | 0.502604 | 0.624111 | 0.072077 | 5 |
0 | 3 | 10 | 88888888 | {'stree__min_samples_leaf': 3, 'stree__min_sam... | 0.646789 | 0.684211 | 0.678571 | 0.691111 | 0.582143 | 0.435583 | ... | 0.6300 | 0.635849 | 0.587054 | 0.509434 | 0.632153 | 0.647436 | 0.489583 | 0.609106 | 0.074834 | 6 |
6 rows × 32 columns
##################################
# Identifying the best model
##################################
stree_best_model_train_cv = stree_grid_search.best_estimator_
print('Best Survival Tree Model using the Cross-Validated Train Data: ')
print(f"Best Model Parameters: {stree_grid_search.best_params_}")
Best Survival Tree Model using the Cross-Validated Train Data: 
Best Model Parameters: {'stree__min_samples_leaf': 6, 'stree__min_samples_split': 20, 'stree__random_state': 88888888}
##################################
# Obtaining the cross-validation model performance of the
# optimal Survival Tree Model
# on the train set
##################################
optimal_stree_heart_failure_y_crossvalidation_ci = stree_grid_search.best_score_
print(f"Cross-Validation Concordance Index: {optimal_stree_heart_failure_y_crossvalidation_ci}")
Cross-Validation Concordance Index: 0.6541686643258245
##################################
# Formulating a Survival Tree Model
# with optimal hyperparameters
##################################
optimal_stree_model = stree_grid_search.best_estimator_
optimal_stree_model.fit(X_train, y_train_array)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('stree', SurvivalTree(min_samples_leaf=6, min_samples_split=20, random_state=88888888))])
##################################
# Measuring model performance of the
# optimal Survival Tree Model
# on the train set
##################################
optimal_stree_heart_failure_y_train_pred = optimal_stree_model.predict(X_train)
optimal_stree_heart_failure_y_train_ci = concordance_index_censored(y_train_array['DEATH_EVENT'],
y_train_array['TIME'],
optimal_stree_heart_failure_y_train_pred)[0]
print(f"Apparent Concordance Index: {optimal_stree_heart_failure_y_train_ci}")
Apparent Concordance Index: 0.7992339610596872
##################################
# Measuring model performance of the
# optimal Survival Tree Model
# on the validation set
##################################
optimal_stree_heart_failure_y_validation_pred = optimal_stree_model.predict(X_validation)
optimal_stree_heart_failure_y_validation_ci = concordance_index_censored(y_validation_array['DEATH_EVENT'],
y_validation_array['TIME'],
optimal_stree_heart_failure_y_validation_pred)[0]
print(f"Validation Concordance Index: {optimal_stree_heart_failure_y_validation_ci}")
Validation Concordance Index: 0.6446111869031378
##################################
# Gathering the concordance indices
# from the train and test sets for
# Survival Tree Model
##################################
stree_set = pd.DataFrame(["Train","Cross-Validation","Validation"])
stree_ci_values = pd.DataFrame([optimal_stree_heart_failure_y_train_ci,
optimal_stree_heart_failure_y_crossvalidation_ci,
optimal_stree_heart_failure_y_validation_ci])
stree_method = pd.DataFrame(["STREE"]*3)
stree_summary = pd.concat([stree_set,
stree_ci_values,
stree_method], axis=1)
stree_summary.columns = ['Set', 'Concordance.Index', 'Method']
stree_summary.reset_index(inplace=True, drop=True)
display(stree_summary)
| | Set | Concordance.Index | Method |
|---|---|---|---|
| 0 | Train | 0.799234 | STREE |
| 1 | Cross-Validation | 0.654169 | STREE |
| 2 | Validation | 0.644611 | STREE |
##################################
# Binning the predicted risks
# into dichotomous groups and
# exploring the relationships with
# survival event and duration
##################################
heart_failure_validation.reset_index(drop=True, inplace=True)
kmf = KaplanMeierFitter()
heart_failure_validation['Predicted_Risks_STree'] = optimal_stree_heart_failure_y_validation_pred
heart_failure_validation['Predicted_RiskGroups_STree'] = risk_groups = pd.qcut(heart_failure_validation['Predicted_Risks_STree'], 2, labels=['Low-Risk', 'High-Risk'])
plt.figure(figsize=(17, 8))
for group in risk_groups.unique():
    group_data = heart_failure_validation[risk_groups == group]
    kmf.fit(group_data['TIME'], event_observed=group_data['DEATH_EVENT'], label=group)
    kmf.plot_survival_function()
plt.title('STREE Survival Probabilities by Predicted Risk Groups on Validation Set')
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
plt.show()
##################################
# Gathering the predictor information
# for 5 test case samples
##################################
validation_case_details = X_validation.iloc[[5, 10, 15, 20, 25]]
display(validation_case_details)
| | AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM |
|---|---|---|---|---|---|---|
| 291 | 60.0 | 0 | 35.0 | 0 | 1.4 | 139.0 |
| 66 | 42.0 | 1 | 15.0 | 0 | 1.3 | 136.0 |
| 112 | 50.0 | 0 | 25.0 | 0 | 1.6 | 136.0 |
| 89 | 57.0 | 1 | 25.0 | 1 | 1.1 | 144.0 |
| 17 | 45.0 | 0 | 14.0 | 0 | 0.8 | 127.0 |
##################################
# Gathering the event and duration information
# for 5 test case samples
##################################
print(y_validation_array[[5, 10, 15, 20, 25]])
[(False, 258) ( True, 65) (False, 90) (False, 79) ( True, 14)]
##################################
# Gathering the risk-groups
# for 5 test case samples
##################################
print(heart_failure_validation.loc[[5, 10, 15, 20, 25]][['Predicted_RiskGroups_STree']])
   Predicted_RiskGroups_STree
5                    Low-Risk
10                  High-Risk
15                  High-Risk
20                  High-Risk
25                  High-Risk
##################################
# Estimating the cumulative hazard
# and survival functions
# for 5 validation cases
##################################
validation_case = X_validation.iloc[[5, 10, 15, 20, 25]]
validation_case_labels = ['Patient_5','Patient_10','Patient_15','Patient_20','Patient_25',]
validation_case_cumulative_hazard_function = optimal_stree_model.predict_cumulative_hazard_function(validation_case)
validation_case_survival_function = optimal_stree_model.predict_survival_function(validation_case)
fig, ax = plt.subplots(1,2,figsize=(17, 8))
for hazard_prediction, survival_prediction in zip(validation_case_cumulative_hazard_function, validation_case_survival_function):
    ax[0].step(hazard_prediction.x, hazard_prediction(hazard_prediction.x), where='post')
    ax[1].step(survival_prediction.x, survival_prediction(survival_prediction.x), where='post')
ax[0].set_title('STREE Cumulative Hazard for 5 Validation Cases')
ax[0].set_xlabel('TIME')
ax[0].set_ylim(0,2)
ax[0].set_ylabel('Cumulative Hazard')
ax[0].legend(validation_case_labels, loc="upper left")
ax[1].set_title('STREE Survival Function for 5 Validation Cases')
ax[1].set_xlabel('TIME')
ax[1].set_ylabel('DEATH_EVENT Survival Probability')
ax[1].legend(validation_case_labels, loc="lower left")
plt.show()
##################################
# Saving the best Survival Tree Model
# developed from the original training data
##################################
joblib.dump(stree_best_model_train_cv,
os.path.join("..", MODELS_PATH, "stree_best_model.pkl"))
['..\\models\\stree_best_model.pkl']
1.6.7 Random Survival Forest Model Fitting | Hyperparameter Tuning | Validation ¶
Random Survival Forest is an ensemble method that builds multiple survival trees and averages their predictions. The model combines the predictions of multiple survival trees, each built on a bootstrap sample of the data and a random subset of predictors. It uses the concept of ensemble learning to improve predictive accuracy and robustness. As a method, it handles high-dimensional data and complex interactions between variables well; can be more accurate and robust than a single survival tree; and provides measures of variable importance. However, the process can be computationally intensive due to the need to build multiple trees, and may be less interpretable than single trees or parametric models like the Cox model. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves generating multiple bootstrap samples from the original dataset; building a survival tree for each bootstrap sample by recursively splitting the data at nodes using a random subset of predictor variables; combining the predictions of all survival trees to form the random survival forest; and averaging the survival functions predicted by all trees in the forest to obtain the final survival function for new data.
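The averaging step described above can be written compactly as follows (standard notation, added here for clarity): with $B$ trees, the ensemble cumulative hazard for a new case $x$ is

$$
\hat{H}(t \mid x) = \frac{1}{B} \sum_{b=1}^{B} \hat{H}_b(t \mid x),
$$

where each $\hat{H}_b$ is the Nelson-Aalen estimate computed from the training cases that fall into the same terminal node as $x$ in tree $b$; the ensemble survival function is obtained analogously by averaging the per-node Kaplan-Meier estimates across trees.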
- The random survival forest model from the sksurv.ensemble Python library API was implemented.
- The model implementation used 2 hyperparameters:
- n_estimators = number of trees in the forest made to vary between 100, 200 and 300
- min_samples_split = minimum number of samples required to split an internal node made to vary between 10, 15 and 20
- Hyperparameter tuning was conducted using the 5-fold cross-validation method repeated 5 times with optimal model performance using the concordance index determined for:
- n_estimators = 300
- min_samples_split = 10
- The cross-validated model performance of the optimal model is summarized as follows:
- Concordance Index = 0.7091
- The apparent model performance of the optimal model is summarized as follows:
- Concordance Index = 0.8714
- The validation model performance of the final model is summarized as follows:
- Concordance Index = 0.6930
- A significant difference between the apparent and cross-validated model performance was observed, indicative of excessive model overfitting.
- Survival probability curves obtained from the groups generated by dichotomizing the risk scores demonstrated sufficient differentiation across the entire duration.
- Hazard and survival probability estimations for 5 sampled cases demonstrated reasonably smooth profiles.
##################################
# Performing hyperparameter tuning
# through K-fold cross-validation
# using the Random Survival Forest Model
##################################
rsf_grid_search.fit(X_train, y_train_array)
Fitting 25 folds for each of 9 candidates, totalling 225 fits
GridSearchCV(cv=RepeatedKFold(n_repeats=5, n_splits=5, random_state=88888888), estimator=Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('rsf', RandomSurvivalForest())]), n_jobs=-1, param_grid={'rsf__min_samples_split': [10, 15, 20], 'rsf__n_estimators': [100, 200, 300], 'rsf__random_state': [88888888]}, verbose=1)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('rsf', RandomSurvivalForest(min_samples_split=10, n_estimators=300, random_state=88888888))])
##################################
# Summarizing the hyperparameter tuning
# results from K-fold cross-validation
##################################
rsf_grid_search_results = pd.DataFrame(rsf_grid_search.cv_results_).sort_values(by='mean_test_score', ascending=False)
rsf_grid_search_results.loc[:, ~rsf_grid_search_results.columns.str.endswith('_time')]
param_rsf__min_samples_split | param_rsf__n_estimators | param_rsf__random_state | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | ... | split18_test_score | split19_test_score | split20_test_score | split21_test_score | split22_test_score | split23_test_score | split24_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2 | 10 | 300 | 88888888 | {'rsf__min_samples_split': 10, 'rsf__n_estimat... | 0.737003 | 0.687307 | 0.744898 | 0.840000 | 0.585714 | 0.650307 | ... | 0.660 | 0.701887 | 0.772321 | 0.759434 | 0.727520 | 0.598291 | 0.671875 | 0.709097 | 0.068130 | 1 |
8 | 20 | 300 | 88888888 | {'rsf__min_samples_split': 20, 'rsf__n_estimat... | 0.730887 | 0.705882 | 0.734694 | 0.831111 | 0.592857 | 0.647239 | ... | 0.675 | 0.698113 | 0.758929 | 0.754717 | 0.741144 | 0.598291 | 0.671875 | 0.707597 | 0.066923 | 2 |
0 | 10 | 100 | 88888888 | {'rsf__min_samples_split': 10, 'rsf__n_estimat... | 0.755352 | 0.690402 | 0.755102 | 0.844444 | 0.578571 | 0.647239 | ... | 0.665 | 0.713208 | 0.772321 | 0.754717 | 0.727520 | 0.581197 | 0.640625 | 0.707268 | 0.074312 | 3 |
1 | 10 | 200 | 88888888 | {'rsf__min_samples_split': 10, 'rsf__n_estimat... | 0.743119 | 0.684211 | 0.739796 | 0.835556 | 0.592857 | 0.644172 | ... | 0.660 | 0.705660 | 0.758929 | 0.754717 | 0.732970 | 0.594017 | 0.651042 | 0.707263 | 0.068872 | 4 |
5 | 15 | 300 | 88888888 | {'rsf__min_samples_split': 15, 'rsf__n_estimat... | 0.733945 | 0.708978 | 0.739796 | 0.831111 | 0.578571 | 0.647239 | ... | 0.675 | 0.701887 | 0.754464 | 0.754717 | 0.741144 | 0.594017 | 0.666667 | 0.706527 | 0.068454 | 5 |
7 | 20 | 200 | 88888888 | {'rsf__min_samples_split': 20, 'rsf__n_estimat... | 0.730887 | 0.702786 | 0.734694 | 0.835556 | 0.578571 | 0.653374 | ... | 0.675 | 0.683019 | 0.758929 | 0.759434 | 0.732970 | 0.598291 | 0.666667 | 0.706351 | 0.067533 | 6 |
6 | 20 | 100 | 88888888 | {'rsf__min_samples_split': 20, 'rsf__n_estimat... | 0.740061 | 0.699690 | 0.739796 | 0.840000 | 0.557143 | 0.650307 | ... | 0.670 | 0.698113 | 0.772321 | 0.745283 | 0.722071 | 0.594017 | 0.671875 | 0.706127 | 0.069620 | 7 |
4 | 15 | 200 | 88888888 | {'rsf__min_samples_split': 15, 'rsf__n_estimat... | 0.737003 | 0.690402 | 0.739796 | 0.835556 | 0.571429 | 0.653374 | ... | 0.675 | 0.690566 | 0.758929 | 0.768868 | 0.735695 | 0.598291 | 0.666667 | 0.704593 | 0.071233 | 8 |
3 | 15 | 100 | 88888888 | {'rsf__min_samples_split': 15, 'rsf__n_estimat... | 0.740061 | 0.684211 | 0.755102 | 0.840000 | 0.571429 | 0.644172 | ... | 0.665 | 0.698113 | 0.745536 | 0.745283 | 0.732970 | 0.585470 | 0.661458 | 0.703792 | 0.074270 | 9 |
9 rows × 32 columns
##################################
# Identifying the best model
##################################
rsf_best_model_train_cv = rsf_grid_search.best_estimator_
print('Best Random Survival Forest Model using the Cross-Validated Train Data: ')
print(f"Best Model Parameters: {rsf_grid_search.best_params_}")
Best Random Survival Forest Model using the Cross-Validated Train Data: 
Best Model Parameters: {'rsf__min_samples_split': 10, 'rsf__n_estimators': 300, 'rsf__random_state': 88888888}
##################################
# Obtaining the cross-validation model performance of the
# optimal Random Survival Forest Model
# on the train set
##################################
optimal_rsf_heart_failure_y_crossvalidation_ci = rsf_grid_search.best_score_
print(f"Cross-Validation Concordance Index: {optimal_rsf_heart_failure_y_crossvalidation_ci}")
Cross-Validation Concordance Index: 0.7090965292195327
##################################
# Formulating a Random Survival Forest Model
# with optimal hyperparameters
##################################
optimal_rsf_model = rsf_grid_search.best_estimator_
optimal_rsf_model.fit(X_train, y_train_array)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('rsf', RandomSurvivalForest(min_samples_split=10, n_estimators=300, random_state=88888888))])
##################################
# Measuring model performance of the
# optimal Random Survival Forest Model
# on the train set
##################################
optimal_rsf_heart_failure_y_train_pred = optimal_rsf_model.predict(X_train)
optimal_rsf_heart_failure_y_train_ci = concordance_index_censored(y_train_array['DEATH_EVENT'],
y_train_array['TIME'],
optimal_rsf_heart_failure_y_train_pred)[0]
print(f"Apparent Concordance Index: {optimal_rsf_heart_failure_y_train_ci}")
Apparent Concordance Index: 0.8713692946058091
##################################
# Measuring model performance of the
# optimal Random Survival Forest Model
# on the validation set
##################################
optimal_rsf_heart_failure_y_validation_pred = optimal_rsf_model.predict(X_validation)
optimal_rsf_heart_failure_y_validation_ci = concordance_index_censored(y_validation_array['DEATH_EVENT'],
y_validation_array['TIME'],
optimal_rsf_heart_failure_y_validation_pred)[0]
print(f"Validation Concordance Index: {optimal_rsf_heart_failure_y_validation_ci}")
Validation Concordance Index: 0.6930422919508867
##################################
# Gathering the concordance indices
# from the train, cross-validation and validation sets for
# Random Survival Forest Model
##################################
rsf_set = pd.DataFrame(["Train","Cross-Validation","Validation"])
rsf_ci_values = pd.DataFrame([optimal_rsf_heart_failure_y_train_ci,
optimal_rsf_heart_failure_y_crossvalidation_ci,
optimal_rsf_heart_failure_y_validation_ci])
rsf_method = pd.DataFrame(["RSF"]*3)
rsf_summary = pd.concat([rsf_set,
rsf_ci_values,
rsf_method], axis=1)
rsf_summary.columns = ['Set', 'Concordance.Index', 'Method']
rsf_summary.reset_index(inplace=True, drop=True)
display(rsf_summary)
Set | Concordance.Index | Method | |
---|---|---|---|
0 | Train | 0.871369 | RSF |
1 | Cross-Validation | 0.709097 | RSF |
2 | Validation | 0.693042 | RSF |
##################################
# Binning the predicted risks
# into dichotomous groups and
# exploring the relationships with
# survival event and duration
##################################
heart_failure_validation.reset_index(drop=True, inplace=True)
kmf = KaplanMeierFitter()
heart_failure_validation['Predicted_Risks_RSF'] = optimal_rsf_heart_failure_y_validation_pred
heart_failure_validation['Predicted_RiskGroups_RSF'] = risk_groups = pd.qcut(heart_failure_validation['Predicted_Risks_RSF'], 2, labels=['Low-Risk', 'High-Risk'])
plt.figure(figsize=(17, 8))
for group in risk_groups.unique():
group_data = heart_failure_validation[risk_groups == group]
kmf.fit(group_data['TIME'], event_observed=group_data['DEATH_EVENT'], label=group)
kmf.plot_survival_function()
plt.title('RSF Survival Probabilities by Predicted Risk Groups on Validation Set')
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
plt.show()
##################################
# Gathering the predictor information
# for 5 test case samples
##################################
validation_case_details = X_validation.iloc[[5, 10, 15, 20, 25]]
display(validation_case_details)
AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM | |
---|---|---|---|---|---|---|
291 | 60.0 | 0 | 35.0 | 0 | 1.4 | 139.0 |
66 | 42.0 | 1 | 15.0 | 0 | 1.3 | 136.0 |
112 | 50.0 | 0 | 25.0 | 0 | 1.6 | 136.0 |
89 | 57.0 | 1 | 25.0 | 1 | 1.1 | 144.0 |
17 | 45.0 | 0 | 14.0 | 0 | 0.8 | 127.0 |
##################################
# Gathering the event and duration information
# for 5 test case samples
##################################
print(y_validation_array[[5, 10, 15, 20, 25]])
[(False, 258) ( True, 65) (False, 90) (False, 79) ( True, 14)]
##################################
# Gathering the risk-groups
# for 5 test case samples
##################################
print(heart_failure_validation.loc[[5, 10, 15, 20, 25]][['Predicted_RiskGroups_RSF']])
   Predicted_RiskGroups_RSF
5                  Low-Risk
10                High-Risk
15                High-Risk
20                High-Risk
25                High-Risk
##################################
# Estimating the cumulative hazard
# and survival functions
# for 5 validation cases
##################################
validation_case = X_validation.iloc[[5, 10, 15, 20, 25]]
validation_case_labels = ['Patient_5','Patient_10','Patient_15','Patient_20','Patient_25',]
validation_case_cumulative_hazard_function = optimal_rsf_model.predict_cumulative_hazard_function(validation_case)
validation_case_survival_function = optimal_rsf_model.predict_survival_function(validation_case)
fig, ax = plt.subplots(1,2,figsize=(17, 8))
for hazard_prediction, survival_prediction in zip(validation_case_cumulative_hazard_function, validation_case_survival_function):
ax[0].step(hazard_prediction.x,hazard_prediction(hazard_prediction.x),where='post')
ax[1].step(survival_prediction.x,survival_prediction(survival_prediction.x),where='post')
ax[0].set_title('RSF Cumulative Hazard for 5 Validation Cases')
ax[0].set_xlabel('TIME')
ax[0].set_ylim(0,2)
ax[0].set_ylabel('Cumulative Hazard')
ax[0].legend(validation_case_labels, loc="upper left")
ax[1].set_title('RSF Survival Function for 5 Validation Cases')
ax[1].set_xlabel('TIME')
ax[1].set_ylabel('DEATH_EVENT Survival Probability')
ax[1].legend(validation_case_labels, loc="lower left")
plt.show()
##################################
# Saving the best Random Survival Forest Model
# developed from the original training data
##################################
joblib.dump(rsf_best_model_train_cv,
os.path.join("..", MODELS_PATH, "rsf_best_model.pkl"))
['..\\models\\rsf_best_model.pkl']
1.6.8 Gradient Boosted Survival Model Fitting | Hyperparameter Tuning | Validation ¶
Survival Analysis deals with the analysis of time-to-event data. It focuses on the expected duration of time until one or more events of interest occur, such as death, failure, or relapse. This type of analysis is used to study and model the time until the occurrence of an event, taking into account that the event might not have occurred for all subjects during the study period. Several key aspects of survival analysis include the survival function which refers to the probability that an individual survives longer than a certain time, hazard function which describes the instantaneous rate at which events occur, given no prior event, and censoring pertaining to a condition where the event of interest has not occurred for some subjects during the observation period.
Right-Censored Survival Data occurs when the event of interest has not happened for some subjects by the end of the study period or the last follow-up time. This type of censoring is common in survival analysis because not all individuals may experience the event before the study ends, or they might drop out or be lost to follow-up. Right-censored data is crucial in survival analysis as it allows the inclusion of all subjects in the analysis, providing more accurate and reliable estimates.
Survival Models refer to statistical methods used to analyze survival data, accounting for censored observations. These models aim to describe the relationship between survival time and one or more predictor variables, and to estimate the survival function and hazard function. Survival models are essential for understanding the factors that influence time-to-event data, allowing for predictions and comparisons between different groups or treatment effects. They are widely used in clinical trials, reliability engineering, and other areas where time-to-event data is prevalent.
Gradient Boosted Survival is an ensemble technique that builds a series of survival trees, where each tree tries to correct the errors of the previous one. The model uses boosting, a sequential technique where each new tree is fit to the residuals of the combined previous trees, and combines the predictions of all the trees to produce a final prediction. As a method, it has high predictive accuracy, the ability to model complex relationships, and reduces bias and variance compared to single-tree models. However, the process can even be more computationally intensive than Random Survival Forest, requires careful tuning of multiple hyperparameters, and makes interpretation challenging due to the complex nature of the model. Given a dataset with survival times, event indicators, and predictor variables, the algorithm involves starting with an initial prediction (often the median survival time or a simple model); calculating the residuals (errors) of the current model's predictions; fitting a survival tree to the residuals to learn the errors made by the current model; updating the current model by adding the new tree weighted by a learning rate parameter; repeating previous steps for a fixed number of iterations or until convergence; and summing the predictions of all trees in the sequence to obtain the final survival function for new data.
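To make the boosting procedure described above concrete, the short sketch below (illustrative only, and separate from the project pipeline) fits a standalone gradient boosted survival model on the WHAS500 dataset bundled with scikit-survival; the dataset choice and hyperparameter values are arbitrary assumptions used purely for demonstration.
##################################
# Illustrative sketch only:
# fitting a standalone Gradient Boosted Survival model
# on the bundled WHAS500 dataset
##################################
from sklearn.model_selection import train_test_split
from sksurv.datasets import load_whas500
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# Loading a small right-censored dataset (y is a structured array of event indicator and time)
X_demo, y_demo = load_whas500()
X_demo = X_demo.select_dtypes("number")  # keeping numeric predictors only for simplicity

X_demo_train, X_demo_test, y_demo_train, y_demo_test = train_test_split(X_demo, y_demo, test_size=0.25, random_state=0)

# Each boosting stage fits a regression tree to the gradient of the loss of the
# current ensemble, shrunk by the learning rate before being added to the model
gbs_demo = GradientBoostingSurvivalAnalysis(n_estimators=200, learning_rate=0.1, random_state=0)
gbs_demo.fit(X_demo_train, y_demo_train)

# score() reports Harrell's concordance index on the held-out split
print(f"Held-out Concordance Index: {gbs_demo.score(X_demo_test, y_demo_test):.4f}")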
Concordance Index measures the model's ability to correctly order pairs of observations based on their predicted survival times. Values range from 0.5 to 1.0, with 0.5 indicating no predictive power (random guessing) and 1.0 indicating perfect predictions. As a metric, it provides a measure of discriminative ability and is useful for ranking predictions. However, it does not provide information on the magnitude of errors and may be insensitive to the calibration of predicted survival probabilities.
- The gradient boosted survival model from the sksurv.ensemble Python library API was implemented.
- The model implementation used 2 hyperparameters:
- n_estimators = number of regression trees to create made to vary between 100, 200 and 300
- learning_rate = shrinkage parameter for the contribution of each tree made to vary between 0.05, 0.10 and 0.15
- Hyperparameter tuning was conducted using the 5-fold cross-validation method repeated 5 times with optimal model performance using the concordance index determined for:
- n_estimators = 200
- learning_rate = 0.10
- The cross-validated model performance of the optimal model is summarized as follows:
- Concordance Index = 0.6765
- The apparent model performance of the optimal model is summarized as follows:
- Concordance Index = 0.9275
- The validation model performance of the final model is summarized as follows:
- Concordance Index = 0.6575
- A significant difference between the apparent and cross-validated model performance was observed, indicative of excessive model overfitting.
- Survival probability curves obtained from the groups generated by dichotomizing the risk scores demonstrated sufficient differentiation across the entire duration.
- Hazard and survival probability estimations for 5 sampled cases demonstrated reasonably smooth profiles.
##################################
# Performing hyperparameter tuning
# through K-fold cross-validation
# using the Gradient Boosted Survival Model
##################################
gbs_grid_search.fit(X_train, y_train_array)
Fitting 25 folds for each of 9 candidates, totalling 225 fits
GridSearchCV(cv=RepeatedKFold(n_repeats=5, n_splits=5, random_state=88888888), estimator=Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('gbs', GradientBoostingSurvivalAnalysis())]), n_jobs=-1, param_grid={'gbs__learning_rate': [0.05, 0.1, 0.15], 'gbs__n_estimators': [100, 200, 300], 'gbs__random_state': [88888888]}, verbose=1)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('gbs', GradientBoostingSurvivalAnalysis(n_estimators=200, random_state=88888888))])
##################################
# Summarizing the hyperparameter tuning
# results from K-fold cross-validation
##################################
gbs_grid_search_results = pd.DataFrame(gbs_grid_search.cv_results_).sort_values(by='mean_test_score', ascending=False)
gbs_grid_search_results.loc[:, ~gbs_grid_search_results.columns.str.endswith('_time')]
param_gbs__learning_rate | param_gbs__n_estimators | param_gbs__random_state | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | split5_test_score | ... | split18_test_score | split19_test_score | split20_test_score | split21_test_score | split22_test_score | split23_test_score | split24_test_score | mean_test_score | std_test_score | rank_test_score | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
4 | 0.10 | 200 | 88888888 | {'gbs__learning_rate': 0.1, 'gbs__n_estimators... | 0.694190 | 0.712074 | 0.704082 | 0.782222 | 0.564286 | 0.604294 | ... | 0.620 | 0.656604 | 0.736607 | 0.712264 | 0.697548 | 0.594017 | 0.692708 | 0.676537 | 0.059798 | 1 |
7 | 0.15 | 200 | 88888888 | {'gbs__learning_rate': 0.15, 'gbs__n_estimator... | 0.700306 | 0.718266 | 0.704082 | 0.760000 | 0.571429 | 0.592025 | ... | 0.630 | 0.679245 | 0.709821 | 0.721698 | 0.711172 | 0.594017 | 0.677083 | 0.675917 | 0.055144 | 2 |
5 | 0.10 | 300 | 88888888 | {'gbs__learning_rate': 0.1, 'gbs__n_estimators... | 0.678899 | 0.705882 | 0.698980 | 0.773333 | 0.564286 | 0.604294 | ... | 0.625 | 0.656604 | 0.723214 | 0.716981 | 0.697548 | 0.594017 | 0.708333 | 0.675460 | 0.054978 | 3 |
2 | 0.05 | 300 | 88888888 | {'gbs__learning_rate': 0.05, 'gbs__n_estimator... | 0.712538 | 0.708978 | 0.688776 | 0.788889 | 0.550000 | 0.598160 | ... | 0.615 | 0.664151 | 0.709821 | 0.707547 | 0.700272 | 0.606838 | 0.700521 | 0.675307 | 0.059465 | 4 |
8 | 0.15 | 300 | 88888888 | {'gbs__learning_rate': 0.15, 'gbs__n_estimator... | 0.703364 | 0.705882 | 0.714286 | 0.742222 | 0.557143 | 0.579755 | ... | 0.630 | 0.652830 | 0.709821 | 0.726415 | 0.727520 | 0.581197 | 0.651042 | 0.673835 | 0.058635 | 5 |
3 | 0.10 | 100 | 88888888 | {'gbs__learning_rate': 0.1, 'gbs__n_estimators... | 0.732416 | 0.705882 | 0.660714 | 0.777778 | 0.521429 | 0.576687 | ... | 0.605 | 0.673585 | 0.727679 | 0.683962 | 0.673025 | 0.602564 | 0.671875 | 0.669783 | 0.060433 | 6 |
6 | 0.15 | 100 | 88888888 | {'gbs__learning_rate': 0.15, 'gbs__n_estimator... | 0.700306 | 0.696594 | 0.683673 | 0.768889 | 0.564286 | 0.604294 | ... | 0.600 | 0.660377 | 0.714286 | 0.693396 | 0.683924 | 0.619658 | 0.666667 | 0.669566 | 0.056325 | 7 |
1 | 0.05 | 200 | 88888888 | {'gbs__learning_rate': 0.05, 'gbs__n_estimator... | 0.737003 | 0.696594 | 0.668367 | 0.777778 | 0.492857 | 0.576687 | ... | 0.600 | 0.660377 | 0.714286 | 0.688679 | 0.678474 | 0.606838 | 0.684896 | 0.669083 | 0.063933 | 8 |
0 | 0.05 | 100 | 88888888 | {'gbs__learning_rate': 0.05, 'gbs__n_estimator... | 0.723242 | 0.662539 | 0.665816 | 0.760000 | 0.500000 | 0.536810 | ... | 0.620 | 0.635849 | 0.705357 | 0.681604 | 0.675749 | 0.574786 | 0.653646 | 0.658122 | 0.067976 | 9 |
9 rows × 32 columns
##################################
# Identifying the best model
##################################
gbs_best_model_train_cv = gbs_grid_search.best_estimator_
print('Best Gradient Boosted Survival Model using the Cross-Validated Train Data: ')
print(f"Best Model Parameters: {gbs_grid_search.best_params_}")
Best Gradient Boosted Survival Model using the Cross-Validated Train Data: 
Best Model Parameters: {'gbs__learning_rate': 0.1, 'gbs__n_estimators': 200, 'gbs__random_state': 88888888}
##################################
# Obtaining the cross-validation model performance of the
# optimal Gradient Boosted Survival Model
# on the train set
##################################
optimal_gbs_heart_failure_y_crossvalidation_ci = gbs_grid_search.best_score_
print(f"Cross-Validation Concordance Index: {optimal_gbs_heart_failure_y_crossvalidation_ci}")
Cross-Validation Concordance Index: 0.6765369976540313
##################################
# Formulating a Gradient Boosted Survival Model
# with optimal hyperparameters
##################################
optimal_gbs_model = gbs_grid_search.best_estimator_
optimal_gbs_model.fit(X_train, y_train_array)
Pipeline(steps=[('yeo_johnson', ColumnTransformer(remainder='passthrough', transformers=[('numeric_predictors', PowerTransformer(), ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM'])])), ('gbs', GradientBoostingSurvivalAnalysis(n_estimators=200, random_state=88888888))])
##################################
# Measuring model performance of the
# optimal Gradient Boosted Survival Model
# on the train set
##################################
optimal_gbs_heart_failure_y_train_pred = optimal_gbs_model.predict(X_train)
optimal_gbs_heart_failure_y_train_ci = concordance_index_censored(y_train_array['DEATH_EVENT'],
y_train_array['TIME'],
optimal_gbs_heart_failure_y_train_pred)[0]
print(f"Apparent Concordance Index: {optimal_gbs_heart_failure_y_train_ci}")
Apparent Concordance Index: 0.9274656878391319
##################################
# Measuring model performance of the
# optimal Gradient Boosted Survival Model
# on the validation set
##################################
optimal_gbs_heart_failure_y_validation_pred = optimal_gbs_model.predict(X_validation)
optimal_gbs_heart_failure_y_validation_ci = concordance_index_censored(y_validation_array['DEATH_EVENT'],
y_validation_array['TIME'],
optimal_gbs_heart_failure_y_validation_pred)[0]
print(f"Validation Concordance Index: {optimal_gbs_heart_failure_y_validation_ci}")
Validation Concordance Index: 0.6575716234652115
##################################
# Gathering the concordance indices
# from the train, cross-validation and validation sets for
# Gradient Boosted Survival Model
##################################
gbs_set = pd.DataFrame(["Train","Cross-Validation","Validation"])
gbs_ci_values = pd.DataFrame([optimal_gbs_heart_failure_y_train_ci,
optimal_gbs_heart_failure_y_crossvalidation_ci,
optimal_gbs_heart_failure_y_validation_ci])
gbs_method = pd.DataFrame(["GBS"]*3)
gbs_summary = pd.concat([gbs_set,
gbs_ci_values,
gbs_method], axis=1)
gbs_summary.columns = ['Set', 'Concordance.Index', 'Method']
gbs_summary.reset_index(inplace=True, drop=True)
display(gbs_summary)
Set | Concordance.Index | Method | |
---|---|---|---|
0 | Train | 0.927466 | GBS |
1 | Cross-Validation | 0.676537 | GBS |
2 | Validation | 0.657572 | GBS |
##################################
# Binning the predicted risks
# into dichotomous groups and
# exploring the relationships with
# survival event and duration
##################################
heart_failure_validation.reset_index(drop=True, inplace=True)
kmf = KaplanMeierFitter()
heart_failure_validation['Predicted_Risks_GBS'] = optimal_gbs_heart_failure_y_validation_pred
heart_failure_validation['Predicted_RiskGroups_GBS'] = risk_groups = pd.qcut(heart_failure_validation['Predicted_Risks_GBS'], 2, labels=['Low-Risk', 'High-Risk'])
plt.figure(figsize=(17, 8))
for group in risk_groups.unique():
group_data = heart_failure_validation[risk_groups == group]
kmf.fit(group_data['TIME'], event_observed=group_data['DEATH_EVENT'], label=group)
kmf.plot_survival_function()
plt.title('GBS Survival Probabilities by Predicted Risk Groups on Validation Set')
plt.xlabel('TIME')
plt.ylabel('DEATH_EVENT Survival Probability')
plt.show()
##################################
# Gathering the predictor information
# for 5 test case samples
##################################
validation_case_details = X_validation.iloc[[5, 10, 15, 20, 25]]
display(validation_case_details)
AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM | |
---|---|---|---|---|---|---|
291 | 60.0 | 0 | 35.0 | 0 | 1.4 | 139.0 |
66 | 42.0 | 1 | 15.0 | 0 | 1.3 | 136.0 |
112 | 50.0 | 0 | 25.0 | 0 | 1.6 | 136.0 |
89 | 57.0 | 1 | 25.0 | 1 | 1.1 | 144.0 |
17 | 45.0 | 0 | 14.0 | 0 | 0.8 | 127.0 |
##################################
# Gathering the event and duration information
# for 5 test case samples
##################################
print(y_validation_array[[5, 10, 15, 20, 25]])
[(False, 258) ( True, 65) (False, 90) (False, 79) ( True, 14)]
##################################
# Gathering the risk-groups
# for 5 test case samples
##################################
print(heart_failure_validation.loc[[5, 10, 15, 20, 25]][['Predicted_RiskGroups_GBS']])
   Predicted_RiskGroups_GBS
5                  Low-Risk
10                High-Risk
15                High-Risk
20                High-Risk
25                High-Risk
##################################
# Estimating the cumulative hazard
# and survival functions
# for 5 validation cases
##################################
validation_case = X_validation.iloc[[5, 10, 15, 20, 25]]
validation_case_labels = ['Patient_5','Patient_10','Patient_15','Patient_20','Patient_25',]
validation_case_cumulative_hazard_function = optimal_gbs_model.predict_cumulative_hazard_function(validation_case)
validation_case_survival_function = optimal_gbs_model.predict_survival_function(validation_case)
fig, ax = plt.subplots(1,2,figsize=(17, 8))
for hazard_prediction, survival_prediction in zip(validation_case_cumulative_hazard_function, validation_case_survival_function):
ax[0].step(hazard_prediction.x,hazard_prediction(hazard_prediction.x),where='post')
ax[1].step(survival_prediction.x,survival_prediction(survival_prediction.x),where='post')
ax[0].set_title('GBS Cumulative Hazard for 5 Validation Cases')
ax[0].set_xlabel('TIME')
ax[0].set_ylim(0,2)
ax[0].set_ylabel('Cumulative Hazard')
ax[0].legend(validation_case_labels, loc="upper left")
ax[1].set_title('GBS Survival Function for 5 Validation Cases')
ax[1].set_xlabel('TIME')
ax[1].set_ylabel('DEATH_EVENT Survival Probability')
ax[1].legend(validation_case_labels, loc="lower left")
plt.show()
##################################
# Saving the best Gradient Boosted Survival Model
# developed from the original training data
##################################
joblib.dump(gbs_best_model_train_cv,
os.path.join("..", MODELS_PATH, "gbs_best_model.pkl"))
['..\\models\\gbs_best_model.pkl']
1.6.9 Model Selection ¶
- The Cox proportional hazards regression model was selected as the final model by demonstrating the best concordance index on the validation data with minimal overfitting between the apparent and cross-validated train data:
- train data (apparent) = 0.7419
- train data (cross-validated) = 0.7073
- validation data = 0.7394
- The optimal hyperparameters for the final model configuration were determined as follows:
- alpha = 10.00
- The Cox net survival model also demonstrated comparably good survival prediction, but was not selected over the Cox proportional hazards regression model due to its greater model complexity.
- The survival tree, random survival forest, and gradient boosted survival models all showed signs of overfitting, as demonstrated by a considerable difference between their apparent and cross-validated concordance index values (an illustrative ranking sketch follows the comparison table below).
##################################
# Gathering the concordance indices from
# training, cross-validation and validation
##################################
set_labels = ['Train','Cross-Validation','Validation']
ci_plot = pd.DataFrame({'COXPH': list([optimal_coxph_heart_failure_y_train_ci,
optimal_coxph_heart_failure_y_crossvalidation_ci,
optimal_coxph_heart_failure_y_validation_ci]),
'COXNS': list([optimal_coxns_heart_failure_y_train_ci,
optimal_coxns_heart_failure_y_crossvalidation_ci,
optimal_coxns_heart_failure_y_validation_ci]),
'STREE': list([optimal_stree_heart_failure_y_train_ci,
optimal_stree_heart_failure_y_crossvalidation_ci,
optimal_stree_heart_failure_y_validation_ci]),
'RSF': list([optimal_rsf_heart_failure_y_train_ci,
optimal_rsf_heart_failure_y_crossvalidation_ci,
optimal_rsf_heart_failure_y_validation_ci]),
'GBS': list([optimal_gbs_heart_failure_y_train_ci,
optimal_gbs_heart_failure_y_crossvalidation_ci,
optimal_gbs_heart_failure_y_validation_ci])}, index = set_labels)
display(ci_plot)
COXPH | COXNS | STREE | RSF | GBS | |
---|---|---|---|---|---|
Train | 0.741941 | 0.741941 | 0.799234 | 0.871369 | 0.927466 |
Cross-Validation | 0.707318 | 0.701369 | 0.654169 | 0.709097 | 0.676537 |
Validation | 0.739427 | 0.729877 | 0.644611 | 0.693042 | 0.657572 |
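As an illustrative aid (not part of the original selection procedure), the short sketch below reuses the ci_plot table above to rank the candidate models by validation concordance index and to report the gap between the apparent and cross-validated scores as a simple overfitting indicator; the column names introduced here are hypothetical.
##################################
# Illustrative sketch only:
# ranking candidate models by validation concordance index
# and flagging the apparent versus cross-validated gap
##################################
model_selection_summary = pd.DataFrame({'Validation.CI': ci_plot.loc['Validation'],
                                        'Overfitting.Gap': ci_plot.loc['Train'] - ci_plot.loc['Cross-Validation']}).sort_values('Validation.CI', ascending=False)
display(model_selection_summary)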
##################################
# Plotting all the concordance indices
# for all models
##################################
ci_plot = ci_plot.plot.barh(figsize=(10, 6), width=0.90)
ci_plot.set_xlim(0.00,1.00)
ci_plot.set_title("Survival Prediction Model Comparison by Concordance Index")
ci_plot.set_xlabel("Concordance Index")
ci_plot.set_ylabel("Data Set")
ci_plot.grid(False)
ci_plot.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
for container in ci_plot.containers:
ci_plot.bar_label(container, fmt='%.5f', padding=-50, color='white', fontweight='bold')
1.6.10 Model Testing ¶
- The selected Cox proportional hazards regression model demonstrated a sufficient concordance index on the independent test data:
- train data (apparent) = 0.7419
- train data (cross-validated) = 0.7073
- validation data = 0.7394
- test data = 0.7064
- For benchmarking purposes, all candidate models were evaluated on the test data. Interestingly, the survival tree, random survival forest, and gradient boosted survival models performed better than the selected model on this set. In this case, the inconsistent performance (poor on validation, good on test) might be an indicator of instability. The Cox proportional hazards regression model (and, to some extent, the Cox net survival model), which shows more consistent performance across the validation and test sets, is more reliable (see the stability sketch after the updated comparison table below). Although the selected model may not perform as well on the test set alone, its generalization across both validation and test sets makes it a more robust and stable choice in practice.
##################################
# Evaluating the concordance indices
# on the test data
##################################
optimal_coxph_heart_failure_y_test_ci = concordance_index_censored(y_test_array['DEATH_EVENT'],
y_test_array['TIME'],
optimal_coxph_model.predict(X_test))[0]
optimal_coxns_heart_failure_y_test_ci = concordance_index_censored(y_test_array['DEATH_EVENT'],
y_test_array['TIME'],
optimal_coxns_model.predict(X_test))[0]
optimal_stree_heart_failure_y_test_ci = concordance_index_censored(y_test_array['DEATH_EVENT'],
y_test_array['TIME'],
optimal_stree_model.predict(X_test))[0]
optimal_rsf_heart_failure_y_test_ci = concordance_index_censored(y_test_array['DEATH_EVENT'],
y_test_array['TIME'],
optimal_rsf_model.predict(X_test))[0]
optimal_gbs_heart_failure_y_test_ci = concordance_index_censored(y_test_array['DEATH_EVENT'],
y_test_array['TIME'],
optimal_gbs_model.predict(X_test))[0]
##################################
# Adding the the concordance index estimated
# from the test data
##################################
set_labels = ['Train','Cross-Validation','Validation','Test']
updated_ci_plot = pd.DataFrame({'COXPH': list([optimal_coxph_heart_failure_y_train_ci,
optimal_coxph_heart_failure_y_crossvalidation_ci,
optimal_coxph_heart_failure_y_validation_ci,
optimal_coxph_heart_failure_y_test_ci]),
'COXNS': list([optimal_coxns_heart_failure_y_train_ci,
optimal_coxns_heart_failure_y_crossvalidation_ci,
optimal_coxns_heart_failure_y_validation_ci,
optimal_coxns_heart_failure_y_test_ci]),
'STREE': list([optimal_stree_heart_failure_y_train_ci,
optimal_stree_heart_failure_y_crossvalidation_ci,
optimal_stree_heart_failure_y_validation_ci,
optimal_stree_heart_failure_y_test_ci]),
'RSF': list([optimal_rsf_heart_failure_y_train_ci,
optimal_rsf_heart_failure_y_crossvalidation_ci,
optimal_rsf_heart_failure_y_validation_ci,
optimal_rsf_heart_failure_y_test_ci]),
'GBS': list([optimal_gbs_heart_failure_y_train_ci,
optimal_gbs_heart_failure_y_crossvalidation_ci,
optimal_gbs_heart_failure_y_validation_ci,
optimal_gbs_heart_failure_y_test_ci])}, index = set_labels)
display(updated_ci_plot)
COXPH | COXNS | STREE | RSF | GBS | |
---|---|---|---|---|---|
Train | 0.741941 | 0.741941 | 0.799234 | 0.871369 | 0.927466 |
Cross-Validation | 0.707318 | 0.701369 | 0.654169 | 0.709097 | 0.676537 |
Validation | 0.739427 | 0.729877 | 0.644611 | 0.693042 | 0.657572 |
Test | 0.706422 | 0.719831 | 0.762526 | 0.760056 | 0.778758 |
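To make the consistency argument explicit, the short sketch below (illustrative only) compares each candidate's validation and test concordance indices from the updated_ci_plot table above; a smaller absolute difference indicates more consistent behavior across the two hold-out sets, which is the basis of the stability argument made earlier.
##################################
# Illustrative sketch only:
# summarizing model stability as the absolute difference
# between the validation and test concordance indices
##################################
stability_summary = pd.DataFrame({'Validation.CI': updated_ci_plot.loc['Validation'],
                                  'Test.CI': updated_ci_plot.loc['Test'],
                                  'Absolute.Difference': (updated_ci_plot.loc['Validation'] - updated_ci_plot.loc['Test']).abs()}).sort_values('Absolute.Difference')
display(stability_summary)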
##################################
# Plotting all the concordance indices
# for all models
##################################
updated_ci_plot = updated_ci_plot.plot.barh(figsize=(10, 8), width=0.90)
updated_ci_plot.set_xlim(0.00,1.00)
updated_ci_plot.set_title("Survival Prediction Model Comparison by Concordance Index")
updated_ci_plot.set_xlabel("Concordance Index")
updated_ci_plot.set_ylabel("Data Set")
updated_ci_plot.grid(False)
updated_ci_plot.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
for container in updated_ci_plot.containers:
updated_ci_plot.bar_label(container, fmt='%.5f', padding=-50, color='white', fontweight='bold')
1.6.11 Model Inference ¶
- For the final selected survival prediction model developed from the train data, the contributions of the predictors, ranked by permutation-based importance, are given as follows:
- Cox proportional hazards regression model
- SERUM_CREATININE
- EJECTION_FRACTION
- SERUM_SODIUM
- ANAEMIA
- AGE
- HIGH_BLOOD_PRESSURE
- Model inference involved presenting the characteristics of a new case and predicting its survival probability profile against the model training observations (a deployment-style sketch is shown after this list):
- Characteristics based on all predictors used for generating the final selected survival prediction model
- Predicted heart failure survival probability profile based on the final selected survival prediction model
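To connect the inference steps above to the deployment setting, the sketch below is a hedged illustration rather than the notebook's actual code: it loads the saved artifacts and scores one hypothetical new case. The file name coxph_best_model.pkl and the predictor values of the new case are assumptions made only for this example; the threshold and median artifacts follow the saving convention used elsewhere in this notebook.
##################################
# Deployment-style inference sketch (illustrative only):
# loading the saved artifacts and scoring one new case
##################################
final_model = joblib.load(os.path.join("..", MODELS_PATH, "coxph_best_model.pkl"))  # assumed artifact name
risk_threshold = joblib.load(os.path.join("..", PARAMETERS_PATH, "coxph_best_model_risk_group_threshold.pkl"))

# Hypothetical new case described with the six model predictors
new_case = pd.DataFrame([{'AGE': 65, 'ANAEMIA': 1, 'EJECTION_FRACTION': 30,
                          'HIGH_BLOOD_PRESSURE': 0, 'SERUM_CREATININE': 1.8, 'SERUM_SODIUM': 135}])

# Risk score, dichotomous risk category, and survival probability profile
new_case_risk = final_model.predict(new_case)[0]
new_case_group = "High-Risk" if new_case_risk > risk_threshold else "Low-Risk"
new_case_survival_function = final_model.predict_survival_function(new_case)[0]

print(f"Predicted Risk Category: {new_case_group}")
for survival_time in [50, 100, 150, 200, 250]:
    survival_probability = np.interp(survival_time, new_case_survival_function.x, new_case_survival_function.y)
    print(f"Survival Probability ({survival_time} Days): {survival_probability*100:.2f}%")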
##################################
# Determining the Cox Proportional Hazards Regression model
# absolute coefficient-based feature importance
# on train data
##################################
coxph_train_feature_importance = pd.DataFrame(
{'Signed.Coefficient': optimal_coxph_model.named_steps['coxph'].coef_,
'Absolute.Coefficient': np.abs(optimal_coxph_model.named_steps['coxph'].coef_)}, index=X_train.columns)
display(coxph_train_feature_importance.sort_values('Absolute.Coefficient', ascending=False))
Signed.Coefficient | Absolute.Coefficient | |
---|---|---|
EJECTION_FRACTION | 0.407833 | 0.407833 |
SERUM_CREATININE | 0.352092 | 0.352092 |
ANAEMIA | -0.306170 | 0.306170 |
HIGH_BLOOD_PRESSURE | -0.280524 | 0.280524 |
AGE | 0.245804 | 0.245804 |
SERUM_SODIUM | 0.234638 | 0.234638 |
##################################
# Plotting the Cox Proportional Hazards Regression model
# absolute coefficient-based feature importance
# on train data
##################################
coxph_train_coefficient_importance_summary = coxph_train_feature_importance.sort_values('Absolute.Coefficient', ascending=True)
plt.figure(figsize=(17, 8))
plt.barh(coxph_train_coefficient_importance_summary.index, coxph_train_coefficient_importance_summary['Absolute.Coefficient'])
plt.xlabel('Predictor Contribution: Absolute Coefficient')
plt.ylabel('Predictor')
plt.title('Feature Importance - Final Survival Prediction Model: Cox Proportional Hazards Regression')
plt.tight_layout()
plt.show()
##################################
# Determining the Cox Proportional Hazards Regression model
# permutation-based feature importance
# on train data
##################################
coxph_train_feature_importance = permutation_importance(optimal_coxph_model,
X_train,
y_train_array,
n_repeats=15,
random_state=88888888)
coxph_train_feature_importance_summary = pd.DataFrame(
{k: coxph_train_feature_importance[k]
for k in ("importances_mean", "importances_std")},
index=X_train.columns).sort_values(by="importances_mean", ascending=False)
coxph_train_feature_importance_summary.columns = ['Importances.Mean', 'Importances.Std']
display(coxph_train_feature_importance_summary)
Importances.Mean | Importances.Std | |
---|---|---|
SERUM_CREATININE | 0.055362 | 0.017600 |
EJECTION_FRACTION | 0.032078 | 0.014803 |
SERUM_SODIUM | 0.023034 | 0.012928 |
ANAEMIA | 0.018449 | 0.009402 |
AGE | 0.017763 | 0.009435 |
HIGH_BLOOD_PRESSURE | 0.002915 | 0.004219 |
##################################
# Plotting the Cox Proportional Hazards Regression model
# absolute coefficient-based feature importance
# on train data
##################################
coxph_train_feature_importance_summary = coxph_train_feature_importance_summary.sort_values('Importances.Mean', ascending=True)
plt.figure(figsize=(17, 8))
plt.barh(coxph_train_feature_importance_summary.index, coxph_train_feature_importance_summary['Importances.Mean'])
plt.xlabel('Predictor Contribution: Permutation Importance')
plt.ylabel('Predictor')
plt.title('Feature Importance - Final Survival Prediction Model: Cox Proportional Hazards Regression')
plt.tight_layout()
plt.show()
##################################
# Rebuilding the training data
# for plotting kaplan-meier charts
##################################
X_train_indices = X_train.index.tolist()
heart_failure_MI = heart_failure_EDA.copy()
heart_failure_MI = heart_failure_MI.drop(['DIABETES','SEX', 'SMOKING', 'CREATININE_PHOSPHOKINASE','PLATELETS'], axis=1)
heart_failure_MI = heart_failure_MI.loc[X_train_indices]
heart_failure_MI.head()
AGE | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE | DEATH_EVENT | TIME | |
---|---|---|---|---|---|---|---|---|
266 | -0.423454 | -1.773346 | 1.144260 | -0.689301 | Absent | Absent | True | 241.0 |
180 | -2.043070 | -0.633046 | -0.732811 | -0.244181 | Absent | Absent | False | 148.0 |
288 | 0.434332 | -0.160461 | -0.087641 | 1.348555 | Absent | Absent | False | 256.0 |
258 | -1.446547 | -1.163741 | -1.149080 | -0.471658 | Present | Absent | False | 230.0 |
236 | 1.173233 | 1.021735 | -0.087641 | 3.397822 | Absent | Present | False | 209.0 |
##################################
# Determining the medians for the numeric predictors
##################################
heart_failure_MI_numeric = heart_failure_MI[["AGE","EJECTION_FRACTION","SERUM_CREATININE","SERUM_SODIUM"]]
numeric_predictor_median_list = heart_failure_MI_numeric.median()
print("Numeric Predictor Median: ","\n", numeric_predictor_median_list)
Numeric Predictor Median: 
AGE                  0.065124
EJECTION_FRACTION    0.100914
SERUM_CREATININE    -0.087641
SERUM_SODIUM        -0.006503
dtype: float64
##################################
# Saving the numeric predictor medians
# computed from the original training data
# for binning new cases during inference
##################################
joblib.dump(numeric_predictor_median_list,
os.path.join("..", PARAMETERS_PATH, "numeric_feature_median_list.pkl"))
['..\\parameters\\numeric_feature_median_list.pkl']
##################################
# Creating a function to bin
# numeric predictors into two groups
##################################
def bin_numeric_model_predictor(df, predictor):
median = numeric_predictor_median_list.loc[predictor]
df[predictor] = np.where(df[predictor] <= median, "Low", "High")
return df
##################################
# Binning the numeric predictors
# into two groups
##################################
for numeric_column in ["AGE","EJECTION_FRACTION","SERUM_CREATININE","SERUM_SODIUM"]:
heart_failure_MI_EDA = bin_numeric_model_predictor(heart_failure_MI, numeric_column)
##################################
# Exploring the transformed
# dataset for plotting
##################################
heart_failure_MI_EDA.head()
AGE | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE | DEATH_EVENT | TIME | |
---|---|---|---|---|---|---|---|---|
266 | Low | Low | High | Low | Absent | Absent | True | 241.0 |
180 | Low | Low | Low | Low | Absent | Absent | False | 148.0 |
288 | High | Low | Low | High | Absent | Absent | False | 256.0 |
258 | Low | Low | Low | Low | Present | Absent | False | 230.0 |
236 | High | High | Low | High | Absent | Present | False | 209.0 |
##################################
# Defining a function to plot the
# estimated survival profiles
# using Kaplan-Meier Plots
##################################
def plot_kaplan_meier(df, cat_var, ax, new_case_value=None):
kmf = KaplanMeierFitter()
# Defining the color scheme for each category
if cat_var in ['AGE', 'EJECTION_FRACTION', 'SERUM_CREATININE', 'SERUM_SODIUM']:
categories = ['Low', 'High']
colors = {'Low': 'blue', 'High': 'red'}
else:
categories = ['Absent', 'Present']
colors = {'Absent': 'blue', 'Present': 'red'}
# Plotting each category with a partly red or blue transparent line
for value in categories:
mask = df[cat_var] == value
kmf.fit(df['TIME'][mask], event_observed=df['DEATH_EVENT'][mask], label=f'{cat_var}={value} (Baseline Distribution)')
kmf.plot_survival_function(ax=ax, ci_show=False, color=colors[str(value)], linestyle='-', linewidth=6.0, alpha=0.30)
# Overlaying a black broken line for the new case if provided
if new_case_value is not None:
mask_new_case = df[cat_var] == new_case_value
kmf.fit(df['TIME'][mask_new_case], event_observed=df['DEATH_EVENT'][mask_new_case], label=f'{cat_var}={new_case_value} (Test Case)')
kmf.plot_survival_function(ax=ax, ci_show=False, color='black', linestyle=':', linewidth=3.0)
##################################
# Plotting the estimated survival profiles
# of the model training data
# using Kaplan-Meier Plots
##################################
fig, axes = plt.subplots(3, 2, figsize=(17, 13))
heart_failure_predictors = ['AGE','EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM','ANAEMIA','HIGH_BLOOD_PRESSURE']
for i, predictor in enumerate(heart_failure_predictors):
ax = axes[i // 2, i % 2]
plot_kaplan_meier(heart_failure_MI_EDA, predictor, ax, new_case_value=None)
ax.set_title(f'DEATH_EVENT Survival Probabilities by {predictor} Categories')
ax.set_ylim(-0.05, 1.05)
ax.set_xlabel('Time (Days)')
ax.set_ylabel('Estimated Survival Probability')
ax.legend(loc='lower left')
plt.tight_layout()
plt.show()
##################################
# Estimating the survival functions
# for the training data
##################################
heart_failure_train_survival_function = optimal_coxph_model.predict_survival_function(X_train)
##################################
# Resetting the index for
# plotting survival functions
# for the training data
##################################
y_train_reset_index = y_train.reset_index()
##################################
# Plotting the baseline survival functions
# for the training data
##################################
plt.figure(figsize=(17, 8))
for i, surv_func in enumerate(heart_failure_train_survival_function):
plt.step(surv_func.x,
surv_func.y,
where="post",
color='red' if y_train_reset_index['DEATH_EVENT'][i] == 1 else 'blue',
linewidth=6.0,
alpha=0.05)
red_patch = plt.Line2D([0], [0], color='red', lw=6, alpha=0.30, label='Death Event Status = True')
blue_patch = plt.Line2D([0], [0], color='blue', lw=6, alpha=0.30, label='Death Event Status = False')
plt.legend(handles=[red_patch, blue_patch], facecolor='white', framealpha=1, loc='upper center', bbox_to_anchor=(0.5, -0.10), ncol=3)
plt.title('Final Survival Prediction Model: Cox Proportional Hazards Regression')
plt.xlabel('Time (Days)')
plt.ylabel('Estimated Survival Probability')
plt.tight_layout(rect=[0, 0, 1.00, 0.95])
plt.show()
##################################
# Determining the risk category threshold
# from the training data predictions
##################################
optimal_coxph_heart_failure_y_train_pred = optimal_coxph_model.predict(X_train)
heart_failure_train['Predicted_Risks_CoxPH'] = optimal_coxph_heart_failure_y_train_pred
risk_groups, risk_group_bin_range = pd.qcut(heart_failure_train['Predicted_Risks_CoxPH'], 2, labels=['Low-Risk', 'High-Risk'], retbins=True)
risk_group_threshold = risk_group_bin_range[1]
print("Risk Category Threshold: ", risk_group_threshold)
Risk Category Threshold: 0.1856637832961452
##################################
# Saving the risk category threshold
# from the best Cox Proportional Hazards Regression Model
# developed from the original training data
##################################
joblib.dump(risk_group_threshold,
os.path.join("..", PARAMETERS_PATH, "coxph_best_model_risk_group_threshold.pkl"))
['..\\parameters\\coxph_best_model_risk_group_threshold.pkl']
##################################
# Describing the details of a
# low-risk test case for evaluation
##################################
X_sample = {'AGE': 43,
'ANAEMIA': 0,
'EJECTION_FRACTION': 75,
'HIGH_BLOOD_PRESSURE': 1,
'SERUM_CREATININE': 0.75,
'SERUM_SODIUM': 100}
X_test_sample = pd.DataFrame([X_sample])
X_test_sample.head()
AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM | |
---|---|---|---|---|---|---|
0 | 43 | 0 | 75 | 1 | 0.75 | 100 |
##################################
# Applying preprocessing to
# the test case
##################################
X_test_sample_transformed = coxph_pipeline.named_steps['yeo_johnson'].transform(X_test_sample)
X_test_sample_converted = pd.DataFrame([X_test_sample_transformed[0]], columns=["AGE", "EJECTION_FRACTION", "SERUM_CREATININE", "SERUM_SODIUM", "ANAEMIA", "HIGH_BLOOD_PRESSURE"])
X_test_sample_converted.head()
AGE | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE | |
---|---|---|---|---|---|---|
0 | -1.669035 | 2.423321 | -1.404833 | -4.476082 | 0.0 | 1.0 |
##################################
# Binning numeric predictors into two groups
##################################
for col in ["AGE", "EJECTION_FRACTION", "SERUM_CREATININE", "SERUM_SODIUM"]:
    X_test_sample_converted[col] = X_test_sample_converted[col].apply(lambda x: 'High' if x > numeric_predictor_median_list.loc[col] else 'Low')
##################################
# Converting integer predictors into labels
##################################
for col in ["ANAEMIA", "HIGH_BLOOD_PRESSURE"]:
X_test_sample_converted[col] = X_test_sample_converted[col].apply(lambda x: 'Absent' if x < 1.0 else 'Present')
##################################
# Describing the details of the
# test case for evaluation
##################################
X_test_sample_converted.head()
AGE | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE | |
---|---|---|---|---|---|---|
0 | Low | High | Low | Low | Absent | Present |
##################################
# Plotting the estimated survival profiles
# of the test case
# using Kaplan-Meier Plots
##################################
fig, axes = plt.subplots(3, 2, figsize=(17, 13))
heart_failure_predictors = ['AGE','EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM','ANAEMIA','HIGH_BLOOD_PRESSURE']
for i, predictor in enumerate(heart_failure_predictors):
ax = axes[i // 2, i % 2]
plot_kaplan_meier(heart_failure_MI_EDA, predictor, ax, new_case_value=X_test_sample_converted[predictor][0])
ax.set_title(f'DEATH_EVENT Survival Probabilities by {predictor} Categories')
ax.set_ylim(-0.05, 1.05)
ax.set_xlabel('TIME')
ax.set_ylabel('DEATH_EVENT Survival Probability')
ax.legend(loc='lower left')
plt.tight_layout()
plt.show()
##################################
# Computing the estimated survival probability
# for the test case
##################################
X_test_sample_survival_function = optimal_coxph_model.predict_survival_function(X_test_sample)
##################################
# Determining the risk category
# for the test case
##################################
X_test_sample_risk_category = "High-Risk" if (optimal_coxph_model.predict(X_test_sample) > risk_group_threshold) else "Low-Risk"
##################################
# Computing the estimated survival probabilities
# for the test case at five defined time points
##################################
X_test_sample_survival_time = np.array([50, 100, 150, 200, 250])
X_test_sample_survival_probability = np.interp(X_test_sample_survival_time,
X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y)
X_test_sample_survival_probability_percentage = X_test_sample_survival_probability*100
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability_percentage):
print(f"Test Case Survival Probability ({survival_time} Days): {survival_probability:.2f}%")
print(f"Test Case Risk Category: {X_test_sample_risk_category}")
Test Case Survival Probability (50 Days): 90.98%
Test Case Survival Probability (100 Days): 87.65%
Test Case Survival Probability (150 Days): 84.60%
Test Case Survival Probability (200 Days): 78.48%
Test Case Survival Probability (250 Days): 70.70%
Test Case Risk Category: Low-Risk
##################################
# Plotting the estimated survival probability
# for the test case
# in the baseline survival function
# of the final survival prediction model
##################################
plt.figure(figsize=(17, 8))
for i, surv_func in enumerate(heart_failure_train_survival_function):
plt.step(surv_func.x,
surv_func.y,
where="post",
color='red' if y_train_reset_index['DEATH_EVENT'][i] == 1 else 'blue',
linewidth=6.0,
alpha=0.05)
if X_test_sample_risk_category == "Low-Risk":
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='blue',
linewidth=6.0,
linestyle='-',
alpha=0.30,
label='Test Case (Low-Risk)')
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='black',
linewidth=3.0,
linestyle=':',
label='Test Case (Low-Risk)')
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability):
plt.vlines(x=survival_time, ymin=0, ymax=survival_probability, color='blue', linestyle='-', linewidth=2.0, alpha=0.30)
red_patch = plt.Line2D([0], [0], color='red', lw=6, alpha=0.30, label='Death Event Status = True')
blue_patch = plt.Line2D([0], [0], color='blue', lw=6, alpha=0.30, label='Death Event Status = False')
black_patch = plt.Line2D([0], [0], color='black', lw=3, linestyle=":", label='Test Case (Low-Risk)')
if X_test_sample_risk_category == "High-Risk":
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='red',
linewidth=6.0,
linestyle='-',
alpha=0.30,
label='Test Case (High-Risk)')
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='black',
linewidth=3.0,
linestyle=':',
label='Test Case (High-Risk)')
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability):
plt.vlines(x=survival_time, ymin=0, ymax=survival_probability, color='red', linestyle='-', linewidth=2.0, alpha=0.30)
red_patch = plt.Line2D([0], [0], color='red', lw=6, alpha=0.30, label='Death Event Status = True')
blue_patch = plt.Line2D([0], [0], color='blue', lw=6, alpha=0.30, label='Death Event Status = False')
black_patch = plt.Line2D([0], [0], color='black', lw=3, linestyle=":", label='Test Case (High-Risk)')
plt.legend(handles=[red_patch, blue_patch, black_patch], facecolor='white', framealpha=1, loc='upper center', bbox_to_anchor=(0.5, -0.10), ncol=3)
plt.title('Final Survival Prediction Model: Cox Proportional Hazards Regression')
plt.xlabel('Time (Days)')
plt.ylabel('Estimated Survival Probability')
plt.tight_layout(rect=[0, 0, 1.00, 0.95])
plt.show()
##################################
# Describing the details of a
# high-risk test case for evaluation
##################################
X_sample = {'AGE': 70,
'ANAEMIA': 1,
'EJECTION_FRACTION': 20,
'HIGH_BLOOD_PRESSURE': 1,
'SERUM_CREATININE': 0.75,
'SERUM_SODIUM': 100}
X_test_sample = pd.DataFrame([X_sample])
X_test_sample.head()
|   | AGE | ANAEMIA | EJECTION_FRACTION | HIGH_BLOOD_PRESSURE | SERUM_CREATININE | SERUM_SODIUM |
|---|-----|---------|-------------------|---------------------|------------------|--------------|
| 0 | 70  | 1       | 20                | 1                   | 0.75             | 100          |
##################################
# Applying preprocessing to
# the test case
##################################
coxph_pipeline.fit(X_train, y_train_array)
X_test_sample_transformed = coxph_pipeline.named_steps['yeo_johnson'].transform(X_test_sample)
X_test_sample_converted = pd.DataFrame([X_test_sample_transformed[0]], columns=["AGE", "EJECTION_FRACTION", "SERUM_CREATININE", "SERUM_SODIUM", "ANAEMIA", "HIGH_BLOOD_PRESSURE"])
X_test_sample_converted.head()
|   | AGE      | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE |
|---|----------|-------------------|------------------|--------------|---------|---------------------|
| 0 | 0.758286 | -1.733233         | -1.404833        | -4.476082    | 1.0     | 1.0                 |
##################################
# Binning numeric predictors into two groups
##################################
for i, col in enumerate(["AGE", "EJECTION_FRACTION", "SERUM_CREATININE", "SERUM_SODIUM"]):
X_test_sample_converted[col] = X_test_sample_converted[col].apply(lambda x: 'High' if x > numeric_predictor_median_list[i] else 'Low')
##################################
# Converting integer predictors into labels
##################################
for col in ["ANAEMIA", "HIGH_BLOOD_PRESSURE"]:
X_test_sample_converted[col] = X_test_sample_converted[col].apply(lambda x: 'Absent' if x < 1.0 else 'Present')
##################################
# Describing the details of the
# test case for evaluation
##################################
X_test_sample_converted.head()
|   | AGE  | EJECTION_FRACTION | SERUM_CREATININE | SERUM_SODIUM | ANAEMIA | HIGH_BLOOD_PRESSURE |
|---|------|-------------------|------------------|--------------|---------|---------------------|
| 0 | High | Low               | Low              | Low          | Present | Present             |
##################################
# Plotting the estimated survival profiles
# of the test case
# using Kaplan-Meier Plots
##################################
fig, axes = plt.subplots(3, 2, figsize=(17, 13))
heart_failure_predictors = ['AGE','EJECTION_FRACTION','SERUM_CREATININE','SERUM_SODIUM','ANAEMIA','HIGH_BLOOD_PRESSURE']
for i, predictor in enumerate(heart_failure_predictors):
ax = axes[i // 2, i % 2]
plot_kaplan_meier(heart_failure_MI_EDA, predictor, ax, new_case_value=X_test_sample_converted[predictor][0])
ax.set_title(f'DEATH_EVENT Survival Probabilities by {predictor} Categories')
ax.set_ylim(-0.05, 1.05)
ax.set_xlabel('TIME')
ax.set_ylabel('DEATH_EVENT Survival Probability')
ax.legend(loc='lower left')
plt.tight_layout()
plt.show()
##################################
# Computing the estimated survival probability
# for the test case
##################################
X_test_sample_survival_function = optimal_coxph_model.predict_survival_function(X_test_sample)
##################################
# Determining the risk category
# for the test case
##################################
X_test_sample_risk_category = "High-Risk" if (optimal_coxph_model.predict(X_test_sample) > risk_group_threshold) else "Low-Risk"
##################################
# Computing the estimated survival probabilities
# for the test case at five defined time points
##################################
X_test_sample_survival_time = np.array([50, 100, 150, 200, 250])
X_test_sample_survival_probability = np.interp(X_test_sample_survival_time,
X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y)
X_test_sample_survival_probability_percentage = X_test_sample_survival_probability*100
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability_percentage):
print(f"Test Case Survival Probability ({survival_time} Days): {survival_probability:.2f}%")
print(f"Test Case Risk Category: {X_test_sample_risk_category}")
Test Case Survival Probability (50 Days): 41.81%
Test Case Survival Probability (100 Days): 29.65%
Test Case Survival Probability (150 Days): 21.39%
Test Case Survival Probability (200 Days): 10.71%
Test Case Survival Probability (250 Days): 4.09%
Test Case Risk Category: High-Risk
##################################
# Plotting the estimated survival probability
# for the test case
# in the baseline survival function
# of the final survival prediction model
##################################
plt.figure(figsize=(17, 8))
for i, surv_func in enumerate(heart_failure_train_survival_function):
plt.step(surv_func.x,
surv_func.y,
where="post",
color='red' if y_train_reset_index['DEATH_EVENT'][i] == 1 else 'blue',
linewidth=6.0,
alpha=0.05)
if X_test_sample_risk_category == "Low-Risk":
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='blue',
linewidth=6.0,
linestyle='-',
alpha=0.30,
label='Test Case (Low-Risk)')
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='black',
linewidth=3.0,
linestyle=':',
label='Test Case (Low-Risk)')
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability):
plt.vlines(x=survival_time, ymin=0, ymax=survival_probability, color='blue', linestyle='-', linewidth=2.0, alpha=0.30)
red_patch = plt.Line2D([0], [0], color='red', lw=6, alpha=0.30, label='Death Event Status = True')
blue_patch = plt.Line2D([0], [0], color='blue', lw=6, alpha=0.30, label='Death Event Status = False')
black_patch = plt.Line2D([0], [0], color='black', lw=3, linestyle=":", label='Test Case (Low-Risk)')
if X_test_sample_risk_category == "High-Risk":
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='red',
linewidth=6.0,
linestyle='-',
alpha=0.30,
label='Test Case (High-Risk)')
plt.step(X_test_sample_survival_function[0].x,
X_test_sample_survival_function[0].y,
where="post",
color='black',
linewidth=3.0,
linestyle=':',
label='Test Case (High-Risk)')
for survival_time, survival_probability in zip(X_test_sample_survival_time, X_test_sample_survival_probability):
plt.vlines(x=survival_time, ymin=0, ymax=survival_probability, color='red', linestyle='-', linewidth=2.0, alpha=0.30)
red_patch = plt.Line2D([0], [0], color='red', lw=6, alpha=0.30, label='Death Event Status = True')
blue_patch = plt.Line2D([0], [0], color='blue', lw=6, alpha=0.30, label='Death Event Status = False')
black_patch = plt.Line2D([0], [0], color='black', lw=3, linestyle=":", label='Test Case (High-Risk)')
plt.legend(handles=[red_patch, blue_patch, black_patch], facecolor='white', framealpha=1, loc='upper center', bbox_to_anchor=(0.5, -0.10), ncol=3)
plt.title('Final Survival Prediction Model: Cox Proportional Hazards Regression')
plt.xlabel('Time (Days)')
plt.ylabel('Estimated Survival Probability')
plt.tight_layout(rect=[0, 0, 1.00, 0.95])
plt.show()
1.7. Predictive Model Deployment Using Streamlit and Streamlit Community Cloud ¶
1.7.1 Model Prediction Application Code Development ¶
Streamlit is an open-source Python library that simplifies the creation and deployment of web applications for machine learning and data science projects. It allows developers and data scientists to turn Python scripts into interactive web apps quickly without requiring extensive web development knowledge. Streamlit seamlessly integrates with popular Python libraries such as Pandas, Matplotlib, Plotly, and TensorFlow, allowing one to leverage existing data processing and visualization tools within the application. Streamlit apps can be easily deployed on various platforms, including Streamlit Community Cloud, Heroku, or any cloud service that supports Python web applications.
Streamlit Community Cloud, formerly known as Streamlit Sharing, is a free cloud-based platform provided by Streamlit that allows users to easily deploy and share Streamlit apps online. It is particularly popular among data scientists, machine learning engineers, and developers for quickly showcasing projects, creating interactive demos, and sharing data-driven applications with a wider audience without needing to manage server infrastructure. Significant features include the following:
- free hosting: Streamlit Community Cloud provides free hosting for Streamlit apps, making it accessible for users who want to share their work without incurring hosting costs
- easy deployment: users can connect their GitHub repository to Streamlit Community Cloud, and the app is automatically deployed from the repository
- continuous deployment: if the code in the connected GitHub repository is updated, the app is automatically redeployed with the latest changes
- sharing capabilities: once deployed, apps can be shared with others via a simple URL, making it easy for collaborators, stakeholders, or the general public to access and interact with the app
- built-in authentication: users can restrict access to their apps using GitHub-based authentication, allowing control over who can view and interact with the app
- community support: the platform is supported by a community of users and developers who share knowledge, templates, and best practices for building and deploying Streamlit apps
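As an illustration only (a minimal sketch, not the deployed application's actual code), a Streamlit script of roughly the following shape can be placed in a GitHub repository and deployed on Streamlit Community Cloud; the file name app.py, the widget labels, and the slider bounds are assumptions made for this sketch.
##################################
# Minimal illustrative Streamlit script (assumed file name: app.py)
# Widget labels and slider bounds are assumptions for this sketch
##################################
import streamlit as st

# Page title rendered at the top of the application
st.title("Heart Failure Survival Risk Profiler")

# Range slider capturing a numeric marker for the test case
age = st.slider("AGE", min_value=20, max_value=100, value=60)

# Radio buttons capturing a binary marker for the test case
anaemia = st.radio("ANAEMIA", options=["Absent", "Present"])

# Action button triggering the downstream computations
if st.button("Assess Heart Failure Survival Risk"):
    st.write(f"Received inputs: AGE={age}, ANAEMIA={anaemia}")
Running streamlit run app.py locally serves the script as an interactive web app; once the repository is connected to Streamlit Community Cloud, pushes to the repository automatically redeploy the app.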
1.7.2 Model Application Programming Interface Code Development ¶
- Model prediction application code was developed in Python to:
- generate the Kaplan-Meier plots for the test case and the study population data as baseline
- estimate the heart failure survival profile and probabilities for the test case and the study population data as baseline
- predict risk categories for the test case
- The model prediction application code, whose core prediction logic is sketched below, was saved in a repository that was eventually cloned and uploaded to Streamlit Community Cloud.
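The helper below is a minimal sketch of that core prediction logic, assuming the fitted optimal_coxph_model and risk_group_threshold from the preceding sections; it mirrors the inference steps used earlier in this notebook (survival function prediction, interpolation at fixed time points, and thresholding of the predicted risk score), while the function name and its packaging as a standalone helper are illustrative assumptions.
##################################
# Illustrative helper wrapping the model inference steps
# (function name and packaging are assumptions; the logic
# mirrors the inference code used earlier in this notebook)
##################################
import numpy as np
import pandas as pd

def estimate_survival_profile(model, x_test_sample, risk_group_threshold,
                              time_points=(50, 100, 150, 200, 250)):
    # Predict the survival function for the single test case
    survival_function = model.predict_survival_function(x_test_sample)
    # Interpolate the step function at the requested time points
    survival_probability = np.interp(time_points,
                                     survival_function[0].x,
                                     survival_function[0].y)
    # Classify the case against the precomputed risk-group threshold
    risk_category = ("High-Risk"
                     if model.predict(x_test_sample)[0] > risk_group_threshold
                     else "Low-Risk")
    # Summarize the estimated probabilities and risk category
    return pd.DataFrame({"Time (Days)": list(time_points),
                         "Estimated Survival Probability (%)": survival_probability * 100,
                         "Predicted Risk Category": risk_category})
For example, estimate_survival_profile(optimal_coxph_model, X_test_sample, risk_group_threshold) would reproduce, in tabular form, the probabilities and risk category printed for the test cases above.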
1.7.3 User Interface Application Code Development ¶
- User interface application code was developed in Python to:
- generate the Kaplan-Meier plots for the test case and the study population data as baseline
- estimate the heart failure survival profile and probabilities for the test case and the study population data as baseline
- predict risk categories for the test case
- The user interface application code was saved in a repository that was eventually cloned and uploaded to Streamlit Community Cloud.
1.7.4 Web Application ¶
- The prediction model was deployed using a web application hosted on Streamlit Community Cloud.
- The user interface input consists of the following:
- range sliders to:
- enable numerical input to measure the characteristics of the test case for certain cardiovascular, hematologic and metabolic markers
- radio buttons to:
- enable binary category selection (Present | Absent) to identify the status of the test case for certain hematologic and cardiovascular markers
- action button to:
- process study population data as baseline
- process user input as test case
- render all entries into visualization charts
- execute all computations, estimations and predictions
- render test case prediction into the survival probability plot
- The user interface output consists of the following (a minimal rendering sketch follows this list):
- Kaplan-Meier plots to:
- provide a baseline visualization of the survival profiles of the various feature categories (Present | Absent or High | Low) estimated from the study population given the survival time and event status
- indicate the entries made from the user input to visually assess the survival probabilities of the test case characteristics against the study population across all time points
- survival probability plot to:
- provide a visualization of the baseline survival probability profile using each observation of the study population given the survival time and event status
- indicate the heart failure survival probabilities of the test case at different time points
- summary table to:
- present the estimated heart failure survival probabilities and predicted risk category for the test case
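The block below is a minimal rendering sketch for the interface output described above, assuming the illustrative estimate_survival_profile helper sketched earlier in this section and using the Low-Risk test case probabilities computed earlier as placeholder values; it illustrates the output wiring rather than the deployed application's actual code.
##################################
# Illustrative sketch of the interface output rendering
# (placeholder values taken from the Low-Risk test case above;
# structure and variable names are assumptions for this sketch)
##################################
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st

# Placeholder summary standing in for the model estimates
summary_table = pd.DataFrame({"Time (Days)": [50, 100, 150, 200, 250],
                              "Estimated Survival Probability (%)": [90.98, 87.65, 84.60, 78.48, 70.70],
                              "Predicted Risk Category": "Low-Risk"})

if st.button("Render Survival Profile"):
    # Survival probability plot indicating the test case estimates over time
    fig, ax = plt.subplots(figsize=(17, 8))
    ax.step(summary_table["Time (Days)"],
            summary_table["Estimated Survival Probability (%)"] / 100,
            where="post", color="blue", linewidth=3.0)
    ax.set_xlabel("Time (Days)")
    ax.set_ylabel("Estimated Survival Probability")
    ax.set_ylim(-0.05, 1.05)
    st.pyplot(fig)
    # Summary table presenting the estimated probabilities and risk category
    st.dataframe(summary_table)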
2. Summary ¶
3. References ¶
- [Book] Clinical Prediction Models by Ewout Steyerberg
- [Book] Survival Analysis: A Self-Learning Text by David Kleinbaum and Mitchel Klein
- [Book] Applied Survival Analysis Using R by Dirk Moore
- [Book] Survival Analysis with Python by Avishek Nag
- [Python Library API] SciKit-Survival by SciKit-Survival Team
- [Python Library API] SciKit-Learn by SciKit-Learn Team
- [Python Library API] StatsModels by StatsModels Team
- [Python Library API] SciPy by SciPy Team
- [Python Library API] Lifelines by Lifelines Team
- [Python Library API] Streamlit by Streamlit Team
- [Python Library API] Streamlit Community Cloud by Streamlit Team
- [Kaggle Project] Applied Reliability, Solutions To Problems by Keenan Zhuo (Kaggle)
- [Kaggle Project] Survival Models VS ML Models Benchmark - Churn Tel by Carlos Alonso Salcedo (Kaggle)
- [Kaggle Project] Survival Analysis with Cox Model Implementation by Bryan Boulé (Kaggle)
- [Kaggle Project] Survival Analysis by Gunes Evitan (Kaggle)
- [Kaggle Project] Survival Analysis of Lung Cancer Patients by Sayan Chakraborty (Kaggle)
- [Kaggle Project] COVID-19 Cox Survival Regression by Ilias Katsabalos (Kaggle)
- [Kaggle Project] Liver Cirrhosis Prediction with XGboost & EDA by Arjun Bhaybang (Kaggle)
- [Article] Exploring Time-to-Event with Survival Analysis by Olivia Tanuwidjaja (Towards Data Science)
- [Article] The Complete Introduction to Survival Analysis in Python by Marco Peixeiro (Towards Data Science)
- [Article] Survival Analysis Simplified: Explaining and Applying with Python by Zeynep Atli (Towards Data Science)
- [Article] Survival Analysis in Python (KM Estimate, Cox-PH and AFT Model) by Rahul Raoniar (Medium)
- [Article] How to Evaluate Survival Analysis Models by Nicolo Cosimo Albanese (Towards Data Science)
- [Article] Survival Analysis with Python Tutorial — How, What, When, and Why by Towards AI Team (Medium)
- [Article] Survival Analysis: Predict Time-To-Event With Machine Learning by Lina Faik (Medium)
- [Article] A Complete Guide To Survival Analysis In Python, Part 1 by Pratik Shukla (KDNuggets)
- [Article] A Complete Guide To Survival Analysis In Python, Part 2 by Pratik Shukla (KDNuggets)
- [Article] A Complete Guide To Survival Analysis In Python, Part 3 by Pratik Shukla (KDNuggets)
- [Article] Model Explainability using SHAP (SHapley Additive exPlanations) and LIME (Local Interpretable Model-agnostic Explanations) by Anshul Goel (Medium)
- [Article] A Comprehensive Guide into SHAP (SHapley Additive exPlanations) Values by Brain John Aboze (DeepChecks.Com)
- [Article] Survival Analysis by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Part 1: How to Format Data for Several Types of Survival Analysis Models by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Part 2: Single-Episode Cox Regression Model with Time-Invariant Predictors by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Part 3: Single-Episode Cox Regression Model with Time-Varying Predictors by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Part 4: Recurring-Episode Cox Regression Model with Time-Invariant Predictors by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Part 5: Recurring-Episode Cox Regression Model with Time-Varying Predictors by Jessica Lougheed and Lizbeth Benson (QuantDev.SSRI.PSU.Edu)
- [Article] Parametric Survival Modeling by Devin Incerti (DevinIncerti.Com)
- [Article] Survival Analysis Simplified: Explaining and Applying with Python by Zeynep Atli (Medium)
- [Article] Understanding Survival Analysis Models: Bridging the Gap between Parametric and Semiparametric Approaches by Zeynep Atli (Medium)
- [Article] Survival Modeling — Accelerated Failure Time — XGBoost by Avinash Barnwal (Medium)
- [Publication] Regression Models and Life Tables by David Cox (Royal Statistical Society)
- [Publication] Covariance Analysis of Censored Survival Data by Norman Breslow (Biometrics)
- [Publication] The Efficiency of Cox’s Likelihood Function for Censored Data by Bradley Efron (Journal of the American Statistical Association)
- [Publication] Regularization Paths for Cox’s Proportional Hazards Model via Coordinate Descent by Noah Simon, Jerome Friedman, Trevor Hastie and Rob Tibshirani (Journal of Statistical Software)
- [Publication] Shapley Additive Explanations by Erik Strumbelj and Igor Kononenko (The Journal of Machine Learning Research)
- [Publication] A Unified Approach to Interpreting Model Predictions by Scott Lundberg and Su-In Lee (Conference on Neural Information Processing Systems)
- [Publication] Survival Analysis Part I: Basic Concepts and First Analyses by Taane Clark (British Journal of Cancer)
- [Publication] Survival Analysis Part II: Multivariate Data Analysis – An Introduction to Concepts and Methods by Mike Bradburn (British Journal of Cancer)
- [Publication] Survival Analysis Part III: Multivariate Data Analysis – Choosing a Model and Assessing its Adequacy and Fit by Mike Bradburn (British Journal of Cancer)
- [Publication] Survival Analysis Part IV: Further Concepts and Methods in Survival Analysis by Taane Clark (British Journal of Cancer)
- [Publication] Marginal Likelihoods Based on Cox's Regression and Life Model by Jack Kalbfleisch and Ross Prentice (Biometrika)
- [Publication] Hazard Rate Models with Covariates by Jack Kalbfleisch and Ross Prentice (Biometrics)
- [Publication] Linear Regression with Censored Data by Jonathan Buckley and Ian James (Biometrika)
- [Publication] A Statistical Distribution Function of Wide Applicability by Waloddi Weibull (Journal of Applied Mechanics)
- [Publication] Exponential Survivals with Censoring and Explanatory Variables by Ross Prentice (Biometrika)
- [Publication] The Lognormal Distribution, with Special Reference to its Uses in Economics by John Aitchison and James Brown (Economics Applied Statistics)
- [Course] Survival Analysis in Python by Shae Wang (DataCamp)
from IPython.display import display, HTML
display(HTML("<style>.rendered_html { font-size: 15px; font-family: 'Trebuchet MS'; }</style>"))