"""Implementation of algorithms from Hierarchical Causal Models by E.N. Weinstein and D.M. Blei.
.. seealso::
https://arxiv.org/abs/2401.05330
"""
from __future__ import annotations
import itertools as itt
import typing
from collections.abc import Collection, Iterable, Sequence
from dataclasses import dataclass, field
from itertools import combinations
from typing import TYPE_CHECKING, Any, TypeAlias
import networkx as nx
from y0.dsl import Variable
from y0.graph import NxMixedGraph
if TYPE_CHECKING:
import pygraphviz
# Names exported by ``from ... import *``; this is the public API of the module.
__all__ = [
    "HierarchicalCausalModel",
    "HierarchicalStructuralCausalModel",
    "QVariable",
    "augment_collapsed_model",
    "augment_from_mechanism",
    "augmentation_mechanism",
    "collapse_hcm",
    "marginalize_augmented_model",
]
# A directed graph over subunit variables, e.g. as returned by
# ``HierarchicalCausalModel.get_subunit_graph``.
SubunitGraph: TypeAlias = nx.DiGraph
# Name given to the pygraphviz subgraph that groups the subunit variables in
# the ``to_pygraphviz`` drawings.
SUBUNITS_KEY = "cluster_subunits"
# Anything accepted where a variable is expected; strings are converted by
# ``_upgrade`` (strings starting with ``Q^`` become ``QVariable`` instances).
VHint: TypeAlias = typing.Union[str, Variable, "QVariable"]
def _upgrade(v: VHint) -> Variable:
if not isinstance(v, str):
return v
if v.startswith("Q^"):
return QVariable.parse_str(v)
return Variable(v)
[docs]
class HierarchicalCausalModel:
    """A hierarchical causal model (HCM) over unit and subunit variables.

    The model wraps a :class:`networkx.DiGraph` whose nodes are y0 variables. Two
    annotation sets describe the nodes: :attr:`observed` holds the observed
    variables and :attr:`subunits` holds the subunit-level variables. Every node
    that is not annotated as a subunit is treated as a unit-level variable.
    """

    #: Observed variables, both unit and subunit level
    observed: set[Variable]
    #: Variables annotated as subunit level
    subunits: set[Variable]

    def __init__(self) -> None:
        """Initialize an empty HCM."""
        self._graph = nx.DiGraph()
        self.observed = set()
        self.subunits = set()

    def add_observed_node(self, node: VHint) -> None:
        """Add a node to the graph and mark it as observed."""
        node = _upgrade(node)
        self._graph.add_node(node)
        self.observed.add(node)

    def add_unobserved_node(self, node: VHint) -> None:
        """Add a node to the graph without marking it as observed."""
        self._graph.add_node(_upgrade(node))

    def add_edge(
        self,
        u: VHint,
        v: VHint,
        **kwargs: Any,
    ) -> None:
        """Add a directed edge from ``u`` to ``v``.

        Endpoints that are not yet in the graph are added implicitly (and are
        therefore unobserved unless marked otherwise). Keyword arguments are
        stored as edge attributes.
        """
        self._graph.add_edge(_upgrade(u), _upgrade(v), **kwargs)

    def add_subunits(self, subunit_nodes: Iterable[VHint]) -> None:
        """Annotate the given nodes as subunit variables.

        This only updates the annotation; it does not add the nodes to the graph.
        """
        self.subunits.update(_upgrade(x) for x in subunit_nodes)

    def is_node_observed(self, node: VHint) -> bool:
        """Check if the node is observed."""
        return _upgrade(node) in self.observed

    def get_observed(self) -> set[Variable]:
        """Return the set of observed variables (both unit and subunit) in the HCM.

        This is the model's own annotation set, not a copy.
        """
        return self.observed

    def get_unobserved(self) -> set[Variable]:
        """Return the set of unobserved variables (both unit and subunit) in the HCM."""
        return set(self._graph.nodes()) - self.observed

    def get_subunits(self) -> set[Variable]:
        """Return the set of subunit variables in the HCM.

        This is the model's own annotation set, not a copy.
        """
        return self.subunits

    def get_units(self) -> set[Variable]:
        """Return the set of unit variables, i.e., all nodes not annotated as subunits."""
        return set(self._graph.nodes()) - self.subunits

    def get_subunit_graph(self) -> SubunitGraph:
        """Return a copy of the subgraph induced on the subunit variables."""
        return nx.subgraph(self._graph, self.subunits).copy()

    def get_parents(self, node: VHint) -> set[Variable]:
        """Return the set of parent/predecessor variables of the given variable in the HCM."""
        return set(self._graph.predecessors(_upgrade(node)))

    def delete_node(self, node: VHint) -> None:
        """Delete a node, its incident edges, and its observed/subunit annotations.

        :param node: the node to delete
        :raises networkx.NetworkXError: if the node is not in the graph
        """
        node = _upgrade(node)
        self._graph.remove_node(node)
        # Drop the annotations as well, so get_observed()/get_subunits() never
        # report nodes that are no longer part of the graph.
        self.observed.discard(node)
        self.subunits.discard(node)

    def delete_edge(self, u: VHint, v: VHint) -> None:
        """Delete the directed edge from ``u`` to ``v``."""
        self._graph.remove_edge(_upgrade(u), _upgrade(v))

    def nodes(self) -> list[Variable]:
        """Get all nodes."""
        return list(self._graph.nodes())

    def predecessors(self, node: VHint) -> list[Variable]:
        """Get the direct predecessors (parents) of a node."""
        return list(self._graph.predecessors(_upgrade(node)))

    def successors(self, node: VHint) -> list[Variable]:
        """Get the direct successors (children) of a node."""
        return list(self._graph.successors(_upgrade(node)))

    def edges(self) -> list[tuple[Variable, Variable]]:
        """Get all directed edges as ``(source, target)`` pairs."""
        return list(self._graph.edges())

    def get_direct_unit_descendants(self, subunit_node: VHint) -> set[Variable]:
        """Return the direct unit descendants of the given subunit variable.

        Starting from the node's children, subunit children are expanded further;
        the first unit variables reached on each path are collected.

        :param subunit_node: the subunit variable to start from
        :returns: the set of unit variables reachable through subunit variables only
        """
        frontier = set(self.successors(subunit_node))
        expanded: set[Variable] = set()
        direct_unit_descendants: set[Variable] = set()
        while frontier:
            next_frontier: set[Variable] = set()
            for descendant in frontier:
                if descendant not in self.subunits:
                    direct_unit_descendants.add(descendant)
                elif descendant not in expanded:
                    # Expand each subunit at most once so that cycles terminate and
                    # subunits shared by several paths are not re-explored.
                    expanded.add(descendant)
                    next_frontier.update(self.successors(descendant))
            frontier = next_frontier
        return direct_unit_descendants

    @classmethod
    def from_lists(
        cls,
        *,
        observed_subunits: Sequence[VHint] | None = None,
        unobserved_subunits: Sequence[VHint] | None = None,
        observed_units: Sequence[VHint] | None = None,
        unobserved_units: Sequence[VHint] | None = None,
        edges: Sequence[tuple[VHint, VHint]] | None = None,
    ) -> HierarchicalCausalModel:
        """Create a hierarchical causal model from the given node and edge lists.

        :param observed_subunits: a list of names for the observed subunit variables
        :param unobserved_subunits: a list of names for the unobserved subunit variables
        :param observed_units: a list of names for the observed unit variables
        :param unobserved_units: a list of names for the unobserved unit variables
        :param edges: a list of edges
        :returns: a hierarchical causal model whose subunit variables are annotated
            as subunits (drawn in the :data:`SUBUNITS_KEY` subgraph)
        """
        if observed_subunits is None:
            observed_subunits = []
        if unobserved_subunits is None:
            unobserved_subunits = []
        if observed_units is None:
            observed_units = []
        if unobserved_units is None:
            unobserved_units = []
        hcm = cls()
        for observed_node in itt.chain(observed_subunits, observed_units):
            hcm.add_observed_node(observed_node)
        for unobserved_node in itt.chain(unobserved_subunits, unobserved_units):
            hcm.add_unobserved_node(unobserved_node)
        for u, v in edges or []:
            hcm.add_edge(u, v)
        hcm.add_subunits(itt.chain(observed_subunits, unobserved_subunits))
        return hcm

    def copy_hcm(self) -> HierarchicalCausalModel:
        """Return a copy of the HCM (nodes, edges, and observed/subunit annotations)."""
        obs = self.get_observed()
        unobs = self.get_unobserved()
        units = self.get_units()
        subunits = self.get_subunits()
        return self.from_lists(
            observed_subunits=list(obs & subunits),
            unobserved_subunits=list(unobs & subunits),
            observed_units=list(obs & units),
            unobserved_units=list(unobs & units),
            edges=self._graph.edges(),
        )

    def to_hcgm(self: HierarchicalCausalModel) -> HierarchicalCausalModel:
        """Convert an HCM to a hierarchical causal graphical model (HCGM) with promoted Q variables.

        For every subunit variable ``s`` a unit-level Q variable is added. Every
        unit-level parent of ``s`` is rerouted to point at the Q variable, which in
        turn points at ``s``.

        :returns: a new model; this one is not modified
        """
        hcgm = self.copy_hcm()
        observed = self.get_observed()
        subunits = self.get_subunits()
        subunit_graph = self.get_subunit_graph()
        for s in subunits:
            q_variable = _create_qvar(subunit_graph, s)
            parents = self.get_parents(s)
            # Q^s is observed only when s and all of its subunit parents are observed
            if s in observed and (parents & subunits) <= observed:
                hcgm.add_observed_node(q_variable)
            else:
                hcgm.add_unobserved_node(q_variable)
            # Parents of s that are not subunits are exactly its unit-level parents
            for unit_parent in parents - subunits:
                hcgm.delete_edge(unit_parent, s)
                hcgm.add_edge(unit_parent, q_variable)
            hcgm.add_edge(q_variable, s)
        return hcgm

    def to_hscm(self: HierarchicalCausalModel) -> HierarchicalStructuralCausalModel:
        """Convert the input HCM to an explicit hierarchical structural causal model (HSCM)."""
        obs = self.get_observed()
        unobs = self.get_unobserved()
        units = self.get_units()
        subunits = self.get_subunits()
        hscm = HierarchicalStructuralCausalModel.from_lists(
            observed_subunits=list(obs & subunits),
            unobserved_subunits=list(unobs & subunits),
            observed_units=list(obs & units),
            unobserved_units=list(unobs & units),
            edges=self.edges(),
        )
        return typing.cast(HierarchicalStructuralCausalModel, hscm)

    def to_admg(self, *, return_hcgm: bool = False) -> NxMixedGraph:
        """Return a collapsed hierarchical causal model.

        :param return_hcgm: if True, also returns the intermediate hierarchical causal
            graphical model (HCGM) with subunits and promoted Q variables
        :returns: a mixed graph, or a pair of the mixed graph and the HCGM when
            ``return_hcgm`` is True
        :raises NotImplementedError: currently cannot handle unobserved subunit
            variables
        """
        if self.get_unobserved() & self.get_subunits():
            raise NotImplementedError("Currently cannot handle unobserved subunit variables.")
        hcgm = self.to_hcgm()
        # ``hcgm`` is mutated below, so build a separate, untouched HCGM for the caller
        hcgm_original = self.to_hcgm() if return_hcgm else None
        subunit_graph = self.get_subunit_graph()
        q_variables: set[QVariable] = set()
        for subunit in self.get_subunits():
            q_variable = _create_qvar(subunit_graph, subunit)
            q_variables.add(q_variable)
            for dud in self.get_direct_unit_descendants(subunit):
                hcgm.add_edge(q_variable, dud)
            hcgm.delete_node(subunit)
        # Children of the same unobserved node become confounded (bidirected edges)
        undirected = [
            pair
            for node in hcgm.get_unobserved()
            for pair in combinations(hcgm.successors(node), r=2)
        ]
        directed = [(source, target) for source, target in hcgm.edges() if source in hcgm.observed]
        collapsed = NxMixedGraph.from_edges(directed=directed, undirected=undirected)
        # ``from_edges`` only creates nodes that have an edge. Add Q variables and
        # observed variables that ended up without any edges (e.g. an observed unit
        # whose only parent is an unobserved unit with no other children).
        for node in q_variables | hcgm.get_observed():
            if node not in collapsed:
                collapsed.add_node(node)
        if return_hcgm:
            return collapsed, hcgm_original  # type:ignore
        return collapsed

    def to_pygraphviz(self) -> pygraphviz.AGraph:  # TODO make style configurable with defaults
        """Get a pygraphviz object.

        Observed nodes are filled light grey and subunit variables are grouped in a
        dashed subgraph named :data:`SUBUNITS_KEY`.
        """
        import pygraphviz as pgv

        def _pgv(n: Variable) -> str:
            if isinstance(n, QVariable):
                return n.pgv_str()
            else:
                return n.name

        rv = pgv.AGraph(directed=True)
        for node in self._graph.nodes():
            if node in self.observed:
                rv.add_node(_pgv(node), style="filled", color="lightgrey")
            else:
                rv.add_node(_pgv(node))
        rv.add_subgraph(
            [_pgv(node) for node in self.subunits],
            name=SUBUNITS_KEY,
            style="dashed",
            label="m",
        )
        for u, v in self._graph.edges():
            rv.add_edge(_pgv(u), _pgv(v))
        return rv
[docs]
class HierarchicalStructuralCausalModel(HierarchicalCausalModel):
    """A hierarchical structural causal model (HSCM) with explicit exogenous noise.

    Every node added through :meth:`add_unobserved_node` or
    :meth:`add_observed_node` receives two exogenous noise parents: a unit-level
    noise ``y_i^<node>`` and a subunit-level noise ``e_ij^<node>``. The
    subunit-level noise is annotated as a subunit variable.
    """

    #: Exogenous noise variables created for the endogenous nodes
    exogenous_noise: set[Variable]

    def __init__(self) -> None:
        """Initialize an empty HSCM."""
        self.exogenous_noise = set()
        super().__init__()

    def add_unobserved_node(self, node: VHint) -> None:
        """Add an unobserved node together with its unit and subunit exogenous noise."""
        node = _upgrade(node)
        # TODO make the "y_i" / "e_ij" prefixes configurable, keeping these as defaults
        # TODO how to render e_{ij} while also formatting {node}?
        unit_exogenous = _upgrade(f"y_i^{node}")
        subunit_exogenous = _upgrade(f"e_ij^{node}")
        self._graph.add_node(node)
        self._graph.add_edge(unit_exogenous, node)
        self._graph.add_edge(subunit_exogenous, node)
        self.add_subunits([subunit_exogenous])
        self.exogenous_noise.update({unit_exogenous, subunit_exogenous})

    def add_observed_node(self, node: VHint) -> None:
        """Add an observed node together with its unit and subunit exogenous noise."""
        node = _upgrade(node)
        self.add_unobserved_node(node)
        self.observed.add(node)

    def add_edge(
        self,
        u: VHint,
        v: VHint,
        **kwargs: Any,
    ) -> None:
        """Add a directed edge between two endogenous variables.

        :param u: the source variable
        :param v: the target variable
        :param kwargs: edge attributes
        :raises ValueError: if either endpoint is an exogenous noise variable
        """
        # Upgrade first: the noise set holds Variables, so a raw string such as
        # "y_i^X" would otherwise slip past the membership check.
        u, v = _upgrade(u), _upgrade(v)
        if u in self.exogenous_noise or v in self.exogenous_noise:
            raise ValueError("Cannot add an edge to or from exogenous noise variables.")
        super().add_edge(u, v, **kwargs)

    def get_exogenous_noise(self) -> set[Variable]:
        """Return the set of exogenous noise variables in the HSCM."""
        return self.exogenous_noise

    def to_hcm(self) -> HierarchicalCausalModel:
        """Convert the HSCM to a hierarchical causal model (HCM) by dropping the exogenous noise."""
        endogenous = set(self.nodes()) - self.get_exogenous_noise()
        obs = self.get_observed() & endogenous
        unobs = self.get_unobserved() & endogenous
        units = self.get_units() & endogenous
        subunits = self.get_subunits() & endogenous
        return HierarchicalCausalModel.from_lists(
            observed_subunits=list(obs & subunits),
            unobserved_subunits=list(unobs & subunits),
            observed_units=list(obs & units),
            unobserved_units=list(unobs & units),
            # out-edges of endogenous nodes only; noise -> node edges are left out
            edges=self._graph.edges(nbunch=list(endogenous)),
        )

    def to_hcgm(self: HierarchicalStructuralCausalModel) -> HierarchicalCausalModel:
        """Convert an HSCM to a hierarchical causal graphical model (HCGM) with promoted Q variables."""
        return self.to_hcm().to_hcgm()

    def to_admg(self, *, return_hcgm: bool = False) -> NxMixedGraph:
        """Return a collapsed hierarchical causal model.

        :param return_hcgm: if True, also returns the intermediate hierarchical causal
            graphical model (HCGM) with subunits and promoted Q variables
        :returns: a mixed graph (paired with the HCGM when ``return_hcgm`` is True)
        :raises NotImplementedError: propagated from :meth:`HierarchicalCausalModel.to_admg`
            for unobserved subunit variables
        """
        return self.to_hcm().to_admg(return_hcgm=return_hcgm)

    def to_pygraphviz(self) -> pygraphviz.AGraph:  # TODO make style configurable with defaults
        """Get a pygraphviz object.

        Exogenous noise is drawn as plain text and endogenous nodes as squares;
        observed nodes are filled light grey.
        """
        import pygraphviz as pgv

        def _pgv(n: Variable) -> str:
            if isinstance(n, QVariable):
                return n.pgv_str()
            else:
                return n.name

        rv = pgv.AGraph(directed=True)
        for node in self._graph.nodes():
            if node in self.observed:
                rv.add_node(_pgv(node), style="filled", color="lightgrey")
            else:
                rv.add_node(_pgv(node))
        for node in self.nodes():
            if node in self.get_exogenous_noise():
                rv.get_node(_pgv(node)).attr["shape"] = "plaintext"
            else:
                rv.get_node(_pgv(node)).attr["shape"] = "square"
        rv.add_subgraph(
            [_pgv(node) for node in self.subunits],
            name=SUBUNITS_KEY,
            style="dashed",
            label="m",
        )
        for u, v in self._graph.edges():
            rv.add_edge(_pgv(u), _pgv(v))
        return rv
def get_ancestors(subunit_graph: SubunitGraph, start_node: VHint) -> set[Variable]:
    """Collect every ancestor of a node in a subunit graph with a depth-first walk.

    :param subunit_graph: A subunit graph
    :param start_node: the node whose ancestors are wanted
    :returns: all nodes from which ``start_node`` is reachable, excluding ``start_node`` itself
    """
    root = _upgrade(start_node)
    seen: set[Variable] = set()
    to_visit = [root]
    while to_visit:
        current = to_visit.pop()
        if current in seen:
            continue
        seen.add(current)
        to_visit.extend(
            parent for parent in subunit_graph.predecessors(current) if parent not in seen
        )
    # the walk always records the starting node; it is not its own ancestor
    return seen - {root}
[docs]
@dataclass(frozen=True, order=True, repr=False)
class QVariable(Variable):
    """A unit-level Q variable: a child variable conditioned on a frozen set of parents.

    ``name`` holds the child (left-hand side) and ``parents`` the conditioning
    variables (right-hand side), rendered as ``Q^{Y|A,B}``.
    """

    #: the conditioning (right-hand side) variables
    parents: frozenset[Variable] = field(default_factory=frozenset)

    def get_lhs(self) -> Variable:
        """Get the left-hand side (i.e., child) as a plain :class:`Variable`."""
        return Variable(self.name)

    def get_all(self) -> frozenset[Variable]:
        """Get the union of the left-hand side and right-hand side."""
        return self.parents.union({self.get_lhs()})

    def pgv_str(self) -> str:
        """Get a string compatible with the V1 implementation.

        :returns: ``Q^Y`` without parents, otherwise ``Q^{Y|A,B}`` with sorted parent names
        """
        child_name = self.name
        if not self.parents:
            return f"Q^{child_name}"
        parent_str = ",".join(sorted(p.name for p in self.parents))
        return f"Q^{{{child_name}|{parent_str}}}"

    def _iter_variables(self) -> Iterable[Variable]:
        """Yield the left-hand side followed by each parent."""
        yield self.get_lhs()
        yield from self.parents

    def to_text(self) -> str:
        """Get text (same as :meth:`pgv_str`)."""
        return self.pgv_str()

    def to_y0(self) -> str:
        """Get a string that can be parsed by Y0.

        Parent names are sorted so the output does not depend on set iteration order.
        """
        p = sorted(parent.name for parent in self.parents)
        return f"QVariable({self.name}, {p})"

    def to_latex(self) -> str:
        """Get latex for the q-variable (same as :meth:`pgv_str`)."""
        return self.pgv_str()

    @classmethod
    def parse_str(cls, s: str) -> QVariable:
        """Parse a Q variable from its string representation.

        Two forms are accepted: ``Q^Y`` (no parents) and ``Q^{Y|A,B}`` (parents
        ``A`` and ``B``), i.e., what :meth:`pgv_str` produces for single-character
        child names. Whitespace around names inside the braces is ignored.

        :param s: the string to parse
        :returns: the parsed Q variable
        :raises ValueError: if ``s`` does not start with ``Q^``, a conditional form
            is not wrapped in braces or has more than one ``|``, the left-hand side
            is missing or longer than one character, or a parent name is empty
        """
        if not s.startswith("Q^"):
            raise ValueError(f"Q-variable string should start with `Q^`: {s}")
        if "|" not in s:
            lhs = s[2:]
            if len(lhs) > 1:
                raise ValueError("Invalid format for input Q variable")
            if not lhs:
                raise ValueError(f"Invalid q-variable string: {s}\n\nMissing left-hand side")
            return cls(name=lhs)
        # The conditional form must look like ``Q^{...}``; without this check a
        # missing closing brace would silently truncate the last parent name.
        if not (s.startswith("Q^{") and s.endswith("}")):
            raise ValueError(f"Invalid q-variable string: {s}\n\nExpected the form Q^{{Y|A,B}}")
        parts = s[3:-1].split("|")
        if len(parts) != 2:
            raise ValueError("Invalid format for input Q variable")
        lhs = parts[0].strip()
        parent_names = [name.strip() for name in parts[1].split(",")]
        if len(lhs) > 1:
            raise ValueError("Invalid format for input Q variable")
        if not lhs:
            raise ValueError(f"Invalid q-variable string: {s}\n\nMissing left-hand side")
        if not all(parent_names):
            raise ValueError(f"Invalid q-variable string: {s}\n\nEmpty parent name")
        return cls(name=lhs, parents=frozenset(Variable(p) for p in parent_names))
def _create_qvar(subunit_graph: SubunitGraph, subunit_node: Variable) -> QVariable:
    """Build the unit-level Q variable belonging to a subunit variable.

    The Q variable takes its name from ``subunit_node`` and is conditioned on the
    node's parents inside ``subunit_graph``.
    """
    child_name = subunit_node.name
    subunit_parents = frozenset(subunit_graph.predecessors(subunit_node))
    return QVariable(name=child_name, parents=subunit_parents)
def _str_or_q(augmentation_variable: str | QVariable) -> QVariable:
if isinstance(augmentation_variable, str):
return QVariable.parse_str(augmentation_variable)
return augmentation_variable
[docs]
def augment_from_mechanism(
    collapsed: NxMixedGraph, aug: str | QVariable, mechanism: Iterable[QVariable]
) -> NxMixedGraph:
    """Augment a collapsed model with a given augmentation variable and its mechanism.

    The augmentation variable gets every mechanism variable as a parent. Any other
    node whose directed parents include the whole mechanism is rewired so that it
    depends on the augmentation variable instead of on the mechanism variables.

    :param collapsed: NxMixedGraph of the input collapsed model
    :param aug: new variable to add into the collapsed model
    :param mechanism: collection of variables in the collapsed model that determine the
        augmentation_variable
    :returns: NxMixedGraph of the augmented model (the input is not modified)
    :raises TypeError: if any of the parts of the mechanism aren't q-variables
    :raises ValueError: if the mechanism is empty, or if the mechanism variables
        are not all contained in the collapsed model
    """
    aug = _str_or_q(aug)
    mechanism_set = set(mechanism)
    # With an empty mechanism every node vacuously "has all mechanism parents",
    # which would make ``aug`` a parent of the whole graph.
    if not mechanism_set:
        raise ValueError("The input mechanism must contain at least one variable.")
    if any(not isinstance(m, QVariable) for m in mechanism_set):
        raise TypeError("all variables in mechanism need to be QVariables")
    if not mechanism_set <= collapsed.nodes():
        raise ValueError("The input mechanism must be contained in the collapsed model.")
    augmented = collapsed.copy()
    augmented.add_node(aug)
    for var in mechanism_set:
        augmented.add_directed_edge(var, aug)
    for var in set(augmented.nodes()) - {aug}:
        if mechanism_set <= set(augmented.directed.predecessors(var)):
            augmented.add_directed_edge(aug, var)
            for parent in mechanism_set:
                augmented.directed.remove_edge(parent, var)
    return augmented
[docs]
def augmentation_mechanism(
    subunit_graph: SubunitGraph, augmentation_variable: str | QVariable
) -> list[QVariable]:
    """Generate the augmentation mechanism of a Q variable.

    For ``Q^{Y|X}`` the mechanism is the Q variable of ``Y`` followed by the Q
    variables of every subunit ancestor of ``Y`` that is not among the
    conditioning parents ``X``.

    :param subunit_graph: the subunit graph of the hierarchical model
    :param augmentation_variable: the Q variable (or its string form) to augment
    :returns: list of Q variables, starting with the Q variable of the left-hand side
    :raises KeyError: if the left-hand side or any parent is not a node of the subunit graph
    """
    augmentation_variable = _str_or_q(augmentation_variable)
    nodes = set(subunit_graph.nodes())
    lhs_var = augmentation_variable.get_lhs()
    rhs = augmentation_variable.parents
    if lhs_var not in nodes:
        raise KeyError(
            f"Augmentation variable's left hand side {lhs_var} is not in subunit graph: {nodes}"
        )
    if not rhs.issubset(nodes):
        raise KeyError(
            f"Augmentation variable's right hand side {rhs} are not all in subunit graph: {nodes}"
        )
    mechanism = [_create_qvar(subunit_graph, lhs_var)]
    ancestors = get_ancestors(subunit_graph, lhs_var).difference(rhs)
    mechanism.extend(_create_qvar(subunit_graph, ancestor) for ancestor in ancestors)
    return mechanism
[docs]
def collapse_hcm(model: HierarchicalCausalModel, return_hcgm: bool = False) -> NxMixedGraph:
    """Collapse the given hierarchical model according to Algorithm 1 of the HCM paper.

    This delegates to ``model.to_admg``. A :class:`HierarchicalStructuralCausalModel`
    is therefore handled through its own override, which converts it to an HCM first.

    :param model: the hierarchical model to collapse
    :param return_hcgm: if True, also return the intermediate HCGM
    :returns: the collapsed mixed graph, paired with the HCGM when ``return_hcgm`` is True
    """
    collapsed = model.to_admg(return_hcgm=return_hcgm)
    return collapsed
[docs]
def augment_collapsed_model(
    model: NxMixedGraph,
    subunit_graph: SubunitGraph,
    augmentation_variable: QVariable | str,
    mechanism: Iterable[QVariable] | None = None,
) -> NxMixedGraph:
    """Augment the given variable into a collapsed model.

    :param model: the collapsed model
    :param subunit_graph: the subunit graph used to derive the mechanism when none is given
    :param augmentation_variable: the Q variable (or its string form) to augment
    :param mechanism: the Q variables determining the augmentation variable; when
        omitted it is computed with :func:`augmentation_mechanism`
    :returns: the augmented model
    """
    # TODO test
    aug_var = _str_or_q(augmentation_variable)
    resolved_mechanism = (
        augmentation_mechanism(subunit_graph, aug_var) if mechanism is None else mechanism
    )
    return augment_from_mechanism(model, aug_var, resolved_mechanism)
[docs]
def marginalize_augmented_model(
    augmented: NxMixedGraph,
    augmentation_variable: str | QVariable,
    marginal_parents: Collection[str | QVariable],
) -> NxMixedGraph:
    """Marginalize out a given collection of variables from an augmented model.

    Each marginalized parent is removed. Its directed parents become directed
    parents of the augmentation variable, and its bidirected neighbors become
    bidirected neighbors of the augmentation variable.

    :param augmented: NxMixedGraph of the input augmented model
    :param augmentation_variable: the variable that was previously augmented into the
        model
    :param marginal_parents: collection of parents of the augmentation variable to be
        marginalized out.
    :returns: NxMixedGraph of the marginalized model (the input is not modified)
    :raises ValueError: augmentation_variable must be in the augmented model
    :raises ValueError: marginal_parents must all be parents of augmentation_variable
    :raises ValueError: marginal_parents cannot be all the parents of
        augmentation_variable
    :raises ValueError: augmentation_variable must be the only child of the each
        marginal parent
    """
    augmentation_variable = _str_or_q(augmentation_variable)
    marginal_parents = [_str_or_q(mp) for mp in marginal_parents]
    # Check membership before querying predecessors, which would otherwise raise a
    # NetworkXError instead of the documented ValueError.
    if augmentation_variable not in augmented.nodes():
        raise ValueError("Augmentation variable must be in the input augmented model.")
    mechanism = set(augmented.directed.predecessors(augmentation_variable))
    marginal_set = set(marginal_parents)
    if not marginal_set <= mechanism:
        raise ValueError("The marginalized variables must all be parents of the augmentation variable.")
    if marginal_set == mechanism:
        raise ValueError("Cannot marginalize all parents of the augmentation variable.")
    marginalized = augmented.copy()
    check_set = {augmentation_variable}
    # dict.fromkeys de-duplicates while keeping the caller's order
    for parent in dict.fromkeys(marginal_parents):
        if set(marginalized.directed.successors(parent)) != check_set:
            raise ValueError(
                "The augmentation variable must be the only child of the marginalized parents."
            )
        for gp in list(marginalized.directed.predecessors(parent)):
            marginalized.add_directed_edge(gp, augmentation_variable)
        marginalized.directed.remove_node(parent)
        for neighbor in list(marginalized.undirected.neighbors(parent)):
            # A parent confounded with the augmentation variable itself must not
            # turn into a self-loop on the augmentation variable.
            if neighbor != augmentation_variable:
                marginalized.add_undirected_edge(neighbor, augmentation_variable)
        marginalized.undirected.remove_node(parent)
    return marginalized