"""Implement of surrogate outcomes and transportability from [tikka2018]_."""
import logging
from collections.abc import Collection, Iterable
from copy import deepcopy
from dataclasses import dataclass
from typing import cast
from y0.algorithm.conditional_independencies import are_d_separated
from y0.dsl import (
TARGET_DOMAIN,
CounterfactualVariable,
Distribution,
Expression,
Fraction,
Intervention,
One,
Population,
PopulationProbability,
Probability,
Product,
Sum,
Variable,
Zero,
_upgrade_variables_set,
)
from y0.graph import NxMixedGraph
from y0.mutate.canonicalize_expr import canonicalize
# Names exported by ``from <this module> import *``.
__all__ = [
    "TransportQuery",
    "identify_target_outcomes",
    "trso",
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_nodes_to_transport(
    *,
    surrogate_interventions: Variable | Iterable[Variable],
    surrogate_outcomes: Variable | Iterable[Variable],
    graph: NxMixedGraph,
) -> set[Variable]:
    """Select the target-domain nodes that should receive a transport node.

    :param surrogate_interventions: The interventions performed in an experiment.
    :param surrogate_outcomes: The outcomes observed in an experiment.
    :param graph: The graph of the target domain.
    :returns: The union of two sets: the inclusive descendants of the interventions
        minus the outcomes, and the members of every district that contains an
        outcome minus the result of ``graph.get_intervened_ancestors``.
    """
    interventions = _upgrade_variables_set(surrogate_interventions)
    outcomes = _upgrade_variables_set(surrogate_outcomes)

    # Pool the members of every district (c-component) sharing a node with the outcomes
    outcome_district_nodes: set[Variable] = set().union(
        *(district for district in graph.districts() if outcomes.intersection(district))
    )

    intervened_ancestors = graph.get_intervened_ancestors(interventions, outcomes)
    intervention_descendants = graph.descendants_inclusive(interventions)

    downstream_of_interventions = intervention_descendants - outcomes
    confounded_with_outcomes = outcome_district_nodes - intervened_ancestors
    return downstream_of_interventions.union(confounded_with_outcomes)
# Name prefix that marks a variable as a transport node (see ``is_transport_node``).
_TRANSPORT_PREFIX = "T_"
def transport_variable(variable: Variable) -> Variable:
    """Build the transport node that points at the given variable.

    :param variable: The variable that the transport node will point to.
    :returns: A new variable whose name is the transport prefix followed by the
        name of ``variable``.
    :raises TypeError: If a counterfactual variable or intervention is passed.
    """
    is_plain_variable = not isinstance(variable, CounterfactualVariable | Intervention)
    if is_plain_variable:
        return Variable(_TRANSPORT_PREFIX + variable.name)
    raise TypeError
def is_transport_node(node: Variable) -> bool:
    """Decide whether a variable is a transport node.

    :param node: A node to evaluate.
    :returns: True if the node is neither a counterfactual variable nor an
        intervention and its name starts with the transport prefix, False otherwise.
    """
    if isinstance(node, CounterfactualVariable | Intervention):
        return False
    return node.name.startswith(_TRANSPORT_PREFIX)
def get_transport_nodes(graph: NxMixedGraph) -> set[Variable]:
    """Collect the transport nodes of a graph.

    :param graph: An NxMixedGraph which may have transport nodes.
    :returns: The set of nodes in the graph for which :func:`is_transport_node` holds.
    """
    return set(filter(is_transport_node, graph.nodes()))
def get_regular_nodes(graph: NxMixedGraph) -> set[Variable]:
    """Collect the nodes of a graph that are not transport nodes.

    :param graph: An NxMixedGraph.
    :returns: The set of nodes in the graph for which :func:`is_transport_node`
        does not hold.
    """
    regular: set[Variable] = set()
    for node in graph.nodes():
        if is_transport_node(node):
            continue
        regular.add(node)
    return regular
def _c14n_safe(expression: Expression | None) -> Expression | None:
    """Canonicalize an expression, passing ``None`` through unchanged."""
    return None if expression is None else canonicalize(expression)
def create_transport_diagram(
    *,
    nodes_to_transport: Iterable[Variable],
    graph: NxMixedGraph,
) -> NxMixedGraph:
    """Copy a graph and attach a transport node to each requested node.

    :param nodes_to_transport: Nodes which get a transport node pointing to them.
    :param graph: The graph of the target domain.
    :returns: A new graph with the nodes, directed edges, and undirected edges of
        ``graph`` plus one directed edge from a transport node to each node in
        ``nodes_to_transport``.
    """
    diagram = NxMixedGraph()

    # Reproduce the input graph in a fresh object so the input is left untouched
    for existing_node in graph.nodes():
        diagram.add_node(existing_node)
    for source, target in graph.directed.edges():
        diagram.add_directed_edge(source, target)
    for left, right in graph.undirected.edges():
        diagram.add_undirected_edge(left, right)

    # Add the transport nodes on top of the copy
    for transported in nodes_to_transport:
        diagram.add_directed_edge(transport_variable(transported), transported)
    return diagram
[docs]
@dataclass
class TransportQuery:
    """A transportability query, as produced by :func:`surrogate_to_transport`."""

    # Interventions for the target domain
    target_interventions: set[Variable]
    # Target variables for causal effects
    target_outcomes: set[Variable]
    # Graph for each population; surrogate_to_transport stores a transport diagram
    # per surrogate population and the unmodified graph under TARGET_DOMAIN
    graphs: dict[Population, NxMixedGraph]
    # Populations of the query; surrogate_to_transport fills this with the
    # surrogate populations only (TARGET_DOMAIN is not included)
    domains: set[Population]
    # Interventions performed in each surrogate population
    surrogate_interventions: dict[Population, set[Variable]]
    # surrogate_to_transport always sets this to an empty set
    target_experiments: set[Variable]
@dataclass
class TRSOQuery:
    """The state passed to, and between recursive calls of, :func:`trso`."""

    # Interventions still to be handled in the current recursion step
    target_interventions: set[Variable]
    # Outcomes whose distribution is being identified
    target_outcomes: set[Variable]
    # The distribution available in the current domain
    expression: Expression
    # Surrogate interventions activated by line 6 of the algorithm (empty until then)
    active_interventions: set[Variable]
    # The population the algorithm is currently working in
    domain: Population
    # The populations of the query
    domains: set[Population]
    # Graph for each population, looked up by ``graphs[domain]``
    graphs: dict[Population, NxMixedGraph]
    # Interventions performed in each surrogate population
    surrogate_interventions: dict[Population, set[Variable]]
def surrogate_to_transport(
    *,
    graph: NxMixedGraph,
    target_outcomes: set[Variable],
    target_interventions: set[Variable],
    surrogate_outcomes: dict[Population, set[Variable]],
    surrogate_interventions: dict[Population, set[Variable]],
) -> TransportQuery:
    """Turn a surrogate outcome problem into a transportability query.

    One transport diagram is built per surrogate population; the unmodified graph
    is registered for the target domain.

    :param graph: The graph of the target domain.
    :param target_outcomes: A set of target variables for causal effects.
    :param target_interventions: A set of interventions for the target domain.
    :param surrogate_outcomes: A dictionary of outcomes in other populations
    :param surrogate_interventions: A dictionary of interventions in other populations
    :returns: A :class:`TransportQuery` holding the targets, the per-population
        graphs, the surrogate populations, and the surrogate interventions.
    :raises ValueError: if surrogate outcomes' and surrogate interventions' keys do
        not correspond
    """
    if set(surrogate_outcomes) != set(surrogate_interventions):
        raise ValueError("Inconsistent surrogate outcome and intervention domains")

    graphs: dict[Population, NxMixedGraph] = {}
    for domain, domain_outcomes in surrogate_outcomes.items():
        nodes_to_transport = get_nodes_to_transport(
            surrogate_interventions=surrogate_interventions[domain],
            surrogate_outcomes=domain_outcomes,
            graph=graph,
        )
        graphs[domain] = create_transport_diagram(
            graph=graph,
            nodes_to_transport=nodes_to_transport,
        )
    # The target domain uses the graph as given, without transport nodes
    graphs[TARGET_DOMAIN] = graph

    return TransportQuery(
        target_interventions=target_interventions,
        target_outcomes=target_outcomes,
        graphs=graphs,
        domains=set(surrogate_outcomes),
        surrogate_interventions=surrogate_interventions,
        target_experiments=set(),
    )
def trso_line1(
    target_outcomes: set[Variable],
    expression: Expression,
    graph: NxMixedGraph,
) -> Expression:
    """Marginalize the expression down to the target outcomes (no interventions left).

    :param target_outcomes: A set of nodes that comprise our target outcomes.
    :param expression: The distribution in the current domain.
    :param graph: The graph with transport nodes in this domain.
    :returns: The expression summed over every regular (non-transport) node of the
        graph that is not a target outcome.
    """
    marginalized_nodes = get_regular_nodes(graph) - target_outcomes
    return Sum.safe(expression, marginalized_nodes)
def trso_line2(
    query: TRSOQuery,
    outcomes_ancestors: set[Variable],
) -> TRSOQuery:
    """Shrink the interventions and every diagram to the ancestors of the outcomes.

    :param query: A TRSO query
    :param outcomes_ancestors: the ancestors of target variables in
        transportability_diagram
    :returns: A modified copy of the query; the input query is not changed.
    :raises TypeError: if the new query's expression is a probability but not a
        population probability
    """
    restricted = deepcopy(query)
    restricted.target_interventions.intersection_update(outcomes_ancestors)

    # Every domain's graph is cut down to the ancestors computed in that graph
    for domain, domain_graph in query.graphs.items():
        restricted.graphs[domain] = domain_graph.subgraph(
            domain_graph.ancestors_inclusive(query.target_outcomes)
        )

    non_ancestors = get_regular_nodes(query.graphs[query.domain]) - outcomes_ancestors
    marginal = Sum.safe(query.expression, non_ancestors, simplify=True)
    restricted.expression = marginal
    if not isinstance(marginal, Probability):
        return restricted
    if not isinstance(marginal, PopulationProbability):
        raise TypeError

    # The population of the simplified probability may differ from the query's
    # domain; it is rebuilt here over the query's domain, keeping only the children.
    restricted.expression = PopulationProbability(
        population=restricted.domain,
        distribution=Distribution(
            children=marginal.children,
        ),
    )
    return restricted
def trso_line3(query: TRSOQuery, additional_interventions: set[Variable]) -> TRSOQuery:
    """Extend the target interventions of the query with extra nodes.

    :param query: A TRSO query
    :param additional_interventions: interventions to be added to target_interventions
    :returns: A copy of the query with the enlarged set of target interventions;
        the input query is not changed.
    """
    extended = deepcopy(query)
    extended.target_interventions.update(additional_interventions)
    return extended
def trso_line4(
    query: TRSOQuery,
    components: Iterable[frozenset[Variable]],
) -> dict[frozenset[Variable], TRSOQuery]:
    """Build one TRSO subquery per C-component.

    :param query: A TRSO query
    :param components: Set of c_components of transportability_diagram without
        target_interventions
    :returns: Dictionary mapping each component to a copy of the query whose target
        outcomes are the component and whose target interventions are all other
        regular nodes of the current domain's graph.
    """
    regular_nodes = get_regular_nodes(query.graphs[query.domain])
    subqueries: dict[frozenset[Variable], TRSOQuery] = {}
    for component in components:
        subquery = deepcopy(query)
        subquery.target_interventions = regular_nodes - component
        subquery.target_outcomes = set(component)
        subqueries[component] = subquery
    return subqueries
def trso_line6(query: TRSOQuery) -> dict[Population, TRSOQuery]:
    """Build the line 6 subquery for every non-target domain that admits one.

    :param query: A TRSO query
    :returns: Dictionary mapping each domain to its subquery; domains for which
        :func:`_line_6_helper` returns ``None`` are left out.
    """
    subqueries: dict[Population, TRSOQuery] = {}
    for domain, domain_graph in query.graphs.items():
        if domain != TARGET_DOMAIN:
            candidate = _line_6_helper(query, domain, domain_graph)
            if candidate is not None:
                subqueries[domain] = candidate
    return subqueries
def _line_6_helper(query: TRSOQuery, domain: Population, graph: NxMixedGraph) -> TRSOQuery | None:
    """Build the line 6 subquery for one domain, if its experiments are usable.

    :param query: A TRSO query
    :param domain: A given population
    :param graph: The graph of that population
    :returns: ``None`` if the domain's surrogate interventions do not overlap the
        target interventions or the d-separation check fails; otherwise a copy of
        the query switched to ``domain`` with the overlap as active interventions.
    """
    domain_interventions = query.surrogate_interventions[domain]
    usable_interventions = domain_interventions.intersection(query.target_interventions)
    if not usable_interventions:
        return None

    transports_separated = all_transports_d_separated(
        graph,
        target_interventions=query.target_interventions,
        target_outcomes=query.target_outcomes,
    )
    if not transports_separated:
        return None

    subquery = deepcopy(query)
    subquery.domain = domain
    subquery.active_interventions = usable_interventions
    subquery.target_interventions = query.target_interventions - domain_interventions
    # The activated interventions are removed from this domain's graph
    subquery.graphs[domain] = graph.remove_nodes_from(usable_interventions)
    return subquery
def activate_domain_and_interventions(
    expression: Expression, interventions: set[Variable], domain: Population
) -> Expression:
    """Rewrite an expression so its probabilities are intervened and live in ``domain``.

    :param expression: A probability expression.
    :param interventions: Set of active interventions
    :param domain: A given population
    :returns: A new expression in which every population probability is replaced by
        one over ``domain``, with the interventions removed from its children and
        applied as interventions.
    :raises NotImplementedError: If an expression type that is not handled gets passed
    :raises TypeError: if the expression is a probability but not a population
        probability
    """

    def _activate(part: Expression) -> Expression:
        """Apply the same rewrite to a sub-expression."""
        return activate_domain_and_interventions(part, interventions, domain)

    if isinstance(expression, Probability):
        if isinstance(expression, PopulationProbability):
            remaining_children = set(expression.children) - interventions
            activated = PopulationProbability(
                population=domain,
                distribution=Distribution.safe(remaining_children),
            )
            return activated.intervene(interventions)
        raise TypeError

    if isinstance(expression, Sum):
        # TODO need full integration test to trso() function that covers this branch
        # The ranges are passed through unchanged: counterfactual variables
        # shouldn't be in ranges, so only the summand is rewritten.
        summand = _activate(expression.expression)
        return Sum.safe(summand, expression.ranges)

    if isinstance(expression, Fraction):
        numerator = _activate(expression.numerator)
        denominator = _activate(expression.denominator)
        quotient = numerator / denominator
        return cast(Fraction, quotient).simplify()

    if isinstance(expression, Product):
        # TODO need full integration test to trso() function that covers this branch
        return Product.safe(_activate(factor) for factor in expression.expressions)

    raise NotImplementedError(f"Unhandled expression type: {type(expression)}")
def all_transports_d_separated(
    graph: NxMixedGraph, target_interventions: set[Variable], target_outcomes: set[Variable]
) -> bool:
    """Check if every transport node is d-separated from every target outcome.

    The check is done in the graph with the incoming edges of the target
    interventions removed, conditioning on the target interventions.

    :param graph: The graph with transport nodes in this domain.
    :param target_interventions: Set of target interventions
    :param target_outcomes: Set of target outcomes
    :returns: True if all transport nodes that remain in the modified graph are
        d-separated from all outcomes, False otherwise.
    """
    transport_nodes = get_transport_nodes(graph)
    mutilated_graph = graph.remove_in_edges(target_interventions)
    for transport_node in transport_nodes:
        if transport_node not in mutilated_graph:
            continue
        for outcome in target_outcomes:
            separated = are_d_separated(
                mutilated_graph,
                transport_node,
                outcome,
                conditions=target_interventions,
            )
            if not separated:
                return False
    return True
def trso_line9(query: TRSOQuery, district: set[Variable]) -> Expression:
    """Get the probability in the case with exactly one districts_without_interventions and it is present in districts.

    For every node of the district, the query's expression is turned into a fraction
    whose numerator sums out the nodes after it in the topological order and whose
    denominator additionally sums out the node itself. The product of these
    fractions is then summed over the district's nodes that are not target outcomes.

    :param query: A TRSO query
    :param district: The C-component present in both districts_without_interventions and
        districts
    :returns: An Expression
    :raises RuntimeError: If the query's expression is zero. This should never happen
    """
    logger.debug(
        "Calling trso algorithm line 9 with expression %s \n district %s",
        query.expression,
        district,
    )
    if isinstance(query.expression, Zero):  # pragma: no cover
        # TODO if we can't create an integration test (i.e., a call to trso)
        #  that triggers this line, then it can be safely removed
        raise RuntimeError

    # Materialize the topological sort: it is converted to a set, searched with
    # ``index`` and sliced below, none of which a one-shot iterator supports.
    ordering = list(query.graphs[query.domain].topological_sort())
    ordering_set = set(ordering)
    my_product: Expression = One()
    for node in district:
        i = ordering.index(node)
        # Nodes strictly after ``node``, and ``node`` together with those nodes
        after_node = ordering_set - set(ordering[: i + 1])
        node_onwards = ordering_set - set(ordering[:i])
        numerator = Sum.safe(query.expression, after_node)
        denominator = Sum.safe(query.expression, node_onwards)
        my_product *= numerator / denominator
    my_product = cast(Fraction, my_product).simplify()

    # Build the result once and reuse it for both the log message and the return
    result = Sum.safe(my_product, district - query.target_outcomes)
    logger.debug(
        "Returning trso algorithm line 9 with expression %s",
        result,
    )
    return result
def trso_line10(
    query: TRSOQuery,
    district: set[Variable],
    new_surrogate_interventions: dict[Population, set[Variable]],
) -> TRSOQuery:
    """Update the TRSO query to restrict interventions and graph to district.

    The new expression is the canonicalized product, over the nodes of the district,
    of the probability of each node given the nodes before it in the topological
    order of the current domain's graph.

    :param query: A TRSO query
    :param district: The C-component of districts which contains
        district_without_interventions
    :param new_surrogate_interventions: Dict mapping domains to interventions performed
        in that domain.
    :returns: A modified copy of the query; the input query is not changed.
    """
    # Materialize the topological sort: it is searched with ``index`` and sliced
    # once per node, which a one-shot iterator does not support.
    ordering = list(query.graphs[query.domain].topological_sort())
    expressions = []
    for node in district:
        i = ordering.index(node)
        pre_node = set(ordering[:i])
        # note tikka splits this into two expressions that when taken together equal pre_node
        distribution = Distribution.safe(node | pre_node)
        expressions.append(
            PopulationProbability(population=query.domain, distribution=distribution)
        )

    new_query = deepcopy(query)
    new_query.target_interventions = query.target_interventions.intersection(district)
    new_query.expression = canonicalize(Product.safe(expressions))
    new_query.graphs[query.domain] = query.graphs[query.domain].subgraph(district)
    new_query.surrogate_interventions = new_surrogate_interventions
    return new_query
[docs]
def trso(query: TRSOQuery) -> Expression | None:  # noqa:C901
    """Run the TRSO algorithm to evaluate a transport problem.

    The function works through the numbered lines of the algorithm in order and
    recurses on modified copies of the query. The first line whose condition holds
    determines the result.

    :param query: A TRSO query, which contains 8 instance variables needed for TRSO
    :returns: An Expression evaluating the given query, or None if the algorithm
        fails to identify it
    :raises RuntimeError: when an impossible condition is met
    """
    # NOTE: the following input checks are not implemented yet:
    # Check that domain is in query.domains
    # check that query.surrogate_interventions keys are equals to domains
    # check that query.graphs keys are equal to domains
    logger.debug(
        "Calling trso algorithm with "
        "\t- target_interventions: %s\n"
        "\t- target_outcomes: %s\n"
        "\t- expression: %s\n"
        "\t- active_interventions: %s\n"
        "\t- domain: %s\n"
        "\t- domains: %s\n"
        "\t- graph[domain] nodes: %s\n"
        "\t- surrogate_interventions: %s",
        query.target_interventions,
        query.target_outcomes,
        query.expression,
        query.active_interventions,
        query.domain,
        query.domains,
        query.graphs[query.domain].nodes(),
        query.surrogate_interventions,
    )
    graph = query.graphs[query.domain]

    # line 1: nothing left to intervene on, so marginalize down to the outcomes
    if not query.target_interventions:
        logger.debug("Calling trso algorithm line 1")
        return canonicalize(trso_line1(query.target_outcomes, query.expression, graph))

    # line 2: if some regular node is not an ancestor of the outcomes,
    # restrict the problem to the ancestors and recurse
    outcome_ancestors = graph.ancestors_inclusive(query.target_outcomes)
    if get_regular_nodes(graph) - outcome_ancestors:
        new_query = trso_line2(query, outcome_ancestors)
        logger.debug("Calling trso algorithm line 2")
        return _c14n_safe(trso(new_query))

    # line 3: add the nodes reported by get_no_effect_on_outcomes to the
    # interventions and recurse
    additional_interventions = graph.get_no_effect_on_outcomes(
        query.target_interventions, query.target_outcomes
    )
    if additional_interventions:
        new_query = trso_line3(query, additional_interventions)
        logger.debug("Calling trso algorithm line 3")
        return _c14n_safe(trso(new_query))

    # line 4: if the graph without the interventions has several districts,
    # solve one subquery per district and combine the results
    districts_without_interventions: set[frozenset[Variable]] = graph.remove_nodes_from(
        query.target_interventions
    ).districts()
    if len(districts_without_interventions) > 1:
        subqueries = trso_line4(
            query,
            districts_without_interventions,
        )
        terms = []
        logger.debug("Calling trso algorithm line 4 with %d subqueries", len(subqueries))
        for i, subquery in enumerate(subqueries.values()):
            logger.debug("Calling subquery %d of trso algorithm line 4", i + 1)
            term = trso(subquery)
            # a single failing subquery makes the whole query fail
            if term is None:
                return None
            terms.append(term)
        product = Product.safe(terms)
        summand = canonicalize(product)  # fix sort order inside product
        return canonicalize(
            Sum.safe(
                summand,
                get_regular_nodes(graph) - query.target_interventions.union(query.target_outcomes),
            )
        )

    # line 6: only attempted while no interventions are active yet and
    # surrogate experiments are available
    if not query.active_interventions and query.surrogate_interventions:
        expressions: dict[Population, Expression] = {}
        for domain, subquery in trso_line6(query).items():
            logger.debug("Calling trso algorithm line 6 for domain %s", domain)
            expression = trso(subquery)
            # this domain could not identify the query; try the next one
            if expression is None:
                continue
            expression = activate_domain_and_interventions(
                expression, subquery.active_interventions, domain
            )
            if expression is not None:  # line7
                logger.debug(
                    "Calling trso algorithm line 7",
                )
                expressions[domain] = expression
        if len(expressions) == 1:
            return canonicalize(next(iter(expressions.values())))
        elif len(expressions) > 1:
            # TODO need full integration test to trso() function that covers this branch
            #  or change to ``raise RuntimeError`` if it's not possible to reach in practice
            logger.warning("more than one expression were non-none")
            # What if more than 1 expression doesn't fail?
            # Is it non-deterministic or can we prove it will be length 1?
            # For now, the first expression in dictionary order is returned.
            return canonicalize(next(iter(expressions.values())))
        else:
            # if there are no expressions, then we move on to line 8
            pass

    # line 8 checks that len(districts) != 1
    districts = graph.districts()
    # line 11 states return fail if len(districts)==1
    # keep explicit tests for 0 and 1 to ensure adequate testing
    if len(districts) == 0:
        # TODO we need an integration test (i.e., call to trso()) that covers this.
        #  if it's not possible to cover in a real setting, then we can change this
        #  to raising a runtime error. Nathaniel notes that this probably only occurs
        #  if the graph is empty, but it's not clear if it makes sense to have an empty
        #  graph
        return None
    elif len(districts) == 1:
        return None
    # line 8, i.e. len(districts)>1

    # line 9
    if len(districts_without_interventions) == 0:  # pragma: no cover
        # This would happen if there is an intervention that is also an outcome,
        # which we ensure is not possible when calling the algorithm from its harness
        raise RuntimeError
    # at this point, we already checked for cases where len > 1 (line 4) and len == 0,
    # so we can safely pop the only element
    district_without_interventions = districts_without_interventions.pop()
    if district_without_interventions in districts:
        return canonicalize(trso_line9(query, set(district_without_interventions)))

    # line10
    logger.debug("Calling trso algorithm line 10")
    # find the district of the full graph that contains the intervention-free district
    target_districts = [
        district for district in districts if district_without_interventions.issubset(district)
    ]
    if len(target_districts) != 1:  # pragma: no cover
        # At this point, the mathematics require this, and therefore this
        # test should never evaluate to true
        raise RuntimeError
    target_district = target_districts.pop()
    # district is C' districts should be D[C'], but we chose to return set of nodes instead of subgraph
    if len(query.active_interventions) == 0:
        # TRSO Line 6 could return an empty list and skip over the returns, allowing this line to be reached.
        new_surrogate_interventions = {}
    elif _pillow_has_transport(graph, target_district):
        # a transport node in the Markov pillow of the district means failure
        return None
    else:
        new_surrogate_interventions = query.surrogate_interventions
    new_query = trso_line10(
        query,
        set(target_district),
        new_surrogate_interventions,
    )
    return _c14n_safe(trso(new_query))
def _pillow_has_transport(graph: NxMixedGraph, district: Collection[Variable]) -> bool:
    """Check whether the Markov pillow of the district contains a transport node."""
    for pillow_node in graph.get_markov_pillow(district):
        if is_transport_node(pillow_node):
            return True
    return False
def check_and_raise_missing(nodes: set[Variable], graph: NxMixedGraph, name: str) -> None:
    """Verify that nodes are present in the graph.

    :param nodes: A set of nodes that should be in the graph.
    :param graph: An NxMixedGraph(), the graph of the target domain.
    :param name: Name of the set of nodes, used in the error message
    :raises ValueError: If any element of nodes is not in the graph. The missing
        nodes are listed in sorted order, so the message is reproducible.
    """
    missing_nodes = nodes - graph.nodes()
    if not missing_nodes:
        return
    # Sets have no stable iteration order, so sort the texts before joining them
    missing_nodes_text = sorted({node.to_text() for node in missing_nodes})
    raise ValueError(
        f"The following {name} are not in the graph: {', '.join(missing_nodes_text)}"
    )
[docs]
def identify_target_outcomes(
    graph: NxMixedGraph,
    *,
    target_outcomes: set[Variable],
    target_interventions: set[Variable],
    surrogate_outcomes: dict[Population, set[Variable]],
    surrogate_interventions: dict[Population, set[Variable]],
) -> Expression | None:
    r"""Get the estimand for the target outcomes given the surrogate outcomes.

    .. seealso::

        Originally described in https://arxiv.org/abs/1806.07172.

    :param graph: The graph of the target domain.
    :param target_outcomes: A set of target variables for causal effects.
    :param target_interventions: A set of interventions for the target domain.
    :param surrogate_outcomes: A dictionary of outcomes in other populations
    :param surrogate_interventions: A dictionary of interventions in other populations
    :returns: An Expression evaluating the given query, or None
    :raises ValueError: If any of the given variables is not in the graph, if the
        target outcomes and target interventions intersect, or if the surrogate
        outcomes and surrogate interventions are not keyed by the same populations

    The example from figure 8 of the original paper can be executed with the following
    code:

    .. code-block:: python

        from y0.algorithm.transport import identify_target_outcomes
        from y0.dsl import X1, X2, Y1, Y2, Pi1, Pi2
        from y0.examples import tikka_trso_figure_8_graph

        estimand = identify_target_outcomes(
            graph=tikka_trso_figure_8_graph,
            target_outcomes={Y1, Y2},
            target_interventions={X1, X2},
            surrogate_outcomes={Pi1: {Y1}, Pi2: {Y2}},
            surrogate_interventions={Pi1: {X1}, Pi2: {X2}},
        )

    This returns the following estimand: $\sum_{W, Z} P(W, Z) \frac{P_{X_1}^{π_1}(W,
    Y_1, Z)}{P_{X_1}(W, Z)} \frac{P_{X_2}^{π_2}(W, X_1, Y_2, Z)}{P_{X_2}(W, X_1, Z)}$
    """
    # TODO add vanilla identification check, i.e., warn when the query is already
    #  identifiable without surrogates

    # Every variable mentioned in the query has to be a node of the graph
    check_and_raise_missing(target_outcomes, graph, "target_outcomes")
    check_and_raise_missing(target_interventions, graph, "target_interventions")
    all_surrogate_outcomes = set().union(*surrogate_outcomes.values())
    check_and_raise_missing(all_surrogate_outcomes, graph, "surrogate_outcomes")
    all_surrogate_interventions = set().union(*surrogate_interventions.values())
    check_and_raise_missing(all_surrogate_interventions, graph, "surrogate_interventions")

    # A variable cannot be intervened on and observed as an outcome at the same time
    outcome_is_intervention = target_outcomes.intersection(target_interventions)
    if outcome_is_intervention:
        raise ValueError(
            f"The variables {outcome_is_intervention} cannot be target_outcomes and target_interventions"
        )

    transport_query = surrogate_to_transport(
        graph=graph,
        target_outcomes=target_outcomes,
        target_interventions=target_interventions,
        surrogate_outcomes=surrogate_outcomes,
        surrogate_interventions=surrogate_interventions,
    )

    # The algorithm starts in the target domain from the joint distribution over
    # all nodes of the graph, with no interventions active
    joint_distribution = PopulationProbability(
        population=TARGET_DOMAIN,
        distribution=Distribution.safe(graph.nodes()),
    )
    return trso(
        TRSOQuery(
            target_interventions=transport_query.target_interventions,
            target_outcomes=transport_query.target_outcomes,
            expression=joint_distribution,
            active_interventions=set(),
            domain=TARGET_DOMAIN,
            domains=transport_query.domains,
            graphs=transport_query.graphs,
            surrogate_interventions=transport_query.surrogate_interventions,
        )
    )