How to generate covariate-adjusted cox survival/hazard functions?

You want to obtain survival probabilities from a Cox model for certain values of some covariate of interest, while adjusting for other covariates. However, because a Cox model makes no assumption about the distribution of the survival times (the baseline hazard is left unspecified), we cannot obtain survival probabilities from it directly. We first have to estimate the baseline hazard function, which is typically done with the non-parametric Breslow estimator. When the Cox model is fitted with coxph() from the survival package, we can obtain such probabilities with a call to the survfit() function. You may consult ?survfit.coxph for more information.

Let's see how we can do this by using the lung data set.

library(survival)

# select covariates of interest
df <- subset(lung, select = c(time, status, age, sex, ph.karno))

# assess whether there are any missing observations per column
# (colSums() on the logical matrix avoids apply(), which would first coerce
# the data.frame to a matrix)
colSums(is.na(df)) # 1 in ph.karno

# listwise delete missing observations
df <- df[complete.cases(df), ]

# Cox model; the event indicator is status == 2
fit <- coxph(Surv(time, status == 2) ~ age + sex + ph.karno, data = df)

## Note that I ignore the fact that ph.karno does not satisfy the PH assumption.

# specify for which combinations of values of age, sex, and
# ph.karno we want to derive survival probabilities;
# survfit() estimates one survival curve per row of newdata
ND1 <- with(df, expand.grid(
  age = median(age),
  sex = c(1, 2), # 1 = males, 2 = females
  ph.karno = median(ph.karno)
))
ND2 <- with(df, expand.grid(
  age = median(age),
  sex = 1, # males
  # create_intervals() is not defined here: it comes from the post linked in
  # the text and must be defined before this line is run
  ph.karno = round(create_intervals(n_groups = 3L))
))

# Obtain the predicted survival curves (one per row of ND1 / ND2)
sfit1 <- survfit(fit, newdata = ND1)
sfit2 <- survfit(fit, newdata = ND2)

The code for the function create_intervals() can be found in this post; I simply replaced speed with ph.karno in that function. Note that create_intervals() must be defined before ND2 is created.

Printing sfit1 shows the estimated median survival times and the corresponding 95% confidence intervals for the combinations of covariates specified in ND1 (one row per row of ND1).

> sfit1
Call: survfit(formula = fit, newdata = ND1)

    n events median 0.95LCL 0.95UCL
1 227    164    283     223     329
2 227    164    371     320     524

Survival probabilities at specific follow-up times can be obtained with the times argument of the summary() method.

# estimated survival probability of each curve in sfit1 after 200 days
summary(object = sfit1, times = 200)

The output again contains the estimated survival probabilities, but now at 200 days of follow-up. survival1 corresponds to the first row of ND1, i.e. a male patient of median age with median ph.karno. survival2 corresponds to the second row, i.e. a female patient with the same age and ph.karno.

> summary(sfit1, times = 200)
Call: survfit(formula = fit, newdata = ND1)

 time n.risk n.event survival1 survival2
  200    144      71     0.625     0.751

The 95% confidence limits associated with these two probabilities can be manually extracted from summary().

sum_sfit <- summary(sfit1, times = 200)
# bind estimate, lower and upper limit as rows, then transpose so that
# every curve (row of ND1) becomes one row of the table
sum_sfit <- t(do.call(rbind, list(sum_sfit$surv, sum_sfit$lower, sum_sfit$upper)))
colnames(sum_sfit) <- c("S_hat", "2.5 %", "97.5 %")
# ------------------------------------------------------
> sum_sfit
      S_hat     2.5 %    97.5 %
1 0.6250586 0.5541646 0.7050220
2 0.7513961 0.6842830 0.8250914

If you would like to use ggplot to depict the expected survival probabilities (and the corresponding 95% confidence intervals) for the combinations of values specified in ND1 and ND2, we first need to put this information into data.frames in long format (one row per time point and curve).

# Reshape the output of a survfit object into a long data.frame that can be
# used in a call to ggplot(): one row per (time point, curve) pair.
#
# surv_obj: survfit object (or list) with components time, n.risk, n.event,
#           surv, upper and lower; surv/upper/lower hold one column per
#           curve (a plain vector when there is only one curve).
# newdata:  the data.frame that was passed to survfit(); kept for backward
#           compatibility only.
# factor:   name of the covariate that varies across the curves; kept for
#           backward compatibility only.
# Returns a data.frame with columns time, n.risk, n.event, surv, upper,
# lower and `factor`, a factor with levels "1", "2", ... giving the curve
# index (i.e. the row of newdata the curve belongs to).
df_fun <- \(surv_obj, newdata, factor) {
  # Count the curves from the estimates themselves. Counting the unique
  # values of one covariate breaks as soon as newdata has duplicated values
  # or several varying covariates (time would then be repeated too few times).
  n_curves <- NCOL(surv_obj[['surv']])
  n_times <- length(surv_obj[['time']])
  data.frame(
    time = rep(surv_obj[['time']], times = n_curves),
    n.risk = rep(surv_obj[['n.risk']], times = n_curves),
    n.event = rep(surv_obj[['n.event']], times = n_curves),
    # as.vector() flattens column by column, i.e. curve after curve
    surv = as.vector(surv_obj[['surv']]),
    upper = as.vector(surv_obj[['upper']]),
    lower = as.vector(surv_obj[['lower']]),
    factor = gl(n_curves, n_times)
  )
}

# long-format data for panel A: one curve per sex (rows of ND1)
df_leftPanel <- df_fun(sfit1, ND1, 'sex')

# long-format data for panel B: one curve per ph.karno value (rows of ND2)
df_rightPanel <- df_fun(sfit2, ND2, 'ph.karno')

Now that we have defined our data.frames, we need to define a new geom that draws the 95% CIs as step functions; geom_ribbon() would instead connect consecutive time points with straight lines. We give it the generic name geom_stepribbon.

library(ggplot2)

# geom_stepribbon(): layer constructor for GeomStepribbon, a ribbon whose
# ymin/ymax band is drawn as a step function. The arguments are forwarded to
# ggplot2::layer(). na.rm and any extra arguments in `...` become the
# layer's params.
geom_stepribbon <- function(mapping = NULL, data = NULL, stat = "identity",
                            position = "identity", na.rm = FALSE,
                            show.legend = NA, inherit.aes = TRUE, ...) {
  layer(
    geom        = GeomStepribbon,
    stat        = stat,
    data        = data,
    mapping     = mapping,
    position    = position,
    params      = list(na.rm = na.rm, ...),
    inherit.aes = inherit.aes,
    show.legend = show.legend
  )
}

# ggproto Geom that inherits everything from GeomRibbon except draw_group().
# The replacement draw_group() first turns the (x, ymin, ymax) points into a
# staircase, so that the band is drawn as a step function instead of by
# straight lines between consecutive points.
GeomStepribbon <- ggproto(
  "GeomStepribbon", GeomRibbon,
  # declare na.rm as a non-aesthetic parameter accepted by this geom
  extra_params = c("na.rm"),
  draw_group = function(data, panel_scales, coord, na.rm = FALSE) {
    # optionally drop rows with a missing x, ymin or ymax
    if (na.rm) data <- data[complete.cases(data[c("x", "ymin", "ymax")]), ]
    # duplicate every row and sort by x, so that the two copies of each
    # point end up next to each other
    data   <- rbind(data, data)
    data   <- data[order(data$x), ]
    # shift x up by one row: the second copy of point k gets the x of point
    # k + 1, i.e. point k's ymin/ymax are held constant until the next x
    # NOTE(review): 2:nrow(data) misbehaves if data has 0 rows (e.g. all
    # rows removed by na.rm) -- confirm this cannot happen or guard it
    data$x <- c(data$x[2:nrow(data)], NA)
    # the last row has no next x (NA after the shift) and is dropped
    data   <- data[complete.cases(data["x"]), ]
    # draw the staircase with the parent GeomRibbon; na.rm = FALSE because
    # the optional NA filtering has already been done above
    GeomRibbon$draw_group(data, panel_scales, coord, na.rm = FALSE)
  }
)

Finally, we can plot the expected survival probabilities for ND1 and ND2.

# axis labels shared by both panels
yl <- 'Expected Survival probability\n'
xl <- '\nTime (days)'

# left panel
# helper: the same colour with its alpha multiplied by 0.2 (CI bands)
adj_colour <- \(x) adjustcolor(x, alpha.f = 0.2)
# line colours of the two curves, followed by their translucent versions
my_colours <- c('blue4', 'darkorange')
my_colours <- c(my_colours, adj_colour(my_colours))
# NOTE: ggplot2 >= 3.4.0 deprecates `size` for lines in favour of
# `linewidth`; `size` is kept here so older ggplot2 versions still work
left_panel <- ggplot(df_leftPanel,
                     aes(x = time, colour = factor, fill = factor)) +
  geom_step(aes(y = surv), size = 0.8) +
  geom_stepribbon(aes(ymin = lower, ymax = upper), colour = NA) +
  scale_colour_manual(
    name = 'Sex',
    values = setNames(my_colours[1:2], c('1', '2')),
    labels = c('1' = 'Males', '2' = 'Females')
  ) +
  scale_fill_manual(
    name = 'Sex',
    values = setNames(my_colours[3:4], c('1', '2')),
    labels = c('1' = 'Males', '2' = 'Females')
  ) +
  labs(x = xl, y = yl) +
  theme(
    legend.position = 'top',
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 12),
    legend.text = element_text(size = 12),
    legend.title = element_text(size = 12)
  )

# right panel
# line colours of the three curves (rows of ND2, labelled Low / Middle /
# High in that order), followed by their translucent versions
my_colours <- c('blue4', 'darkorange', '#00b0a4')
my_colours <- c(my_colours, adj_colour(my_colours))
right_panel <- ggplot(df_rightPanel,
                      aes(x = time, colour = factor, fill = factor)) +
  geom_step(aes(y = surv), size = 0.8) +
  geom_stepribbon(aes(ymin = lower, ymax = upper), colour = NA) +
  scale_colour_manual(
    name = 'Ph.karno',
    values = setNames(my_colours[1:3], c('1', '2', '3')),
    labels = c('1' = 'Low', '2' = 'Middle', '3' = 'High')
  ) +
  scale_fill_manual(
    name = 'Ph.karno',
    values = setNames(my_colours[4:6], c('1', '2', '3')),
    labels = c('1' = 'Low', '2' = 'Middle', '3' = 'High')
  ) +
  labs(x = xl, y = yl) +
  theme(
    legend.position = 'top',
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 12),
    legend.text = element_text(size = 12),
    legend.title = element_text(size = 12)
  )

# composite plot
library(ggpubr)
# panel A (sex) on the left, panel B (ph.karno) on the right
ggarrange(left_panel, right_panel,
          labels = c('A', 'B'),
          nrow = 1, ncol = 2)

Output

enter image description here

Interpretation

  • Panel A shows the expected survival probabilities for a male and female patient of median age with a median ph.karno.
  • Panel B shows the expected survival probabilities for three male patients of median age with ph.karno values of 67 (low), 83 (middle), and 100 (high).

These survival curves satisfy the PH assumption by construction, because they were derived from the Cox model. This holds regardless of whether the assumption is actually met by the data.

Note: use function(x) instead of \(x) if you use a version of R <4.1.0