# Source code for y0.algorithm.identify.api

"""High-level API for identification algorithms."""

from collections.abc import Iterable, Sequence
from typing import Literal, overload

from .id_c import idc
from .id_std import identify
from .utils import Identification, Unidentifiable
from ...dsl import Expression, Variable
from ...graph import NxMixedGraph

__all__ = [
    "identify_outcomes",
]


# docstr-coverage:excused `overload`
@overload
def identify_outcomes(
    graph: NxMixedGraph,
    treatments: Variable | Iterable[Variable],
    outcomes: Variable | Iterable[Variable],
    conditions: None | Variable | Iterable[Variable] = ...,
    *,
    # Deliberately has no default: the implementation defaults ``strict`` to
    # False, so only an explicit ``strict=True`` may select this overload and
    # its non-optional return type. Calls that omit ``strict`` fall through
    # to the ``Literal[False]`` overload, which is typed ``Expression | None``.
    strict: Literal[True],
    ordering: Sequence[Variable] | None = ...,
) -> Expression: ...


# docstr-coverage:excused `overload`
@overload
def identify_outcomes(
    graph: NxMixedGraph,
    treatments: Variable | Iterable[Variable],
    outcomes: Variable | Iterable[Variable],
    conditions: None | Variable | Iterable[Variable] = ...,
    *,
    # Non-strict typing overload: the result is optional because the
    # implementation returns None when ``Unidentifiable`` is caught and
    # ``strict`` is false (which is also the implementation's default).
    strict: Literal[False] = ...,
    ordering: Sequence[Variable] | None = ...,
) -> Expression | None: ...


def identify_outcomes(
    graph: NxMixedGraph,
    treatments: Variable | Iterable[Variable],
    outcomes: Variable | Iterable[Variable],
    conditions: None | Variable | Iterable[Variable] = None,
    *,
    strict: bool = False,
    ordering: Sequence[Variable] | None = None,
) -> Expression | None:
    """Calculate the estimand for the treatment(s), outcome(s), and optional condition(s).

    :param graph: An acyclic directed mixed graph
    :param treatments: The node or nodes that are treated
    :param outcomes: The node or nodes that are outcomes
    :param conditions: Optional condition or condition nodes. If given, uses the IDC
        algorithm via :func:`y0.algorithm.identify.idc`. Otherwise, uses the ID
        algorithm via :func:`y0.algorithm.identify.identify`.
    :param strict: If true, raises an error if identification fails rather than
        returning None
    :param ordering: A topological ordering of the variables. If not passed, is
        calculated from the directed component of the mixed graph.

    :returns: An expression representing the estimand if the query is identifiable.
        If the query is not identifiable and ``strict`` is false, returns None.

    :raises Unidentifiable: If the query is not identifiable and ``strict`` is true.
    """
    identification = Identification.from_parts(
        graph=graph, treatments=treatments, outcomes=outcomes, conditions=conditions
    )
    try:
        # An empty/missing set of conditions means a plain interventional
        # query (ID); otherwise dispatch to the conditional variant (IDC).
        if not identification.conditions:
            rv = identify(identification, ordering=ordering)
        else:
            rv = idc(identification, ordering=ordering)
    except Unidentifiable:
        if strict:
            # Bare ``raise`` keeps the original exception and traceback.
            raise
        return None
    return rv