# -*- coding: utf-8 -*-
"""Utilities for identification algorithms."""
from __future__ import annotations
from typing import Any, Iterable, Optional, Union
import networkx as nx
from y0.dsl import (
CounterfactualVariable,
Distribution,
Expression,
Intervention,
P,
Probability,
Variable,
)
from y0.graph import NxMixedGraph, _ensure_set
from y0.mutate.canonicalize_expr import canonical_expr_equal
__all__ = [
"Query",
"Identification",
"Unidentifiable",
"str_nodes_to_variable_nodes",
]
[docs]
class Unidentifiable(Exception):  # noqa:N818
    """Exception raised when the identification algorithm fails."""
[docs]
class Query:
    """An identification query.

    A query is made of three sets of variables: the outcomes, the treatments,
    and the conditions.
    """

    outcomes: set[Variable]
    treatments: set[Variable]
    conditions: set[Variable]

    def __init__(
        self,
        outcomes: Union[Variable, set[Variable]],
        treatments: Union[Variable, set[Variable]],
        conditions: Union[None, Variable, set[Variable]] = None,
    ) -> None:
        """Instantiate a query.

        :param outcomes: The outcomes in the query
        :param treatments: The treatments in the query (e.g., counterfactual variables)
        :param conditions: The conditions in the query (e.g., coming after the bar).
            ``None`` (or any other falsy value) is replaced by an empty set.
        """
        self.outcomes = _ensure_set(outcomes)
        self.treatments = _ensure_set(treatments)
        self.conditions = _ensure_set(conditions or set())

    def __eq__(self, other: Any) -> bool:
        """Check if the outcomes, treatments, and conditions are equal."""
        return (
            isinstance(other, Query)
            and self.outcomes == other.outcomes
            and self.treatments == other.treatments
            and self.conditions == other.conditions
        )

    @staticmethod
    def _variables_from_names(names: Union[str, Iterable[str]]) -> set[Variable]:
        """Wrap a single name, or each name of an iterable, in a :class:`Variable`."""
        if isinstance(names, str):
            # A bare string is one name, not an iterable of one-character names.
            return {Variable(names)}
        return {Variable(name) for name in names}

    @classmethod
    def from_str(
        cls,
        outcomes: Union[str, Iterable[str]],
        treatments: Union[str, Iterable[str]],
        conditions: Union[None, str, Iterable[str]] = None,
    ) -> Query:
        """Construct a query from text variable names.

        :param outcomes: A single outcome name or an iterable of outcome names
        :param treatments: A single treatment name or an iterable of treatment names
        :param conditions: A single condition name, an iterable of condition names,
            or ``None`` for no conditions. A bare string is handled the same way as
            for ``outcomes`` and ``treatments`` (one variable, not one per character).
        :returns: A query over the corresponding variables
        """
        return cls(
            outcomes=cls._variables_from_names(outcomes),
            treatments=cls._variables_from_names(treatments),
            conditions=None if conditions is None else cls._variables_from_names(conditions),
        )

    @classmethod
    def from_expression(
        cls,
        query: Union[Probability, Distribution],
    ) -> Query:
        """Instantiate a query from a probability expression.

        The outcomes are the bases (``get_base()``) of the expression's children and
        the conditions are the bases of its parents. The treatments are the bases of
        the interventions on the first child, or empty if the first child is not a
        counterfactual variable.

        :param query: The query probability expression
        :returns: A query
        :raises ValueError: If the interventions are used inconsistently, i.e., the
            first child is not counterfactual but some child or parent is, or the
            first child is counterfactual but some child or parent does not carry
            exactly the same set of interventions.
        """
        outcomes = {child.get_base() for child in query.children}  # clean counterfactuals
        conditions = {parent.get_base() for parent in query.parents}

        first_child = query.children[0]
        if not isinstance(first_child, CounterfactualVariable):
            if _unexp_interventions(query.children) or _unexp_interventions(query.parents):
                raise ValueError("Inconsistent usage of interventions")
            treatments = set()
        else:
            interventions = set(first_child.interventions)
            if _ragged_interventions(query.children, interventions) or _ragged_interventions(
                query.parents, interventions
            ):
                raise ValueError("Inconsistent usage of interventions")
            treatments = {intervention.get_base() for intervention in first_child.interventions}

        return Query(
            outcomes=outcomes,
            treatments=treatments,
            conditions=conditions,
        )

    def exchange_observation_with_action(
        self, variables: Union[Variable, Iterable[Variable]]
    ) -> Query:
        """Move the condition variable(s) to the treatments.

        :param variables: One variable or an iterable of variables, all of which
            must currently be conditions of this query
        :returns: A new query; this query is not modified
        :raises ValueError: If any of the variables is not a condition of this query
        """
        if isinstance(variables, Variable):
            variables = {variables}
        else:
            variables = set(variables)
        not_conditions = variables - self.conditions
        if not_conditions:
            raise ValueError(
                f"can not move variables that are not conditions to the treatments: {not_conditions}"
            )
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments | variables,
            conditions=self.conditions - variables,
        )

    def exchange_action_with_observation(
        self, variables: Union[Variable, Iterable[Variable]]
    ) -> Query:
        """Move the treatment variable(s) to the conditions.

        :param variables: One variable or an iterable of variables, all of which
            must currently be treatments of this query
        :returns: A new query; this query is not modified
        :raises ValueError: If any of the variables is not a treatment of this query
        """
        if isinstance(variables, Variable):
            variables = {variables}
        else:
            variables = set(variables)
        not_treatments = variables - self.treatments
        if not_treatments:
            raise ValueError(
                f"can not move variables that are not treatments to the conditions: {not_treatments}"
            )
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments - variables,
            conditions=self.conditions | variables,
        )

    def with_treatments(self, extra_treatments: Iterable[Variable]) -> Query:
        """Create a new query with additional treatments.

        :param extra_treatments: Variables to add to the current treatments
        :returns: A new query with the same outcomes and conditions
        """
        return Query(
            outcomes=self.outcomes,
            treatments=self.treatments.union(extra_treatments),
            conditions=self.conditions,
        )

    def uncondition(self) -> Query:
        """Move the conditions to outcomes.

        :returns: A new query whose outcomes are the union of this query's outcomes
            and conditions, and which has no conditions
        """
        return Query(
            outcomes=self.outcomes | self.conditions,
            treatments=self.treatments,
            conditions=None,
        )

    @property
    def expression(self) -> Expression:
        """Return the query as a probabilistic expression."""
        # NOTE(review): outcomes and conditions are combined with a set union, so the
        # conditions are passed as ordinary arguments rather than after a conditioning
        # bar - confirm that this is the intended expression.
        variables = self.outcomes | self.conditions
        if self.treatments:
            return P[self.treatments](variables)
        return P(variables)
def _unexp_interventions(variables: Iterable[Variable]) -> bool:
    """Return whether at least one of the variables is a counterfactual variable."""
    for variable in variables:
        if isinstance(variable, CounterfactualVariable):
            return True
    return False
def _ragged_interventions(variables: Iterable[Variable], interventions: set[Intervention]) -> bool:
    """Return whether some variable does not carry exactly the given interventions.

    A variable counts as ragged if it is not a counterfactual variable at all, or if
    the set of its interventions differs from ``interventions``.
    """
    for variable in variables:
        if not isinstance(variable, CounterfactualVariable):
            return True
        if set(variable.interventions) != interventions:
            return True
    return False
[docs]
class Identification:
    """A package of a query and resulting estimand from identification on a graph."""

    query: Query
    graph: NxMixedGraph
    estimand: Expression

    def __init__(
        self,
        query: Query,
        graph: NxMixedGraph,
        estimand: Optional[Expression] = None,
    ) -> None:
        """Instantiate an identification.

        :param query: The generalized identification query (outcomes/treatments/conditions)
        :param graph: The graph. It is passed through :func:`str_nodes_to_variable_nodes`
            before being stored.
        :param estimand: If none is given, will use the joint distribution over all variables in the graph.
        """
        self.query = query
        self.graph = str_nodes_to_variable_nodes(graph)
        self.estimand = P(self.graph.nodes()) if estimand is None else estimand

    @classmethod
    def from_parts(
        cls,
        outcomes: set[Variable],
        treatments: set[Variable],
        graph: NxMixedGraph,
        estimand: Optional[Expression] = None,
        conditions: Optional[set[Variable]] = None,
    ) -> Identification:
        """Instantiate an identification from the parts of a query.

        :param outcomes: The outcomes in the query
        :param treatments: The treatments in the query (e.g., counterfactual variables)
        :param graph: The graph
        :param estimand: If none is given, will use the joint distribution over all variables in the graph.
        :param conditions: The conditions in the query (e.g., coming after the bar)
        :returns: An identification object
        """
        return cls(
            query=Query(outcomes=outcomes, treatments=treatments, conditions=conditions),
            graph=graph,
            estimand=estimand,
        )

    @classmethod
    def from_expression(
        cls,
        *,
        query: Union[Probability, Distribution],
        graph: NxMixedGraph,
        estimand: Optional[Expression] = None,
    ) -> Identification:
        """Instantiate an identification from a probability expression.

        :param query: The query probability expression, converted with :meth:`Query.from_expression`
        :param graph: The graph
        :param estimand: If none is given, will use the joint distribution over all variables in the graph.
        :returns: An identification object
        """
        return cls(
            query=Query.from_expression(query),
            graph=graph,
            estimand=estimand,
        )

    @property
    def outcomes(self) -> set[Variable]:
        """Return this identification object's query's outcomes."""
        return self.query.outcomes

    @property
    def treatments(self) -> set[Variable]:
        """Return this identification object's query's treatments."""
        return self.query.treatments

    @property
    def conditions(self) -> set[Variable]:
        """Return this identification object's query's conditions."""
        return self.query.conditions

    def exchange_observation_with_action(
        self, variables: Union[Variable, Iterable[Variable]]
    ) -> Identification:
        """Move the condition variable(s) to the treatments.

        :param variables: The condition variable(s) to move
        :returns: A new identification with the updated query and the same graph and estimand
        """
        return Identification(
            query=self.query.exchange_observation_with_action(variables),
            graph=self.graph,
            estimand=self.estimand,
        )

    def exchange_action_with_observation(
        self, variables: Union[Variable, Iterable[Variable]]
    ) -> Identification:
        """Move the treatment variable(s) to the conditions.

        :param variables: The treatment variable(s) to move
        :returns: A new identification with the updated query and the same graph and estimand
        """
        return Identification(
            query=self.query.exchange_action_with_observation(variables),
            graph=self.graph,
            estimand=self.estimand,
        )

    def with_treatments(self, extra_treatments: Iterable[Variable]) -> Identification:
        """Create a new identification with additional treatments.

        :param extra_treatments: Variables to add to the query's treatments
        :returns: A new identification with the updated query and the same graph and estimand
        """
        return Identification(
            query=self.query.with_treatments(extra_treatments),
            estimand=self.estimand,
            graph=self.graph,
        )

    def uncondition(self) -> Identification:
        """Move the conditions to outcomes.

        :returns: A new identification with the updated query and the same graph and estimand
        """
        return Identification(
            query=self.query.uncondition(),
            estimand=self.estimand,
            graph=self.graph,
        )

    def __repr__(self) -> str:
        """Return a text representation listing the query parts, graph, and estimand."""
        # Every value is wrapped in a balanced pair of double quotes and the
        # fields are separated by ", ".
        return (
            f'Identification(outcomes="{self.outcomes}", treatments="{self.treatments}", '
            f'conditions="{self.conditions}", graph="{self.graph!r}", estimand="{self.estimand}")'
        )

    def __eq__(self, other: Any) -> bool:
        """Check if the query, estimand, and graph are equal.

        The estimands are compared with :func:`canonical_expr_equal` rather than ``==``.
        """
        return (
            isinstance(other, Identification)
            and self.query == other.query
            and canonical_expr_equal(self.estimand, other.estimand)
            and self.graph == other.graph
        )
def str_nodes_to_variable_nodes(graph: NxMixedGraph) -> NxMixedGraph:
    """Build a new mixed graph whose nodes and edge endpoints went through ``Variable.norm``.

    :param graph: The graph to convert (e.g., one whose nodes are strings)
    :returns: A graph created with :meth:`NxMixedGraph.from_edges` from the converted
        nodes, directed edges, and undirected edges
    """
    variable_nodes = {Variable.norm(node) for node in graph.nodes()}
    directed_edges = _convert(graph.directed)
    undirected_edges = _convert(graph.undirected)
    return NxMixedGraph.from_edges(
        nodes=variable_nodes,
        directed=directed_edges,
        undirected=undirected_edges,
    )
def _convert(graph: nx.Graph) -> list[tuple[Variable, Variable]]:
    """List the graph's edges with both endpoints passed through ``Variable.norm``."""
    converted = []
    for source, target in graph.edges():
        converted.append((Variable.norm(source), Variable.norm(target)))
    return converted