Source code for y0.algorithm.identify.id_std

"""An implementation of the identification algorithm from [shpitser2006]_."""

from collections.abc import Sequence
from typing import Annotated

from .utils import Identification, Unidentifiable
from ...dsl import Expression, P, Probability, Product, Sum, Variable
from ...graph import NxMixedGraph
from ...util import InPaperAs

__all__ = [
    "identify",
]


[docs] def identify( identification: Identification, *, ordering: Sequence[Variable] | None = None ) -> Expression: """Run the ID algorithm from [shpitser2006]_. :param identification: The identification tuple :param ordering: A topological ordering of the variables. If not passed, is calculated from the directed component of the mixed graph. :returns: the expression corresponding to the identification :raises Unidentifiable: If no appropriate identification can be found See also :func:`identify_outcomes` for a more idiomatic way of running the ID algorithm given a graph, treatments, and outcomes. """ return _identify( graph=identification.graph, treatments=identification.treatments, outcomes=identification.outcomes, # estimand begins as the joint distribution over all # nodes in the graph and is progressively refined estimand=identification.estimand, ordering=ordering, )
def _identify( # noqa:C901 graph: NxMixedGraph, treatments: Annotated[set[Variable], InPaperAs("x")], outcomes: Annotated[set[Variable], InPaperAs("y")], estimand: Expression, *, ordering: Sequence[Variable] | None = None, ) -> Expression: """Run the ID algorithm from [shpitser2006]_. :param ordering: A topological ordering of the variables. If not passed, is calculated from the directed component of the mixed graph. :returns: the expression corresponding to the identification :raises Unidentifiable: If no appropriate identification can be found """ # see page 5 of https://cdn.aaai.org/AAAI/2006/AAAI06-191.pdf nodes: Annotated[set[Variable], InPaperAs(r"\mathbf{v}")] = set(graph.nodes()) # line 1 if not treatments: # FIXME why is estimand passed here? Why not Probability.safe(nodes)? # this is a problem when recurring on a subgraph - return Sum.safe(expression=estimand, ranges=nodes.difference(outcomes)) # line 2 outcomes_and_ancestors = graph.ancestors_inclusive(outcomes) not_outcomes_or_ancestors: Annotated[set[Variable], InPaperAs(r"An(\mathbf{Y})_G")] = ( nodes.difference(outcomes_and_ancestors) ) if not_outcomes_or_ancestors: return _identify( outcomes=outcomes, treatments=treatments & outcomes_and_ancestors, # FIXME there's a closed form way for calculating this that isn't # based on an external estimand. should use that instead. estimand=Sum.safe(expression=estimand, ranges=not_outcomes_or_ancestors), graph=graph.subgraph(outcomes_and_ancestors), ordering=ordering, ) # line 3 no_effect_on_outcome = graph.get_no_effect_on_outcomes(treatments, outcomes) if no_effect_on_outcome: return _identify( outcomes=outcomes, treatments=treatments | no_effect_on_outcome, estimand=estimand, graph=graph, ordering=ordering, ) # line 4.1 C(G \ X) = {S1, ..., Sk} graph_without_treatments: Annotated[NxMixedGraph, InPaperAs(r"G \setminus \mathbf{X}")] = ( graph.remove_nodes_from(treatments) ) districts_without_treatment: Annotated[ set[frozenset[Variable]], InPaperAs(r"C(G \setminus \mathbf{X})") ] = graph_without_treatments.districts() if not graph_without_treatments.is_connected(): # line 4.2a expression = Product.safe( _identify( outcomes=set(district_without_treatment), treatments=nodes - district_without_treatment, estimand=estimand, graph=graph, ordering=ordering, ) for district_without_treatment in districts_without_treatment ) # line 4.2b return Sum.safe( expression=expression, ranges=nodes.difference(outcomes | treatments), ) # line 5, if C(G) = {G}, if graph.is_connected(): # e.g., there's only 1 c-component, and it encompasses all nodes raise Unidentifiable(graph.nodes(), districts_without_treatment) if len(districts_without_treatment) != 1: # pragma: no cover raise RuntimeError # line 4.3, C(G \setminus X) = {S} district_without_treatment: Annotated[frozenset[Variable], InPaperAs("S")] = next( iter(districts_without_treatment) ) # TODO move this up, but causes some errors if ordering is None: ordering = graph.topological_sort() else: # in case ordering was passed in, remove any variables # that are irrelevant to the current graph. these might # exist if the sort was run on a supergraph ordering = [v for v in ordering if v in nodes] # line 6, if S ∈ C(G) if district_without_treatment in graph.districts(): return Sum.safe( expression=_district_product(district_without_treatment, ordering), ranges=district_without_treatment - outcomes, ) # line 7, if (∃S')S ⊂ S' ∈ C(G) for district in graph.districts(): if district_without_treatment < district: return _identify( graph=graph.subgraph(district), outcomes=outcomes, treatments=treatments & district, estimand=_district_product(district, ordering), ordering=ordering, ) raise RuntimeError # pragma: no cover def _get_single_district(graph: NxMixedGraph) -> frozenset[Variable]: districts = graph.districts() if len(districts) != 1: raise RuntimeError return districts.pop() def line_1( identification: Identification, ) -> Annotated[Expression, InPaperAs(r"\sum_{v - y} P(\mathbf{v})")]: r"""Run line 1 of identification algorithm. If no action has been taken, the effect on :math:`\mathbf Y` is just the marginal of the observational distribution :param identification: The data structure with the treatment, outcomes, estimand, and graph :returns: The marginal of the outcome variables """ outcomes = identification.outcomes nodes = set(identification.graph.nodes()) return Sum.safe( expression=identification.estimand, ranges=nodes.difference(outcomes), ) def line_2(identification: Identification) -> Identification: r"""Run line 2 of the identification algorithm. If we are interested in the effect on :math:`\mathbf Y`, it is sufficient to restrict our attention on the parts of the model ancestral to :math:`\mathbf Y`. .. math:: \text{if }\mathbf V - An(\mathbf Y)_G \neq \emptyset \\ \text{ return } \mathbf{ ID}\left(\mathbf y, \mathbf x\cap An(\mathbf Y)_G, \sum_{\mathbf V - An(Y)_G}P, G_{An(\mathbf Y)}\right) :param identification: The data structure with the treatment, outcomes, estimand, and graph :returns: The new estimand :raises ValueError: If the line 2 precondition is not met """ graph = identification.graph treatments = identification.treatments outcomes = identification.outcomes nodes = set(graph.nodes()) outcomes_and_ancestors = graph.ancestors_inclusive(outcomes) not_outcomes_or_ancestors = nodes.difference(outcomes_and_ancestors) if not not_outcomes_or_ancestors: raise ValueError("line 2 precondition not met") reduced_treatments: Annotated[set[Variable], InPaperAs(r"x ^ An(\mathbf{Y})_G")] = ( treatments & outcomes_and_ancestors ) identification = Identification.from_parts( outcomes=outcomes, treatments=reduced_treatments, estimand=Sum.safe(expression=identification.estimand, ranges=not_outcomes_or_ancestors), graph=graph.subgraph(outcomes_and_ancestors), ) return identification def line_3(identification: Identification) -> Identification: r"""Run line 3 of the identification algorithm. Forces an action on any node where such an action would have no effect on :math:\mathbf Y`—assuming we already acted on :math:`\mathbf X`. Since actions remove incoming arrows, we can view line 3 as simplifying the causal graph we consider by removing certain arcs from the graph, without affecting the overall answer. :param identification: The data structure with the treatment, outcomes, estimand, and graph :returns: The new estimand :raises ValueError: If the preconditions for line 3 aren't met. """ outcomes = identification.outcomes treatments = identification.treatments graph = identification.graph no_effect_on_outcome = graph.get_no_effect_on_outcomes(treatments, outcomes) if not no_effect_on_outcome: raise ValueError( 'Line 3 precondition not met. There were no variables in "no_effect_on_outcome"' ) return identification.with_treatments(no_effect_on_outcome) def line_4(identification: Identification) -> list[Identification]: r"""Run line 4 of the identification algorithm. The key line of the algorithm, it decomposes the problem into a set of smaller problems using the key property of *c-component factorization* of causal models. If the entire graph is a single C-component already, further problem decomposition is impossible, and we must provide base cases. :math:`\mathbf{ID}` has three base cases. :param identification: The data structure with the treatment, outcomes, estimand, and graph :returns: A list of new estimands :raises ValueError: If the precondition that there are more than 1 districts without treatments is not met """ treatments = identification.treatments estimand = identification.estimand graph = identification.graph nodes = set(graph.nodes()) # line 4 graph_without_treatments: Annotated[NxMixedGraph, InPaperAs(r"G \setminus \mathbf{X}")] = ( graph.remove_nodes_from(treatments) ) districts_without_treatment: Annotated[ set[frozenset[Variable]], InPaperAs(r"C(G \setminus \mathbf{X})") ] = graph_without_treatments.districts() if len(districts_without_treatment) <= 1: raise ValueError("Line 4 precondition not met") return [ Identification.from_parts( outcomes=district_without_treatment, treatments=nodes - district_without_treatment, estimand=estimand, graph=graph, ) for district_without_treatment in districts_without_treatment ] def line_5(identification: Identification) -> None: r"""Run line 5 of the identification algorithm. Fails because it finds two C-components, the graph :math:`G` itself, and a subgraph :math:`S` that does not contain any :math:`\mathbf X` nodes. But that is exactly one of the properties of C-forests that make up a hedge. In fact, it turns out that it is always possible to recover a hedge from these two c-components. :param identification: The data structure with the treatment, outcomes, estimand, and graph :raises Unidentifiable: If line 5 realizes that identification is not possible """ treatments = identification.treatments graph = identification.graph nodes = set(graph.nodes()) graph_without_treatments = graph.remove_nodes_from(treatments) districts_without_treatment = graph_without_treatments.districts() # line 5 districts = graph.districts() if districts == {frozenset(nodes)}: raise Unidentifiable(districts, districts_without_treatment) # TODO this line 6 isn't used in the actual implementation, delete or merge def line_6( identification: Identification, *, ordering: Sequence[Variable] | None = None ) -> Expression: r"""Run line 6 of the identification algorithm. Asserts that if there are no bidirected arcs from :math:`X` to the other nodes in the current subproblem under consideration, then we can replace acting on :math:`X` by conditioning, and thus solve the subproblem. ..math: :: \text{ if }S\in C(G) \\ \text{ return }\sum_{S - \mathbf y}\prod_{\{i|V_i\in S\}}P\left(v_i|v_\pi^{(i-1)}\right) :param identification: The data structure with the treatment, outcomes, estimand, and graph :param ordering: A topologically ordered sequence of all variables. All occurring before the child will be used as parents. :returns: A list of new estimands :raises ValueError: If line 6 precondition is not met """ outcomes = identification.outcomes treatments = identification.treatments graph = identification.graph districts = graph.districts() graph_without_treatments = graph.remove_nodes_from(treatments) district_without_treatments = _get_single_district(graph_without_treatments) # line 6 if district_without_treatments not in districts: raise ValueError("Line 6 precondition not met") if ordering is None: ordering = graph.topological_sort() expression = _district_product(district_without_treatments, ordering) ranges = district_without_treatments - outcomes return Sum.safe(expression=expression, ranges=ranges) def line_7( identification: Identification, ordering: Sequence[Variable] | None = None ) -> Identification: r"""Run line 7 of the identification algorithm. The most complex case where :math:`\mathbf X` is partitioned into two sets, :math:`\mathbf W` which contain bidirected arcs into other nodes in the subproblem, and :math:`\mathbf Z` which do not. In this situation, identifying :math:`P(\mathbf y|do(\mathbf x))` from :math:`P(v)` is equivalent to identifying :math:`P(\mathbf y|do(\mathbf w))` from :math:`P(\mathbf V|do(\mathbf z))`, since :math:`P(\mathbf y|do(\mathbf x)) = P(\mathbf y|do(\mathbf w), do(\mathbf z))`. But the term :math:`P(\mathbf V|do(\mathbf z))` is identifiable using the previous base case, so we can consider the subproblem of identifying :math:`P(\mathbf y|do(\mathbf w))` .. math:: \text{ if }(\exists S')S\subset S'\in C(G) \\ \text{ return }\mathbf{ID}\left(\mathbf y, \mathbf x\cap S', \prod_{\{i|V_i\in S'\}}P(V_i|V_\pi^{(i-1)}\cap S', V_\pi^{(i-1)} - S'), G_{S'}\right) :param identification: The data structure with the treatment, outcomes, estimand, and graph :param ordering: A topologically ordered sequence of all variables. All occurring before the child will be used as parents. :returns: A new estimand :raises ValueError: If line 7 does not find a suitable district """ outcomes = identification.outcomes treatments = identification.treatments graph = identification.graph graph_without_treatments = graph.remove_nodes_from(treatments) # line 7 precondition requires single district district_without_treatment = _get_single_district(graph_without_treatments) if ordering is None: ordering = graph.topological_sort() # line 7 for district in graph.districts(): if district_without_treatment < district: return Identification.from_parts( outcomes=outcomes, treatments=treatments & district, estimand=_district_product(district, ordering), graph=graph.subgraph(district), ) raise ValueError("Could not identify suitable district") def _district_product(district: frozenset[Variable], ordering: Sequence[Variable]) -> Expression: return Product.safe(p_parents(v, ordering) for v in district) def p_parents(child: Variable, ordering: Sequence[Variable]) -> Probability: """Get a probability expression based on a topological ordering. :param child: The child variable :param ordering: A topologically ordered sequence of all variables. All occurring before the child will be used as parents. :returns: A probability expression """ return P(child | ordering[: ordering.index(child)])