Source code for y0.algorithm.identify.utils

"""Utilities for identification algorithms."""

from __future__ import annotations

from collections.abc import Iterable
from itertools import chain
from typing import Any, cast

import networkx as nx

from y0.dsl import (
    CounterfactualVariable,
    Distribution,
    Expression,
    P,
    Probability,
    Variable,
)
from y0.graph import NxMixedGraph, _ensure_set
from y0.mutate.canonicalize_expr import canonical_expr_equal

__all__ = [
    "Identification",
    "Query",
    "Unidentifiable",
    "str_nodes_to_variable_nodes",
]


class Unidentifiable(Exception):  # noqa:N818
    """Signal that the identification algorithm failed."""
class Query:
    """An identification query.

    A query holds three sets of variables: the ``outcomes``, the
    ``treatments``, and the ``conditions`` (the variables coming after the
    bar). All methods that "modify" a query return a new :class:`Query`
    and leave the original untouched.
    """

    outcomes: set[Variable]
    treatments: set[Variable]
    conditions: set[Variable]

    def __init__(
        self,
        outcomes: Variable | Iterable[Variable],
        treatments: Variable | Iterable[Variable],
        conditions: None | Variable | Iterable[Variable] = None,
    ) -> None:
        """Instantiate an identification query.

        :param outcomes: The outcomes in the query
        :param treatments: The treatments in the query (e.g., counterfactual variables)
        :param conditions: The conditions in the query (e.g., coming after the bar).
            If none are given, the query has an empty set of conditions.
        """
        self.outcomes = _ensure_set(outcomes)
        self.treatments = _ensure_set(treatments)
        self.conditions = _ensure_set(conditions) if conditions is not None else set()

    def __eq__(self, other: Any) -> bool:
        """Check if the outcomes, treatments, and conditions are equal."""
        return (
            isinstance(other, Query)
            and self.outcomes == other.outcomes
            and self.treatments == other.treatments
            and self.conditions == other.conditions
        )

    @classmethod
    def from_str(
        cls,
        outcomes: str | Iterable[str],
        treatments: str | Iterable[str],
        conditions: Iterable[str] | None = None,
    ) -> Query:
        """Construct a query from text variable names.

        :param outcomes: The name, or names, of the outcome variables
        :param treatments: The name, or names, of the treatment variables
        :param conditions: The names of the condition variables, if any
        :returns: A query over variables built from the given names
        """
        return cls(
            # a bare string is one name, not an iterable of one-character names
            outcomes=(
                {Variable(outcomes)}
                if isinstance(outcomes, str)
                else {Variable(n) for n in outcomes}
            ),
            treatments=(
                {Variable(treatments)}
                if isinstance(treatments, str)
                else {Variable(n) for n in treatments}
            ),
            conditions=None if conditions is None else {Variable(n) for n in conditions},
        )

    @classmethod
    def from_expression(
        cls,
        query: Probability | Distribution,
    ) -> Query:
        """Instantiate a query from a probability expression.

        :param query: The query probability expression
        :returns: A query whose outcomes are the bases of the expression's children,
            whose conditions are the bases of its parents, and whose treatments
            are the bases of the interventions shared by all of its variables
            (or the empty set if no variable is counterfactual)
        :raises ValueError: If there are ragged counterfactual variables in the query,
            i.e., only some variables are counterfactual or they do not all
            share the same set of interventions
        """
        outcomes = {child.get_base() for child in query.children}  # clean counterfactuals
        conditions = {parent.get_base() for parent in query.parents}
        treatments: set[Variable]
        if any(isinstance(c, CounterfactualVariable) for c in chain(query.children, query.parents)):
            if not all(isinstance(c, CounterfactualVariable) for c in query.children):
                raise ValueError(
                    "if any children or parents are counterfactual variables, all children have to be"
                )
            if not all(isinstance(c, CounterfactualVariable) for c in query.parents):
                raise ValueError(
                    "if any children or parents are counterfactual variables, all parents have to be"
                )
            # every variable must carry exactly the same interventions, so
            # collecting them into a set has to leave a single element
            intervention_sets: set[frozenset[Variable]] = {
                cast(CounterfactualVariable, c).interventions
                for c in chain(query.children, query.parents)
            }
            if len(intervention_sets) != 1:
                raise ValueError("inconsistent usage of interventions")
            treatments = {x.get_base() for x in next(iter(intervention_sets))}
        else:
            treatments = set()
        return Query(outcomes=outcomes, treatments=treatments, conditions=conditions)

    def exchange_observation_with_action(self, variables: Variable | Iterable[Variable]) -> Query:
        """Move the condition variable(s) to the treatments.

        :param variables: The variable(s) to move; all must be in the conditions
        :returns: A new query with the variables moved from conditions to treatments
        :raises ValueError: If any of the variables do not appear in the conditions
        """
        variables = _ensure_set(variables)
        if missing := (variables - self.conditions):
            raise ValueError(f"variables don't appear in conditions: {missing}")
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments | variables,
            conditions=self.conditions - variables,
        )

    def exchange_action_with_observation(self, variables: Variable | Iterable[Variable]) -> Query:
        """Move the treatment variable(s) to the conditions.

        :param variables: The variable(s) to move; all must be in the treatments
        :returns: A new query with the variables moved from treatments to conditions
        :raises ValueError: If any of the variables do not appear in the treatments
        """
        variables = _ensure_set(variables)
        if missing := (variables - self.treatments):
            raise ValueError(f"variables don't appear in treatments: {missing}")
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments - variables,
            conditions=self.conditions | variables,
        )

    def with_treatments(self, extra_treatments: Iterable[Variable]) -> Query:
        """Create a new query with additional treatments.

        :param extra_treatments: The variables to add to the existing treatments
        :returns: A new query with the same outcomes and conditions
        """
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments.union(extra_treatments),
            conditions=self.conditions,
        )

    def uncondition(self) -> Query:
        """Move the conditions to outcomes.

        :returns: A new query with no conditions, whose outcomes are the union
            of this query's outcomes and conditions
        """
        return Query(
            outcomes=self.outcomes | self.conditions,
            treatments=self.treatments,
            conditions=None,
        )

    @property
    def expression(self) -> Expression:
        """Return the query as a Probabilistic expression.

        The conditions (if any) are applied first and the treatments (if
        any) second. Both steps are independent: a query that has conditions
        *and* treatments keeps both, rather than silently dropping the
        treatments.
        """
        distribution = Distribution.safe(self.outcomes)
        if self.conditions:
            distribution = distribution.given(self.conditions)
        # NOTE(review): intervening after conditioning presumably also marks
        # the conditions as intervened, which is the shape ``from_expression``
        # accepts -- confirm against ``Distribution.intervene``.
        if self.treatments:
            distribution = distribution.intervene(self.treatments)
        return Probability(distribution)
class Identification:
    """A package of a query and resulting estimand from identification on a graph."""

    query: Query
    graph: NxMixedGraph
    estimand: Expression

    def __init__(
        self,
        query: Query,
        graph: NxMixedGraph,
        estimand: Expression | None = None,
    ) -> None:
        """Instantiate an identification.

        :param query: The generalized identification query (outcomes/treatments/conditions)
        :param graph: The graph
        :param estimand: If none is given, will use the joint distribution
            over all variables in the graph.
        """
        self.query = query
        # the stored graph is rebuilt so that all of its nodes are normalized variables
        self.graph = str_nodes_to_variable_nodes(graph)
        self.estimand = P(self.graph.nodes()) if estimand is None else estimand

    @classmethod
    def from_parts(
        cls,
        outcomes: Variable | Iterable[Variable],
        treatments: Variable | Iterable[Variable],
        graph: NxMixedGraph,
        estimand: Expression | None = None,
        conditions: Variable | Iterable[Variable] | None = None,
    ) -> Identification:
        """Instantiate an identification from the parts of a query.

        :param outcomes: The outcomes in the query
        :param treatments: The treatments in the query (e.g., counterfactual variables)
        :param conditions: The conditions in the query (e.g., coming after the bar)
        :param graph: The graph
        :param estimand: If none is given, will use the joint distribution
            over all variables in the graph.
        :returns: An identification object
        """
        return cls(
            query=Query(outcomes=outcomes, treatments=treatments, conditions=conditions),
            graph=graph,
            estimand=estimand,
        )

    @classmethod
    def from_expression(
        cls,
        *,
        query: Probability | Distribution,
        graph: NxMixedGraph,
        estimand: Expression | None = None,
    ) -> Identification:
        """Instantiate an identification from a probability expression.

        :param query: The query probability expression
        :param graph: The graph
        :param estimand: If none is given, will use the joint distribution
            over all variables in the graph.
        :returns: An identification object
        """
        return cls(
            query=Query.from_expression(query),
            graph=graph,
            estimand=estimand,
        )

    @property
    def outcomes(self) -> set[Variable]:
        """Return this identification object's query's outcomes."""
        return self.query.outcomes

    @property
    def treatments(self) -> set[Variable]:
        """Return this identification object's query's treatments."""
        return self.query.treatments

    @property
    def conditions(self) -> set[Variable]:
        """Return this identification object's query's conditions."""
        return self.query.conditions

    def exchange_observation_with_action(
        self, variables: Variable | Iterable[Variable]
    ) -> Identification:
        """Move the condition variable(s) to the treatments.

        :param variables: The condition variable(s) to move
        :returns: A new identification with the updated query and the same
            graph and estimand
        """
        return Identification(
            query=self.query.exchange_observation_with_action(variables),
            graph=self.graph,
            estimand=self.estimand,
        )

    def exchange_action_with_observation(
        self, variables: Variable | Iterable[Variable]
    ) -> Identification:
        """Move the treatment variable(s) to the conditions.

        :param variables: The treatment variable(s) to move
        :returns: A new identification with the updated query and the same
            graph and estimand
        """
        return Identification(
            query=self.query.exchange_action_with_observation(variables),
            graph=self.graph,
            estimand=self.estimand,
        )

    def with_treatments(self, extra_treatments: Iterable[Variable]) -> Identification:
        """Create a new identification with additional treatments.

        :param extra_treatments: The variables to add to the query's treatments
        :returns: A new identification with the updated query and the same
            graph and estimand
        """
        return Identification(
            query=self.query.with_treatments(extra_treatments),
            estimand=self.estimand,
            graph=self.graph,
        )

    def uncondition(self) -> Identification:
        """Move the conditions to outcomes.

        :returns: A new identification with the unconditioned query and the
            same graph and estimand
        """
        return Identification(
            query=self.query.uncondition(),
            estimand=self.estimand,
            graph=self.graph,
        )

    def __repr__(self) -> str:
        """Return a string showing the query's parts, the graph, and the estimand."""
        # every field is wrapped in a matching pair of double quotes and
        # the fields are separated by ", "
        return (
            f'Identification(outcomes="{self.outcomes}", treatments="{self.treatments}", '
            f'conditions="{self.conditions}", graph="{self.graph!r}", estimand="{self.estimand}")'
        )

    def __eq__(self, other: Any) -> bool:
        """Check if the query, estimand, and graph are equal."""
        return (
            isinstance(other, Identification)
            and self.query == other.query
            # estimands are compared with canonical_expr_equal rather than ``==``
            and canonical_expr_equal(self.estimand, other.estimand)
            and self.graph == other.graph
        )
def str_nodes_to_variable_nodes(graph: NxMixedGraph) -> NxMixedGraph:
    """Build a new mixed graph whose nodes and edge endpoints went through ``Variable.norm``.

    :param graph: A mixed graph whose nodes may be given as strings
    :returns: A new mixed graph built from the normalized nodes, directed
        edges, and undirected edges of the input
    """
    variable_nodes = {Variable.norm(name) for name in graph.nodes()}
    directed_edges = _convert(graph.directed)
    undirected_edges = _convert(graph.undirected)
    return NxMixedGraph.from_edges(
        nodes=variable_nodes,
        directed=directed_edges,
        undirected=undirected_edges,
    )


def _convert(graph: nx.Graph) -> list[tuple[Variable, Variable]]:
    """List the edges of the graph with both endpoints passed through ``Variable.norm``."""
    edges: list[tuple[Variable, Variable]] = []
    for source, target in graph.edges():
        edges.append((Variable.norm(source), Variable.norm(target)))
    return edges