Source code for whynot.algorithms.causal_suite

"""Algorithms for causal inference."""
import whynot

try:
    import whynot_estimators

    ESTIMATORS_INSTALLED = True
except ImportError:
    ESTIMATORS_INSTALLED = False


def causal_suite(covariates, treatment, outcome, verbose=False):
    """Run a collection of causal inference algorithms on the observational dataset.

    By default, the suite only runs estimators implemented in Python:

    - Ordinary least squares (``ols``)
    - Propensity score matching (``propensity_score_matching``)
    - Propensity weighted least squares (``propensity_weighted_ols``)

    If ``whynot_estimators`` can be imported, the suite additionally runs each
    of the following estimators whose ``is_installed()`` check returns True:

    - An IP weighting estimator (``ip_weighting``)
    - Matching with a mahalanobis distance metric (``mahalanobis_matching``)
    - Causal Forest (``causal_forest``)
    - TMLE (``tmle``)

    Parameters
    ----------
    covariates : np.ndarray
        Array of shape [num_samples, num_features].
    treatment : np.ndarray
        Boolean array of shape [num_samples] indicating treatment status for
        each sample.
    outcome : np.ndarray
        Array of shape [num_samples] containing the observed outcome for each
        sample.
    verbose : bool
        If True, print incremental messages as estimators are executed.

    Returns
    -------
    results : dict
        Dictionary mapping the name of each estimator that ran successfully to
        the corresponding :class:`InferenceResult`, e.g.
        ``results['causal_forest'] -> inference_result``. An estimator that
        raises an exception is reported on stdout (regardless of ``verbose``)
        and omitted from the dictionary.

    """
    # Map from algorithm name to estimation function.
    methods = {
        "ols": whynot.ols,
        "propensity_score_matching": whynot.propensity_score_matching,
        "propensity_weighted_ols": whynot.propensity_weighted_ols,
    }

    if ESTIMATORS_INSTALLED:
        additional_methods = {
            "ip_weighting": whynot_estimators.ip_weighting,
            "mahalanobis_matching": whynot_estimators.matching,
            "causal_forest": whynot_estimators.causal_forest,
            "tmle": whynot_estimators.tmle,
        }
        # The package being importable does not guarantee every estimator in
        # it is usable; keep only those that report themselves as installed.
        methods.update(
            {
                name: estimator
                for name, estimator in additional_methods.items()
                if estimator.is_installed()
            }
        )
    elif verbose:
        print("Only the base estimators are installed.")
        print("To install additional estimators, `pip install whynot_estimators`.")

    results = {}
    for name, estimator in methods.items():
        if verbose:
            print(f"Running estimator: {name}")
        try:
            inference_result = estimator.estimate_treatment_effect(
                covariates, treatment, outcome
            )
        # Best effort: a single failing estimator must not abort the whole
        # suite, but report why it failed instead of discarding the error.
        except Exception as err:  # pylint: disable=broad-except
            print(f"Estimator {name} failed to run ({err!r})... skipping!")
        else:
            results[name] = inference_result

    return results