# -*- coding: utf-8 -*-
"""An implementation of the identification algorithm."""
from typing import List, Sequence
from .utils import Identification, Unidentifiable
from ...dsl import Expression, P, Probability, Product, Sum, Variable
from ...graph import NxMixedGraph
__all__ = [
"identify",
]
[docs]
def identify(identification: Identification) -> Expression:
    """Run the ID algorithm from [shpitser2006]_.

    The lines of the algorithm are tried in order. The first line whose
    precondition holds either produces an expression directly (lines 1, 4,
    and 6), raises (line 5), or rewrites the query into a new identification
    that is identified recursively (lines 2, 3, and 7).

    :param identification: The identification tuple
    :returns: the expression corresponding to the identification
    :raises Unidentifiable: If no appropriate identification can be found
    :raises RuntimeError: If lines 6/7 are reached and the graph without the
        treatments does not have exactly one district

    See also :func:`identify_outcomes` for a more idiomatic way of running
    the ID algorithm given a graph, treatments, and outcomes.
    """
    graph = identification.graph
    treatments = identification.treatments
    outcomes = identification.outcomes
    vertices = set(graph.nodes())

    # line 1
    if not treatments:
        return line_1(identification)

    # line 2
    outcomes_and_ancestors = graph.ancestors_inclusive(outcomes)
    if vertices.difference(outcomes_and_ancestors):
        return identify(line_2(identification))

    # line 3
    if graph.get_no_effect_on_outcomes(treatments, outcomes):
        return identify(line_3(identification))

    # line 4
    graph_without_treatments = graph.remove_nodes_from(treatments)
    if not graph_without_treatments.is_connected():
        # each sub-identification produced by line 4 is identified on its own
        expression = Product.safe(map(identify, line_4(identification)))
        return Sum.safe(
            expression=expression,
            ranges=vertices.difference(outcomes | treatments),
        )

    # line 5
    if graph.is_connected():  # e.g., there's only 1 c-component, and it encompasses all vertices
        raise Unidentifiable(graph.nodes(), graph_without_treatments.districts())

    # line 6
    district_without_treatment = _get_single_district(graph_without_treatments)
    if district_without_treatment in graph.districts():
        # The precondition was just checked, so line_6 cannot raise its ValueError here;
        # delegating keeps the closed-form expression defined in a single place.
        return line_6(identification)

    # line 7
    return identify(line_7(identification))
def _get_single_district(graph: "NxMixedGraph") -> "frozenset[Variable]":
    """Return the only district of a graph.

    :param graph: A graph whose ``districts()`` method returns the collection of its districts
    :returns: The single district of the graph
    :raises RuntimeError: If the graph does not have exactly one district
    """
    districts = graph.districts()
    if len(districts) != 1:
        raise RuntimeError(f"expected exactly one district, found {len(districts)}")
    # Unpack instead of calling ``pop()`` so the collection handed back by the
    # graph is left untouched, in case the graph keeps a reference to it.
    (district,) = districts
    return district
def line_1(identification: Identification) -> Expression:
    r"""Run line 1 of identification algorithm.

    If no action has been taken, the effect on :math:`\mathbf Y` is just the marginal of
    the observational distribution.

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: The estimand summed over every vertex of the graph that is not an outcome
    """
    non_outcomes = set(identification.graph.nodes()).difference(identification.outcomes)
    return Sum.safe(expression=identification.estimand, ranges=non_outcomes)
def line_2(identification: Identification) -> Identification:
    r"""Run line 2 of the identification algorithm.

    If we are interested in the effect on :math:`\mathbf Y`, it is sufficient to restrict our attention
    on the parts of the model ancestral to :math:`\mathbf Y`.

    .. math::

        \text{if }\mathbf V - An(\mathbf Y)_G \neq \emptyset \\
        \text{ return } \mathbf{ ID}\left(\mathbf y, \mathbf x\cap An(\mathbf Y)_G, \sum_{\mathbf V -
        An(Y)_G}P, G_{An(\mathbf Y)}\right)

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: A new identification over the subgraph of the outcomes and their ancestors, whose
        estimand is the old estimand summed over the vertices that were dropped
    :raises ValueError: If the line 2 precondition is not met
    """
    graph = identification.graph
    ancestral = graph.ancestors_inclusive(identification.outcomes)
    non_ancestral = set(graph.nodes()).difference(ancestral)
    ancestral_graph = graph.subgraph(ancestral)

    if non_ancestral:
        marginal = Sum.safe(expression=identification.estimand, ranges=non_ancestral)
        return Identification.from_parts(
            outcomes=identification.outcomes,
            treatments=identification.treatments & ancestral,
            estimand=marginal,
            graph=ancestral_graph,
        )

    # every vertex is an outcome or an ancestor of one, so there is nothing to prune
    raise ValueError("line 2 precondition not met")
def line_3(identification: Identification) -> Identification:
    r"""Run line 3 of the identification algorithm.

    Forces an action on any node where such an action would have no
    effect on :math:`\mathbf Y`—assuming we already acted on
    :math:`\mathbf X`. Since actions remove incoming arrows, we can
    view line 3 as simplifying the causal graph we consider by
    removing certain arcs from the graph, without affecting the
    overall answer.

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: The identification produced by ``with_treatments`` for the variables
        that the graph reports as having no effect on the outcomes
    :raises ValueError: If the preconditions for line 3 aren't met.
    """
    ineffective = identification.graph.get_no_effect_on_outcomes(
        identification.treatments,
        identification.outcomes,
    )
    if ineffective:
        return identification.with_treatments(ineffective)
    raise ValueError(
        'Line 3 precondition not met. There were no variables in "no_effect_on_outcome"'
    )
def line_4(identification: Identification) -> List[Identification]:
    r"""Run line 4 of the identification algorithm.

    The key line of the algorithm, it decomposes the problem into a set
    of smaller problems using the key property of *c-component
    factorization* of causal models. If the entire graph is a single
    C-component already, further problem decomposition is impossible,
    and we must provide base cases. :math:`\mathbf{ID}` has three base
    cases.

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: One identification per district of the graph without the treatments. Each uses
        the district as outcomes and all remaining vertices as treatments, and keeps the
        original estimand and graph.
    :raises ValueError: If the precondition that there are more than 1 districts without treatments is not met
    """
    graph = identification.graph
    estimand = identification.estimand
    all_vertices = set(graph.nodes())

    # line 4
    districts = graph.remove_nodes_from(identification.treatments).districts()
    if len(districts) <= 1:
        raise ValueError("Line 4 precondition not met")

    def _subproblem(district) -> Identification:
        # the district becomes the outcome set; everything else is intervened on
        return Identification.from_parts(
            outcomes=set(district),
            treatments=all_vertices - district,
            estimand=estimand,
            graph=graph,
        )

    return [_subproblem(district) for district in districts]
def line_5(identification: Identification) -> None:
    r"""Run line 5 of the identification algorithm.

    Fails because it finds two C-components, the graph :math:`G`
    itself, and a subgraph :math:`S` that does not contain any
    :math:`\mathbf X` nodes. But that is exactly one of the properties
    of C-forests that make up a hedge. In fact, it turns out that it
    is always possible to recover a hedge from these two c-components.

    Returns ``None`` without raising when the graph's districts are anything
    other than a single district containing every vertex.

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :raises Unidentifiable: If line 5 realizes that identification is not possible
    """
    graph = identification.graph
    whole_graph = frozenset(graph.nodes())
    districts_sans_treatments = graph.remove_nodes_from(identification.treatments).districts()

    # line 5
    graph_districts = graph.districts()
    if graph_districts == {whole_graph}:
        raise Unidentifiable(graph_districts, districts_sans_treatments)
def line_6(identification: Identification) -> Expression:
    r"""Run line 6 of the identification algorithm.

    Asserts that if there are no bidirected arcs from :math:`X` to the other nodes in the current subproblem
    under consideration, then we can replace acting on :math:`X` by conditioning, and thus solve the subproblem.

    .. math::

        \text{ if }S\in C(G) \\
        \text{ return }\sum_{S - \mathbf y}\prod_{\{i|V_i\in S\}}P\left(v_i|v_\pi^{(i-1)}\right)

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: The product of :func:`p_parents` over the district, summed over the members of the
        district that are not outcomes
    :raises ValueError: If line 6 precondition is not met
    :raises RuntimeError: If the graph without the treatments does not have exactly one district
    """
    graph = identification.graph
    graph_districts = graph.districts()
    district = _get_single_district(graph.remove_nodes_from(identification.treatments))

    # line 6
    if district not in graph_districts:
        raise ValueError("Line 6 precondition not met")

    # one topological ordering of the full graph is shared by every factor
    ordering = list(graph.topological_sort())
    factors = (p_parents(node, ordering) for node in district)
    return Sum.safe(
        expression=Product.safe(factors),
        ranges=district - identification.outcomes,
    )
def line_7(identification: Identification) -> Identification:
    r"""Run line 7 of the identification algorithm.

    The most complex case where :math:`\mathbf X` is partitioned into
    two sets, :math:`\mathbf W` which contain bidirected arcs into
    other nodes in the subproblem, and :math:`\mathbf Z` which do
    not. In this situation, identifying :math:`P(\mathbf y|do(\mathbf
    x))` from :math:`P(v)` is equivalent to identifying
    :math:`P(\mathbf y|do(\mathbf w))` from :math:`P(\mathbf
    V|do(\mathbf z))`, since :math:`P(\mathbf y|do(\mathbf x)) =
    P(\mathbf y|do(\mathbf w), do(\mathbf z))`. But the term
    :math:`P(\mathbf V|do(\mathbf z))` is identifiable using the
    previous base case, so we can consider the subproblem of
    identifying :math:`P(\mathbf y|do(\mathbf w))`

    .. math::

        \text{ if }(\exists S')S\subset S'\in C(G) \\
        \text{ return }\mathbf{ID}\left(\mathbf y, \mathbf x\cap S',
        \prod_{\{i|V_i\in S'\}}P(V_i|V_\pi^{(i-1)}\cap S', V_\pi^{(i-1)} -
        S'), G_{S'}\right)

    :param identification: The data structure with the treatment, outcomes, estimand, and graph
    :returns: A new identification restricted to the first district of the graph that strictly
        contains the single district of the graph without the treatments
    :raises ValueError: If line 7 does not find a suitable district
    :raises RuntimeError: If the graph without the treatments does not have exactly one district
    """
    graph = identification.graph
    treatments = identification.treatments

    # line 7 precondition requires single district
    inner = _get_single_district(graph.remove_nodes_from(treatments))

    # line 7: pick the first district that is a strict superset of the inner one
    enclosing = next(
        (candidate for candidate in graph.districts() if inner < candidate),
        None,
    )
    if enclosing is None:
        raise ValueError("Could not identify suitable district")

    ordering = list(graph.topological_sort())
    return Identification.from_parts(
        outcomes=identification.outcomes,
        treatments=treatments & enclosing,
        estimand=Product.safe(p_parents(node, ordering) for node in enclosing),
        graph=graph.subgraph(enclosing),
    )
def p_parents(child: Variable, ordering: Sequence[Variable]) -> Probability:
    """Build the probability of a variable given its predecessors in an ordering.

    :param child: The child variable
    :param ordering: A topologically ordered sequence of all variables. All occurring before the
        child will be used as parents.
    :return: A probability expression of the child conditioned on everything preceding it
        in ``ordering``
    """
    position = ordering.index(child)
    predecessors = ordering[:position]
    return P(child | predecessors)