# -*- coding: utf-8 -*-
"""Graph data structures."""
from __future__ import annotations
import itertools as itt
import json
import warnings
from dataclasses import dataclass, field
from itertools import chain, combinations
from typing import (
TYPE_CHECKING,
Any,
Collection,
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import networkx as nx
from networkx.classes.reportviews import NodeView
from networkx.utils import open_file
from .dsl import CounterfactualVariable, Intervention, Variable, vmap_adj, vmap_pairs
if TYPE_CHECKING:
import ananke.graphs
import pgmpy.inference.CausalInference
import pgmpy.models
import sympy
# Names exported by ``from <module> import *``.
__all__ = [
    "NxMixedGraph",
    "CausalEffectGraph",
    "DEFULT_PREFIX",
    "DEFAULT_TAG",
    "set_latent",
]

#: Placeholder type used to annotate the value returned by
#: :meth:`NxMixedGraph.to_causaleffect` (the result of an ``rpy2`` call).
CausalEffectGraph = Any

#: The default key in a latent variable DAG represented as a :class:`networkx.DiGraph`
#: for nodes that correspond to "latent" variables
DEFAULT_TAG = "hidden"

#: The default prefix for latent variables in a latent variable DAG. After the prefix,
#: there will be a number assigned that's incremented during construction.
#: NOTE(review): the name is misspelled ("DEFULT") but it is exported through
#: ``__all__``, so it is kept as-is here.
DEFULT_PREFIX = "u_"

#: Key stored in the ``graph`` attribute dictionary of the two networkx graphs
#: held by a :class:`NxMixedGraph`; :func:`set_latent` refuses to operate on a
#: graph carrying this flag.
NO_SET_LATENT_FLAG = "no_set_latent"
@dataclass
class NxMixedGraph:
    """A mixed graph based on a :class:`networkx.Graph` and a :class:`networkx.DiGraph`.

    Directed edges live in :attr:`directed` and bidirected edges live in
    :attr:`undirected`. The ``add_*`` methods register every node in both
    graphs, so node-level queries can be answered from :attr:`directed` alone.

    Example usage:

    .. code-block:: python

        graph = NxMixedGraph()
        graph.add_directed_edge('X', 'Y')
        graph.add_undirected_edge('X', 'Y')
    """

    #: A directed graph
    directed: nx.DiGraph = field(default_factory=nx.DiGraph)
    #: A undirected graph
    undirected: nx.Graph = field(default_factory=nx.Graph)

    def __post_init__(self):
        """Flag both graphs so :func:`set_latent` refuses to operate on them."""
        self.directed.graph[NO_SET_LATENT_FLAG] = True
        self.undirected.graph[NO_SET_LATENT_FLAG] = True

    def __eq__(self, other: Any) -> bool:
        """Check for equality of nodes, directed edges, and undirected edges."""
        return (
            isinstance(other, NxMixedGraph)
            and self.nodes() == other.nodes()
            and (self.directed.edges() == other.directed.edges())
            and (self.undirected.edges() == other.undirected.edges())
        )

    def __iter__(self) -> Iterable[Variable]:
        """Iterate over nodes in the graph."""
        return iter(self.directed)

    def __len__(self) -> int:
        """Count the nodes in the graph."""
        return len(self.directed)

    def __contains__(self, item: Variable) -> bool:
        """Check if the given item is a node in the graph."""
        return item in self.directed

    def copy(self):
        """Get a copy of the graph (both underlying graphs are copied)."""
        return self.__class__(
            directed=self.directed.copy(),
            undirected=self.undirected.copy(),
        )

    def is_counterfactual(self) -> bool:
        """Check if this is a counterfactual graph.

        :returns: True if any node is a :class:`CounterfactualVariable`
        """
        return any(isinstance(n, CounterfactualVariable) for n in self.nodes())

    def raise_on_counterfactual(self) -> None:
        """Raise an error if this is a counterfactual graph.

        :raises ValueError: if this graph is a counterfactual graph
        """
        if self.is_counterfactual():
            raise ValueError("This operation is not available for counterfactual graphs")

    def add_node(self, n: Variable) -> None:
        """Add a node to both the directed and the undirected graph."""
        n = Variable.norm(n)
        self.directed.add_node(n)
        self.undirected.add_node(n)

    def add_directed_edge(self, u: Union[str, Variable], v: Union[str, Variable], **attr) -> None:
        """Add a directed edge from u to v.

        Both endpoints are also registered as nodes of the undirected graph.
        """
        u = Variable.norm(u)
        v = Variable.norm(v)
        self.directed.add_edge(u, v, **attr)
        self.undirected.add_node(u)
        self.undirected.add_node(v)

    def add_undirected_edge(self, u: Union[str, Variable], v: Union[str, Variable], **attr) -> None:
        """Add an undirected edge between u and v.

        Both endpoints are also registered as nodes of the directed graph.
        """
        u = Variable.norm(u)
        v = Variable.norm(v)
        self.undirected.add_edge(u, v, **attr)
        self.directed.add_node(u)
        self.directed.add_node(v)

    def nodes(self) -> NodeView[Variable]:
        """Get the nodes in the graph."""
        return self.directed.nodes()

    def to_admg(self) -> "ananke.graphs.ADMG":
        """Get an ananke ADMG.

        :returns: An ADMG built from the names of the nodes and edges
        :raises ValueError: if this graph is a counterfactual graph
        """
        self.raise_on_counterfactual()
        from ananke.graphs import ADMG

        # update the way stringification happens so this
        # can support arbitrary variables, like counterfactuals
        return ADMG(
            vertices=[n.name for n in self.nodes()],
            di_edges=[(u.name, v.name) for u, v in self.directed.edges()],
            bi_edges=[(u.name, v.name) for u, v in self.undirected.edges()],
        )

    def to_pgmpy_bayesian_network(self) -> "pgmpy.models.BayesianNetwork":
        """Convert a mixed graph to an equivalent :class:`pgmpy.BayesianNetwork`.

        Every undirected edge ``u - v`` is replaced by a latent node named
        ``U_{u}_{v}`` with directed edges to both ``u`` and ``v``.
        """
        from pgmpy.models import BayesianNetwork

        edges = [(u.name, v.name) for u, v in self.directed.edges()]
        latents = set()
        for u, v in self.undirected.edges():
            latent = f"U_{u.name}_{v.name}"
            latents.add(latent)
            edges.append((latent, u.name))
            edges.append((latent, v.name))
        model = BayesianNetwork(ebunch=edges, latents=latents)
        return model

    def to_pgmpy_causal_inference(self) -> "pgmpy.inference.CausalInference.CausalInference":
        """Get a pgmpy causal inference object."""
        from pgmpy.inference.CausalInference import CausalInference

        return CausalInference(self.to_pgmpy_bayesian_network())

    def to_linear_scm_sympy(self) -> dict[Variable, "sympy.Expr"]:
        """Generate a Sympy system of equations.

        :returns: A mapping from each node (in topological order) to the sum of
            one ``beta * parent`` term per parent, one ``epsilon`` noise symbol,
            and one ``gamma`` symbol per incident undirected edge.
        """
        import sympy

        variable_to_equation = {}
        for node in self.topological_sort():
            terms = []
            # Add parent edges
            for parent in self.directed.predecessors(node):
                beta = sympy_nested(r"\beta", parent, node)
                terms.append(beta * parent.to_sympy())
            # Add noise term
            epsilon_symbol = sympy_nested(r"\epsilon", node)
            terms.append(epsilon_symbol)
            # get bidirected edges
            for u, v in self.undirected.edges(node):
                # sort the endpoints so both ends of an edge share one symbol
                u, v = sorted([u, v])
                gamma_symbol = sympy_nested(r"\gamma", u, v)
                terms.append(gamma_symbol)
            variable_to_equation[node] = cast(sympy.Expr, sum(terms))
        return variable_to_equation

    def to_linear_scm_latex(self) -> str:
        """Generate a LaTeX ``align*`` block for the linear SCM equations.

        :returns: A string subclass exposing ``_repr_latex_`` whose content is
            built from :meth:`to_linear_scm_sympy`.
        """
        import sympy

        equations_dict = self.to_linear_scm_sympy()
        latex_equations = [
            rf"{variable.to_latex()} &= {sympy.latex(expression)} \\"
            for variable, expression in equations_dict.items()
        ]
        return _LatexStr(r"\begin{align*}" + "\n ".join(latex_equations) + r"\end{align*}")

    @classmethod
    def from_admg(cls, admg) -> NxMixedGraph:
        """Create from an ananke ADMG."""
        return cls.from_str_edges(
            nodes=admg.vertices,
            directed=admg.di_edges,
            undirected=admg.bi_edges,
        )

    def to_latent_variable_dag(
        self,
        *,
        prefix: Optional[str] = None,
        start: int = 0,
        tag: Optional[str] = None,
    ) -> nx.DiGraph:
        """Create a labeled DAG where bi-directed edges are assigned as nodes upstream of their two incident nodes.

        :param prefix: The prefix for latent variables. If none, defaults to :data:`DEFULT_PREFIX`.
        :param start: The starting number for latent variables (defaults to 0, could be changed to 1 if desired)
        :param tag: The key for node data describing whether it is latent.
            If None, defaults to :data:`DEFAULT_TAG`.
        :return: A latent variable DAG.
        :raises ValueError: if this graph is a counterfactual graph
        """
        self.raise_on_counterfactual()
        return _latent_dag(
            di_edges=self.directed.edges(),
            bi_edges=self.undirected.edges(),
            prefix=prefix,
            start=start,
            tag=tag,
        )

    @classmethod
    def from_latent_variable_dag(cls, graph: nx.DiGraph, tag: Optional[str] = None) -> NxMixedGraph:
        """Load a labeled DAG.

        :param graph: A directed graph in which every node carries the ``tag`` key
        :param tag: The key for node data describing whether it is latent.
            If None, defaults to :data:`DEFAULT_TAG`.
        :returns: A mixed graph where each pair of children of a latent node is
            joined by an undirected edge and edges out of other nodes are kept
            as directed edges.
        :raises ValueError: if any node is missing the ``tag`` key
        """
        if tag is None:
            tag = DEFAULT_TAG
        if any(tag not in data for data in graph.nodes.values()):
            raise ValueError(f"missing label {tag} in one or more nodes.")

        rv = cls()
        for node, data in graph.nodes.items():
            if data[tag]:
                for a, b in itt.combinations(graph.successors(node), 2):
                    rv.add_undirected_edge(a, b)
            else:
                for child in graph.successors(node):
                    rv.add_directed_edge(node, child)
        return rv

    def to_causaleffect(self) -> CausalEffectGraph:
        """Get a causaleffect R object.

        :returns: A causaleffect R object.

        .. warning:: Appropriate R imports need to be done first for 'causaleffect' and 'igraph'.
        """
        import rpy2.robjects

        return rpy2.robjects.r(self.to_causaleffect_str())

    def joint(self) -> nx.MultiGraph:
        """Return a multigraph holding the nodes plus both the directed and undirected edges."""
        rv = nx.MultiGraph()
        rv.add_nodes_from(self.directed)
        rv.add_edges_from(self.directed.edges)
        rv.add_edges_from(self.undirected.edges)
        return rv

    def moralize(self):
        """Moralize the graph.

        :returns: A moralized ADMG in which all nodes $U$ and $v$ that are parents of some
            node $N$ are connected with an undirected edge.

        .. seealso:: https://en.wikipedia.org/wiki/Moral_graph
        """
        rv = NxMixedGraph(directed=self.directed.copy(), undirected=self.undirected.copy())
        # Moralize (link parents of mentioned nodes)
        for u, v in iter_moral_links(self):
            rv.add_undirected_edge(u, v)
        return rv

    def draw(
        self, ax=None, title: Optional[str] = None, prog: Optional[str] = None, latex: bool = True
    ) -> None:
        """Render the graph using matplotlib.

        :param ax: Axis to draw on (if none specified, makes a new one)
        :param title: The optional title to show with the graph
        :param prog: The pydot program to use, like dot, neato, osage, etc.
            If none is given, uses osage for small graphs and dot for larger ones.
        :param latex: Parse string variables as y0 if possible to make pretty latex output
        """
        import matplotlib.pyplot as plt

        if prog is None:
            if self.directed.number_of_nodes() > 6:
                prog = "dot"
            else:
                prog = "osage"

        layout = _layout(self, prog=prog)
        # the undirected edges are drawn through a DiGraph proxy so that the
        # ``connectionstyle``/``arrowstyle`` options below can be passed
        u_proxy = nx.DiGraph(self.undirected.edges)
        labels = None if not latex else {node: _get_latex(node) for node in self.directed}

        if ax is None:
            ax = plt.gca()

        # TODO choose sizes based on size of axis
        node_size = 1_500
        node_size_offset = 500
        line_widths = 2
        margins = 0.3
        font_size = 20
        arrow_size = 20
        radius = 0.3

        nx.draw_networkx_nodes(
            self.directed,
            pos=layout,
            node_color="white",
            node_size=node_size,
            edgecolors="black",
            linewidths=line_widths,
            ax=ax,
            margins=margins,
        )
        nx.draw_networkx_labels(
            self.directed, pos=layout, ax=ax, labels=labels, font_size=font_size
        )
        nx.draw_networkx_edges(
            self.directed,
            pos=layout,
            edge_color="black",
            ax=ax,
            node_size=node_size + node_size_offset,
            width=line_widths,
            arrowsize=arrow_size,
        )
        nx.draw_networkx_edges(
            u_proxy,
            pos=layout,
            node_size=node_size + node_size_offset,
            ax=ax,
            style=":",
            width=line_widths,
            connectionstyle=f"arc3, rad={radius}",
            arrowstyle="-",
            edge_color="grey",
        )

        if title:
            ax.set_title(title)
        ax.axis("off")

    @classmethod
    def from_causaleffect(cls, graph) -> NxMixedGraph:
        """Construct an instance from a causaleffect R graph.

        :raises NotImplementedError: always, since this is not implemented
        """
        raise NotImplementedError

    def to_causaleffect_str(self) -> str:
        """Get a string to be imported by R.

        :returns: R code that builds ``g`` with ``graph.formula`` and marks the
            two arcs generated for each undirected edge with ``description = "U"``.
        :raises ValueError: if the graph has no directed edges
        """
        n_directed = self.directed.number_of_edges()
        # Check the edge count rather than the truthiness of the graph: a
        # networkx graph is truthy as soon as it has *nodes*, and a graph with
        # only undirected edges would otherwise produce a formula starting
        # with a dangling ", ".
        if not n_directed:
            raise ValueError("graph must have some directed edges")

        formula = ", ".join(f"{u} -+ {v}" for u, v in self.directed.edges())
        formula += "".join(f", {u} -+ {v}, {v} -+ {u}" for u, v in self.undirected.edges())

        rv = f"g <- graph.formula({formula}, simplify = FALSE)"
        for i in range(self.undirected.number_of_edges()):
            # each undirected edge contributed two arcs after the directed ones
            idx = 2 * i + n_directed + 1
            rv += (
                f'\ng <- set.edge.attribute(graph = g, name = "description",'
                f' index = c({idx}, {idx + 1}), value = "U")'
            )
        return rv

    @classmethod
    def from_edges(
        cls,
        nodes: Optional[Iterable[Variable]] = None,
        directed: Optional[Iterable[Tuple[Variable, Variable]]] = None,
        undirected: Optional[Iterable[Tuple[Variable, Variable]]] = None,
    ) -> NxMixedGraph:
        """Make a mixed graph from a pair of edge lists.

        :param nodes: Optional nodes to add regardless of whether they have edges
        :param directed: Optional directed edges
        :param undirected: Optional undirected edges
        :returns: A new mixed graph
        :raises ValueError: if both ``directed`` and ``undirected`` are None
        """
        if directed is None and undirected is None:
            raise ValueError("must provide at least one of directed/undirected edge lists")
        rv = cls()
        for n in nodes or []:
            rv.add_node(n)
        for u, v in directed or []:
            rv.add_directed_edge(u, v)
        for u, v in undirected or []:
            rv.add_undirected_edge(u, v)
        return rv

    @classmethod
    def from_str_edges(
        cls,
        nodes: Optional[Iterable[str]] = None,
        directed: Optional[Iterable[Tuple[str, str]]] = None,
        undirected: Optional[Iterable[Tuple[str, str]]] = None,
    ) -> NxMixedGraph:
        """Make a mixed graph from a pair of edge lists where nodes are strings."""
        return cls.from_edges(
            nodes=None if nodes is None else [Variable(n) for n in nodes],
            directed=None if directed is None else vmap_pairs(directed),
            undirected=None if undirected is None else vmap_pairs(undirected),
        )

    @classmethod
    def from_adj(
        cls,
        nodes: Optional[Iterable[Variable]] = None,
        directed: Optional[Mapping[Variable, Collection[Variable]]] = None,
        undirected: Optional[Mapping[Variable, Collection[Variable]]] = None,
    ) -> NxMixedGraph:
        """Make a mixed graph from a pair of adjacency lists.

        Keys of the adjacency mappings are added as nodes even when their
        collection of neighbors is empty.
        """
        rv = cls()
        for n in nodes or []:
            rv.add_node(n)
        for u, vs in (directed or {}).items():
            rv.add_node(u)
            for v in vs:
                rv.add_directed_edge(u, v)
        for u, vs in (undirected or {}).items():
            rv.add_node(u)
            for v in vs:
                rv.add_undirected_edge(u, v)
        return rv

    @classmethod
    def from_str_adj(
        cls,
        nodes: Optional[Iterable[str]] = None,
        directed: Optional[Mapping[str, Collection[str]]] = None,
        undirected: Optional[Mapping[str, Collection[str]]] = None,
    ) -> NxMixedGraph:
        """Make a mixed graph from a pair of adjacency lists of strings."""
        return cls.from_adj(
            nodes=None if nodes is None else [Variable(n) for n in nodes],
            directed=None if directed is None else vmap_adj(directed),
            undirected=None if undirected is None else vmap_adj(undirected),
        )

    @classmethod
    @open_file(1)
    def from_causalfusion_path(cls, file) -> NxMixedGraph:
        """Load a graph from a CausalFusion JSON file."""
        return cls.from_causalfusion_json(json.load(file))

    @classmethod
    def from_causalfusion_json(cls, data: Mapping[str, Any]) -> NxMixedGraph:
        """Load a graph from a CausalFusion JSON object.

        :param data: A mapping with an ``"edges"`` list whose entries have
            ``"from"``, ``"to"`` and ``"type"`` keys
        :returns: A new mixed graph
        :raises ValueError: if an edge type is neither ``directed`` nor ``bidirected``
        """
        rv = cls()
        for edge in data["edges"]:
            u, v = edge["from"], edge["to"]
            if edge["type"] == "directed":
                rv.add_directed_edge(u, v)
            elif edge["type"] == "bidirected":
                rv.add_undirected_edge(u, v)
            else:
                raise ValueError(f'unhandled edge type: {edge["type"]}')
        return rv

    def subgraph(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
        """Return a subgraph given a set of vertices.

        :param vertices: a subset of nodes
        :returns: A NxMixedGraph subgraph
        """
        vertices = _ensure_set(vertices)
        return self.from_edges(
            nodes=vertices,
            directed=_include_adjacent(self.directed, vertices),
            undirected=_include_adjacent(self.undirected, vertices),
        )

    def remove_in_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
        """Return a mutilated graph given a set of interventions.

        Directed edges pointing into any of the vertices and undirected edges
        touching any of them are removed. All nodes of the graph are kept.

        :param vertices: a subset of nodes from which to remove incoming edges
        :returns: A NxMixedGraph subgraph
        """
        vertices = _ensure_set(vertices)
        return self.from_edges(
            # keep every node, even those left without any edge, in the same
            # way as remove_out_edges() does
            nodes=self.nodes(),
            directed=_exclude_target(self.directed, vertices),
            undirected=_exclude_adjacent(self.undirected, vertices),
        )

    def get_intervened_ancestors(self, interventions, outcomes) -> Set[Variable]:
        """Get the ancestors of outcomes in a graph that has been intervened on.

        :param interventions: a set of interventions in the graph
        :param outcomes: a set of outcomes in the graph
        :returns: Set of nodes
        """
        return self.remove_in_edges(interventions).ancestors_inclusive(outcomes)

    def get_no_effect_on_outcomes(self, interventions, outcomes) -> Set[Variable]:
        """Find nodes in the graph which have no effect on the outcomes.

        :param interventions: a set of interventions in the graph
        :param outcomes: a set of outcomes in the graph
        :returns: Set of nodes
        """
        return self.nodes() - interventions - self.get_intervened_ancestors(interventions, outcomes)

    def remove_nodes_from(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
        """Return a subgraph that does not contain any of the specified vertices.

        :param vertices: a set of nodes to remove from graph
        :returns: A NxMixedGraph subgraph
        """
        vertices = _ensure_set(vertices)
        return self.from_edges(
            nodes=self.nodes() - vertices,
            directed=_exclude_adjacent(self.directed, vertices),
            undirected=_exclude_adjacent(self.undirected, vertices),
        )

    def remove_out_edges(self, vertices: Union[Variable, Iterable[Variable]]) -> NxMixedGraph:
        """Return a subgraph that does not have any outgoing edges from any of the given vertices.

        :param vertices: a set of nodes whose outgoing edges get removed from the graph
        :returns: NxMixedGraph subgraph
        """
        vertices = _ensure_set(vertices)
        return self.from_edges(
            nodes=self.nodes(),
            directed=_exclude_source(self.directed, vertices),
            undirected=self.undirected.edges(),
        )

    def ancestors_inclusive(self, sources: Union[Variable, Iterable[Variable]]) -> set[Variable]:
        """Ancestors of a set include the set itself."""
        sources = _ensure_set(sources)
        return _ancestors_inclusive(self.directed, sources)

    def descendants_inclusive(self, sources: Union[Variable, Iterable[Variable]]) -> set[Variable]:
        """Descendants of a set include the set itself."""
        sources = _ensure_set(sources)
        return _descendants_inclusive(self.directed, sources)

    def topological_sort(self) -> Iterable[Variable]:
        """Get a topological sort from the directed component of the mixed graph."""
        return nx.topological_sort(self.directed)

    def get_c_components(self) -> list[frozenset[Variable]]:
        """Get the co-components (i.e., districts) in the undirected portion of the graph.

        Deprecated in favor of :meth:`districts`; emits a :class:`DeprecationWarning`.
        """
        warnings.warn("use NxMixedGraph.districts()", DeprecationWarning, stacklevel=2)
        return list(self.districts())

    def districts(self) -> set[frozenset[Variable]]:
        """Get the districts (connected components of the undirected graph)."""
        return {frozenset(c) for c in nx.connected_components(self.undirected)}

    def get_district(self, node: Variable) -> frozenset[Variable]:
        """Get the district the node is in.

        :param node: A node in the graph
        :returns: The district containing the node
        :raises KeyError: if no district contains the node
        """
        for district in self.districts():
            if node in district:
                return district
        raise KeyError(f"{node} not found in graph")

    def is_connected(self) -> bool:
        """Return if there is only a single connected component in the undirected graph."""
        return nx.is_connected(self.undirected)

    def intervene(self, variables: Set[Intervention]) -> NxMixedGraph:
        """Intervene on the given variables.

        :param variables: A set of interventions
        :returns: A graph that has been intervened on the given variables, with edges into the intervened nodes removed
        """
        return self.from_edges(
            nodes=[node.intervene(variables) for node in self.nodes()],
            directed=[
                (u.intervene(variables), v.intervene(variables))
                for u, v in self.directed.edges()
                if _node_not_an_intervention(v, variables)
            ],
            undirected=[
                (u.intervene(variables), v.intervene(variables))
                for u, v in self.undirected.edges()
                if _node_not_an_intervention(u, variables)
                and _node_not_an_intervention(v, variables)
            ],
        )

    def get_markov_pillow(self, nodes: Collection[Variable]) -> Set[Variable]:
        """Get the parents of the given nodes that are not among the given nodes.

        :param nodes: A collection of nodes, e.g., a district
        :returns: The union of the directed predecessors of the nodes, minus the nodes themselves
        """
        parents_of_district: Set[Variable] = set()
        for node in nodes:
            parents_of_district |= set(self.directed.predecessors(node))
        return parents_of_district - set(nodes)

    def get_markov_blanket(self, nodes: Union[Variable, Iterable[Variable]]) -> Set[Variable]:
        """Get the Markov blanket for a set of nodes.

        The Markov blanket in a directed graph is the union of the parents, children,
        and parents of children of a given node.

        :param nodes: A node or nodes to get the Markov blanket from
        :return: A set of variables comprising the Markov blanket
        """
        if isinstance(nodes, Variable):
            nodes = {nodes}
        else:
            nodes = set(nodes)
        blanket = set()
        for node in nodes:
            blanket.update(self.directed.predecessors(node))
            for successor in self.directed.successors(node):
                blanket.add(successor)
                blanket.update(self.directed.predecessors(successor))
        return blanket.difference(nodes)

    def disorient(self) -> nx.Graph:
        """Return a graph with all edges converted to a flat undirected graph."""
        rv = nx.Graph()
        rv.add_nodes_from(self.nodes())
        rv.add_edges_from(self.directed.edges())
        rv.add_edges_from(self.undirected.edges())
        return rv

    def pre(
        self,
        nodes: Union[Variable, Iterable[Variable]],
        topological_sort_order: Optional[Sequence[Variable]] = None,
    ) -> list[Variable]:
        """Find all nodes prior to the given set of nodes under a topological sort order.

        :param nodes: iterable of nodes.
        :param topological_sort_order: A valid topological sort order. If none given, calculates from the graph.
        :return: list corresponding to the order up until the given nodes.
            This does not include any of the nodes from the query.
        """
        if not topological_sort_order:
            topological_sort_order = list(self.topological_sort())
        node_set = _ensure_set(nodes)
        pre = []
        for node in topological_sort_order:
            # stop at the first queried node that appears in the order
            if node in node_set:
                break
            pre.append(node)
        return pre
class _LatexStr(str):
    """A string that offers its own content through ``_repr_latex_``.

    NOTE(review): presumably this lets notebook front-ends that honor the
    ``_repr_latex_`` hook render the string as LaTeX -- confirm with callers.
    """

    def _repr_latex_(self):
        """Return the string itself as its LaTeX representation."""
        return self
def _node_not_an_intervention(node: Variable, interventions: Set[Intervention]) -> bool:
    """Confirm that neither ``+node`` nor ``-node`` is among the interventions.

    :param node: A plain variable taken from a graph
    :param interventions: The interventions to check against
    :returns: True if the node is not intervened on
    :raises TypeError: if the node itself is an intervention or a counterfactual variable
    """
    if isinstance(node, (Intervention, CounterfactualVariable)):
        raise TypeError(
            "this shouldn't happen since the graph should not have interventions as nodes"
        )
    # check the ``+node`` form first and only build ``-node`` if needed
    if +node in interventions:
        return False
    return -node not in interventions
def _ancestors_inclusive(graph: nx.DiGraph, sources: set[Variable]) -> set[Variable]:
    """Get the union of the sources and the ancestors of every source in the graph."""
    gathered = set()
    for source in sources:
        gathered.update(nx.algorithms.dag.ancestors(graph, source))
    return sources | gathered
def _descendants_inclusive(graph: nx.DiGraph, sources: set[Variable]) -> set[Variable]:
    """Get the union of the sources and the descendants of every source in the graph."""
    gathered = set()
    for source in sources:
        gathered.update(nx.algorithms.dag.descendants(graph, source))
    return sources | gathered
def _include_adjacent(
    graph: nx.Graph, vertices: set[Variable]
) -> Collection[Tuple[Variable, Variable]]:
    """List the edges of the graph whose two endpoints are both in the vertices."""
    allowed = _ensure_set(vertices)
    kept = []
    for left, right in graph.edges():
        if left in allowed and right in allowed:
            kept.append((left, right))
    return kept
def _exclude_source(
    graph: nx.Graph, vertices: set[Variable]
) -> Collection[Tuple[Variable, Variable]]:
    """List the edges of the graph whose first endpoint is not in the vertices."""
    kept = []
    for source, target in graph.edges():
        if source in vertices:
            continue
        kept.append((source, target))
    return kept
def _exclude_target(
    graph: nx.Graph, vertices: set[Variable]
) -> Collection[Tuple[Variable, Variable]]:
    """List the edges of the graph whose second endpoint is not in the vertices."""
    kept = []
    for source, target in graph.edges():
        if target in vertices:
            continue
        kept.append((source, target))
    return kept
def _exclude_adjacent(
    graph: nx.Graph, vertices: set[Variable]
) -> Collection[Tuple[Variable, Variable]]:
    """List the edges of the graph that do not touch any of the vertices."""
    kept = []
    for left, right in graph.edges():
        if left in vertices or right in vertices:
            continue
        kept.append((left, right))
    return kept
def _latent_dag(
    di_edges: Iterable[Tuple[Variable, Variable]],
    bi_edges: Iterable[Tuple[Variable, Variable]],
    *,
    prefix: Optional[str] = None,
    start: int = 0,
    tag: Optional[str] = None,
) -> nx.DiGraph:
    """Create a labeled DAG where bi-directed edges are assigned as nodes upstream of their two incident nodes.

    :param di_edges: A list of directional edges
    :param bi_edges: A list of bidirectional edges
    :param prefix: The prefix for latent variables. If none, defaults to :data:`DEFULT_PREFIX`.
    :param start: The starting number for latent variables (defaults to 0, could be changed to 1 if desired)
    :param tag: The key for node data describing whether it is latent.
        If None, defaults to :data:`DEFAULT_TAG`.
    :return: A latent variable DAG.
    """
    tag = DEFAULT_TAG if tag is None else tag
    prefix = DEFULT_PREFIX if prefix is None else prefix

    # materialize once since the edges are traversed twice below
    bidirected = list(bi_edges)

    dag = nx.DiGraph()
    dag.add_nodes_from(node for edge in bidirected for node in edge)
    dag.add_edges_from(di_edges)
    # every node added so far is an observed one
    nx.set_node_attributes(dag, False, tag)

    index = start
    for left, right in sorted(bidirected):
        latent = Variable(f"{prefix}{index}")
        dag.add_node(latent, **{tag: True})
        dag.add_edge(latent, left)
        dag.add_edge(latent, right)
        index += 1
    return dag
def set_latent(
    graph: nx.DiGraph,
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: Optional[str] = None,
) -> None:
    """Quickly set the latent variables in a graph.

    Every node of the graph gets the ``tag`` attribute: true for the given
    latent nodes and false for all others.

    :param graph: A directed graph that is not one of the graphs inside a :class:`NxMixedGraph`
    :param latent_nodes: A node or nodes to mark as latent
    :param tag: The key for node data describing whether it is latent.
        If None, defaults to :data:`DEFAULT_TAG`.
    :raises RuntimeError: if the graph carries the flag set by :class:`NxMixedGraph`
    """
    if graph.graph.get(NO_SET_LATENT_FLAG):
        raise RuntimeError(
            "Do not set latent variables on graphs inside a NxMixedGraph using set_latent().\n"
            "This function is strictly only for nx.DiGraphs that have been constructed based on "
            "a NxMixedGraph, but not the NxMixedGraph itself."
        )
    key = DEFAULT_TAG if tag is None else tag
    if isinstance(latent_nodes, Variable):
        latent = {latent_nodes}
    else:
        latent = set(latent_nodes)
    for node in graph.nodes:
        graph.nodes[node][key] = node in latent
def _get_latex(node) -> str:
    """Get a LaTeX label for a node.

    Strings are parsed with ``parse_y0`` and returned unchanged if parsing
    fails; variables are rendered through ``_repr_latex_``.

    :raises TypeError: if the node is neither a string nor a variable
    """
    if isinstance(node, str):
        from y0.parser import parse_y0

        try:
            parsed = parse_y0(node)
        except Exception:
            # best effort: fall back to the raw string if it can not be parsed
            return node
        return parsed._repr_latex_()

    from y0.dsl import Variable

    if not isinstance(node, Variable):
        raise TypeError
    return node._repr_latex_()
def _ensure_set(vertices: Union[Variable, Iterable[Variable]]) -> set[Variable]:
    """Wrap a single variable in a set, or convert an iterable of variables to a set.

    :raises TypeError: if any element is an :class:`Intervention`
    """
    if isinstance(vertices, Variable):
        result = {vertices}
    else:
        result = set(vertices)
    for vertex in result:
        if isinstance(vertex, Intervention):
            raise TypeError("can not use interventions here")
    return result
def _layout(self, prog):
    """Compute node positions for the joint graph.

    Tries the pygraphviz layout first, then the pydot layout, and falls back
    to :func:`networkx.spring_layout` if both raise :class:`ImportError`.
    """
    combined = self.joint()

    try:
        return nx.nx_agraph.pygraphviz_layout(combined, prog=prog)
    except ImportError:
        pass

    try:
        return nx.nx_pydot.pydot_layout(combined, prog=prog)
    except ImportError:
        pass

    return nx.spring_layout(combined)
def is_a_fixable(graph: NxMixedGraph, treatments: Union[Variable, Collection[Variable]]) -> bool:
    """Check if the treatments are a-fixable.

    A treatment is said to be a-fixable if it can be fixed by removing a single directed edge from the graph.
    In other words, a treatment is a-fixable if it has exactly one descendant in its district.

    This code was adapted from :mod:`ananke` ananke code at:
    https://gitlab.com/causal/ananke/-/blob/dev/ananke/estimation/counterfactual_mean.py?ref_type=heads#L58-65

    :param graph: A NxMixedGraph
    :param treatments: A list of treatments
    :raises NotImplementedError: a-fixability on multiple treatments is an open research question
    :returns: bool
    """
    if not isinstance(treatments, Variable):
        raise NotImplementedError(
            "a-fixability on multiple treatments is an open research question"
        )
    # the descendants are inclusive, so the treatment itself is always counted
    reachable = graph.descendants_inclusive(treatments)
    district = graph.get_district(treatments)
    return len(district.intersection(reachable)) == 1
def is_p_fixable(graph: NxMixedGraph, treatments: Union[Variable, Collection[Variable]]) -> bool:
    """Check if the treatments are p-fixable.

    The treatment is reported as p-fixable when none of its children are in its district.

    This code was adapted from :mod:`ananke` ananke code at:
    https://gitlab.com/causal/ananke/-/blob/dev/ananke/estimation/counterfactual_mean.py?ref_type=heads#L85-92

    :param graph: A NxMixedGraph
    :param treatments: A list of treatments
    :raises NotImplementedError: p-fixability on multiple treatments is an open research question
    :returns: bool
    """
    if not isinstance(treatments, Variable):
        raise NotImplementedError(
            "p-fixability on multiple treatments is an open research question"
        )
    direct_children = set(graph.directed.successors(treatments))
    district = graph.get_district(treatments)
    return len(district.intersection(direct_children)) == 0
def is_markov_blanket_shielded(graph: NxMixedGraph) -> bool:
    """Check if the ADMG is a Markov blanket shielded.

    Being Markov blanket (Mb) shielded means that two vertices are non-adjacent
    only when they are absent from each others' Markov blankets.

    This code was adapted from :mod:`ananke` ananke code at:
    https://gitlab.com/causal/ananke/-/blob/dev/ananke/graphs/admg.py?ref_type=heads#L381-403

    :param graph: A NxMixedGraph
    :returns: bool
    """
    for left, right in itt.combinations(graph.nodes(), 2):
        adjacent = (
            graph.directed.has_edge(left, right)
            or graph.directed.has_edge(right, left)
            or graph.undirected.has_edge(left, right)
        )
        if adjacent:
            continue
        # a non-adjacent pair where one is in the Markov blanket of the other
        # means the graph is not mb-shielded
        if _markov_blanket_overlap(graph, left, right):
            return False
    return True
def get_district_and_predecessors(
    graph: NxMixedGraph,
    nodes: Iterable[Variable],
    topological_sort_order: Optional[Sequence[Variable]] = None,
) -> Set[Variable]:
    """Get the union of district, predecessors and predecessors of district for a given set of nodes.

    This code was adapted from :mod:`ananke` ananke code at:
    https://gitlab.com/causal/ananke/-/blob/dev/ananke/graphs/admg.py?ref_type=heads#L96-117

    :param graph: A NxMixedGraph
    :param nodes: List of nodes (any iterable, including one-shot iterators)
    :param topological_sort_order: A valid topological sort order. If none given, calculates from the graph.
    :return: Set corresponding to union of district, predecessors and predecessors of district of a given set of nodes
    """
    # The nodes are traversed several times below, so materialize them once.
    # Otherwise a generator would be exhausted by the first use and silently
    # produce a wrong result.
    nodes = list(nodes)

    if not topological_sort_order:
        topological_sort_order = list(graph.topological_sort())

    # Get the subgraph corresponding to the nodes and nodes prior to them
    pre = graph.pre(nodes, topological_sort_order)
    sub_graph = graph.subgraph(pre + nodes)

    result: Set[Variable] = set()
    for node in nodes:
        result.update(sub_graph.get_district(node))
    # iterate over a snapshot since the set grows while adding predecessors
    for node in result.copy():
        result.update(sub_graph.directed.predecessors(node))
    return result - set(nodes)
def _markov_blanket_overlap(graph: NxMixedGraph, u: Variable, v: Variable) -> bool:
    """Check if either node is in the district-and-predecessors set of the other."""
    if u in get_district_and_predecessors(graph, [v]):
        return True
    return v in get_district_and_predecessors(graph, [u])
def iter_moral_links(graph: NxMixedGraph) -> Iterable[Tuple[Variable, Variable]]:
    """Generate links to ensure all co-parents in a graph are linked.

    May generate links that already exist as we assume we are not working on a multi-graph.

    :param graph: Graph to process
    :yields: An collection of edges to add.
    """
    for node in graph.nodes():
        parents = graph.directed.predecessors(node)
        # note that combinations(x, 2) yields nothing when x has fewer than 2 elements
        for pair in combinations(parents, 2):
            yield pair
def get_nodes_in_directed_paths(
    graph: NxMixedGraph,
    sources: Union[Variable, Set[Variable]],
    targets: Union[Variable, Set[Variable]],
) -> Set[Variable]:
    """Get all nodes appearing in directed paths from sources to targets.

    :param graph: an NxMixedGraph
    :param sources: source nodes
    :param targets: target nodes
    :return: the nodes on all causal paths from sources to targets
    """
    source_set = _ensure_set(sources)
    target_set = _ensure_set(targets)
    if not nx.is_directed_acyclic_graph(graph.directed):
        # note, this is a simpler implementation can use :func:`nx.all_simple_paths`,
        # but it is less efficient since it requires potentially calculating the same
        # paths over and over again.
        return _get_nodes_in_directed_paths_cyclic(graph.directed, source_set, target_set)
    return _get_nodes_in_directed_paths_dag(graph.directed, source_set, target_set)
def _get_nodes_in_directed_paths_dag(
    graph: nx.DiGraph, sources: set[Variable], targets: set[Variable]
) -> set[Variable]:
    """Get the nodes on directed paths from sources to targets using the transitive closure."""
    closure: nx.DiGraph = nx.transitive_closure_dag(graph)
    pairs = list(itt.product(sources, targets))

    found = set()
    # intermediate nodes: reachable from a source and reaching the paired target
    for node in graph.nodes():
        for source, target in pairs:
            if closure.has_edge(source, node) and closure.has_edge(node, target):
                found.add(node)
                break
    # endpoints are only included when the target is reachable from the source
    for source, target in pairs:
        if closure.has_edge(source, target):
            found.update((source, target))
    return found
def _get_nodes_in_directed_paths_cyclic(
    graph: nx.DiGraph, sources: set[Variable], targets: set[Variable]
) -> set[Variable]:
    """Get the nodes on all simple directed paths from any source to any target."""
    found = set()
    for source, target in itt.product(sources, targets):
        for causal_path in nx.all_simple_paths(graph, source, target):
            found.update(causal_path)
    return found
def sympy_nested(glyph: str, *variables: Variable) -> "sympy.Symbol":
    """Create a sympy symbol named by the glyph with the variables' LaTeX as a subscript.

    :param glyph: The base of the symbol name, e.g., ``\\beta``
    :param variables: The variables whose LaTeX forms are comma-joined in the subscript
    :returns: A sympy symbol
    """
    import sympy

    pieces = [variable.to_latex() for variable in variables]
    subscript = ",".join(pieces)
    return sympy.Symbol(rf"{glyph}_{{{subscript}}}")