"""Tigramite causal inference for time series."""
# Author: Jakob Runge <jakob@jakob-runge.com>
#
# License: GNU General Public License v3.0
import numpy as np
import math
import itertools
from copy import deepcopy
from collections import defaultdict
from tigramite.models import Models
import struct
[docs]class CausalEffects():
r"""Causal effect estimation.
Methods for the estimation of linear or non-parametric causal effects
between (potentially multivariate) X and Y (potentially conditional
on S) by (generalized) backdoor adjustment. Various graph types are
supported, also including hidden variables.
Linear and non-parametric estimators are based on sklearn. For the
linear case without hidden variables, an efficient estimator
based on Wright's path coefficients is also available. This estimator
also allows estimating mediation effects.
See the corresponding paper [6]_ and tigramite tutorial for an
in-depth introduction.
References
----------
.. [6] J. Runge, Necessary and sufficient graphical conditions for
optimal adjustment sets in causal graphical models with
hidden variables, Advances in Neural Information Processing
Systems, 2021, 34
https://proceedings.neurips.cc/paper/2021/hash/8485ae387a981d783f8764e508151cd9-Abstract.html
Parameters
----------
graph : array of either shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]
Different graph types are supported, see tutorial.
X : list of tuples
List of tuples [(i, -tau), ...] containing cause variables.
Y : list of tuples
List of tuples [(j, 0), ...] containing effect variables.
S : list of tuples
List of tuples [(i, -tau), ...] containing conditioned variables.
graph_type : str
Type of graph.
hidden_variables : list of tuples
Hidden variables in format [(i, -tau), ...]. The internal graph is
constructed by a latent projection.
check_SM_overlap : bool
Whether to check whether S overlaps with M.
verbosity : int, optional (default: 0)
Level of verbosity.
"""
def __init__(self,
             graph,
             graph_type,
             X,
             Y,
             S=None,
             hidden_variables=None,
             check_SM_overlap=True,
             verbosity=0):
    """Validate the inputs and precompute the node sets used later on.

    Parameters
    ----------
    graph : array
        Input graph; its first dimension is taken as the number of
        variables N.
    graph_type : str
        One of 'dag', 'admg', 'tsg_dag', 'tsg_admg', 'stationary_dag',
        'stationary_admg'.
    X, Y : iterables of tuples
        Cause and effect nodes in the format (variable, -lag).
    S : iterable of tuples, optional
        Conditioned nodes. None is treated as empty.
    hidden_variables : iterable of tuples, optional
        Unobserved nodes. None is treated as empty.
    check_SM_overlap : bool
        If True, raise when S overlaps with the mediators M.
    verbosity : int
        Values above 0 print a summary and warnings.

    Raises
    ------
    ValueError
        If graph_type is unsupported, if X, Y, S overlap with each other,
        with the hidden variables, (optionally) with the mediators or with
        the descendants of Y, or if a lag lies outside the graph.
    """
    self.verbosity = verbosity
    self.N = graph.shape[0]

    if S is None:
        S = []

    # Keep both the caller's ordering (lists) and fast membership (sets).
    self.listX = list(X)
    self.listY = list(Y)
    self.listS = list(S)
    self.X = set(X)
    self.Y = set(Y)
    self.S = set(S)

    #
    # Checks regarding graph type
    #
    # MAG and PAG variants are not enabled (yet).
    supported_graphs = ['dag',
                        'admg',
                        'tsg_dag',
                        'tsg_admg',
                        'stationary_dag',
                        'stationary_admg',
                        ]
    if graph_type not in supported_graphs:
        raise ValueError("Only graph types %s supported!" %supported_graphs)

    # TODO?: check that masking aligns with hidden samples in variables
    if hidden_variables is None:
        hidden_variables = []
    self.hidden_variables = set(hidden_variables)
    observed_roles = self.X.union(self.Y).union(self.S)
    if len(self.hidden_variables.intersection(observed_roles)) > 0:
        raise ValueError("XYS overlaps with hidden_variables!")

    # Only relevant for a later extension to MAGs/PAGs.
    is_pag = 'pag' in graph_type
    self.possible = is_pag
    self.definite_status = is_pag

    # Build the internal graph from the input graph, depending on the
    # graph type and the hidden variables, then validate it.
    self._construct_graph(graph=graph, graph_type=graph_type,
                          hidden_variables=hidden_variables)
    self._check_graph(self.graph)
    self._check_XYS()

    self.ancX = self._get_ancestors(X)
    self.ancY = self._get_ancestors(Y)
    self.ancS = self._get_ancestors(S)

    # If no node of X is an ancestor of Y, there is no causal path.
    self.no_causal_path = self.ancY.intersection(set(X)) == set()
    if self.no_causal_path and self.verbosity > 0:
        print("No causal path from X to Y exists.")

    # Mediators on causal paths from X to Y
    self.M = set(self.get_mediators(start=self.X, end=self.Y))
    self.listM = list(self.M)

    for varlag in self.X.union(self.Y).union(self.S):
        if abs(varlag[1]) > self.tau_max:
            raise ValueError("X, Y, S must have time lags inside graph.")

    if len(self.X.intersection(self.Y)) > 0:
        raise ValueError("Overlap between X and Y.")

    if len(self.S.intersection(self.Y.union(self.X))) > 0:
        raise ValueError("Conditions S overlap with X or Y.")

    # TODO: an overlap between X and des(M) might imply
    # non-identifiability; this is not checked (sufficiency unproven).

    if check_SM_overlap and len(self.S.intersection(self.M)) > 0:
        raise ValueError("Conditions S overlap with mediators M.")

    self.desX = self._get_descendants(self.X)
    self.desY = self._get_descendants(self.Y)
    self.desM = self._get_descendants(self.M)
    self.descendants = self.desY.union(self.desM)

    # Forbidden nodes: X together with the descendants of Y and M
    self.forbidden_nodes = self.descendants.union(self.X)

    # Valid ancestors: ancestors of X, Y, S that are not forbidden
    all_ancestors = self.ancX.union(self.ancY).union(self.ancS)
    self.vancs = all_ancestors - self.forbidden_nodes

    if self.verbosity > 0 and len(self.S.intersection(self.desX)) > 0:
        print("Warning: Potentially outside assumptions: Conditions S overlap with des(X)")

    # Only an overlap of S with des(Y) is rejected; whether S may contain
    # nodes of des(M) is left to the user.
    if len(self.S.intersection(self.desY)) > 0:
        raise ValueError("Not identifiable: Conditions S overlap with des(Y).")

    if self.verbosity > 0:
        print("\n##\n## Initializing CausalEffects class\n##"
              "\n\nInput:")
        summary_parts = ["\ngraph_type = %s" % graph_type,
                         "\nX = %s" % self.listX,
                         "\nY = %s" % self.listY,
                         "\nS = %s" % self.listS,
                         "\nM = %s" % self.listM,
                         ]
        print("".join(summary_parts))
        if len(self.hidden_variables) > 0:
            print("\nhidden_variables = %s" % self.hidden_variables)
        print("\n\n")
        if self.no_causal_path:
            print("No causal path from X to Y exists!")
def _construct_graph(self, graph, graph_type, hidden_variables):
    """Build the internal graph from the input graph and hidden variables.

    Sets ``self.graph``, ``self.tau_max`` and ``self.graph_type``; for
    stationary graph types ``self.stationary_graph`` is set as well.
    Hidden variables are removed with the latent projection operation
    (``self._get_latent_projection_graph``).

    Parameters
    ----------
    graph : array
        Input graph with 2 ('dag', 'admg'), 4 ('tsg_dag', 'tsg_admg') or
        3 ('stationary_dag', 'stationary_admg') dimensions.
    graph_type : str
        Type of the input graph. Other values leave the instance untouched.
    hidden_variables : collection of tuples
        Hidden nodes; only its length is inspected here.

    Raises
    ------
    ValueError
        If graph.ndim does not fit graph_type.
    """
    if graph_type in ('dag', 'admg'):
        if graph.ndim != 2:
            raise ValueError("graph_type in ['dag', 'admg'] assumes graph.shape=(N, N).")

        # Append two dummy lag axes -> shape [N, N, 1, 1], so the graph is
        # processed like a tsg_dag / tsg_admg.
        self.graph = np.expand_dims(graph, axis=(2, 3))

        # tau_max has to be set before the latent projection is computed.
        self.tau_max = 0

        if len(hidden_variables) > 0:
            self.graph = self._get_latent_projection_graph()
            self.graph_type = "tsg_admg"
        else:
            self.graph_type = 'tsg_' + graph_type
        return

    if graph_type in ('tsg_dag', 'tsg_admg'):
        if graph.ndim != 4:
            raise ValueError("tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).")

        # tau_max is implicitly given by the lag dimension.
        self.graph = graph
        self.tau_max = graph.shape[2] - 1

        if len(hidden_variables) > 0:
            self.graph = self._get_latent_projection_graph()
            self.graph_type = "tsg_admg"
        else:
            self.graph_type = graph_type
        return

    if graph_type in ('stationary_dag', 'stationary_admg'):
        if graph.ndim != 3:
            raise ValueError("stationary graph_type assumes graph.shape=(N, N, tau_max+1).")

        # Conservative choice of tau_max: lag dimension of the stationary
        # graph plus the largest lag occurring in X, Y, S.
        statgraph_tau_max = graph.shape[2] - 1
        role_nodes = self.X.union(self.Y).union(self.S)
        maxlag_XYS = max((abs(node[1]) for node in role_nodes), default=0)
        self.tau_max = maxlag_XYS + statgraph_tau_max

        stat_graph = deepcopy(graph)

        # The projection reads the stationary graph from self.graph and
        # returns the unrolled ADMG (path-length cut-off: TO BE REVISED).
        self.graph = graph
        self.graph = self._get_latent_projection_graph(stationary=True)
        self.graph_type = "tsg_admg"

        # Stationary graph zero-padded ('' entries) up to tau_max
        self.stationary_graph = np.zeros(
            (self.N, self.N, self.tau_max + 1), dtype='<U3')
        self.stationary_graph[:, :, :stat_graph.shape[2]] = stat_graph
def _check_XYS(self):
    """Check that all nodes in X, Y and S lie inside the graph.

    Raises
    ------
    ValueError
        If a variable index is negative or not smaller than ``self.N``,
        or if a lag is positive or smaller than ``-self.tau_max``.
    """
    for var, lag in self.X.union(self.Y).union(self.S):
        index_out_of_range = var < 0 or var >= self.N
        if index_out_of_range:
            raise ValueError("XYS vars must be in [0...N]")
        lag_out_of_range = lag < -self.tau_max or lag > 0
        if lag_out_of_range:
            raise ValueError("XYS lags must be in [-taumax...0]")
def check_XYS_paths(self):
    """Check whether nodes without proper causal paths can be removed from X and Y.

    Nothing is modified on the instance; with ``verbosity > 0`` a
    suggestion is printed if the pruned sets differ from X or Y.

    Returns
    -------
    X, Y : cleaned lists of X and Y with irrelevant nodes removed.
    """
    # TODO: Also check S...
    original_X = self.X.copy()
    original_Y = self.Y.copy()

    # Keep only those x that are ancestors of Y or S ...
    ancestors_YS = self.ancY.union(self.ancS)
    pruned_X = {x for x in self.X if x in ancestors_YS}

    # ... and only those y that are descendants of X.
    pruned_Y = {y for y in self.Y if y in self.desX}

    # Additionally require a proper path from x to Y or S, i.e. x must be
    # a parent of (or itself be) one of the path nodes.
    mediators_S = self.get_mediators(start=self.X, end=self.S)
    path_nodes = list(self.M.union(pruned_Y).union(mediators_S))
    pruned_X = pruned_X.intersection(self._get_all_parents(path_nodes))

    verbose = self.verbosity > 0
    if set(original_X) != set(pruned_X) and verbose:
        print("Consider pruning X = %s to X = %s " %(original_X, pruned_X) +
              "since only these have causal path to Y")

    if set(original_Y) != set(pruned_Y) and verbose:
        print("Consider pruning Y = %s to Y = %s " %(original_Y, pruned_Y) +
              "since only these have causal path from X")

    return (list(pruned_X), list(pruned_Y))
def _check_graph(self, graph):
    """Check that graph has no invalid entries and no directed cycle.

    Assumes graph.shape = (N, N, tau_max+1, tau_max+1).

    Raises
    ------
    ValueError
        If an edge is not mirrored by its reversed counterpart, if an edge
        type is not allowed for ``self.graph_type``, or if the directed
        edges ("-->" and "+->") form a cycle.
    """
    allowed_edges = ["-->", "<--"]
    # Further edge marks depending on the graph type; as in an if/elif
    # chain only the first matching entry applies.
    extra_edges_by_type = (
        ('admg', ["<->", "<-+", "+->"]),
        ('mag', ["<->"]),
        ('pag', ["<->", "o-o", "o->", "<-o"]),
    )
    for type_tag, extra_edges in extra_edges_by_type:
        if type_tag in self.graph_type:
            allowed_edges += extra_edges
            break

    n_lags = self.tau_max + 1
    # Maps the flat index of a node to the flat indices of its parents.
    parents_of = defaultdict(list)

    for i, j, taui, tauj in zip(*np.where(graph)):
        edge = graph[i, j, taui, tauj]

        mirrored = self._reverse_link(graph[j, i, tauj, taui])
        if edge != mirrored:
            raise ValueError(
                "graph needs to have consistent edges (eg"
                " graph[i,j,taui,tauj]='-->' requires graph[j,i,tauj,taui]='<--')"
                )

        if edge not in allowed_edges:
            raise ValueError("Invalid graph edge %s. " %(edge) +
                "For graph_type = %s only %s are allowed." %(self.graph_type, str(allowed_edges)))

        if edge in ("-->", "+->"):
            # Flatten (variable, lag index) into a single node index
            source_index = i * n_lags + taui
            target_index = j * n_lags + tauj
            parents_of[target_index].append(source_index)

    # Check for cycles
    if self._check_cyclic(parents_of):
        raise ValueError("graph is cyclic.")
def _check_cyclic(self, graph_dict):
    """Return True if the graph_dict has a cycle.

    graph_dict must be represented as a dictionary mapping vertices to
    iterables of neighbouring vertices. Vertices that only occur as
    neighbours are treated as having no neighbours themselves. For example:

    >>> self._check_cyclic({1: (2,), 2: (3,), 3: (1,)})
    True
    >>> self._check_cyclic({1: (2,), 2: (3,), 3: (4,)})
    False

    The depth-first search uses an explicit stack instead of recursion so
    that long chains cannot exceed Python's recursion limit.
    """
    visited = set()   # vertices whose exploration has started
    on_path = set()   # vertices on the current depth-first path

    for root in graph_dict:
        if root in visited:
            continue
        visited.add(root)
        on_path.add(root)
        # Each stack entry holds a vertex and an iterator over its
        # not-yet-inspected neighbours. ``.get`` avoids inserting keys
        # when graph_dict is a defaultdict.
        stack = [(root, iter(graph_dict.get(root, ())))]
        while stack:
            vertex, neighbours = stack[-1]
            for neighbour in neighbours:
                if neighbour in on_path:
                    # Back edge onto the current path closes a cycle.
                    return True
                if neighbour not in visited:
                    visited.add(neighbour)
                    on_path.add(neighbour)
                    stack.append(
                        (neighbour, iter(graph_dict.get(neighbour, ()))))
                    break
            else:
                # All neighbours handled: leave the current path.
                on_path.discard(vertex)
                stack.pop()
    return False
def _get_mediators_stationary_graph(self, start, end, max_lag):
    """Return the mediators on proper causal paths from start to end
    in a stationary graph.

    The search walks backwards from the nodes in ``end`` along parent
    links of ``self.graph`` and keeps every node that is a descendant of
    ``start`` but neither in ``start`` nor in ``end``.

    Parameters
    ----------
    start, end : collections of tuples
        Nodes in the format (variable, lag).
    max_lag : int
        Passed on to the stationary adjacency and descendant searches.

    Returns
    -------
    mediators : set of tuples
    """
    descendants_of_start = self._get_descendants_stationary_graph(start, max_lag)

    mediators = set()
    for target in end:
        # Walk from target towards start; only accepted nodes are expanded.
        to_expand = [target]
        while to_expand:
            current = to_expand.pop()
            adjacents = self._get_adjacents_stationary_graph(
                graph=self.graph, node=current, patterns=["<*-", "<*+"],
                max_lag=max_lag, exclude=None)
            for _, parent in adjacents:
                if parent not in descendants_of_start:
                    continue
                if parent in mediators or parent in start or parent in end:
                    continue
                mediators.add(parent)
                to_expand.append(parent)

    return mediators
def _reverse_link(self, link):
    """Return link as seen from its other end.

    The two end marks are swapped; an arrowhead changes its direction
    (">" becomes "<" and vice versa) while all other marks are kept.
    The empty link is returned unchanged.
    """
    if link == "":
        return ""

    # The original right mark becomes the new left mark ...
    right_end = link[2]
    new_left = "<" if right_end == ">" else right_end

    # ... and the original left mark becomes the new right mark.
    left_end = link[0]
    new_right = ">" if left_end == "<" else left_end

    return new_left + link[1] + new_right
def _match_link(self, pattern, link):
    """Check whether link agrees with pattern, where '*' is a wildcard.

    In an ADMG we have edge types ["-->", "<--", "<->", "+->", "<-+"].
    Here +-> corresponds to having both "-->" and "<->".

    In a MAG we have edge types ["-->", "<--", "<->", "---"].

    An empty pattern or link only matches if both are empty.

    Returns
    -------
    bool
    """
    if pattern == '' or link == '':
        return bool(pattern == link)

    left_mark, middle_mark, right_mark = pattern
    # Compare the left, right and middle marks in this order; a wildcard
    # position is never looked up in link.
    for position, wanted in ((0, left_mark), (2, right_mark), (1, middle_mark)):
        if wanted != '*' and link[position] != wanted:
            return False
    return True
def _find_adj(self, node, patterns, exclude=None, return_link=False):
    """Find adjacencies of node that match given patterns.

    Parameters
    ----------
    node : tuple
        Node (variable, lag); the absolute value of the lag is used as
        index into ``self.graph``.
    patterns : str or list of str
        Link patterns as seen from node, '*' being a wildcard.
    exclude : collection of tuples, optional
        Adjacent nodes that are not returned.
    return_link : bool
        If True, return (link, adjacent node) pairs instead of nodes.

    Returns
    -------
    adj : list
        Matching adjacent nodes (variable, -lag index) without duplicates,
        in no particular order.
    """
    graph = self.graph
    excluded = [] if exclude is None else exclude

    var, lag = node
    lag_index = abs(lag)

    if type(patterns) == str:
        patterns = [patterns]

    found = []

    def _record(link, other):
        # Store an adjacency unless the caller asked to leave it out.
        if other not in excluded:
            found.append((link, other) if return_link else other)

    # Entries graph[var, k, lag_index, lag_k]: links as seen from node
    for k, lag_k in zip(*np.where(graph[var, :, lag_index, :])):
        link = graph[var, k, lag_index, lag_k]
        if any(self._match_link(patt, link) for patt in patterns):
            _record(link, (k, -lag_k))

    # Entries graph[k, var, lag_k, lag_index]: links as seen from the other
    # node, so patterns and reported links have to be reversed.
    for k, lag_k in zip(*np.where(graph[:, var, :, lag_index])):
        link = graph[k, var, lag_k, lag_index]
        if any(self._match_link(self._reverse_link(patt), link)
               for patt in patterns):
            _record(self._reverse_link(link), (k, -lag_k))

    return list(set(found))
def _is_match(self, nodei, nodej, pattern_ij):
    """Check whether the link from nodei to nodej agrees with pattern_ij.

    The link is looked up with a single lag index, ``graph[i, j, tau]``
    with ``tau = lag_j - lag_i`` (or ``graph[j, i, |tau|]`` with the
    reversed pattern for negative tau). Lag differences beyond
    ``graph.shape[2]`` never match.

    NOTE(review): this indexing looks like it expects a 3-dimensional
    stationary graph, whereas ``_construct_graph`` stores 4-dimensional
    graphs in ``self.graph`` -- confirm against the callers.
    """
    graph = self.graph

    i, lag_i = nodei
    j, lag_j = nodej
    tauij = lag_j - lag_i

    if abs(tauij) >= graph.shape[2]:
        return False

    if tauij >= 0:
        return self._match_link(pattern_ij, graph[i, j, tauij]) or False

    reversed_pattern = self._reverse_link(pattern_ij)
    return self._match_link(reversed_pattern, graph[j, i, abs(tauij)])
def _get_children(self, varlag):
    """Return the list of children (varlag --> ...) of (lagged) varlag.

    With ``self.possible`` set, links with circle marks are included as
    possible children.
    """
    patterns = ['-*>', 'o*o', 'o*>'] if self.possible else ['-*>', '+*>']
    return self._find_adj(node=varlag, patterns=patterns)
def _get_parents(self, varlag):
    """Return the list of parents (varlag <-- ...) of (lagged) varlag.

    With ``self.possible`` set, links with circle marks are included as
    possible parents.
    """
    patterns = ['<*-', 'o*o', '<*o'] if self.possible else ['<*-', '<*+']
    return self._find_adj(node=varlag, patterns=patterns)
def _get_spouses(self, varlag):
    """Return the list of spouses (varlag <-> ...) of (lagged) varlag.

    Links "+->" and "<-+" count as well since they contain a
    bidirected component.
    """
    bidirected_patterns = ['<*>', '+*>', '<*+']
    return self._find_adj(node=varlag, patterns=bidirected_patterns)
def _get_neighbors(self, varlag):
    """Return the list of neighbors (varlag --- ...) of (lagged) varlag."""
    undirected_patterns = ['-*-']
    return self._find_adj(node=varlag, patterns=undirected_patterns)
def _get_ancestors(self, W):
    """Get ancestors of nodes in W up to time tau_max.

    Includes the nodes themselves.

    Returns
    -------
    ancestors : set of tuples
    """
    ancestors = set(W)

    for root in W:
        # Nodes whose parents still have to be looked up
        pending = [root]
        while pending:
            current = pending.pop()
            for parent in self._get_parents(current):
                _, lag = parent
                if parent in ancestors or not (-self.tau_max <= lag <= 0):
                    continue
                ancestors.add(parent)
                pending.append(parent)

    return ancestors
def _get_all_parents(self, W):
    """Get parents of nodes in W up to time tau_max.

    Includes the nodes themselves.

    Returns
    -------
    parents : set of tuples
    """
    parents = set(W)

    for node in W:
        for parent in self._get_parents(node):
            _, lag = parent
            inside_window = -self.tau_max <= lag <= 0
            if inside_window and parent not in parents:
                parents.add(parent)

    return parents
def _get_all_spouses(self, W):
    """Get spouses of nodes in W up to time tau_max.

    Includes the nodes themselves.

    Returns
    -------
    spouses : set of tuples
    """
    spouses = set(W)

    for node in W:
        for spouse in self._get_spouses(node):
            _, lag = spouse
            inside_window = -self.tau_max <= lag <= 0
            if inside_window and spouse not in spouses:
                spouses.add(spouse)

    return spouses
def _get_descendants_stationary_graph(self, W, max_lag):
    """Get descendants of nodes in W up to time t in stationary graph.

    Includes the nodes themselves. Children are obtained from
    ``self._get_adjacents_stationary_graph`` on ``self.graph``; the lag
    window is enforced there via max_lag.

    Returns
    -------
    descendants : set of tuples
    """
    descendants = set(W)

    for root in W:
        # Nodes whose children still have to be looked up
        pending = [root]
        while pending:
            current = pending.pop()
            adjacents = self._get_adjacents_stationary_graph(
                graph=self.graph, node=current, patterns=["-*>", "-*+"],
                max_lag=max_lag, exclude=None)
            for _, child in adjacents:
                if child in descendants:
                    continue
                descendants.add(child)
                pending.append(child)

    return descendants
def _get_descendants(self, W):
    """Get descendants of nodes in W up to time t.

    Includes the nodes themselves. Children with a lag outside
    [-tau_max, 0] are neither included nor followed further.

    Returns
    -------
    descendants : set of tuples
    """
    descendants = set(W)

    for root in W:
        # Nodes whose children still have to be looked up
        pending = [root]
        while pending:
            current = pending.pop()
            for child in self._get_children(current):
                _, lag = child
                if child in descendants or not (-self.tau_max <= lag <= 0):
                    continue
                descendants.add(child)
                pending.append(child)

    return descendants
def _get_collider_path_nodes(self, W, descendants):
    """Get non-descendant collider path nodes and their parents of nodes in W up to time t.

    Starting from the nodes in W, chains of spouse links are followed as
    long as the reached nodes are not in descendants and have a lag in
    [-tau_max, 0]. Afterwards the parents of the nodes found this way are
    added under the same two conditions (parents of these parents are not
    added).

    Parameters
    ----------
    W : iterable of tuples
        Start nodes (variable, lag). They are not part of the result
        unless they are reached as a spouse or parent.
    descendants : collection of tuples
        Nodes that must not be included.

    Returns
    -------
    collider_path_nodes : set of tuples
    """
    collider_path_nodes = set()

    for root in W:
        # Nodes whose spouses still have to be looked up
        pending = [root]
        while pending:
            current = pending.pop()
            for spouse in self._get_spouses(current):
                _, spouse_lag = spouse
                if (spouse not in collider_path_nodes
                        and spouse not in descendants
                        and -self.tau_max <= spouse_lag <= 0):
                    collider_path_nodes.add(spouse)
                    pending.append(spouse)

    # Add parents. Iterate over a snapshot so that only parents of the
    # collider path nodes themselves are considered.
    for node in tuple(collider_path_nodes):
        for par in self._get_parents(node):
            # Use the parent's own lag for the time-window check (a stale
            # lag from the spouse search must not decide this).
            _, par_lag = par
            if (par not in collider_path_nodes
                    and par not in descendants
                    and -self.tau_max <= par_lag <= 0):
                collider_path_nodes.add(par)

    return collider_path_nodes
def _get_adjacents_stationary_graph(self, graph, node, patterns,
                                    max_lag=0, exclude=None):
    """Find adjacencies of node matching patterns in a stationary graph.

    Parameters
    ----------
    graph : array
        Stationary graph indexed as graph[i, k, lag].
    node : tuple
        Node (variable, lag).
    patterns : str or list of str
        Link patterns as seen from node, '*' being a wildcard.
    max_lag : int
        Only adjacent nodes with a lag in [-max_lag, 0] are returned.
    exclude : collection of tuples, optional
        Adjacent nodes that are not returned.

    Returns
    -------
    adj : list of tuples
        Pairs (link as seen from node, adjacent node) without duplicates,
        in no particular order.
    """
    var, lag = node
    excluded = [] if exclude is None else exclude
    if type(patterns) == str:
        patterns = [patterns]

    found = []

    # Entries graph[var, k, lag_k]: the adjacent node is shifted by +lag_k
    for k, lag_k in zip(*np.where(graph[var, :, :])):
        link = graph[var, k, lag_k]
        hits = [self._match_link(patt, link) for patt in patterns]
        if any(hits):
            other = (k, lag + lag_k)
            if other not in excluded and -max_lag <= other[1] <= 0:
                found.append((link, other))

    # Entries graph[k, var, lag_k]: stored as seen from the other node, so
    # patterns and links are reversed and the shift is -lag_k.
    for k, lag_k in zip(*np.where(graph[:, var, :])):
        link = graph[k, var, lag_k]
        hits = [self._match_link(self._reverse_link(patt), link)
                for patt in patterns]
        if any(hits):
            other = (k, lag - lag_k)
            if other not in excluded and -max_lag <= other[1] <= 0:
                found.append((self._reverse_link(link), other))

    return list(set(found))
def _get_canonical_dag_from_graph(self, graph):
"""Constructs canonical DAG as links_coeffs dictionary from graph.
For every <-> link further latent variables are added.
This corresponds to a canonical DAG (Richardson Spirtes 2002).
Can be used to evaluate d-separation.
"""
N, N, tau_maxplusone = graph.shape
tau_max = tau_maxplusone - 1
links = {j: [] for j in range(N)}
# Add further latent variables to accommodate <-> links
latent_index = N
for i, j, tau in zip(*np.where(graph)):
edge_type = graph[i, j, tau]
# Consider contemporaneous links only once
if tau == 0 and j > i:
continue
if edge_type == "-->":
links[j].append((i, -tau))
elif edge_type == "<--":
links[i].append((j, -tau))
elif edge_type == "<->":
links[latent_index] = []
links[i].append((latent_index, 0))
links[j].append((latent_index, -tau))
latent_index += 1
# elif edge_type == "---":
# links[latent_index] = []
# selection_vars.append(latent_index)
# links[latent_index].append((i, -tau))
# links[latent_index].append((j, 0))
# latent_index += 1
elif edge_type == "+->":
links[j].append((i, -tau))
links[latent_index] = []
links[i].append((latent_index, 0))
links[j].append((latent_index, -tau))
latent_index += 1
elif edge_type == "<-+":
links[i].append((j, -tau))
links[latent_index] = []
links[i].append((latent_index, 0))
links[j].append((latent_index, -tau))
latent_index += 1
return links
def _get_maximum_possible_lag(self, XYZ, graph):
    """Construct maximum relevant time lag for d-separation in stationary graph.

    TO BE REVISED!

    Starting from every node in XYZ, parent links of the canonical DAG
    are followed backwards in time. A path is not extended by a link
    that (up to a time shift) already occurs on it.

    Parameters
    ----------
    XYZ : iterable of tuples
        Nodes (variable, lag) with lag <= 0.
    graph : array
        Stationary graph, passed to ``_get_canonical_dag_from_graph``.

    Returns
    -------
    max_lag : int
    """

    def _repeating(link, seen_path):
        """Returns True if link or a time-shifted version of it connects
        two consecutive nodes of seen_path."""
        (i, taui), (j, tauj) = link
        lag_gap = abs(tauj - taui)
        consecutive_pairs = zip(seen_path[:-1], seen_path[1:])
        return any(
            i == seen_i and j == seen_j
            and lag_gap == abs(seen_tauj - seen_taui)
            for (seen_i, seen_taui), (seen_j, seen_tauj) in consecutive_pairs)

    # TODO: does this work with PAGs?
    parents_at_t0 = self._get_canonical_dag_from_graph(graph)

    max_lag = 0
    for node in XYZ:
        _, node_lag = node   # node_lag <= 0
        max_lag = max(max_lag, abs(node_lag))

        # Depth-first search; every entry is (node, path that led to it)
        stack = [(node, [])]
        while stack:
            varlag, tail = stack.pop()
            path = [varlag] + tail
            var, lag = varlag
            for parent_var, relative_lag in parents_at_t0[var]:
                # The canonical DAG is given at t=0, so shift by lag.
                par = (parent_var, relative_lag + lag)
                if par in path:
                    continue
                if len(path) == 1:
                    # First step away from the start node: always follow,
                    # without updating max_lag.
                    stack.append((par, path))
                    continue
                if not _repeating((par, varlag), path):
                    max_lag = max(max_lag, abs(par[1]))
                    stack.append((par, path))

    return max_lag
def _get_latent_projection_graph(self, stationary=False):
"""For DAGs/ADMGs uses the Latent projection operation (Pearl 2009).
Assumes a normal or stationary graph with potentially unobserved nodes.
Also allows particular time steps to be unobserved. By stationarity
that pattern of unobserved nodes is repeated into -infinity.
Latent projection operation for latents = nodes before t-tau_max or due to <->:
(i) auxADMG contains (i, -taui) --> (j, -tauj) iff there is a directed path
(i, -taui) --> ... --> (j, -tauj) on which
every non-endpoint vertex is in hidden variables (= not in observed_vars)
here iff (i, -|taui-tauj|) --> j in graph
(ii) auxADMG contains (i, -taui) <-> (j, -tauj) iff there exists a path of the
form (i, -taui) <-- ... --> (j, -tauj) on
which every non-endpoint vertex is non-collider AND in L (=not in observed_vars)
here iff (i, -|taui-tauj|) <-> j OR there is path
(i, -taui) <-- nodes before t-tau_max --> (j, -tauj)
"""
# graph = self.graph
# if self.hidden_variables is None:
# hidden_variables_here = []
# else:
hidden_variables_here = self.hidden_variables
aux_graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1), dtype='<U3')
aux_graph[:] = ""
for (i, j) in itertools.product(range(self.N), range(self.N)):
for jt, tauj in enumerate(range(0, self.tau_max + 1)):
for it, taui in enumerate(range(0, self.tau_max + 1)):
tau = abs(taui - tauj)
if tau == 0 and j == i:
continue
if (i, -taui) in hidden_variables_here or (j, -tauj) in hidden_variables_here:
continue
# print("\n")
# print((i, -taui), (j, -tauj))
cond_i_xy = (
# tau <= graph_taumax
# and (graph[i, j, tau] == '-->' or graph[i, j, tau] == '+->')
# )
# and
self._check_path( #graph=graph,
start=[(i, -taui)],
end=[(j, -tauj)],
conditions=None,
starts_with=['-*>', '+*>'],
ends_with=['-*>', '+*>'],
path_type='causal',
hidden_by_taumax=False,
hidden_variables=hidden_variables_here,
stationary_graph=stationary,
))
cond_i_yx = (
# tau <= graph_taumax
# and (graph[i, j, tau] == '<--' or graph[i, j, tau] == '<-+')
# )
# and
self._check_path( #graph=graph,
start=[(j, -tauj)],
end=[(i, -taui)],
conditions=None,
starts_with=['-*>', '+*>'],
ends_with=['-*>', '+*>'],
path_type='causal',
hidden_by_taumax=False,
hidden_variables=hidden_variables_here,
stationary_graph=stationary,
))
if stationary:
hidden_by_taumax_here = True
else:
hidden_by_taumax_here = False
cond_ii = (
# tau <= graph_taumax
# and
(
# graph[i, j, tau] == '<->'
# or graph[i, j, tau] == '+->' or graph[i, j, tau] == '<-+'))
self._check_path( #graph=graph,
start=[(i, -taui)],
end=[(j, -tauj)],
conditions=None,
starts_with=['<**', '+**'],
ends_with=['**>', '**+'],
path_type='any',
hidden_by_taumax=hidden_by_taumax_here,
hidden_variables=hidden_variables_here,
stationary_graph=stationary,
)))
if cond_i_xy and not cond_i_yx and not cond_ii:
aux_graph[i, j, taui, tauj] = "-->" #graph[i, j, tau]
# if tau == 0:
aux_graph[j, i, tauj, taui] = "<--" # graph[j, i, tau]
elif not cond_i_xy and cond_i_yx and not cond_ii:
aux_graph[i, j, taui, tauj] = "<--" #graph[i, j, tau]
# if tau == 0:
aux_graph[j, i, tauj, taui] = "-->" # graph[j, i, tau]
elif not cond_i_xy and not cond_i_yx and cond_ii:
aux_graph[i, j, taui, tauj] = '<->'
# if tau == 0:
aux_graph[j, i, tauj, taui] = '<->'
elif cond_i_xy and not cond_i_yx and cond_ii:
aux_graph[i, j, taui, tauj] = '+->'
# if tau == 0:
aux_graph[j, i, tauj, taui] = '<-+'
elif not cond_i_xy and cond_i_yx and cond_ii:
aux_graph[i, j, taui, tauj] = '<-+'
# if tau == 0:
aux_graph[j, i, tauj, taui] = '+->'
elif cond_i_xy and cond_i_yx:
raise ValueError("Cycle between %s and %s!" %(str(i, -taui), str(j, -tauj)))
# print(aux_graph[i, j, taui, tauj])
# print((i, -taui), (j, -tauj), cond_i_xy, cond_i_yx, cond_ii, aux_graph[i, j, taui, tauj], aux_graph[j, i, tauj, taui])
return aux_graph
def _check_path(self,
# graph,
start, end,
conditions=None,
starts_with=None,
ends_with=None,
path_type='any',
# causal_children=None,
stationary_graph=False,
hidden_by_taumax=False,
hidden_variables=None,
):
"""Check whether an open/active path between start and end given conditions exists.

Also allows to restrict start and end patterns and to consider causal/non-causal paths.
hidden_by_taumax and hidden_variables are relevant for the latent projection operation.

Parameters
----------
start : iterable of tuples
    Nodes (i, -tau) at which a path may begin. Converted to a set.
end : iterable of tuples
    Nodes (j, -tau) at which a path may terminate. Converted to a set.
conditions : iterable of tuples, optional (default: None)
    Conditioned nodes. None is treated as the empty set.
starts_with : str or list of str, optional (default: None)
    Link pattern(s) the first link of the path must match. None is
    replaced by ['***'].
ends_with : str or list of str, optional (default: None)
    Link pattern(s) the last link of the path must match. None is
    replaced by ['***'].
path_type : {'any', 'causal', 'non_causal'}
    For 'causal' the first neighbor must be in causal_children (mediators
    united with end) and every later link must match '-*>' or '+*>'.
    For 'non_causal' a first link into causal_children that matches '-*>'
    but not '+*>' is skipped.
stationary_graph : bool
    If True, adjacencies are taken from _get_adjacents_stationary_graph
    with max_lag = 10*self.tau_max, otherwise from _find_adj.
hidden_by_taumax : bool
    If True, all nodes with lags tau_max+1, ..., max_lag are added to
    hidden_variables.
hidden_variables : set of tuples, optional (default: None)
    If not None, a neighbor is only traversed when it is in end or in
    hidden_variables.

Returns
-------
bool
    True as soon as a node in end is reached via a link that matches one
    of the ends_with patterns, False if the search is exhausted before.
"""
if conditions is None:
conditions = set([])
# if conditioned_variables is None:
# S = []
start = set(start)
end = set(end)
conditions = set(conditions)
# Get maximal possible time lag of a connecting path
# See Thm. XXXX - TO BE REVISED!
# XYZ is currently only referenced by the commented-out lag computation below.
XYZ = start.union(end).union(conditions)
if stationary_graph:
max_lag = 10*self.tau_max # TO BE REVISED! self._get_maximum_possible_lag(XYZ, self.graph)
causal_children = list(self._get_mediators_stationary_graph(start, end, max_lag).union(end))
else:
max_lag = None
causal_children = list(self.get_mediators(start, end).union(end))
# if hidden_variables is None:
# hidden_variables = set([])
# NOTE(review): max_lag is None unless stationary_graph is True, so
# hidden_by_taumax=True without stationary_graph=True would fail at
# "max_lag + 1" below. Also, .union() requires hidden_variables to be a
# set (a list would fail) -- confirm that all callers respect both.
if hidden_by_taumax:
if hidden_variables is None:
hidden_variables = set([])
hidden_variables = hidden_variables.union([(k, -tauk) for k in range(self.N)
for tauk in range(self.tau_max+1, max_lag + 1)])
# Normalize the pattern arguments to lists of pattern strings
if starts_with is None:
starts_with = ['***']
elif type(starts_with) == str:
starts_with = [starts_with]
if ends_with is None:
ends_with = ['***']
elif type(ends_with) == str:
ends_with = [ends_with]
#
# Graph search to find connection. Open entries are kept in a set and
# taken out with set.pop(), i.e., in arbitrary order (not strictly
# breadth-first).
#
# Seed the search with all admissible first links (x, link, neighbor)
start_from = set()
for x in start:
if stationary_graph:
link_neighbors = self._get_adjacents_stationary_graph(graph=self.graph, node=x, patterns=starts_with,
max_lag=max_lag, exclude=list(start))
else:
link_neighbors = self._find_adj(node=x, patterns=starts_with, exclude=list(start), return_link=True)
for link_neighbor in link_neighbors:
link, neighbor = link_neighbor
# if before_taumax and neighbor[1] >= -self.tau_max:
# continue
# With hidden_variables given, only end nodes and hidden nodes may be visited
if (hidden_variables is not None and neighbor not in end
and neighbor not in hidden_variables):
continue
if path_type == 'non_causal':
# Skip first links that start a causal path into the causal children
if (neighbor in causal_children and self._match_link('-*>', link)
and not self._match_link('+*>', link)):
continue
elif path_type == 'causal':
if (neighbor not in causal_children): # or self._match_link('<**', link)):
continue
start_from.add((x, link, neighbor))
# visited holds (link, node) pairs, i.e., a node may be re-entered via a
# different link type
visited = set()
for (varlag_i, link_ik, varlag_k) in start_from:
visited.add((link_ik, varlag_k))
# Traversing through motifs i *-* k *-* j
while start_from:
# First handle all open entries that already reached an end node:
# either the last link matches ends_with (connected) or the entry is
# dropped from the search.
removables = []
for (varlag_i, link_ik, varlag_k) in start_from:
if varlag_k in end:
if np.any([self._match_link(patt, link_ik) for patt in ends_with]):
# Connected
return True
else:
removables.append((varlag_i, link_ik, varlag_k))
for removable in removables:
start_from.remove(removable)
if len(start_from)==0:
return False
# Get any neighbor from starting nodes
# link_ik, varlag_k = start_from.pop()
varlag_i, link_ik, varlag_k = start_from.pop()
if stationary_graph:
link_neighbors = self._get_adjacents_stationary_graph(graph=self.graph, node=varlag_k, patterns='***',
max_lag=max_lag, exclude=list(start))
else:
link_neighbors = self._find_adj(node=varlag_k, patterns='***', exclude=list(start), return_link=True)
for link_neighbor in link_neighbors:
link_kj, varlag_j = link_neighbor
if (link_kj, varlag_j) in visited:
# if (varlag_i, link_kj, varlag_j) in visited:
continue
if path_type == 'causal':
# On causal paths every link must point forward
if not (self._match_link('-*>', link_kj) or self._match_link('+*>', link_kj)):
continue
# If motif i *-* k *-* j is open,
# then add link_kj, varlag_j to visited and start_from
# left_mark is the edge mark of link i-k at k, right_mark the edge
# mark of link k-j at k
left_mark = link_ik[2]
right_mark = link_kj[0]
if self.definite_status:
# Exclude paths that are not definite_status implying that any of the following
# motifs occurs:
# i *-> k o-* j
if (left_mark == '>' and right_mark == 'o'):
continue
# i *-o k <-* j
if (left_mark == 'o' and right_mark == '<'):
continue
# i *-o k o-* j and i and j are adjacent
if (left_mark == 'o' and right_mark == 'o'
and self._is_match(varlag_i, varlag_j, "***")):
continue
# If k is in conditions and motif is *-o k o-*, then motif is blocked since
# i and j are non-adjacent due to the check above
if varlag_k in conditions and (left_mark == 'o' and right_mark == 'o'):
# Motif closed
continue # [('>', '<'), ('>', '+'), ('+', '<'), ('+', '+')]
# If k is in conditions and left or right mark is tail '-', then motif is blocked
if varlag_k in conditions and (left_mark == '-' or right_mark == '-'):
# Motif closed
continue # [('>', '<'), ('>', '+'), ('+', '<'), ('+', '+')]
# If k is not in conditions and left and right mark are heads '><', then motif is blocked
if varlag_k not in conditions and (left_mark == '>' and right_mark == '<'):
# Motif closed
continue # [('>', '<'), ('>', '+'), ('+', '<'), ('+', '+')]
# if (before_taumax and varlag_j not in end
# and varlag_j[1] >= -self.tau_max):
# continue
# With hidden_variables given, only end nodes and hidden nodes may be visited
if (hidden_variables is not None and varlag_j not in end
and varlag_j not in hidden_variables):
continue
# Motif is open
# start_from.add((link_kj, varlag_j))
visited.add((link_kj, varlag_j))
start_from.add((varlag_k, link_kj, varlag_j))
# visited.add((varlag_k, link_kj, varlag_j))
# Search exhausted without reaching end: separated
return False
def get_optimal_set(self,
        alternative_conditions=None,
        minimize=False,
        return_separate_sets=False,
        ):
    """Returns optimal adjustment set.

    See Runge NeurIPS 2021.

    Parameters
    ----------
    alternative_conditions : set of tuples
        Used only internally in optimality theorem. If None, self.S is used.
    minimize : {False, True, 'colliders_only'}
        Minimize optimal set. If True, minimize such that no subset
        can be removed without making it invalid. If 'colliders_only',
        only colliders are minimized.
    return_separate_sets : bool
        Whether to return tuple of parents, colliders, collider_parents, and S.
        The separate sets are never minimized.

    Returns
    -------
    Oset_S : False or list or tuple of lists
        Returns optimal adjustment set if a valid set exists, otherwise False.
    """
    # Needed for optimality theorem where Osets for alternative S are tested
    if alternative_conditions is None:
        S = self.S.copy()
        vancs = self.vancs.copy()
    else:
        S = alternative_conditions
        newancS = self._get_ancestors(S)
        vancs = self.ancX.union(self.ancY).union(newancS) - self.forbidden_nodes

    # Sufficient condition for non-identifiability
    if len(self.X.intersection(self.descendants)) > 0:
        return False

    ##
    ## Construct O-set
    ##

    # Start with parents and remove forbidden nodes
    parents = self._get_all_parents(self.Y.union(self.M)) - self.forbidden_nodes

    # The conditioning set of the collider check does not change while the
    # colliders are collected, so it is built only once.
    collider_conditions = list(parents.union(vancs)) + list(S)

    # Construct valid collider path nodes, level by level starting at Y and M
    colliders = set()
    for w in self.Y.union(self.M):
        this_level = [w]
        non_suitable_nodes = []
        while len(this_level) > 0:
            next_level = []
            for varlag in this_level:
                suitable_spouses = (set(self._get_spouses(varlag))
                                    - set(non_suitable_nodes))
                for spouse in suitable_spouses:
                    _, tau = spouse
                    # A collider path into X makes the effect non-identifiable
                    if spouse in self.X:
                        return False

                    if (# Node not already in set
                        spouse not in colliders
                        # not forbidden
                        and spouse not in self.forbidden_nodes
                        # in time bounds
                        and (-self.tau_max <= tau <= 0)
                        and (spouse in vancs
                             or not self._check_path(
                                    start=self.X, end=[spouse],
                                    conditions=collider_conditions))
                        ):
                        colliders.add(spouse)
                        next_level.append(spouse)
                    elif spouse not in colliders:
                        non_suitable_nodes.append(spouse)

            this_level = set(next_level) - set(non_suitable_nodes)

    # Add parents of colliders; not identifiable if X is among them
    collider_parents = self._get_all_parents(colliders)
    if len(self.X.intersection(collider_parents)) > 0:
        return False

    # The separate sets do not depend on the minimization below
    if return_separate_sets:
        return parents, colliders, collider_parents, S

    # Add valid collider path nodes and their parents
    Oset = parents.union(colliders.union(collider_parents))

    if minimize:
        def candidates(current_set):
            # Nodes that may be removed from the current O-set; always taken
            # from the *current* set so that already removed nodes are not
            # tested again in the second pass.
            if minimize == 'colliders_only':
                return [node for node in current_set if node not in parents]
            return list(current_set)

        # First remove all those that have no path from X
        removable = [
            node for node in candidates(Oset)
            if not self._check_path(
                start=self.X, end=[node],
                conditions=list(Oset - set([node])) + list(S))]
        Oset = Oset - set(removable)

        # Next remove all those with no direct connection to Y
        removable = [
            node for node in candidates(Oset)
            if not self._check_path(
                start=[node], end=self.Y,
                conditions=list(Oset - set([node])) + list(S) + list(self.X),
                ends_with=['**>', '**+'])]
        Oset = Oset - set(removable)

    return list(Oset.union(S))
def _get_collider_paths_optimality(self, source_nodes, target_nodes,
condition,
inside_set=None,
start_with_tail_or_head=False,
):
"""Returns relevant collider paths to check optimality.

Iterates over collider paths within O-set via depth-first search
(the list `queue` is used as a stack through queue.pop()).

Parameters
----------
source_nodes : iterable of tuples
    Nodes at which the collider paths start.
target_nodes : collection of tuples
    Nodes at which a collider path is considered complete.
condition : {'I', 'II'}
    Which condition of the optimality theorem is evaluated.
inside_set : collection of tuples
    Only nodes in this collection are used to extend a path.
    NOTE(review): the default None is not usable, `node in inside_set`
    raises a TypeError for None; the callers in check_optimality always
    pass a collection.
start_with_tail_or_head : bool
    If True, the first step from a source node may also go to its children
    in addition to its spouses.

Returns
-------
bool
    For condition 'I': False as soon as a path reaching a target node is
    found for which get_optimal_set(alternative_conditions=S + path) is
    not False, True otherwise.
    For condition 'II': True as soon as a target node is reached,
    False otherwise.
"""
for w in source_nodes:
# Only used to return *all* collider paths
# (needed in optimality theorem)
coll_path = []
# Stack of (node, path leading to node)
queue = [(w, coll_path)]
# Paths for which no O-set exists; any path containing one of them is skipped
non_valid_subsets = []
while queue:
varlag, coll_path = queue.pop()
coll_path = coll_path + [varlag]
suitable_nodes = set(self._get_spouses(varlag))
if start_with_tail_or_head and coll_path == [w]:
children = set(self._get_children(varlag))
suitable_nodes = suitable_nodes.union(children)
for node in suitable_nodes:
i, tau = node
# Only nodes within the time bounds that are not yet on the path
if ((-self.tau_max <= tau <= 0) # or self.ignore_time_bounds)
and node not in coll_path):
if condition == 'II' and node not in target_nodes and node not in self.vancs:
continue
if node in inside_set:
if condition == 'I':
non_valid = False
for pathset in non_valid_subsets[::-1]:
if set(pathset).issubset(set(coll_path + [node])):
non_valid = True
break
if non_valid is False:
queue.append((node, coll_path))
else:
continue
elif condition == 'II':
queue.append((node, coll_path))
if node in target_nodes:
# yield coll_path
# collider_paths[node].append(coll_path)
if condition == 'I':
# Construct OπiN
Sprime = self.S.union(coll_path)
OpiN = self.get_optimal_set(alternative_conditions=Sprime)
if OpiN is False:
# Drop all stacked entries that contain this non-valid path
queue = [(q_node, q_path) for (q_node, q_path) in queue if set(coll_path).issubset(set(q_path + [q_node])) is False]
non_valid_subsets.append(coll_path)
else:
return False
elif condition == 'II':
return True
# yield coll_path
# No deciding path was found
if condition == 'I':
return True
elif condition == 'II':
return False
# return collider_paths
def check_optimality(self):
    """Check whether optimal adjustment set exists according to Thm. 3 in Runge NeurIPS 2021.

    Returns
    -------
    optimality : bool
        Returns True if an optimal adjustment set exists, otherwise False.
        If get_optimal_set() returns False (no O-set can be constructed),
        conditions I and II count as not fulfilled and only condition 0
        decides.
    """
    # Cond. 0: Exactly one valid adjustment set exists
    cond_0 = self._get_all_valid_adjustment_sets(check_one_set_exists=True)

    separate_sets = self.get_optimal_set(return_separate_sets=True)
    if separate_sets is False:
        # get_optimal_set returns False instead of a 4-tuple when the effect
        # is not identifiable by adjustment; unpacking it would raise.
        cond_I = cond_II = False
        optimality = cond_0
        if self.verbosity > 0:
            print("Optimality = %s with cond_0 = %s, cond_I = %s, cond_II = %s"
                  % (optimality, cond_0, cond_I, cond_II))
        return optimality

    parents, colliders, collider_parents, _ = separate_sets
    Oset = parents.union(colliders).union(collider_parents)
    YM = self.Y.union(self.M)

    #
    # Cond. I
    #
    n_nodes = (self._get_all_spouses(YM.union(colliders))
               - self.forbidden_nodes - Oset - self.S
               - self.Y - self.M - colliders)

    if len(n_nodes) == 0:
        # (1) There are no spouses N ∈ sp(YMC) \ (forbOS)
        cond_I = True
    else:
        # (2) For all N ∈ N and all its collider paths i it holds that
        # OπiN does not block all non-causal paths from X to Y
        cond_I = self._get_collider_paths_optimality(
            source_nodes=list(n_nodes), target_nodes=list(YM),
            condition='I',
            inside_set=Oset.union(self.S), start_with_tail_or_head=False,
            )

    #
    # Cond. II
    #
    e_nodes = Oset.difference(parents)
    cond_II = True
    for E in e_nodes:
        Oset_minusE = Oset.difference(set([E]))
        # Only nodes E with an open path from X given the rest are relevant
        if self._check_path(
                start=list(self.X), end=[E],
                conditions=list(self.S) + list(Oset_minusE)):
            cond_II = self._get_collider_paths_optimality(
                target_nodes=self.Y.union(self.M),
                source_nodes=list(set([E])),
                condition='II',
                inside_set=list(Oset.union(self.S)),
                start_with_tail_or_head=True)
            if cond_II is False:
                if self.verbosity > 1:
                    print("Non-optimal due to E = ", E)
                break

    optimality = (cond_0 or (cond_I and cond_II))
    if self.verbosity > 0:
        print("Optimality = %s with cond_0 = %s, cond_I = %s, cond_II = %s"
              % (optimality, cond_0, cond_I, cond_II))
    return optimality
def _check_validity(self, Z):
"""Checks whether Z is a valid adjustment set."""
# causal_children = list(self.M.union(self.Y))
backdoor_path = self._check_path(#graph=self.graph,
start=list(self.X), end=list(self.Y),
conditions=list(Z),
# causal_children=causal_children,
path_type = 'non_causal')
if backdoor_path:
return False
else:
return True
def _get_adjust_set(self,
minimize=False,
):
"""Returns Adjust-set.
See van der Zander, B.; Liśkiewicz, M. & Textor, J.
Separators and adjustment sets in causal graphs: Complete
criteria and an algorithmic framework
Artificial Intelligence, Elsevier, 2019, 270, 1-40
"""
vancs = self.vancs.copy()
if minimize:
# Get removable nodes by computing minimal valid set from Z
if minimize == 'keep_parentsYM':
minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))
else:
minimize_nodes = vancs
# Zprime2 = Zprime
# First remove all nodes that have no unique path to X given Oset
for node in minimize_nodes:
# path = self.oracle.check_shortest_path(X=X, Y=[node],
# Z=list(vancs - set([node])),
# max_lag=None,
# starts_with=None, #'arrowhead',
# forbidden_nodes=None, #list(Zprime - set([node])),
# return_path=False)
path = self._check_path(#graph=self.graph,
start=self.X, end=[node],
conditions=list(vancs - set([node])),
)
if path is False:
vancs = vancs - set([node])
if minimize == 'keep_parentsYM':
minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))
else:
minimize_nodes = vancs
# print(Zprime2)
# Next remove all nodes that have no unique path to Y given Oset_min
# Z = Zprime2
for node in minimize_nodes:
path = self._check_path(#graph=self.graph,
start=[node], end=self.Y,
conditions=list(vancs - set([node])) + list(self.X),
)
if path is False:
vancs = vancs - set([node])
if self._check_validity(list(vancs)) is False:
return False
else:
return list(vancs)
def _get_all_valid_adjustment_sets(self,
check_one_set_exists=False, yield_index=None):
"""Constructs all valid adjustment sets or just checks whether one exists.
See van der Zander, B.; Liśkiewicz, M. & Textor, J.
Separators and adjustment sets in causal graphs: Complete
criteria and an algorithmic framework
Artificial Intelligence, Elsevier, 2019, 270, 1-40
"""
cond_set = set(self.S)
all_vars = [(i, -tau) for i in range(self.N)
for tau in range(0, self.tau_max + 1)]
all_vars_set = set(all_vars) - self.forbidden_nodes
def find_sep(I, R):
Rprime = R - self.X - self.Y
# TODO: anteriors and NOT ancestors where
# anteriors include --- links in causal paths
# print(I)
XYI = list(self.X.union(self.Y).union(I))
# print(XYI)
ancs = self._get_ancestors(list(XYI))
Z = ancs.intersection(Rprime)
if self._check_validity(Z) is False:
return False
else:
return Z
def list_sep(I, R):
# print(find_sep(X, Y, I, R))
if find_sep(I, R) is not False:
# print(I,R)
if I == R:
# print('--->', I)
yield I
else:
# Pick arbitrary node from R-I
RminusI = list(R - I)
# print(R, I, RminusI)
v = RminusI[0]
# print("here ", X, Y, I.union(set([v])), R)
yield from list_sep(I.union(set([v])), R)
yield from list_sep(I, R - set([v]))
# print("all ", X, Y, cond_set, all_vars_set)
all_sets = []
I = cond_set
R = all_vars_set
for index, valid_set in enumerate(list_sep(I, R)):
# print(valid_set)
all_sets.append(list(valid_set))
if check_one_set_exists and index > 0:
break
if yield_index is not None and index == yield_index:
return valid_set
if yield_index is not None:
return None
if check_one_set_exists:
if len(all_sets) == 1:
return True
else:
return False
return all_sets
def _get_causal_paths(self, source_nodes, target_nodes,
mediators=None,
mediated_through=None,
proper_paths=True,
):
"""Returns causal paths via depth-first search.
Allows to restrict paths through mediated_through.
"""
source_nodes = set(source_nodes)
target_nodes = set(target_nodes)
if mediators is None:
mediators = set()
else:
mediators = set(mediators)
if mediated_through is None:
mediated_through = []
mediated_through = set(mediated_through)
if proper_paths:
inside_set = mediators.union(target_nodes) - source_nodes
else:
inside_set = mediators.union(target_nodes).union(source_nodes)
all_causal_paths = {}
for w in source_nodes:
all_causal_paths[w] = {}
for z in target_nodes:
all_causal_paths[w][z] = []
for w in source_nodes:
causal_path = []
queue = [(w, causal_path)]
while queue:
varlag, causal_path = queue.pop()
causal_path = causal_path + [varlag]
suitable_nodes = set(self._get_children(varlag)
).intersection(inside_set)
for node in suitable_nodes:
i, tau = node
if ((-self.tau_max <= tau <= 0) # or self.ignore_time_bounds)
and node not in causal_path):
queue.append((node, causal_path))
if node in target_nodes:
if len(mediated_through) > 0 and len(set(causal_path).intersection(mediated_through)) == 0:
continue
else:
all_causal_paths[w][node].append(causal_path + [node])
return all_causal_paths
def fit_total_effect(self,
        dataframe,
        estimator,
        adjustment_set='optimal',
        conditional_estimator=None,
        data_transform=None,
        mask_type=None,
        ignore_identifiability=False,
        ):
    """Returns a fitted model for the total causal effect of X on Y
    conditional on S.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    estimator : sklearn model object
        For example, sklearn.linear_model.LinearRegression() for a linear
        regression model.
    adjustment_set : str or list of tuples
        If 'optimal' the Oset is used, if 'minimized_optimal' the minimized Oset,
        and if 'colliders_minimized_optimal', the colliders-minimized Oset.
        If a list of tuples is passed, this set is used.
    conditional_estimator : sklearn model object, optional (default: None)
        Used to fit conditional causal effects in nested regression.
        If None, the same model as for estimator is used.
    data_transform : sklearn preprocessing object, optional (default: None)
        Used to transform data prior to fitting. For example,
        sklearn.preprocessing.StandardScaler for simple standardization. The
        fitted parameters are stored.
    mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
        Masking mode: Indicators for which variables in the dependence
        measure I(X; Y | Z) the samples should be masked. If None, the mask
        is not used. Explained in tutorial on masking and missing values.
    ignore_identifiability : bool
        Only applies to adjustment sets supplied by user. Ignores if that
        set leads to a non-identifiable effect.

    Returns
    -------
    self : CausalEffects
        The instance itself; the fitted model is stored in self.model and
        the used adjustment set in self.adjustment_set.

    Raises
    ------
    ValueError
        If the number of variables of dataframe and graph differ, if a
        user-supplied adjustment set is not valid, or if the effect is not
        identifiable via adjustment.
    """
    # Store the inputs before the early return: predict_total_effect reads
    # self.dataframe.vector_vars before it checks self.no_causal_path.
    self.dataframe = dataframe
    self.conditional_estimator = conditional_estimator

    if self.no_causal_path:
        if self.verbosity > 0:
            print("No causal path from X to Y exists.")
        return self

    if self.N != self.dataframe.N:
        raise ValueError("Dataset dimensions inconsistent with number of variables in graph.")

    if adjustment_set == 'optimal':
        adjustment_set = self.get_optimal_set()
    elif adjustment_set == 'colliders_minimized_optimal':
        adjustment_set = self.get_optimal_set(minimize='colliders_only')
    elif adjustment_set == 'minimized_optimal':
        adjustment_set = self.get_optimal_set(minimize=True)
    else:
        # User-supplied set: validate unless explicitly told not to
        if ignore_identifiability is False and self._check_validity(adjustment_set) is False:
            raise ValueError("Chosen adjustment_set is not valid.")

    # get_optimal_set returns False if no valid adjustment set exists
    if adjustment_set is False:
        raise ValueError("Causal effect not identifiable via adjustment.")

    self.adjustment_set = adjustment_set

    # Fit model of Y on X and Z (and conditions)
    # Build the model
    self.model = Models(
                    dataframe=dataframe,
                    model=estimator,
                    conditional_model=conditional_estimator,
                    data_transform=data_transform,
                    mask_type=mask_type,
                    verbosity=self.verbosity)

    self.model.get_general_fitted_model(
            Y=self.listY, X=self.listX, Z=list(self.adjustment_set),
            conditions=self.listS,
            tau_max=self.tau_max,
            cut_off='tau_max',
            return_data=False)

    return self
def predict_total_effect(self,
        intervention_data,
        conditions_data=None,
        pred_params=None,
        return_further_pred_results=False,
        aggregation_func=np.mean,
        transform_interventions_and_prediction=False,
        ):
    """Predict effect of intervention with fitted model.

    Uses the model.predict() function of the sklearn model.

    Parameters
    ----------
    intervention_data : numpy array
        Numpy array of shape (time, len(X)) that contains the do(X) values.
    conditions_data : data object, optional
        Numpy array of shape (time, len(S)) that contains the S=s values.
    pred_params : dict, optional
        Optional parameters passed on to sklearn prediction function.
    return_further_pred_results : bool, optional (default: False)
        In case the predictor class returns more than just the expected value,
        the entire results can be returned.
    aggregation_func : callable
        Callable applied to output of 'predict'. Default is 'np.mean'.
    transform_interventions_and_prediction : bool (default: False)
        Whether to perform the inverse data_transform on prediction results.

    Returns
    -------
    Results from prediction: an array of shape (time, len(Y)). If no causal
    path from X to Y exists, an array of zeros of that shape is returned.

    Raises
    ------
    ValueError
        If the shapes of intervention_data or conditions_data do not fit
        X and S.
    """
    # Number of columns taken by X and S, counting every component of the
    # dataframe's vector variables
    n_x_columns = sum(len(self.dataframe.vector_vars[node[0]]) for node in self.listX)
    n_s_columns = sum(len(self.dataframe.vector_vars[node[0]]) for node in self.listS)

    if intervention_data.shape[1] != n_x_columns:
        raise ValueError("intervention_data.shape[1] must be len(X).")

    if conditions_data is None:
        if n_s_columns > 0:
            raise ValueError("S specified, but conditions_data is None.")
    elif n_s_columns == 0:
        raise ValueError("conditions_data specified, but S=None or empty.")
    else:
        if conditions_data.shape[1] != n_s_columns:
            raise ValueError("conditions_data.shape[1] must be len(S).")
        if conditions_data.shape[0] != intervention_data.shape[0]:
            raise ValueError("conditions_data.shape[0] must match intervention_data.shape[0].")

    if self.no_causal_path:
        if self.verbosity > 0:
            print("No causal path from X to Y exists.")
        return np.zeros((len(intervention_data), len(self.listY)))

    return self.model.get_general_prediction(
        intervention_data=intervention_data,
        conditions_data=conditions_data,
        pred_params=pred_params,
        return_further_pred_results=return_further_pred_results,
        transform_interventions_and_prediction=transform_interventions_and_prediction,
        aggregation_func=aggregation_func,)
def fit_wright_effect(self,
        dataframe,
        mediation=None,
        method='parents',
        links_coeffs=None,
        data_transform=None,
        mask_type=None,
        ):
    """Returns a fitted model for the total or mediated causal effect of X on Y
    potentially through mediator variables.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    mediation : None, 'direct', or list of tuples
        If None, total effect is estimated, if 'direct' then only the direct effect is estimated,
        else only those causal paths are considered that pass at least through one of these mediator nodes.
    method : {'parents', 'links_coeffs', 'optimal'}
        Method to use for estimating Wright's path coefficients. If 'optimal',
        the Oset is used, if 'links_coeffs', the coefficients in links_coeffs are used,
        if 'parents', the parents are used (only valid for DAGs).
    links_coeffs : dict
        Only used if method = 'links_coeffs'.
        Dictionary of format: {0:[((i, -tau), coeff),...], 1:[...],
        ...} for all variables where i must be in [0..N-1] and tau >= 0 with
        number of variables N. coeff must be a float. Entries may carry
        further elements after coeff; these are ignored.
    data_transform : None
        Not implemented for Wright estimator. Complicated for missing samples.
    mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
        Masking mode: Indicators for which variables in the dependence
        measure I(X; Y | Z) the samples should be masked. If None, the mask
        is not used. Explained in tutorial on masking and missing values.

    Returns
    -------
    self : CausalEffects
        The instance itself; path coefficients are stored in self.coeffs and
        a linear dummy model per target in self.model.fit_results.

    Raises
    ------
    ValueError
        If data_transform is given, if links_coeffs is missing for
        method='links_coeffs', if the dataframe has vector data, if method
        is unknown, or if the chosen method is not applicable to the graph.
    """
    if self.no_causal_path:
        if self.verbosity > 0:
            print("No causal path from X to Y exists.")
        return self

    if data_transform is not None:
        raise ValueError("data_transform not implemented for Wright estimator."
                         " You can preprocess data yourself beforehand.")

    if method == 'links_coeffs' and links_coeffs is None:
        raise ValueError("links_coeffs must be provided if method == 'links_coeffs'.")

    import sklearn.linear_model

    self.dataframe = dataframe
    if self.dataframe.has_vector_data:
        raise ValueError("vector_vars in DataFrame cannot be used together with Wright method!"
                         " You can either 1) estimate vector-valued effects by using multivariate (X, Y, S)"
                         " together with assuming a graph at the level of the components of (X, Y, S), "
                         " or 2) use vector_vars together with fit_total_effect and an estimator"
                         " that supports multiple outputs.")

    estimator = sklearn.linear_model.LinearRegression()

    # Fit model of Y on X and Z (and conditions)
    # Build the model
    self.model = Models(
                    dataframe=dataframe,
                    model=estimator,
                    data_transform=None,
                    mask_type=mask_type,
                    verbosity=self.verbosity)

    mediators = self.M

    # Collect the causal paths from every x in X to every y in Y
    if mediation == 'direct':
        causal_paths = {}
        for w in self.X:
            causal_paths[w] = {}
            for z in self.Y:
                if w in self._get_parents(z):
                    causal_paths[w][z] = [[w, z]]
                else:
                    causal_paths[w][z] = []
    else:
        causal_paths = self._get_causal_paths(source_nodes=self.X,
                            target_nodes=self.Y, mediators=mediators,
                            mediated_through=mediation, proper_paths=True)

    # Nodes for which path coefficients of incoming links are needed
    fit_targets = list(mediators) + list(self.listY)

    coeffs = {}
    if method == 'links_coeffs':
        max_lag = 0
        for medy in fit_targets:
            coeffs[medy] = {}
            _, tauj = medy
            for par_coeff in links_coeffs[medy[0]]:
                # Accept the documented ((i, -tau), coeff) entries as well as
                # longer tuples whose additional elements are not needed here
                par, coeff = par_coeff[0], par_coeff[1]
                i, taui = par
                max_lag = max(abs(taui), max_lag)
                # Shift the parent's lag to the time index of medy
                coeffs[medy][(i, taui + tauj)] = coeff
        self.model.tau_max = max_lag

    elif method == 'optimal':
        for medy in fit_targets:
            coeffs[medy] = {}
            all_parents = self._get_all_parents([medy]) - set([medy])
            mediator_parents = (self._get_all_parents([medy]).intersection(
                mediators.union(self.X).union(self.Y)) - set([medy]))
            for par in mediator_parents:
                # Direct effect of par on medy: condition on the other parents
                Sprime = set(all_parents) - set([par, medy])
                causal_effects = CausalEffects(graph=self.graph,
                                    X=[par], Y=[medy], S=Sprime,
                                    graph_type=self.graph_type,
                                    check_SM_overlap=False,
                                    )
                oset = causal_effects.get_optimal_set()
                if oset is False:
                    raise ValueError("Not identifiable via Wright's method.")
                fit_res = self.model.get_general_fitted_model(
                    Y=[medy], X=[par], Z=oset,
                    tau_max=self.tau_max,
                    cut_off='tau_max',
                    return_data=False)
                # NOTE(review): presumably the first coefficient belongs to
                # par; depends on the shape of coef_ -- confirm.
                coeffs[medy][par] = fit_res['model'].coef_[0]

    elif method == 'parents':
        for medy in fit_targets:
            coeffs[medy] = {}
            all_parents = self._get_all_parents([medy]) - set([medy])
            if 'dag' not in self.graph_type:
                spouses = self._get_all_spouses([medy]) - set([medy])
                if len(spouses) != 0:
                    raise ValueError("method == 'parents' only possible for "
                                     "causal paths without adjacent bi-directed links!")

            # Keep one fixed order for the regressors and the coefficients
            parent_list = list(all_parents)
            fit_res = self.model.get_general_fitted_model(
                        Y=[medy], X=parent_list, Z=[],
                        conditions=None,
                        tau_max=self.tau_max,
                        cut_off='tau_max',
                        return_data=False)

            for ipar, par in enumerate(parent_list):
                coeffs[medy][par] = fit_res['model'].coef_[0][ipar]

    else:
        raise ValueError("method must be 'optimal', 'links_coeffs', or 'parents'.")

    # Effect is sum over products over all path coefficients
    # from x in X to y in Y
    effect = {}
    for (x, y) in itertools.product(self.listX, self.listY):
        effect[(x, y)] = 0.
        for causal_path in causal_paths[x][y]:
            effect_here = 1.
            for node, next_node in zip(causal_path[:-1], causal_path[1:]):
                effect_here *= coeffs[next_node][node]
            effect[(x, y)] += effect_here

    # Make fitted coefficients available as attribute
    self.coeffs = coeffs

    # Modify and overwrite variables in self.model
    self.model.Y = self.listY
    self.model.X = self.listX
    self.model.Z = []
    self.model.conditions = []
    self.model.cut_off = 'tau_max'

    class dummy_fit_class():
        """Linear predictor that applies the estimated effects of X on one y."""
        def __init__(self, y_here, listX_here, effect_here):
            dim = len(listX_here)
            self.coeff_array = np.array(
                [effect_here[(x, y_here)] for x in listX_here]).reshape(dim, 1)

        def predict(self, X):
            return np.dot(X, self.coeff_array).squeeze()

    fit_results = {}
    for y in self.listY:
        fit_results[y] = {}
        fit_results[y]['model'] = dummy_fit_class(y, self.listX, effect)
        # data_transform is always None at this point (see check above)
        fit_results[y]['data_transform'] = None

    self.model.fit_results = fit_results
    return self
def predict_wright_effect(self,
                          intervention_data,
                          pred_params=None,
                          ):
    """Predict linear effect of intervention with fitted Wright-model.

    Parameters
    ----------
    intervention_data : numpy array
        Numpy array of shape (time, len(X)) that contains the do(X) values.
    pred_params : dict, optional
        Optional keyword arguments passed on to the ``predict`` function of
        the fitted model of each target.

    Returns
    -------
    Results from prediction: an array of shape (time, len(Y)). If no causal
    path from X to Y exists, an array of zeros of that shape is returned.

    Raises
    ------
    ValueError
        If ``intervention_data.shape[1]`` differs from ``len(X)`` or if a
        target in Y has no entry in ``self.model.fit_results``.
    """
    lenX = len(self.listX)
    lenY = len(self.listY)

    if intervention_data.shape[1] != lenX:
        raise ValueError("intervention_data.shape[1] must be len(X).")

    if self.no_causal_path:
        if self.verbosity > 0:
            print("No causal path from X to Y exists.")
        # Same column count as the regular output below (one per entry of
        # listY), so both code paths agree on the returned shape.
        return np.zeros((len(intervention_data), lenY))

    intervention_T, _ = intervention_data.shape

    # Resolve the default once instead of re-checking it for every target
    params = {} if pred_params is None else pred_params

    predicted_array = np.zeros((intervention_T, lenY))
    for iy, y in enumerate(self.listY):
        # Print message
        if self.verbosity > 1:
            print("\n## Predicting target %s" % str(y))
            for key in list(params):
                print("%s = %s" % (key, params[key]))

        # Check this is a valid target
        if y not in self.model.fit_results:
            raise ValueError("y = %s not yet fitted" % str(y))

        # No data_transform is applied here (noted in the original code as
        # too complicated for the Wright estimator).
        fitted_model = self.model.fit_results[y]['model']

        # Predict each intervention row separately and store the mean of the
        # returned values for that row.
        for index, dox_vals in enumerate(intervention_data):
            predictor_array = dox_vals.reshape(1, lenX)
            predicted_vals = fitted_model.predict(X=predictor_array, **params)
            predicted_array[index, iy] = predicted_vals.mean()

    return predicted_array
def fit_bootstrap_of(self, method, method_args,
                     boot_samples=100,
                     boot_blocklength=1,
                     seed=None):
    """Fit the chosen method on the original data and on bootstrap samples.

    The method is first run once with the unmodified ``method_args``; a deep
    copy of the resulting ``self.model`` is kept as ``self.original_model``.
    Then, for every bootstrap index ``b``, the ``bootstrap`` attribute of the
    (deep-copied) dataframe in ``method_args`` is set to a dict holding the
    block length and a fresh random generator, the method is run again, and
    a deep copy of the fitted ``self.model`` is stored in
    ``self.bootstrap_results[b]``. Finally ``self.model`` is reset to
    ``self.original_model``.

    Bootstraps for tau=0 are drawn from [max_lag, ..., T] and all lagged
    variables constructed in DataFrame.construct_array are consistently
    shifted with respect to this bootstrap sample to ensure that lagged
    relations in the bootstrap sample are preserved.

    predict_bootstrap_of can afterwards be used to get confidence intervals
    for the effect of interventions.

    Parameters
    ----------
    method : str
        Chosen method among valid functions in this class
        ('fit_total_effect' or 'fit_wright_effect').
    method_args : dict
        Arguments passed to method. Must contain the key 'dataframe'.
    boot_samples : int
        Number of bootstrap samples to draw.
    boot_blocklength : int, optional (default: 1)
        Block length for block-bootstrap.
    seed : int, optional (default: None)
        Seed for the random generator (default_rng). Bootstrap ``b`` uses
        ``seed*boot_samples + b``; with None every generator is unseeded.

    Returns
    -------
    self
    """
    valid_methods = ['fit_total_effect',
                     'fit_wright_effect',
                     ]
    if method not in valid_methods:
        raise ValueError("method must be one of %s" % str(valid_methods))

    fit_function = getattr(self, method)

    # Fit on the original dataframe first, which makes the adjustment set
    # etc. available, and keep that model for the final reset.
    fit_function(**method_args)
    self.original_model = deepcopy(self.model)

    if self.verbosity > 0:
        banner = "".join([
            "\n##\n## Running Bootstrap of %s " % method,
            "\n##\n",
            "\nboot_samples = %s \n" % boot_samples,
            "\nboot_blocklength = %s \n" % boot_blocklength,
        ])
        print(banner)

    # Work on a copy so that the caller's dataframe is not modified
    method_args_bootstrap = deepcopy(method_args)
    boot_dataframe = method_args_bootstrap['dataframe']

    self.bootstrap_results = {}
    for b in range(boot_samples):
        generator_seed = None if seed is None else seed*boot_samples + b
        boot_dataframe.bootstrap = {
            'boot_blocklength': boot_blocklength,
            'random_state': np.random.default_rng(generator_seed),
        }

        # Fit on the bootstrap configuration and keep the fitted model
        fit_function(**method_args_bootstrap)
        self.bootstrap_results[b] = deepcopy(self.model)

    # Reset model
    self.model = self.original_model

    return self
def predict_bootstrap_of(self, method, method_args,
                         conf_lev=0.9,
                         return_individual_bootstrap_results=False):
    """Predicts with fitted bootstraps.

    To be used after fitting with fit_bootstrap_of. Only uses the
    expected values of the predict function, not potential other output:
    if the predict function returns a tuple, only its first element is used.

    Parameters
    ----------
    method : str
        Chosen method among valid functions in this class
        ('predict_total_effect' or 'predict_wright_effect').
    method_args : dict
        Arguments passed to method.
    conf_lev : float, optional (default: 0.9)
        Two-sided confidence interval.
    return_individual_bootstrap_results : bool
        Returns the individual bootstrap predictions.

    Returns
    -------
    confidence_intervals : numpy array
        Array of shape (2,) + prediction shape holding the lower and upper
        percentile over the bootstrap predictions. If
        return_individual_bootstrap_results is True, the tuple
        (bootstrap_predicted_array, confidence_intervals) is returned, where
        bootstrap_predicted_array has shape (boot_samples,) + prediction
        shape.

    Raises
    ------
    ValueError
        If method is not a valid prediction method or if no fitted
        bootstrap models are available.
    """
    valid_methods = ['predict_total_effect',
                     'predict_wright_effect',
                     ]
    if method not in valid_methods:
        raise ValueError("method must be one of %s" % str(valid_methods))

    # Without fitted bootstrap models there is nothing to take percentiles
    # over; fail with a clear message instead of an unbound local below.
    bootstrap_results = getattr(self, 'bootstrap_results', None)
    if not bootstrap_results:
        raise ValueError("No bootstrap results available, run "
                         "fit_bootstrap_of with boot_samples > 0 first.")

    boot_samples = len(bootstrap_results)
    bootstrap_predicted_array = None
    try:
        for b in range(boot_samples):
            # Temporarily swap in the b-th bootstrap model for prediction
            self.model = bootstrap_results[b]
            boot_effect = getattr(self, method)(**method_args)

            if isinstance(boot_effect, tuple):
                boot_effect = boot_effect[0]

            # Allocate lazily: the prediction shape and dtype are only
            # known after the first prediction.
            if bootstrap_predicted_array is None:
                bootstrap_predicted_array = np.zeros(
                    (boot_samples, ) + boot_effect.shape,
                    dtype=boot_effect.dtype)
            bootstrap_predicted_array[b] = boot_effect
    finally:
        # Reset model, also when a prediction raised
        self.model = self.original_model

    # Confidence intervals for val_matrix; interval is two-sided
    c_int = (1. - (1. - conf_lev)/2.)
    confidence_interval = np.percentile(
        bootstrap_predicted_array, axis=0,
        q=[100*(1. - c_int), 100*c_int])

    if return_individual_bootstrap_results:
        return bootstrap_predicted_array, confidence_interval

    return confidence_interval
[docs] @staticmethod
def get_dict_from_graph(graph, parents_only=False):
"""Helper function to convert graph to dictionary of links.
Parameters
---------
graph : array of shape (N, N, tau_max+1)
Matrix format of graph in string format.
parents_only : bool
Whether to only return parents ('-->' in graph)
Returns
-------
links : dict
Dictionary of form {0:{(0, -1): o-o, ...}, 1:{...}, ...}.
"""
N = graph.shape[0]
links = dict([(j, {}) for j in range(N)])
if parents_only:
for (i, j, tau) in zip(*np.where(graph=='-->')):
links[j][(i, -tau)] = graph[i,j,tau]
else:
for (i, j, tau) in zip(*np.where(graph!='')):
links[j][(i, -tau)] = graph[i,j,tau]
return links
[docs] @staticmethod
def get_graph_from_dict(links, tau_max=None):
"""Helper function to convert dictionary of links to graph array format.
Parameters
---------
links : dict
Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.
Also format {0:[(0, -1), ...], 1:[...], ...} is allowed.
tau_max : int or None
Maximum lag. If None, the maximum lag in links is used.
Returns
-------
graph : array of shape (N, N, tau_max+1)
Matrix format of graph with 1 for true links and 0 else.
"""
def _get_minmax_lag(links):
"""Helper function to retrieve tau_min and tau_max from links.
"""
N = len(links)
# Get maximum time lag
min_lag = np.inf
max_lag = 0
for j in range(N):
for link_props in links[j]:
if len(link_props) > 2:
var, lag = link_props[0]
coeff = link_props[1]
# func = link_props[2]
if coeff != 0.:
min_lag = min(min_lag, abs(lag))
max_lag = max(max_lag, abs(lag))
else:
var, lag = link_props
min_lag = min(min_lag, abs(lag))
max_lag = max(max_lag, abs(lag))
return min_lag, max_lag
N = len(links)
# Get maximum time lag
min_lag, max_lag = _get_minmax_lag(links)
# Set maximum lag
if tau_max is None:
tau_max = max_lag
else:
if max_lag > tau_max:
raise ValueError("tau_max is smaller than maximum lag = %d "
"found in links, use tau_max=None or larger "
"value" % max_lag)
graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
for j in links.keys():
for link_props in links[j]:
if len(link_props) > 2:
var, lag = link_props[0]
coeff = link_props[1]
if coeff != 0.:
graph[var, j, abs(lag)] = "-->"
if lag == 0:
graph[j, var, 0] = "<--"
else:
var, lag = link_props
graph[var, j, abs(lag)] = "-->"
if lag == 0:
graph[j, var, 0] = "<--"
return graph
if __name__ == '__main__':
    # Usage example: fit and predict a total causal effect on toy data with
    # vector-valued variables.
    import tigramite.toymodels.structural_causal_processes as toys
    import tigramite.data_processing as pp
    from sklearn.linear_model import LinearRegression

    # Linear toy process with four scalar variables; the only non-zero
    # coefficient is the contemporaneous link 0 --> 2.
    T = 10000

    def lin_f(x):
        return x

    auto_coeff = 0.
    coeff = 2.
    links = {
        0: [((0, -1), auto_coeff, lin_f)],
        1: [((1, -1), auto_coeff, lin_f)],
        2: [((2, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
        3: [((3, -1), auto_coeff, lin_f)],
    }
    data, nonstat = toys.structural_causal_process(links, T=T,
                                                   noises=None, seed=7)

    # NOTE(review): vector_vars presumably groups the four data columns into
    # two vector-valued variables (0 and 1) -- confirm against pp.DataFrame.
    dataframe = pp.DataFrame(data,
                             vector_vars={0: [(0, 0), (1, 0)],
                                          1: [(2, 0), (3, 0)]}
                             )

    # Construct expert knowledge graph for the two vector variables
    aux_links = {0: [(0, -1)],
                 1: [(1, -1), (0, 0)],
                 }

    # Use staticmethod to get graph
    graph = CausalEffects.get_graph_from_dict(aux_links, tau_max=2)

    # We are interested in the total effect of X on Y
    X = [(0, 0), (0, -1)]
    Y = [(1, 0), (1, -1)]

    # Initialize class as `stationary_dag`
    causal_effects = CausalEffects(graph, graph_type='stationary_dag',
                                   X=X, Y=Y, S=None,
                                   hidden_variables=None,
                                   verbosity=1)

    # Fit causal effect model from observational data
    causal_effects.fit_total_effect(
        dataframe=dataframe,
        estimator=LinearRegression(),
    )

    # Predict the effect of a single intervention. The array has four
    # columns; presumably one per component of the two vector-valued nodes
    # in X -- confirm against predict_total_effect.
    intervention_data = np.array([[1., 0., 0., 0.]])
    print(intervention_data)

    pred_Y = causal_effects.predict_total_effect(
        intervention_data=intervention_data)
    print(pred_Y, pred_Y.shape)