"""An implementation of the identification algorithm from [shpitser2006]_."""
from collections.abc import Sequence
from typing import Annotated
from .utils import Identification, Unidentifiable
from ...dsl import Expression, P, Probability, Product, Sum, Variable
from ...graph import NxMixedGraph
from ...util import InPaperAs
# Names exported by a star-import of this module; only ``identify`` is listed.
__all__ = [
    "identify",
]
[docs]
def identify(
    identification: Identification, *, ordering: Sequence[Variable] | None = None
) -> Expression:
    """Run the ID algorithm from [shpitser2006]_.

    :param identification: The identification tuple, from which the graph, the
        treatments, the outcomes, and the starting estimand are read
    :param ordering: A topological ordering of the variables. If not passed, is
        calculated from the directed component of the mixed graph.
    :returns: the expression corresponding to the identification
    :raises Unidentifiable: If no appropriate identification can be found

    See also :func:`identify_outcomes` for a more idiomatic way of running the ID
    algorithm given a graph, treatments, and outcomes.
    """
    # Unpack the query once. The estimand begins as the joint distribution over
    # all nodes in the graph and is progressively refined by the recursion.
    query = {
        "graph": identification.graph,
        "treatments": identification.treatments,
        "outcomes": identification.outcomes,
        "estimand": identification.estimand,
    }
    return _identify(**query, ordering=ordering)
def _identify(  # noqa:C901
    graph: NxMixedGraph,
    treatments: Annotated[set[Variable], InPaperAs("x")],
    outcomes: Annotated[set[Variable], InPaperAs("y")],
    estimand: Expression,
    *,
    ordering: Sequence[Variable] | None = None,
) -> Expression:
    """Run the ID algorithm from [shpitser2006]_.

    This is the recursive worker behind :func:`identify`. The ``line N`` comments
    in the body refer to the numbered lines of the algorithm in the paper. The
    checks are tried in the order written, and the first one that applies either
    returns an expression, recurs on a smaller problem, or raises.

    :param graph: The mixed graph of the current (sub)problem
    :param treatments: The variables intervened on (``x`` in the paper)
    :param outcomes: The variables whose distribution is queried (``y`` in the paper)
    :param estimand: The expression that the current (sub)problem is identified
        from. It is marginalized in lines 1 and 2, passed through unchanged in
        lines 3 and 4, and replaced by a district product in line 7.
    :param ordering: A topological ordering of the variables. If not passed, is
        calculated from the directed component of the mixed graph.
    :returns: the expression corresponding to the identification
    :raises Unidentifiable: If no appropriate identification can be found
    :raises RuntimeError: If an internal consistency check fails (these branches
        are marked as not expected to be reached)
    """
    # see page 5 of https://cdn.aaai.org/AAAI/2006/AAAI06-191.pdf
    nodes: Annotated[set[Variable], InPaperAs(r"\mathbf{v}")] = set(graph.nodes())

    # line 1
    # with no treatments left, marginalize the estimand over every non-outcome node
    if not treatments:
        # FIXME why is estimand passed here? Why not Probability.safe(nodes)?
        #  this is a problem when recurring on a subgraph -
        return Sum.safe(expression=estimand, ranges=nodes.difference(outcomes))

    # line 2
    outcomes_and_ancestors = graph.ancestors_inclusive(outcomes)
    # NOTE(review): the InPaperAs label below names An(Y)_G, but the value computed
    # here is the complement, i.e. the nodes that are NOT in An(Y)_G.
    not_outcomes_or_ancestors: Annotated[set[Variable], InPaperAs(r"An(\mathbf{Y})_G")] = (
        nodes.difference(outcomes_and_ancestors)
    )
    if not_outcomes_or_ancestors:
        # recur on the subgraph over the outcomes and their ancestors, summing the
        # remaining nodes out of the estimand
        return _identify(
            outcomes=outcomes,
            treatments=treatments & outcomes_and_ancestors,
            # FIXME there's a closed form way for calculating this that isn't
            #  based on an external estimand. should use that instead.
            estimand=Sum.safe(expression=estimand, ranges=not_outcomes_or_ancestors),
            graph=graph.subgraph(outcomes_and_ancestors),
            ordering=ordering,
        )

    # line 3
    # enlarge the treatment set with the nodes reported by
    # ``get_no_effect_on_outcomes`` and start over on the same graph
    no_effect_on_outcome = graph.get_no_effect_on_outcomes(treatments, outcomes)
    if no_effect_on_outcome:
        return _identify(
            outcomes=outcomes,
            treatments=treatments | no_effect_on_outcome,
            estimand=estimand,
            graph=graph,
            ordering=ordering,
        )

    # line 4.1 C(G \ X) = {S1, ..., Sk}
    graph_without_treatments: Annotated[NxMixedGraph, InPaperAs(r"G \setminus \mathbf{X}")] = (
        graph.remove_nodes_from(treatments)
    )
    districts_without_treatment: Annotated[
        set[frozenset[Variable]], InPaperAs(r"C(G \setminus \mathbf{X})")
    ] = graph_without_treatments.districts()
    if not graph_without_treatments.is_connected():
        # line 4.2a
        # one recursive call per district: the district becomes the outcome set
        # and every other node of the (unreduced) graph becomes a treatment
        expression = Product.safe(
            _identify(
                outcomes=set(district_without_treatment),
                treatments=nodes - district_without_treatment,
                estimand=estimand,
                graph=graph,
                ordering=ordering,
            )
            for district_without_treatment in districts_without_treatment
        )
        # line 4.2b
        return Sum.safe(
            expression=expression,
            ranges=nodes.difference(outcomes | treatments),
        )

    # line 5, if C(G) = {G},
    if graph.is_connected():  # e.g., there's only 1 c-component, and it encompasses all nodes
        raise Unidentifiable(graph.nodes(), districts_without_treatment)

    # sanity check: presumably a connected graph without treatments always has
    # exactly one district, so this branch is not expected to run
    if len(districts_without_treatment) != 1:  # pragma: no cover
        raise RuntimeError
    # line 4.3, C(G \setminus X) = {S}
    district_without_treatment: Annotated[frozenset[Variable], InPaperAs("S")] = next(
        iter(districts_without_treatment)
    )

    # TODO move this up, but causes some errors
    if ordering is None:
        ordering = graph.topological_sort()
    else:
        # in case ordering was passed in, remove any variables
        # that are irrelevant to the current graph. these might
        # exist if the sort was run on a supergraph
        ordering = [v for v in ordering if v in nodes]

    # line 6, if S ∈ C(G)
    if district_without_treatment in graph.districts():
        return Sum.safe(
            expression=_district_product(district_without_treatment, ordering),
            ranges=district_without_treatment - outcomes,
        )

    # line 7, if (∃S')S ⊂ S' ∈ C(G)
    for district in graph.districts():
        # ``<`` is the proper-subset test on frozensets
        if district_without_treatment < district:
            return _identify(
                graph=graph.subgraph(district),
                outcomes=outcomes,
                treatments=treatments & district,
                estimand=_district_product(district, ordering),
                ordering=ordering,
            )

    raise RuntimeError  # pragma: no cover
def _get_single_district(graph: NxMixedGraph) -> frozenset[Variable]:
    """Return the only district of a graph that has exactly one district.

    :param graph: A mixed graph that is expected to consist of a single district
    :returns: The single district of the graph
    :raises RuntimeError: If the graph does not have exactly one district
    """
    districts = graph.districts()
    if len(districts) != 1:
        raise RuntimeError(f"expected exactly 1 district, found {len(districts)}")
    # Unpack instead of calling ``pop()`` so the collection handed back by
    # ``graph.districts()`` is left untouched, in case the graph shares or
    # caches it between calls.
    (district,) = districts
    return district
def line_1(
    identification: Identification,
) -> Annotated[Expression, InPaperAs(r"\sum_{v - y} P(\mathbf{v})")]:
    r"""Run line 1 of the identification algorithm.

    If no action has been taken, the effect on :math:`\mathbf Y` is just the
    marginal of the observational distribution. This function does not itself
    check that the treatment set is empty; it only builds the marginal.

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :returns: The estimand summed over every graph node that is not an outcome
    """
    all_nodes = set(identification.graph.nodes())
    # everything that is not an outcome gets marginalized out
    marginalized = all_nodes.difference(identification.outcomes)
    return Sum.safe(expression=identification.estimand, ranges=marginalized)
def line_2(identification: Identification) -> Identification:
    r"""Run line 2 of the identification algorithm.

    If we are interested in the effect on :math:`\mathbf Y`, it is sufficient to
    restrict our attention to the parts of the model ancestral to :math:`\mathbf Y`.

    .. math::

        \text{if }\mathbf V - An(\mathbf Y)_G \neq \emptyset \\
        \text{ return } \mathbf{ ID}\left(\mathbf y, \mathbf x\cap An(\mathbf Y)_G, \sum_{\mathbf V -
        An(Y)_G}P, G_{An(\mathbf Y)}\right)

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :returns: A new identification restricted to the outcomes and their ancestors,
        whose estimand is the old one summed over the removed nodes
    :raises ValueError: If the line 2 precondition is not met, i.e. every node of
        the graph is already an outcome or an ancestor of one
    """
    graph = identification.graph
    all_nodes = set(graph.nodes())
    ancestral = graph.ancestors_inclusive(identification.outcomes)
    non_ancestral = all_nodes.difference(ancestral)
    if not non_ancestral:
        raise ValueError("line 2 precondition not met")

    # keep only the treatments that survive the restriction to the ancestral set
    kept_treatments: Annotated[set[Variable], InPaperAs(r"x ^ An(\mathbf{Y})_G")] = (
        identification.treatments & ancestral
    )
    marginal = Sum.safe(expression=identification.estimand, ranges=non_ancestral)
    return Identification.from_parts(
        outcomes=identification.outcomes,
        treatments=kept_treatments,
        estimand=marginal,
        graph=graph.subgraph(ancestral),
    )
def line_3(identification: Identification) -> Identification:
    r"""Run line 3 of the identification algorithm.

    Forces an action on any node where such an action would have no effect on
    :math:`\mathbf Y`, assuming we already acted on :math:`\mathbf X`. Since
    actions remove incoming arrows, we can view line 3 as simplifying the causal
    graph we consider by removing certain arcs from the graph, without affecting
    the overall answer.

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :returns: The identification produced by ``with_treatments`` for the nodes
        found to have no effect on the outcomes (presumably these are added to
        the existing treatments; confirm against ``Identification.with_treatments``)
    :raises ValueError: If the preconditions for line 3 aren't met, i.e. no such
        nodes were found
    """
    extra_treatments = identification.graph.get_no_effect_on_outcomes(
        identification.treatments, identification.outcomes
    )
    if extra_treatments:
        return identification.with_treatments(extra_treatments)
    raise ValueError(
        'Line 3 precondition not met. There were no variables in "no_effect_on_outcome"'
    )
def line_4(identification: Identification) -> list[Identification]:
    r"""Run line 4 of the identification algorithm.

    The key line of the algorithm, it decomposes the problem into a set of smaller
    problems using the key property of *c-component factorization* of causal
    models. If the entire graph is a single C-component already, further problem
    decomposition is impossible, and we must provide base cases.
    :math:`\mathbf{ID}` has three base cases.

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :returns: One identification per district of the graph without treatments.
        Each keeps the original graph and estimand, takes the district as its
        outcomes, and takes every other node as its treatments.
    :raises ValueError: If the precondition that there are more than 1 districts
        without treatments is not met
    """
    graph = identification.graph
    shared_estimand = identification.estimand
    all_nodes = set(graph.nodes())

    # line 4: C(G \ X), the districts that remain once the treatments are removed
    components: Annotated[
        set[frozenset[Variable]], InPaperAs(r"C(G \setminus \mathbf{X})")
    ] = graph.remove_nodes_from(identification.treatments).districts()
    if len(components) < 2:
        raise ValueError("Line 4 precondition not met")

    subproblems = []
    for component in components:
        subproblem = Identification.from_parts(
            outcomes=component,
            treatments=all_nodes - component,
            estimand=shared_estimand,
            graph=graph,
        )
        subproblems.append(subproblem)
    return subproblems
def line_5(identification: Identification) -> None:
    r"""Run line 5 of the identification algorithm.

    Fails because it finds two C-components, the graph :math:`G` itself, and a
    subgraph :math:`S` that does not contain any :math:`\mathbf X` nodes. But that
    is exactly one of the properties of C-forests that make up a hedge. In fact,
    it turns out that it is always possible to recover a hedge from these two
    c-components.

    Returns ``None`` without raising when the graph's districts are anything
    other than a single district covering every node.

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :raises Unidentifiable: If line 5 realizes that identification is not possible
    """
    graph = identification.graph
    whole_graph = frozenset(graph.nodes())
    districts_without_treatment = graph.remove_nodes_from(
        identification.treatments
    ).districts()

    # line 5: does the full graph form one single district, C(G) = {G}?
    districts = graph.districts()
    if districts != {whole_graph}:
        return None
    raise Unidentifiable(districts, districts_without_treatment)
# TODO this line 6 isn't used in the actual implementation, delete or merge
def line_6(
    identification: Identification, *, ordering: Sequence[Variable] | None = None
) -> Expression:
    r"""Run line 6 of the identification algorithm.

    Asserts that if there are no bidirected arcs from :math:`X` to the other nodes
    in the current subproblem under consideration, then we can replace acting on
    :math:`X` by conditioning, and thus solve the subproblem.

    .. math::

        \text{ if }S\in C(G) \\
        \text{ return }\sum_{S - \mathbf y}\prod_{\{i|V_i\in S\}}P\left(v_i|v_\pi^{(i-1)}\right)

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :param ordering: A topologically ordered sequence of variables. All occurring
        before the child will be used as parents. Variables that are not nodes of
        the identification's graph are ignored, so an ordering computed on a
        supergraph can be reused. If not passed, the graph's own topological sort
        is used.
    :returns: The product of the district's conditional probabilities, summed over
        the district's variables that are not outcomes
    :raises ValueError: If line 6 precondition is not met, i.e. the single district
        of the graph without treatments is not a district of the full graph
    :raises RuntimeError: If the graph without treatments does not have exactly
        one district
    """
    outcomes = identification.outcomes
    treatments = identification.treatments
    graph = identification.graph
    districts = graph.districts()
    graph_without_treatments = graph.remove_nodes_from(treatments)
    district_without_treatments = _get_single_district(graph_without_treatments)

    # line 6
    if district_without_treatments not in districts:
        raise ValueError("Line 6 precondition not met")

    if ordering is None:
        ordering = graph.topological_sort()
    else:
        # Mirror the main recursion: drop variables that are irrelevant to the
        # current graph, which exist when the sort was run on a supergraph.
        # Otherwise they would leak into the conditioning sets built below.
        nodes = set(graph.nodes())
        ordering = [v for v in ordering if v in nodes]

    expression = _district_product(district_without_treatments, ordering)
    ranges = district_without_treatments - outcomes
    return Sum.safe(expression=expression, ranges=ranges)
def line_7(
    identification: Identification, ordering: Sequence[Variable] | None = None
) -> Identification:
    r"""Run line 7 of the identification algorithm.

    The most complex case where :math:`\mathbf X` is partitioned into two sets,
    :math:`\mathbf W` which contain bidirected arcs into other nodes in the
    subproblem, and :math:`\mathbf Z` which do not. In this situation, identifying
    :math:`P(\mathbf y|do(\mathbf x))` from :math:`P(v)` is equivalent to
    identifying :math:`P(\mathbf y|do(\mathbf w))` from
    :math:`P(\mathbf V|do(\mathbf z))`, since
    :math:`P(\mathbf y|do(\mathbf x)) = P(\mathbf y|do(\mathbf w), do(\mathbf z))`.
    But the term :math:`P(\mathbf V|do(\mathbf z))` is identifiable using the
    previous base case, so we can consider the subproblem of identifying
    :math:`P(\mathbf y|do(\mathbf w))`.

    .. math::

        \text{ if }(\exists S')S\subset S'\in C(G) \\
        \text{ return }\mathbf{ID}\left(\mathbf y, \mathbf x\cap S',
        \prod_{\{i|V_i\in S'\}}P(V_i|V_\pi^{(i-1)}\cap S', V_\pi^{(i-1)} -
        S'), G_{S'}\right)

    :param identification: The data structure with the treatment, outcomes,
        estimand, and graph
    :param ordering: A topologically ordered sequence of variables. All occurring
        before the child will be used as parents. Variables that are not nodes of
        the identification's graph are ignored, so an ordering computed on a
        supergraph can be reused. If not passed, the graph's own topological sort
        is used.
    :returns: A new identification on the subgraph of the enclosing district
        :math:`S'`, with the treatments restricted to that district and the
        district product as its estimand
    :raises ValueError: If line 7 does not find a suitable district
    :raises RuntimeError: If the graph without treatments does not have exactly
        one district
    """
    outcomes = identification.outcomes
    treatments = identification.treatments
    graph = identification.graph
    graph_without_treatments = graph.remove_nodes_from(treatments)
    # line 7 precondition requires single district
    district_without_treatment = _get_single_district(graph_without_treatments)

    if ordering is None:
        ordering = graph.topological_sort()
    else:
        # Mirror the main recursion: drop variables that are irrelevant to the
        # current graph, which exist when the sort was run on a supergraph.
        # Otherwise they would leak into the conditioning sets built below.
        nodes = set(graph.nodes())
        ordering = [v for v in ordering if v in nodes]

    # line 7: look for a district S' of G that strictly contains S
    for district in graph.districts():
        if district_without_treatment < district:
            return Identification.from_parts(
                outcomes=outcomes,
                treatments=treatments & district,
                estimand=_district_product(district, ordering),
                graph=graph.subgraph(district),
            )

    raise ValueError("Could not identify suitable district")
def _district_product(district: frozenset[Variable], ordering: Sequence[Variable]) -> Expression:
    """Multiply together the :func:`p_parents` factors of a district's variables.

    :param district: The variables to build one factor each for
    :param ordering: The sequence handed to :func:`p_parents` for every variable
    :returns: The product of the per-variable factors
    """
    factors = (p_parents(member, ordering) for member in district)
    return Product.safe(factors)
def p_parents(child: Variable, ordering: Sequence[Variable]) -> Probability:
    """Get a probability expression based on a topological ordering.

    :param child: The child variable, which must occur in ``ordering``
    :param ordering: A topologically ordered sequence of all variables. All
        occurring before the child will be used as parents.
    :returns: A probability expression of the child given its predecessors
    """
    position = ordering.index(child)
    # everything strictly before the child in the ordering is conditioned on
    predecessors = ordering[:position]
    return P(child | predecessors)