Source code for whynot.framework

"""Base experimental class."""
import dataclasses
import inspect
import itertools
from typing import Any, Optional

import numpy as np


# Name of the function attribute under which the ``parameter`` decorator stores
# a function's ParameterCollection; ``extract_params`` reads it back from there.
PARAM_COLLECTION = "__parameter_collection__"


@dataclasses.dataclass
class Dataset:  # pylint:disable-msg=too-few-public-methods
    """Observational dataset and ground truth unit-level effects.

    Attributes
    ----------
    covariates: `np.ndarray`
        Float array of shape [num_samples, num_features] of covariates for each unit.
    treatments: `np.ndarray`
        Integer 0/1 array of shape [num_samples] indicating treatment status for each unit.
        1 indicates treated, 0 indicates unit not treated.
    outcomes: `np.ndarray`
        Float array of shape [num_samples] containing the observed outcome for each unit.
    true_effects: `np.ndarray`
        Float array of shape [num_samples] containing the unit-level treatment effects,
        :math:`Y_i(1) - Y_i(0)` for each :math:`i`.
    sate: float
        Sample average treatment effect based on ground truth unit-level effects
        (computed property, the mean of ``true_effects``).
    causal_graph: networkx.DiGraph, optional
        If supported by the simulator and experiment, the causal graph associated
        with the data. Defaults to None.

    """

    covariates: np.ndarray
    treatments: np.ndarray
    outcomes: np.ndarray
    true_effects: np.ndarray
    causal_graph: Any = None

    @property
    def sate(self):
        """Return the sample average treatment effect."""
        return np.mean(self.true_effects)

    def bootstrap_sample_ate_ci(self, num_bootstrap_samples=2000, alpha=0.05, seed=None):
        """Bootstrap a 100*(1-alpha)% confidence interval for the sample ATE.

        The unit-level ``true_effects`` are resampled with replacement
        ``num_bootstrap_samples`` times. The interval is given by the
        ``alpha/2`` and ``1 - alpha/2`` quantiles of the resampled means.

        Parameters
        ----------
        num_bootstrap_samples: int
            Number of bootstrap resamples (must be at least 1).
        alpha: float
            Significance level; the default 0.05 yields a 95% interval.
        seed: int, optional
            If given, resampling uses a dedicated ``np.random.RandomState(seed)``
            so the interval is reproducible. If None (default), the global NumPy
            random state is used, as in previous versions.

        Returns
        -------
        tuple
            ``(lower_bound, upper_bound)`` of the interval.

        Raises
        ------
        ValueError
            If ``num_bootstrap_samples`` is smaller than 1.

        """
        if num_bootstrap_samples < 1:
            raise ValueError(
                f"num_bootstrap_samples must be at least 1, got {num_bootstrap_samples}."
            )
        # Keep the legacy global-RNG path when no seed is given so existing
        # callers that seed np.random globally see unchanged behavior.
        sampler = np.random if seed is None else np.random.RandomState(seed)
        effects = np.asarray(self.true_effects)
        num_units = len(effects)
        # Resample one row at a time to keep memory at O(num_units) instead of
        # materializing a num_bootstrap_samples x num_units matrix.
        means = np.empty(num_bootstrap_samples)
        for idx in range(num_bootstrap_samples):
            resample = sampler.choice(effects, size=num_units, replace=True)
            means[idx] = np.mean(resample)
        lower_tail, upper_tail = alpha / 2.0, 1.0 - alpha / 2.0
        return (np.quantile(means, lower_tail), np.quantile(means, upper_tail))
@dataclasses.dataclass
class InferenceResult:  # pylint:disable-msg=too-few-public-methods
    """Object to store results of causal inference method.

    Every field defaults to "not available" (None, or ``(None, None)`` for
    ``ci``), so a method only fills in what it supports.

    Attributes
    ----------
    ate:
        Estimated average treatment effect.
    stderr:
        Reported standard error of the ATE estimate. Only available if supported
        by the method, otherwise None.
    ci:
        Reported 95% confidence interval for the ATE (lower_bound, upper_bound).
        Only available if supported by the method, otherwise (None, None).
        TODO: Ideally, we'd support various significance levels.
    individual_effects:
        Heterogeneous treatment effect for each unit. Only available if supported
        by the method, otherwise None.
    elapsed_time:
        How long in seconds (wall-clock time) it took to produce the estimate.

    """

    ate: Optional[float] = None
    stderr: Optional[float] = None
    # A tuple default is immutable, so it is safe to share across instances.
    ci: tuple = (None, None)  # pylint: disable-msg=invalid-name
    individual_effects: Optional[np.ndarray] = None
    elapsed_time: Optional[float] = None
@dataclasses.dataclass
class ExperimentParameter:  # pylint:disable-msg=too-few-public-methods
    """A single knob that an experiment can vary.

    Attributes
    ----------
    name: str
        Identifier of the parameter; it must match an argument name of the
        function it is attached to.
    default:
        Value used whenever the parameter is not explicitly specified.
    values:
        Collection of candidate values (e.g. a list) drawn from during random
        sampling and enumerated for grid search. None when only the default
        is meaningful.
    description: str
        Human-readable explanation of what the parameter controls.

    """

    name: str
    default: Any
    values: Any = dataclasses.field(default=None)
    description: str = dataclasses.field(default="")
def _callable_name(obj):
    """Return a readable name for ``obj`` (its ``__name__`` or its repr) for error messages."""
    return getattr(obj, "__name__", repr(obj))


def parameter(name, default, values=None, description=""):
    """Decorate functions in an experiment with parameters.

    Usage::

        @parameter(name="treatment_bias", default=0.2,
                   values=[0.1, 0.2, 0.5, 0.9],
                   description="How much is treatment biased for control group.")
        def propensity_score(untreated_run, treatment_bias):
            ...

    Parameters
    ----------
    name: str
        Name of the parameter. Must be the name of an argument of the
        decorated function.
    default:
        Value used when the parameter is not specified.
    values: optional
        Candidate values used for random sampling / grid search.
    description: str
        Human-readable description of the parameter.

    Returns
    -------
    callable
        A decorator that records the parameter in the function's
        ``PARAM_COLLECTION`` attribute and returns the function itself.

    Raises
    ------
    ValueError
        (from the returned decorator) if the function has no argument called
        ``name``, or already carries a parameter with that name.

    """
    exp_param = ExperimentParameter(name, default, values, description)

    def parameter_decorator(func):
        """Attach parameter collection object to the function."""
        # Ensure the parameter is an argument to the function.
        method_params = inspect.signature(func).parameters
        if name not in method_params:
            raise ValueError(f"{name} is not an argument to {_callable_name(func)}")
        if hasattr(func, PARAM_COLLECTION):
            # Stacked @parameter decorators accumulate into one collection.
            getattr(func, PARAM_COLLECTION).add_parameter(exp_param)
        else:
            setattr(func, PARAM_COLLECTION, ParameterCollection([exp_param]))
        return func

    return parameter_decorator


def extract_params(func, standard_args):
    """Return WhyNot parameters for user-defined function.

    Performs error-checking to ensure parameters are disjoint from standard
    arguments and all arguments to the function are either standard arguments
    or parameters.

    Parameters
    ----------
    func: function
        Possibly parameterized function.
    standard_args: list
        A list of possible arguments provided by the calling class itself and
        shouldn't be treated as parameters for the function.

    Returns
    -------
    params: `whynot.framework.ParameterCollection`
        A collection of parameters for the func. This is the collection object
        attached by ``@parameter`` (not a copy), or an empty collection if the
        function was never decorated.

    Raises
    ------
    ValueError
        If ``func`` is not callable, if a name is both a parameter and a
        standard argument, or if a signature argument is neither.

    """
    if not callable(func):
        # Non-callables typically lack ``__name__``; formatting via the helper
        # guarantees the intended ValueError instead of an AttributeError.
        msg = f"Trying to extract parameters from {_callable_name(func)}, but not callable."
        raise ValueError(msg)

    func_name = _callable_name(func)

    # Extract parameters specified by the user via the @parameter decorator.
    specified_params = ParameterCollection([])
    if hasattr(func, PARAM_COLLECTION):
        specified_params = getattr(func, PARAM_COLLECTION)

    # Ensure standard_args is disjoint from the specified params.
    for arg in standard_args:
        if arg in specified_params:
            msg = f"{arg} is both a parameter and a standard argument to {func_name}."
            raise ValueError(msg)

    # By construction, every element in specified_params
    # must appear in the function signature, i.e. method_params.
    method_args = inspect.signature(func).parameters
    for arg in method_args:
        if arg not in standard_args and arg not in specified_params:
            msg = (
                f"'{arg}' is in the signature of function {func_name}, "
                f"but '{arg}' is not a standard argument or a parameter. "
                f"Standard arguments: {', '.join(standard_args)}."
            )
            raise ValueError(msg)

    return specified_params
class ParameterCollection:
    """Lightweight wrapper class around a set of parameters.

    Provides utility functions to support sampling and assigning subsets of the
    parameters. Enforces name uniqueness: every parameter must have a unique
    name. Parameters are kept in insertion order, keyed by name.
    """

    def __init__(self, params=None):
        """Build a collection from an optional iterable of parameter objects.

        Raises
        ------
        ValueError
            If two parameters share the same name.

        """
        self.params = {}
        if params is not None:
            for param in params:
                if param.name in self.params:
                    raise ValueError(f"Duplicate name {param.name}")
                self.params[param.name] = param

    def add_parameter(self, param):
        """Add a new parameter to the existing collection.

        Raises ValueError if a parameter with the same name is already present.
        """
        if param.name in self.params:
            raise ValueError(f"Adding parameter with duplicate name {param.name}")
        self.params[param.name] = param

    def default(self):
        """Return the default parameter setting for each parameter."""
        return {name: p.default for name, p in self.params.items()}

    def sample(self, seed=None):
        """Return a random parameter setting for each parameter.

        Parameters without candidate ``values`` fall back to their default.
        The same ``seed`` yields the same setting.
        """
        rng = np.random.RandomState(seed)
        sampled_params = {}
        for name, param in self.params.items():
            if param.values is None:
                sampled_params[name] = param.default
            else:
                sampled_params[name] = rng.choice(param.values)
        return sampled_params

    def project(self, specified_params):
        """Return fully instantiated parameters using defaults for unspecified params.

        ``specified_params`` is a dict of the arguments passed to the run method.
        Any unspecified params are set to their default values.

        Returns a dictionary params[param_name] -> param_value.

        Raises ValueError if ``specified_params`` names a parameter that is not
        in the collection.
        """
        for param_name in specified_params:
            if param_name not in self.params:
                raise ValueError(f"Parameter {param_name} specified, but not used!.")
        params = {}
        for name, param in self.params.items():
            if name in specified_params:
                params[name] = specified_params[name]
            else:
                params[name] = param.default
        return params

    def __contains__(self, name):
        """Check if param name is specified."""
        return name in self.params

    def __getitem__(self, name):
        """Return the parameter object corresponding to name.

        Raises ValueError (not KeyError) if ``name`` is unknown.
        """
        if name in self.params:
            return self.params[name]
        raise ValueError(f"{name} not found")

    def __add__(self, collection):
        """Add two parameter collections together into a new collection.

        Raises ValueError if the collections share a parameter name.
        """
        original_params = list(self.params.values())
        new_params = list(collection.params.values())
        return ParameterCollection(original_params + new_params)

    def __iter__(self):
        """Iterate over the parameter objects in insertion order."""
        for param in self.params.values():
            yield param

    def __repr__(self):
        """See __str__."""
        return self.__str__()

    def __str__(self):
        """Display the collection in a human readable format.

        For instance:
            Params:
                Name:           hidden_dim
                Description:    hidden dimension of 2-layer ReLu network response.
                Default:        32
                Values:         [8, 16, 32, 64, 128, 256, 512]
        """
        class_display = "Params:"
        param_display = (
            "\tName:\t\t{}\n\tDescription:\t{}\n\tDefault:\t{}\n\tValues:\t\t{}\n"
        )
        for param in self.params.values():
            param_values = [] if param.values is None else param.values
            class_display += "\n" + param_display.format(
                param.name, param.description, param.default, param_values
            )
        return class_display

    def grid(self):
        """Return the full parameter grid as a list of settings dicts.

        Each element maps every parameter name to one value; the list covers
        the Cartesian product of all candidate values (later parameters vary
        fastest). Parameters without ``values`` contribute only their default.
        An empty collection yields ``[{}]``.

        Examples
        --------
        >>> p1 = ExperimentParameter(name="a", default=1, values=[1, 2])
        >>> p2 = ExperimentParameter(name="b", default=3, values=[3, 4])
        >>> collection = ParameterCollection(params=[p1, p2])
        >>> for setting in collection.grid():
        ...     print(setting)
        {'a': 1, 'b': 3}
        {'a': 1, 'b': 4}
        {'a': 2, 'b': 3}
        {'a': 2, 'b': 4}

        """
        names = list(self.params)
        # Mirror ``sample``: a parameter with no candidate values is pinned to
        # its default instead of making itertools.product fail on None.
        candidates = [
            [param.default] if param.values is None else param.values
            for param in self.params.values()
        ]
        return [dict(zip(names, combo)) for combo in itertools.product(*candidates)]
####################
# Generic Simulator
####################
class GenericExperiment:  # pylint:disable-msg=too-few-public-methods
    """Bundle a simulation run method into a named causal experiment.

    The wrapped ``run_method`` is called with the keyword arguments
    ``num_samples``, ``seed``, ``show_progress`` and ``parallelize`` plus every
    declared parameter. It must return
    ``((covariates, treatment, outcome), ground_truth)``.
    """

    def __init__(self, name, description, run_method):
        """Record the experiment metadata and collect the run method's parameters."""
        self.name = name
        self.description = description
        self.run_method = run_method
        self.params = extract_params(
            run_method, ["num_samples", "seed", "parallelize", "show_progress"]
        )

    def get_parameters(self):
        """Return the parameter collection gathered from the run method."""
        return self.params

    def run(
        self, num_samples, seed, parallelize=True, show_progress=False, **parameter_args
    ):
        """Execute the run method once and package its output as a Dataset.

        Parameters left out of ``parameter_args`` take their default values.
        """
        resolved = self.params.project(parameter_args)
        observed, true_effects = self.run_method(
            num_samples=num_samples,
            seed=seed,
            show_progress=show_progress,
            parallelize=parallelize,
            **resolved,
        )
        covariates, treatments, outcomes = observed
        return Dataset(
            covariates=covariates,
            treatments=treatments,
            outcomes=outcomes,
            true_effects=true_effects,
        )