.. _causal-graph-discovery:

Causal Graph Discovery
======================

WhyNot comes equipped with tools to automatically generate the causal graph
associated with experiments on the dynamical system simulators. This allows
users both to experiment with graphical methods for causal inference and to
test methods for causal structure discovery.

The user writes simulators and experiments as normal in raw Python and NumPy.
Then, building on work in automatic differentiation, WhyNot traces the
evolution of the state variables during simulation and automatically builds up
the corresponding causal graph. This approach to graph generation is fast,
flexible, and significantly less error-prone than tracking the simulator
dynamics and experimental setup by hand.

*Automatic graph generation is still experimental, and there are likely some
remaining rough edges. Graph generation is currently supported on most, but
not all, simulators. Support for the remaining simulators (world3 and DICE) is
forthcoming.*

Generating Causal Graphs for Dynamical System Simulators
---------------------------------------------------------

First, generate a run of the simulator:

.. code:: python

    import whynot as wn

    initial_state = wn.hiv.State()
    config = wn.hiv.Config()
    run = wn.hiv.simulate(initial_state, config)

Then, generate the causal graph associated with the run(s) of the simulator:

.. code:: python

    graph = wn.causal_graphs.build_dynamics_graph(wn.hiv, [run], config)

The nodes in the graph correspond to the state variables at each time step.

.. code:: python

    >>> print(graph.nodes)
    'uninfected_T1_0.0', 'infected_T1_0.0', 'uninfected_T2_0.0',
    'infected_T2_0.0', 'free_virus_0.0', 'immune_response_0.0',
    'uninfected_T1_1.0', 'infected_T1_1.0', 'uninfected_T2_1.0',
    'infected_T2_1.0', 'free_virus_1.0', 'immune_response_1.0'
    ...

The edges in the graph correspond to edges between subsequent states, which
are defined by the dynamics.
.. code:: python

    >>> print(graph.edges)
    ('uninfected_T1_0.0', 'uninfected_T1_1.0')
    ('uninfected_T1_0.0', 'infected_T1_1.0')
    ('uninfected_T1_0.0', 'free_virus_1.0')
    ('infected_T1_0.0', 'infected_T1_1.0')
    ('infected_T1_0.0', 'free_virus_1.0')
    ...

This graph, along with the data generated by the ``run``, can then be used
with causal graph discovery methods.

Generating Causal Graphs for WhyNot Experiments
-----------------------------------------------

To enable tracing of user-defined functions with NumPy, we import a
thinly-wrapped version of NumPy:

.. code:: python

    import whynot as wn
    import whynot.traceable_numpy as np

Then define a :class:`~whynot.dynamics.DynamicsExperiment` as normal.

.. code:: python

    def sample_initial_states(rng):
        """Initial state distribution."""
        rabbits = rng.randint(10, 100)
        foxes = rng.uniform(0.1, 0.8) * rabbits
        return wn.lotka_volterra.State(rabbits=rabbits, foxes=foxes)

    def soft_threshold(x, threshold, r=20):
        """Continuous relaxation of the hard threshold function.

        Returns approximately 1 if x > threshold, and approximately 0
        if x < threshold.
        """
        return 1. / (np.exp(r * (threshold - x)) + 1)

    def confounded_propensity_scores(untreated_run, intervention):
        """Confounded treatment assignment probability."""
        run = untreated_run
        return 0.3 + 0.4 * (1. - soft_threshold(
            run[intervention.time].foxes, threshold=7))

    exp = wn.DynamicsExperiment(
        name="lotka_volterra_confounding",
        description="Determine effect of reducing rabbits needed to sustain a fox.",
        simulator=wn.lotka_volterra,
        simulator_config=wn.lotka_volterra.Config(
            fox_growth=0.75, delta_t=1, end_time=6),
        intervention=wn.lotka_volterra.Intervention(time=3, fox_growth=0.4),
        state_sampler=sample_initial_states,
        propensity_scorer=confounded_propensity_scores,
        outcome_extractor=lambda run: run.states[-1].foxes,
        covariate_builder=lambda run, intervention: run[intervention.time].foxes)

To generate the dataset and the associated causal graph, pass
``causal_graph=True`` to the ``run`` method.
.. code:: python

    dataset = exp.run(num_samples=100, causal_graph=True)

    # The causal graph is a networkx DiGraph
    causal_graph = dataset.causal_graph

**Important**: While execution tracing is a flexible way to build up the
causal graph, it does not interact well with control flow based on
conditional statements, e.g. ``if``. To understand the problem, suppose we
write:

.. code:: python

    def func(x):
        if x > 2:
            return 3
        return 5

    y = func(x)

Then ``y`` depends on ``x``. However, the graph tracing approach cannot
uncover this dependency since the output of ``func`` is a constant. To avoid
this corner case, all code in the user-defined functions should be
`straight-line` code. For example, in the propensity scoring function above,
we used a `soft threshold` rather than a hard threshold to make graph
construction possible.

.. code:: python

    # BAD: Not traceable
    # Hard if statement: Graph tracing cannot discover that treatment
    # assignment depends on the fox population at the time of intervention.
    def confounded_propensity_scores(untreated_run, intervention):
        if untreated_run[intervention.time].foxes > 7:
            return 0.7
        return 0.4

    # GOOD: Traceable
    # Soft/continuous variant: Graph tracing discovers treatment depends on
    # the fox population at the time of intervention.
    def confounded_propensity_scores(untreated_run, intervention):
        run = untreated_run
        return 0.3 + 0.4 * (1. - soft_threshold(
            run[intervention.time].foxes, threshold=7))
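To build intuition for why the soft threshold is an acceptable substitute, here is a small standalone check (plain NumPy, no tracing involved) showing that, with the sharpness parameter ``r=20`` used above, the continuous relaxation agrees with the hard threshold everywhere except in a narrow band around the cutoff:

.. code:: python

    import numpy as np

    def soft_threshold(x, threshold, r=20):
        """Continuous relaxation of the hard threshold, as defined above."""
        return 1. / (np.exp(r * (threshold - x)) + 1)

    # Away from the cutoff, soft and hard thresholds nearly coincide;
    # exactly at the cutoff, the soft version returns 0.5.
    for foxes in [2.0, 6.0, 7.0, 8.0, 12.0]:
        hard = float(foxes > 7)
        soft = soft_threshold(foxes, threshold=7)
        print(f"foxes={foxes:5.1f}  hard={hard:.0f}  soft={soft:.6f}")

Because the relaxation is a smooth function of its input, every output retains a traced dependence on the fox population, which is exactly what the graph construction needs.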