Causal Graph Discovery

WhyNot comes equipped with tools to automatically generate the causal graph associated with experiments on the dynamical system simulators. This lets users experiment both with graphical methods for causal inference and with methods for causal structure discovery.

The user writes simulators and experiments as usual in plain Python and NumPy. Then, building on ideas from automatic differentiation, WhyNot traces the evolution of the state variables during simulation and automatically builds the corresponding causal graph. This approach to graph generation is fast, flexible, and far less error-prone than tracking the simulator dynamics and experimental setup by hand.
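To make the tracing idea concrete, here is a toy sketch of dependency tracing by operator overloading. It is a hypothetical illustration, not WhyNot's internal tracer: each input is wrapped in a hypothetical Traced value that records which named inputs it was computed from, one step of a made-up predator-prey update runs unchanged, and the recorded dependencies become edges between time t and t + dt.

```python
class Traced:
    """A number that remembers which inputs it was computed from (hypothetical)."""

    def __init__(self, value, deps=frozenset()):
        self.value = value
        self.deps = frozenset(deps)

    def _combine(self, other, fn):
        # Plain constants carry no dependencies; traced operands contribute theirs.
        other = other if isinstance(other, Traced) else Traced(other)
        return Traced(fn(self.value, other.value), self.deps | other.deps)

    def __add__(self, other):
        return self._combine(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._combine(other, lambda a, b: a - b)

    def __mul__(self, other):
        return self._combine(other, lambda a, b: a * b)

    # Addition and multiplication commute, so the reflected versions can reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


def step(rabbits, foxes):
    """Toy predator-prey update with hypothetical parameters."""
    next_rabbits = rabbits + 0.1 * rabbits - 0.02 * rabbits * foxes
    next_foxes = foxes - 0.05 * foxes  # no dependence on rabbits in this toy
    return {"rabbits": next_rabbits, "foxes": next_foxes}


def trace_step(state, t, dt=1.0):
    """Run one traced step and return the (parent, child) edges it induces."""
    # Tag each input with a node name of the form '<variable>_<time>'.
    traced = {name: Traced(value, {f"{name}_{t}"}) for name, value in state.items()}
    out = step(**traced)
    return sorted(
        (parent, f"{name}_{t + dt}") for name, val in out.items() for parent in val.deps
    )
```

Calling trace_step({"rabbits": 40.0, "foxes": 9.0}, 0.0) yields the edges foxes_0.0 → foxes_1.0, foxes_0.0 → rabbits_1.0, and rabbits_0.0 → rabbits_1.0. No rabbits → foxes edge appears, because the toy fox update never reads rabbits. The simulator code itself is not modified; only its inputs are wrapped.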

Automatic graph generation is still experimental, and some rough edges likely remain. Graph generation is currently supported for most, but not all, simulators. Support for the remaining simulators (world3 and DICE) is forthcoming.

Generating Causal Graphs for Dynamical System Simulators

First, generate a run of the simulator:

import whynot as wn

initial_state = wn.hiv.State()
config = wn.hiv.Config()

run = wn.hiv.simulate(initial_state, config)

Then, generate the causal graph associated with the run(s) of the simulator.

graph = wn.causal_graphs.build_dynamics_graph(wn.hiv, [run], config)

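The nodes in the graph correspond to the state variables at each time step. Each node name combines a state variable with a time, so 'free_virus_1.0' is the free_virus state at time 1.0.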

>>> print(graph.nodes)
'uninfected_T1_0.0', 'infected_T1_0.0', 'uninfected_T2_0.0',
'infected_T2_0.0', 'free_virus_0.0', 'immune_response_0.0',
'uninfected_T1_1.0', 'infected_T1_1.0', 'uninfected_T2_1.0',
'infected_T2_1.0', 'free_virus_1.0', 'immune_response_1.0' ...
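Assuming node names follow the '<variable>_<time>' pattern visible in the output above, a small helper can split them back into a variable and a time and group the nodes by time step. Both parse_node and nodes_by_time are hypothetical helpers, not part of WhyNot.

```python
from collections import defaultdict


def parse_node(name):
    """Split a node name such as 'infected_T1_1.0' into ('infected_T1', 1.0).

    Assumes the time step is everything after the last underscore.
    """
    variable, time = name.rsplit("_", 1)
    return variable, float(time)


def nodes_by_time(nodes):
    """Group variable names by time step, keeping their original order."""
    groups = defaultdict(list)
    for node in nodes:
        variable, time = parse_node(node)
        groups[time].append(variable)
    return dict(groups)
```

Since iterating a networkx graph's nodes yields the node names, nodes_by_time(graph.nodes) would give one list of state variables per time step.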

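Edges connect state variables at consecutive time steps and are determined by the simulator dynamics. For example, the edge from 'uninfected_T1_0.0' to 'infected_T1_1.0' means that infected_T1 at time 1.0 depends on uninfected_T1 at time 0.0.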

>>> print(graph.edges)
('uninfected_T1_0.0', 'uninfected_T1_1.0')
('uninfected_T1_0.0', 'infected_T1_1.0')
('uninfected_T1_0.0', 'free_virus_1.0')
('infected_T1_0.0', 'infected_T1_1.0')
('infected_T1_0.0', 'free_virus_1.0')
...
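The time-indexed edges can be read as one per-step dependency map repeated across time. The sketch below unrolls such a map into (parent, child) edges using the same naming pattern. The map is hypothetical and partial: it covers only the five edges printed above, not the full HIV dynamics, and it assumes the dependencies are the same at every step.

```python
def unroll(parents, times):
    """Expand a per-step dependency map into time-indexed (parent, child) edges.

    parents[child] lists the variables at time t that child depends on at the
    next time step. Assumes the same dependencies hold at every step.
    """
    edges = []
    for t_prev, t_next in zip(times, times[1:]):
        for child, parent_vars in parents.items():
            for parent in parent_vars:
                edges.append((f"{parent}_{t_prev}", f"{child}_{t_next}"))
    return edges


# Hypothetical, partial map covering only the five edges printed above.
parents = {
    "uninfected_T1": ["uninfected_T1"],
    "infected_T1": ["uninfected_T1", "infected_T1"],
    "free_virus": ["uninfected_T1", "infected_T1"],
}
edges = unroll(parents, [0.0, 1.0, 2.0])
```

The first five entries of edges reproduce the printed edges from time 0.0 to 1.0, and the next five repeat the pattern from 1.0 to 2.0.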

This graph, together with the data generated by the run, can then be used to evaluate causal graph discovery methods, since the true graph is known.
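With a known ground-truth graph, an estimated graph can be scored directly. Below is a minimal sketch, assuming both graphs are given as collections of directed (parent, child) pairs, such as the entries of graph.edges. The helper edge_scores and its error count are illustrative choices, not part of WhyNot.

```python
def edge_scores(true_edges, estimated_edges):
    """Score an estimated directed graph against the ground truth.

    Returns (precision, recall, errors), where errors counts missing plus
    extra directed edges (a reversed edge counts twice).
    """
    true_set, est_set = set(true_edges), set(estimated_edges)
    hits = len(true_set & est_set)
    precision = hits / len(est_set) if est_set else 0.0
    recall = hits / len(true_set) if true_set else 0.0
    errors = len(true_set - est_set) + len(est_set - true_set)
    return precision, recall, errors
```

For example, scoring the estimate {a → b, c → b} against the truth {a → b, b → c} gives precision 0.5, recall 0.5, and 2 errors, because the reversed edge is both missing and extra.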

Generating Causal Graphs for WhyNot Experiments

To enable tracing of user-defined functions that use NumPy, import a thinly wrapped version of NumPy:

import whynot as wn
import whynot.traceable_numpy as np

Then, define a DynamicsExperiment as usual.

def sample_initial_states(rng):
    """Initial state distribution"""
    rabbits = rng.randint(10, 100)
    foxes = rng.uniform(0.1, 0.8) * rabbits
    return wn.lotka_volterra.State(rabbits=rabbits, foxes=foxes)

def soft_threshold(x, threshold, r=20):
    """A continuous relaxation of the threshold function. If x > tau, return ~1, if x < tau, returns ~0."""
    return 1. / (np.exp(r * (threshold  - x)) + 1)


def confounded_propensity_scores(untreated_run, intervention):
    """Confounded treatment assignment probability."""
    run = untreated_run
    return 0.3 + 0.4 * (1. - soft_threshold(run[intervention.time].foxes, threshold=7))


exp = wn.DynamicsExperiment(
    name="lotka_volterra_confounding",
    description=("Determine effect of reducing rabbits needed to sustain a fox."),
    simulator=wn.lotka_volterra,
    simulator_config=wn.lotka_volterra.Config(fox_growth=0.75, delta_t=1, end_time=6),
    intervention=wn.lotka_volterra.Intervention(time=3, fox_growth=0.4),
    state_sampler=sample_initial_states,
    propensity_scorer=confounded_propensity_scores,
    outcome_extractor=lambda run: run.states[-1].foxes,
    covariate_builder=lambda run, intervention: run[intervention.time].foxes)
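To see what the propensity scorer above actually computes, here is a scalar version of the same formulas using math.exp instead of the traceable NumPy. It is a sketch for inspecting values, not a replacement for the traceable code, since math.exp does not record dependencies; propensity is a hypothetical name for the scorer's formula applied to a single fox count.

```python
import math


def soft_threshold(x, threshold, r=20):
    """Scalar version of soft_threshold above, using math.exp.

    math.exp raises OverflowError once r * (threshold - x) exceeds about 709,
    which does not happen for fox counts near the threshold used here.
    """
    return 1.0 / (math.exp(r * (threshold - x)) + 1)


def propensity(foxes, threshold=7):
    """Same formula as confounded_propensity_scores, applied to a fox count."""
    return 0.3 + 0.4 * (1.0 - soft_threshold(foxes, threshold))
```

The treatment probability is about 0.7 well below 7 foxes, exactly 0.5 at 7, and about 0.3 well above it. The transition is narrow (roughly 0.65 at 6.9 foxes and 0.35 at 7.1), and a larger r makes it sharper.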

To generate the dataset and the associated causal graph, pass causal_graph=True to the run method.

dataset = exp.run(num_samples=100, causal_graph=True)

# The causal graph is a networkx DiGraph
causal_graph = dataset.causal_graph

Important: While execution tracing is a flexible way to build the causal graph, it does not interact well with control flow based on conditional statements such as if. To understand the problem, suppose we write:

def func(x):
    if x > 2:
        return 3
    return 5

y = func(x)

Here, y depends on x. However, tracing cannot uncover this dependency, because each branch returns a constant that is not computed from x; the comparison x > 2 only selects which constant is returned. To avoid this corner case, all code in user-defined functions should be straight-line code, with no branching on traced values.
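The failure mode can be reproduced with a toy tracer. In the hypothetical sketch below, Traced records which inputs a value depends on, and exp plays the role of a traceable math function (in the spirit of a thinly wrapped NumPy, but not WhyNot's actual wrapper). The comparison inside func must return a plain bool, so the dependency is lost, while the straight-line relaxation func_soft keeps it.

```python
import math


class Traced:
    """Hypothetical minimal tracer: a float plus the set of inputs it depends on."""

    def __init__(self, value, deps=frozenset()):
        self.value = value
        self.deps = frozenset(deps)

    def _binop(self, other, fn):
        # Constants carry no dependencies; traced operands contribute theirs.
        other = other if isinstance(other, Traced) else Traced(other)
        return Traced(fn(self.value, other.value), self.deps | other.deps)

    def __add__(self, other):
        return self._binop(other, lambda a, b: a + b)

    def __radd__(self, other):
        return self._binop(other, lambda a, b: b + a)

    def __rsub__(self, other):
        return self._binop(other, lambda a, b: b - a)

    def __rmul__(self, other):
        return self._binop(other, lambda a, b: b * a)

    def __rtruediv__(self, other):
        return self._binop(other, lambda a, b: b / a)

    def __gt__(self, other):
        # `if` needs a plain bool, so the dependency set is dropped here.
        return self.value > other


def exp(x):
    # Traceable exp: keeps the dependency set of its argument.
    if isinstance(x, Traced):
        return Traced(math.exp(x.value), x.deps)
    return math.exp(x)


def deps_of(result):
    # A plain constant (e.g. returned from a branch) depends on nothing.
    return result.deps if isinstance(result, Traced) else frozenset()


def func(x):
    # Hard branch: each path returns a constant.
    if x > 2:
        return 3
    return 5


def func_soft(x, r=20):
    # Straight-line relaxation: about 3 when x > 2, about 5 when x < 2.
    return 5 - 2 * (1.0 / (exp(r * (2 - x)) + 1))
```

With x = Traced(3.0, {"x"}), func(x) returns the plain constant 3 with no recorded dependencies, whereas func_soft(x) returns a value of about 3 whose dependency set is {"x"}.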

For example, in the propensity scoring function above, we used a soft threshold rather than a hard threshold to make graph construction possible.

# BAD: Not traceable
# Hard IF statement: Graph tracing cannot discover that treatment
# assignment depends on the fox population at the time of intervention.
def confounded_propensity_scores(untreated_run, intervention):
    if untreated_run[intervention.time].foxes > 7:
        return 0.3
    return 0.7

# GOOD: Traceable
# Soft/continuous variant: Graph tracing discovers treatment depends on
# the fox population at the time of intervention.
def confounded_propensity_scores(untreated_run, intervention):
    run = untreated_run
    return 0.3 + 0.4 * (1. - soft_threshold(run[intervention.time].foxes, threshold=7))