Causal Graph Discovery
WhyNot comes equipped with tools to automatically generate the causal graph associated with experiments on its dynamical system simulators. This lets users experiment both with graphical methods for causal inference and with methods for causal structure discovery.
The user writes simulators and experiments as normal in raw Python and NumPy. Then, building on work in automatic differentiation, WhyNot traces the evolution of the state variables during simulation and automatically builds up the corresponding causal graph. This approach to graph generation is fast, flexible, and significantly less error-prone than tracking the simulator dynamics and experimental setup by hand.
Automatic graph generation is still experimental, and it is likely there are remaining rough edges. Graph generation is currently supported on most, but not all, simulators. Support for the remaining simulators (world3 and DICE) is forthcoming.
Generating Causal Graphs for Dynamical System Simulators
First, generate a run of the simulator:
import whynot as wn
initial_state = wn.hiv.State()
config = wn.hiv.Config()
run = wn.hiv.simulate(initial_state, config)
Then, generate the causal graph associated with the run(s) of the simulator.
graph = wn.causal_graphs.build_dynamics_graph(wn.hiv, [run], config)
The nodes in the graph correspond to the state variables at each time step.
>>> print(graph.nodes)
'uninfected_T1_0.0', 'infected_T1_0.0', 'uninfected_T2_0.0',
'infected_T2_0.0', 'free_virus_0.0', 'immune_response_0.0',
'uninfected_T1_1.0', 'infected_T1_1.0', 'uninfected_T2_1.0',
'infected_T2_1.0', 'free_virus_1.0', 'immune_response_1.0' ...
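Judging from the output above, node names appear to follow a `<state variable>_<time>` pattern. As a minimal sketch under that assumption, a name can be split back into its variable and time step with plain string handling. The helper `split_node_name` is hypothetical and not part of WhyNot.

```python
def split_node_name(name):
    """Split a node name such as 'free_virus_1.0' into (variable, time).

    Assumes the time step is everything after the last underscore.
    """
    variable, _, time = name.rpartition("_")
    return variable, float(time)


# A few node names copied from the printed output above.
nodes = ["uninfected_T1_0.0", "infected_T1_0.0", "free_virus_1.0", "immune_response_1.0"]
parsed = [split_node_name(node) for node in nodes]
```

Splitting on the last underscore matters because variable names such as `uninfected_T1` contain underscores themselves.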
The edges in the graph connect state variables at one time step to the state variables they influence at the next time step, as defined by the dynamics.
>>> print(graph.edges)
('uninfected_T1_0.0', 'uninfected_T1_1.0')
('uninfected_T1_0.0', 'infected_T1_1.0')
('uninfected_T1_0.0', 'free_virus_1.0')
('infected_T1_0.0', 'infected_T1_1.0')
('infected_T1_0.0', 'free_virus_1.0')
...
This graph, along with the data generated by the run, can then be used to evaluate causal graph discovery methods.
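To illustrate how the edge list can be read, the sketch below copies the five edges printed above into a plain Python list and collects the parents of each node. The printed output is truncated, so the resulting parent sets are only partial. The variable names are illustrative and not part of WhyNot.

```python
from collections import defaultdict

# The five edges shown in the truncated output above.
edges = [
    ("uninfected_T1_0.0", "uninfected_T1_1.0"),
    ("uninfected_T1_0.0", "infected_T1_1.0"),
    ("uninfected_T1_0.0", "free_virus_1.0"),
    ("infected_T1_0.0", "infected_T1_1.0"),
    ("infected_T1_0.0", "free_virus_1.0"),
]

# Map each node to the list of nodes with an edge pointing into it.
parents = defaultdict(list)
for source, target in edges:
    parents[target].append(source)

for node, node_parents in parents.items():
    print(node, "<-", node_parents)
```

Among these five edges, `free_virus_1.0` has both `uninfected_T1_0.0` and `infected_T1_0.0` as parents. A parent set like this is what a structure discovery method would be expected to recover from data.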
Generating Causal Graphs for WhyNot Experiments
To enable tracing of user-defined functions that use NumPy, import WhyNot's thinly wrapped version of NumPy:
import whynot as wn
import whynot.traceable_numpy as np
Then define a DynamicsExperiment as normal.
def sample_initial_states(rng):
    """Initial state distribution"""
    rabbits = rng.randint(10, 100)
    foxes = rng.uniform(0.1, 0.8) * rabbits
    return wn.lotka_volterra.State(rabbits=rabbits, foxes=foxes)

def soft_threshold(x, threshold, r=20):
    """A continuous relaxation of the threshold function. If x > threshold, returns ~1; if x < threshold, returns ~0."""
    return 1. / (np.exp(r * (threshold - x)) + 1)

def confounded_propensity_scores(untreated_run, intervention):
    """Confounded treatment assignment probability."""
    run = untreated_run
    return 0.3 + 0.4 * (1. - soft_threshold(run[intervention.time].foxes, threshold=7))

exp = wn.DynamicsExperiment(
    name="lotka_volterra_confounding",
    description=("Determine effect of reducing rabbits needed to sustain a fox."),
    simulator=wn.lotka_volterra,
    simulator_config=wn.lotka_volterra.Config(fox_growth=0.75, delta_t=1, end_time=6),
    intervention=wn.lotka_volterra.Intervention(time=3, fox_growth=0.4),
    state_sampler=sample_initial_states,
    propensity_scorer=confounded_propensity_scores,
    outcome_extractor=lambda run: run.states[-1].foxes,
    covariate_builder=lambda run, intervention: run[intervention.time].foxes)
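The `soft_threshold` helper above is a logistic function of `x - threshold` with steepness `r`. As a standalone sketch, the snippet below re-implements it with `math.exp` instead of the traceable NumPy wrapper, so it runs without WhyNot, and evaluates the resulting treatment probability just below, at, and just above the threshold of 7 foxes. The `propensity` helper is an illustrative stand-in for `confounded_propensity_scores`.

```python
import math


def soft_threshold(x, threshold, r=20):
    """Logistic relaxation of the step function 1[x > threshold]."""
    return 1.0 / (math.exp(r * (threshold - x)) + 1.0)


def propensity(foxes):
    """Treatment probability as a smooth function of the fox population."""
    return 0.3 + 0.4 * (1.0 - soft_threshold(foxes, threshold=7))


below, at, above = propensity(6.0), propensity(7.0), propensity(8.0)
print(below, at, above)
```

With `r=20` the transition is sharp: the probability is close to 0.7 one fox below the threshold, exactly 0.5 at the threshold, and close to 0.3 one fox above it.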
To generate the dataset and the associated causal graph, pass causal_graph=True into the run method.
dataset = exp.run(num_samples=100, causal_graph=True)
# The causal graph is a networkx DiGraph
causal_graph = dataset.causal_graph
Important: While execution tracing is a flexible way to build up the causal graph, it does not interact well with control flow based on conditional statements, e.g. if. To understand the problem, suppose we write:
def func(x):
    if x > 2:
        return 3
    return 5

y = func(x)
Then, y depends on x. However, the graph tracing approach cannot uncover this dependency since the output of func is a constant. To avoid this corner case, all code in user-defined functions should be straight-line code.
For example, in the propensity scoring function above, we used a soft threshold rather than a hard threshold to make graph construction possible.
# BAD: Not traceable
# Hard IF statement: Graph tracing cannot discover that treatment
# assignment depends on the fox population at the time of intervention.
def confounded_propensity_scores(untreated_run, intervention):
    if untreated_run[intervention.time].foxes > 7:
        return 0.7
    return 0.4

# GOOD: Traceable
# Soft/continuous variant: Graph tracing discovers treatment depends on
# the fox population at the time of intervention.
def confounded_propensity_scores(untreated_run, intervention):
    run = untreated_run
    return 0.3 + 0.4 * (1. - soft_threshold(run[intervention.time].foxes, threshold=7))
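To see why branching hides dependencies from a tracer, consider the toy sketch below. It is not WhyNot's implementation, only a minimal stand-in: the hypothetical `Traced` class wraps a number together with the names of the inputs it was computed from. Arithmetic propagates those names, but a comparison inside an `if` yields a plain boolean, so whatever the branch returns carries no record of `x`.

```python
class Traced:
    """A number that remembers which named inputs it was computed from."""

    def __init__(self, value, deps=frozenset()):
        self.value = value
        self.deps = frozenset(deps)

    @staticmethod
    def _lift(other):
        # Wrap plain numbers so they can be combined with Traced values.
        return other if isinstance(other, Traced) else Traced(other)

    def __add__(self, other):
        other = Traced._lift(other)
        return Traced(self.value + other.value, self.deps | other.deps)

    __radd__ = __add__

    def __mul__(self, other):
        other = Traced._lift(other)
        return Traced(self.value * other.value, self.deps | other.deps)

    __rmul__ = __mul__

    def __gt__(self, other):
        # A comparison returns a plain bool: the dependency is dropped here.
        return self.value > Traced._lift(other).value


def hard(x):
    """Branching code: the returned constants carry no trace of x."""
    if x > 2:
        return 3
    return 5


def straight_line(x):
    """Straight-line code: every operation is recorded."""
    return 2 * x + 1


def dependencies(result):
    """Names of the inputs a result was traced back to."""
    return set(result.deps) if isinstance(result, Traced) else set()


x = Traced(4.0, deps={"x"})
hard_deps = dependencies(hard(x))
smooth_deps = dependencies(straight_line(x))
print(hard_deps, smooth_deps)
```

In this sketch `hard(x)` returns the untraced constant 3, so its dependency set is empty, while `straight_line(x)` is correctly attributed to `x`. The same reasoning motivates replacing the hard threshold with `soft_threshold` in the propensity scorer.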