Source code for y0.dsl

# -*- coding: utf-8 -*-

r"""An internal domain-specific language for probability expressions.

=======================  ====================================================================
Expression               Description
=======================  ====================================================================
:math:`P(A)`             The probability of A occurring
:math:`P(A^*)`           The probability of A not occurring
:math:`P(A, B)`          The joint probability of A and B occurring
:math:`P(A \mid B)`      The conditional probability of A given B occurring
:math:`P(A \mid B^*)`    The conditional probability of A occurring given B not occurring
:math:`P(A^* \mid B)`    The conditional probability of A not occurring given B occurring
:math:`P(A^* \mid B^*)`  The conditional probability of A not occurring given B not occurring
:math:`\sum_A P(A, B)`   The marginal probability of B
=======================  ====================================================================

Level 3 of Pearl's Causal Hierarchy.

==============================  =================================================
Expression                      Description
==============================  =================================================
:math:`P(Y_X \mid X^*, Y^*)`    Probability of sufficient causation
:math:`P(Y^*_{X^*} \mid X, Y)`  Probability of necessary causation
:math:`P(Y_X, Y^*_{X^*})`       Probability of necessary and sufficient causation
==============================  =================================================
"""

from __future__ import annotations

import functools
import itertools as itt
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from operator import attrgetter
from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Protocol,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
)

if TYPE_CHECKING:
    import sympy

__all__ = [
    "Element",
    "Variable",
    "Intervention",
    "CounterfactualVariable",
    "Distribution",
    "Event",
    "P",
    "Probability",
    "Sum",
    "Product",
    "Fraction",
    "Expression",
    "One",
    "Zero",
    "Q",
    "QFactor",
    "A",
    "AA",
    "B",
    "C",
    "D",
    "M",
    "R",
    "S",
    "T",
    "U",
    "W",
    "X",
    "Y",
    "Z",
    "U1",
    "U2",
    "U3",
    "U4",
    "U5",
    "U6",
    "V1",
    "V2",
    "V3",
    "V4",
    "V5",
    "V6",
    "W0",
    "W1",
    "W2",
    "W3",
    "W4",
    "W5",
    "W6",
    "Y1",
    "Y2",
    "Y3",
    "Y4",
    "Y5",
    "Y6",
    "Z1",
    "Z2",
    "Z3",
    "Z4",
    "Z5",
    "Z6",
    # Helpers
    "ensure_ordering",
    "vmap_adj",
    "vmap_pairs",
    # Transport
    "PopulationProbability",
    "PP",
    "Pi1",
    "Pi2",
    "Pi3",
    "Pi4",
    "Pi5",
    "Pi6",
    "π1",
    "π2",
    "π3",
    "π4",
    "π5",
    "π6",
    "Population",
]

T_co = TypeVar("T_co", covariant=True)


def _to_interventions(variables: Sequence[Variable]) -> Tuple[Intervention, ...]:
    return tuple(
        (
            variable
            if isinstance(variable, Intervention)
            else Intervention(name=variable.name, star=False)
        )
        for variable in variables
    )


[docs] class Element(ABC): """An element in the y0 internal domain-speific language that can be converted to text, LaTeX, and code."""
[docs] @abstractmethod def to_text(self) -> str: """Output this DSL object in the internal string format."""
[docs] @abstractmethod def to_latex(self) -> str: """Output this DSL object in the LaTeX string format."""
[docs] @abstractmethod def to_y0(self) -> str: """Output this DSL object as y0 python code."""
def _repr_latex_(self) -> str: # hack for auto-display of latex in jupyter notebook return f"${self.to_latex()}$" def __str__(self) -> str: return self.to_y0() def __repr__(self) -> str: return self.to_y0() @abstractmethod def _iter_variables(self) -> Iterable[Variable]: """Iterate over variables."""
[docs] def get_variables(self) -> Set[Variable]: """Get the set of variables used in this expression.""" return set(self._iter_variables())
[docs] @dataclass(frozen=True, order=True, repr=False) class Variable(Element): """A variable, typically with a single letter.""" #: The name of the variable name: str #: The star status of the variable. None means it's a variable, #: False means it's the same as the value for the variable, #: and True means it's a different value from the variable. star: Optional[bool] = None def __post_init__(self): if not isinstance(self.name, str): raise TypeError(f"Names must be strings: {self.name}") if self.name in {"P", "Q", "PP"}: raise ValueError(f"trust me, {self.name} is a bad variable name.")
[docs] @classmethod def norm(cls, name: Union[str, Variable]) -> Variable: """Automatically upgrade a string to a variable.""" if isinstance(name, str): return Variable(name) elif isinstance(name, Variable): return name else: raise TypeError(f"({type(name)}) {name} is not valid")
[docs] def get_base(self) -> Variable: """Return the base variable, with no other nonsense.""" return Variable(self.name)
[docs] def to_text(self) -> str: """Output this variable in the internal string format.""" return self.name
[docs] def to_sympy(self) -> "sympy.Symbol": """Get the object for sympy.""" import sympy return sympy.Symbol(self.to_latex())
[docs] def to_latex(self) -> str: """Output this variable in the LaTeX string format. :returns: The LaTeX representaton of this variable. >>> Variable('X').to_latex() 'X' >>> Variable('X1').to_latex() 'X_1' >>> Variable('X12').to_latex() 'X_{12}' """ # if it ends with a number, use that as a subscript ending_numeric = 0 for c in reversed(self.name): if c.isnumeric(): ending_numeric += 1 if ending_numeric == 0: return self.name elif ending_numeric == 1: return f"{self.name[:-1]}_{self.name[-1]}" else: return f"{self.name[:-ending_numeric]}_{{{self.name[-ending_numeric:]}}}"
[docs] def to_y0(self) -> str: """Output this variable instance as y0 internal DSL code.""" if self.star is None: return self.name elif self.star: return f"+{self.name}" else: return f"-{self.name}"
[docs] def intervene(self, variables: VariableHint) -> CounterfactualVariable: """Intervene on this variable with the given variable(s). :param variables: The variable(s) used to extend this variable as it is changed to a counterfactual variable :returns: A new counterfactual variable over this variable with the given intervention(s). .. note:: This function can be accessed with the matmult @ operator. """ interventions = _to_interventions(_upgrade_variables(variables)) return CounterfactualVariable( name=self.name, star=self.star, interventions=frozenset(interventions), )
def __matmul__(self, variables: VariableHint) -> CounterfactualVariable: return self.intervene(variables)
[docs] def given(self, parents: Union[VariableHint, Distribution]) -> Distribution: """Create a distribution in which this variable is conditioned on the given variable(s). The new distribution is a Markov Kernel. :param parents: A variable or list of variables to include as conditions in the new conditional distribution :returns: A new conditional probability distribution :raises TypeError: If a distribution is given as the parents that contains conditionals .. note:: This function can be accessed with the or | operator. """ if not isinstance(parents, Distribution): return Distribution( children=(self,), parents=_upgrade_ordering(parents), ) elif parents.is_conditioned(): raise TypeError("can not be given a distribution that has conditionals") else: # The parents variable is actually a Distribution instance with no parents, # so its children become the parents for the new Markov Kernel distribution return Distribution( children=(self,), parents=parents.children, # don't think about this too hard )
def __or__(self, parents: Union[VariableHint, Distribution]) -> Distribution: return self.given(parents)
[docs] def joint(self, children: VariableHint) -> Distribution: """Create a joint distribution between this variable and the given variable(s). :param children: The variable(s) for use with this variable in a joint distribution :returns: A new joint distribution over this variable and the given variables. .. note:: This function can be accessed with the and & operator. """ return Distribution( children=_upgrade_ordering((self, *_upgrade_variables(children))), )
def __and__(self, children: VariableHint) -> Distribution: return self.joint(children) def _intervention(self, star: bool) -> Variable: return Intervention(name=self.name, star=star)
[docs] def invert(self) -> Variable: """Create an :class:`Intervention` variable that is different from what was observed (with a star).""" return self._intervention(not self.star)
def __invert__(self) -> Variable: return self.invert() def __pos__(self) -> Variable: return self._intervention(True) def __neg__(self) -> Variable: return self._intervention(False) @classmethod def __class_getitem__(cls, item) -> Variable: return Variable(item) def _iter_variables(self) -> Iterable[Variable]: """Get a set containing this variable.""" yield self
VariableHint = Union[str, Variable, Iterable[Union[str, Variable]]]
[docs] @dataclass(frozen=True, order=True, repr=False) class Intervention(Variable): """An intervention variable. An intervention variable is usually used as a subscript in a :class:`CounterfactualVariable`. """ def __post_init__(self): if self.star is None: raise ValueError("Intervention must have a non-None star")
[docs] def to_text(self) -> str: """Output this intervention variable in the internal string format.""" return f"{self.name}*" if self.star else self.name
[docs] def to_latex(self) -> str: """Output this intervention variable in the LaTeX string format.""" latex = super().to_latex() return f"{latex}^*" if self.star else latex
[docs] def to_y0(self) -> str: """Output this intervention instance as y0 internal DSL code.""" mark = "+" if self.star else "-" return f"{mark}{self.name}"
[docs] @dataclass(frozen=True, order=True, repr=False) class CounterfactualVariable(Variable): """A counterfactual variable. Counterfactual variables are like normal variables, but can have a list of interventions. Each intervention is either the same as what was observed (no star) or different from what was observed (star). """ #: The interventions on the variable. Should be non-empty interventions: frozenset[Intervention] = field(default_factory=frozenset) def __post_init__(self): if not self.interventions: raise ValueError("should give at least one intervention") for intervention in self.interventions: if not isinstance(intervention, Intervention): raise TypeError( f"only Intervention instances are allowed." f" Got: ({intervention.__class__.__name__}) {intervention}", )
[docs] def to_text(self) -> str: """Output this counterfactual variable in the internal string format.""" intervention_latex = _list_to_text(_sort_interventions(self.interventions)) return f"{self.name}_{{{intervention_latex}}}"
[docs] def to_latex(self) -> str: """Output this counterfactual variable in the LaTeX string format. :returns: A latex representation of this counterfactual variable >>> (Variable('X') @ Variable('Y')).to_latex() '{X}_{Y}' >>> (Variable('X1') @ Variable('Y')).to_latex() '{X_1}_{Y}' >>> (Variable('X12') @ Variable('Y')).to_latex() '{X_{12}}_{Y}' """ intervention_latex = _list_to_latex(_sort_interventions(self.interventions)) prefix = "^*" if self.star else "" return f"{{{super().to_latex()}}}{prefix}_{{{intervention_latex}}}"
[docs] def to_y0(self) -> str: """Output this counterfactual variable instance as y0 internal DSL code.""" if self.star is None: prefix = "" elif self.star: prefix = "+" else: prefix = "-" if len(self.interventions) == 1: return f"{prefix}{self.name} @ {list(self.interventions)[0].to_y0()}" else: ins = ", ".join(i.to_y0() for i in _sort_interventions(self.interventions)) return f"{prefix}{self.name} @ ({ins})"
[docs] def is_event(self) -> bool: """Return if the counterfactual variable has a value.""" return self.star is not None
[docs] def has_tautology(self) -> bool: """Return if the counterfactual variable contain its own value in the subscript. :returns: True if we force a variable X to have a value x and the resulting value of X is x. :raises ValueError: if the counterfactual value doesn't have a value assigned """ if not self.is_event(): raise ValueError( "Can not determine the consistency of a counterfactual variable with no value assigned." ) return any(self.name == i.name and self.star == i.star for i in self.interventions)
[docs] def is_inconsistent(self) -> bool: """Return if the counterfactual variable violates the Axiom of Effectiveness. :returns: True if we force a variable X to have a value x and the resulting value of X is not x :raises ValueError: if the counterfactual value doesn't have a value assigned """ if not self.is_event(): raise ValueError( "Can not determine the consistency of a counterfactual variable with no value assigned." ) return any(self.name == i.name and self.star != i.star for i in self.interventions)
[docs] def intervene(self, variables: VariableHint) -> CounterfactualVariable: """Intervene on this counterfactual variable with the given variable(s). :param variables: The variable(s) used to extend this counterfactual variable's current interventions. Automatically converts variables to interventions. :returns: A new counterfactual variable with both this counterfactual variable's interventions and the given intervention(s). .. warning:: Will raise a value error ff the value of a new intervention conflicts with the value of intervention already listed in this counterfactual. .. note:: This function can be accessed with the matmult @ operator. """ _interventions = _to_interventions(_upgrade_ordering(variables)) interventions = {*self.interventions, *_interventions} self._raise_for_overlapping_interventions(interventions) return CounterfactualVariable( name=self.name, star=self.star, interventions=frozenset(interventions) )
@staticmethod def _raise_for_overlapping_interventions(interventions: Iterable[Intervention]) -> None: """Raise an error if there are two values of the same variable in the list of interventions. :param interventions: Interventions to check for overlap :raises ValueError: If there are overlapping variables given """ overlaps = { (old, new) for old, new in itt.product(interventions, repeat=2) if old.name == new.name and old.star != new.star } if overlaps: raise ValueError(f"Overlapping interventions in new interventions: {overlaps}") def _with_star(self, star: bool) -> CounterfactualVariable: return CounterfactualVariable( name=self.name, star=star, interventions=self.interventions, )
[docs] def invert(self) -> CounterfactualVariable: """Invert the value of the counterfactual variable.""" return self._with_star(not self.star)
def __pos__(self) -> CounterfactualVariable: return self._with_star(True) def __neg__(self) -> CounterfactualVariable: return self._with_star(False) def _iter_variables(self) -> Iterable[Variable]: """Get the union of this variable and its interventions.""" yield from super()._iter_variables() for intervention in self.interventions: yield from intervention._iter_variables()
[docs] @dataclass(frozen=True) class Distribution(Element): """A general distribution over several child variables, conditioned by several parents. P(X | Y) means that X is a child and Y is a parent. """ children: Tuple[Variable, ...] parents: Tuple[Variable, ...] = field(default_factory=tuple) def __post_init__(self): if isinstance(self.children, (list, Variable)): raise TypeError(f"children of wrong type: {type(self.children)}") if isinstance(self.parents, (list, Variable)): raise TypeError if not self.children: raise ValueError("distribution must have at least one child")
[docs] @classmethod def safe( cls, distribution: Union[VariableHint, Distribution], *args: Union[str, Variable, Distribution], ) -> Distribution: """Create a distribution the given variable(s) or distribution. :param distribution: If given a :class:`Distribution`, creates a probability expression directly over the distribution. If given variable or list of variables, conveniently creates a :class:`Distribution` with the variable(s) as children. :param args: If the first argument (``distribution``) was given as a single variable, the ``args`` variadic argument can be used to specify a list of additional variables. :returns: A Distribution object :raises ValueError: If invalid combination of arguments are given. """ if isinstance(distribution, (str, Variable, Distribution)): extended_args = [distribution, *args] dist_pos = [i for i, e in enumerate(extended_args) if isinstance(e, Distribution)] # There are no distributions (e.g., no conditionals were given with the | already) if 0 == len(dist_pos): return Distribution( children=_upgrade_ordering(cast(VariableHint, extended_args)), ) # A single conditional was given. Everything before it should be considered # as child variables, and everything after as parent variables. elif 1 == len(dist_pos): i = dist_pos[0] pre = cast(Iterable[Union[str, Variable]], extended_args[:i]) dist = cast(Distribution, extended_args[i]) post = cast(Iterable[Union[str, Variable]], extended_args[i + 1 :]) return Distribution( children=_sorted_variables((*_upgrade_ordering(pre), *dist.children)), parents=_sorted_variables((*dist.parents, *_upgrade_ordering(post))), ) # Multiple conditionals were detected. This isn't allowed. else: raise ValueError("can not give multiple distribution objects") elif args: raise ValueError("can not use args/parents when giving an iterable as first argument") else: return Distribution( children=_upgrade_ordering(distribution), )
def _to_x(self, func: Callable[[Iterable[Variable]], str]) -> str: children = func(self.children) if not self.parents: return children return f"{children} | {func(self.parents)}"
[docs] def to_text(self) -> str: """Output this distribution in the internal string format.""" return self._to_x(_list_to_text)
[docs] def to_y0(self) -> str: """Output this distribution instance as y0 internal DSL code.""" return self._to_x(_list_to_y0)
[docs] def to_latex(self) -> str: """Output this distribution in the LaTeX string format.""" return self._to_x(_list_to_latex)
[docs] def is_conditioned(self) -> bool: """Return if this distribution is conditioned.""" return 0 < len(self.parents)
[docs] def is_markov_kernel(self) -> bool: """Return if this distribution a markov kernel -> one child variable and one or more conditionals.""" return len(self.children) == 1
[docs] def intervene(self, variables: VariableHint) -> Distribution: """Return a new distribution that has the given intervention(s) on all variables.""" # check that the variables aren't in any of them yet variables = _upgrade_ordering(variables) return Distribution( children=tuple(child.intervene(variables) for child in self.children), parents=tuple(parent.intervene(variables) for parent in self.parents), )
def __matmul__(self, variables: VariableHint) -> Distribution: return self.intervene(variables)
[docs] def uncondition(self) -> Distribution: """Return a new distribution that is not conditioned on the parents.""" return Distribution( children=(*self.children, *self.parents), )
[docs] def joint(self, children: VariableHint) -> Distribution: """Create a new distribution including the given child variables. :param children: The variable(s) with which this distribution's children are extended :returns: A new distribution. .. note:: This function can be accessed with the and & operator. """ return Distribution( children=_upgrade_ordering((*self.children, *_upgrade_variables(children))), parents=self.parents, )
def __and__(self, children: VariableHint) -> Distribution: return self.joint(children)
[docs] def given(self, parents: Union[VariableHint, Distribution]) -> Distribution: """Create a new mixed distribution additionally conditioned on the given parent variables. :param parents: The variable(s) with which this distribution's parents are extended :returns: A new distribution :raises TypeError: If a distribution is given as the parents that contains conditionals .. note:: This function can be accessed with the or | operator. """ # TODO handle duplicate variables in the parents. if not isinstance(parents, Distribution): return Distribution( children=self.children, parents=_upgrade_ordering((*self.parents, *_upgrade_variables(parents))), ) elif parents.is_conditioned(): raise TypeError("can not be given a distribution that has conditionals") else: # The parents variable is actually a Distribution instance with no parents, # so its children get appended as parents for the new mixed distribution return Distribution( children=self.children, parents=( *self.parents, *parents.children, ), # don't think about this too hard )
def __or__(self, parents: Union[VariableHint, Distribution]) -> Distribution: return self.given(parents) def _iter_variables(self) -> Iterable[Variable]: """Get the set of variables used in this distribution.""" for variable in itt.chain(self.children, self.parents): yield from variable._iter_variables()
[docs] class Expression(Element, ABC): """The abstract class representing all expressions.""" @abstractmethod def __mul__(self, other): pass @abstractmethod def _get_key(self) -> tuple: """Generate a sort key for a *canonical* expression. :returns: A tuple in which the first element is the integer priority for the expression and the rest depends on the expression type. """ raise NotImplementedError def __lt__(self, other: Expression): return self._get_key() < other._get_key() def __truediv__(self, expression: Expression) -> Expression: """Divide this expression by another and create a fraction.""" if isinstance(expression, One): return self elif isinstance(expression, Fraction): return Fraction(self * expression.denominator, expression.numerator) else: return Fraction(self, expression)
[docs] def conditional(self, ranges: VariableHint) -> Expression: """Return this expression, conditioned by the given variables. :param ranges: A variable or list of variables over which to marginalize this expression :returns: A fraction in which the denominator is represents the sum over the given ranges >>> from y0.dsl import P, A, B >>> assert P(A, B).conditional(A) == P(A, B) / Sum[B](P(A, B)) >>> assert P(A, B, C).conditional([A, B]) == P(A, B, C) / Sum[C](P(A, B, C)) """ ranges = _upgrade_ordering([r.get_base() for r in _upgrade_variables(ranges)]) ranges_complement = set([c.get_base() for c in self._iter_variables()]) - set(ranges) return self.normalize_marginalize(ranges_complement)
[docs] def normalize_marginalize(self, ranges: VariableHint) -> Expression: """Return this expression, normalized by this expression marginalized by the given variables.""" return self / self.marginalize(ranges)
[docs] def marginalize(self, ranges: VariableHint) -> Expression: """Return this expression, marginalizing out the given variables. :param ranges: A variable or list of variables over which to marginalize this expression :returns: The expression but summed over the given variables >>> from y0.dsl import P, A, B, C >>> assert P(A, B).marginalize(A) == Sum[A](P(A, B)) >>> assert P(A, B, C).marginalize([A, B]) == Sum[A, B](P(A, B, C)) """ return Sum.safe( expression=self, ranges=_upgrade_ordering([r.get_base() for r in _upgrade_variables(ranges)]), )
[docs] @dataclass(frozen=True, repr=False) class Probability(Expression): """The probability over a distribution.""" #: The distribution over which the probability is expressed distribution: Distribution
[docs] @classmethod def safe( cls, distribution: DistributionHint, *args: Union[str, Variable], interventions: Optional[VariableHint] = None, ) -> Probability: """Create a distribution the given variable(s) or distribution. :param distribution: If given a :class:`Distribution`, creates a probability expression directly over the distribution. If given variable or list of variables, conveniently creates a :class:`Distribution` with the variable(s) as children. :param args: If the first argument (``distribution``) was given as a single variable, the ``args`` variadic argument can be used to specify a list of additional variables. :param interventions: An optional variable or variables to use as interventions. :returns: A probability object """ distribution = Distribution.safe(distribution, *args) if interventions is not None: distribution = distribution.intervene(interventions) return Probability(distribution)
def _get_key(self): # TODO incorporate more information from children and parents return 0, self.children[0].name
[docs] def to_text(self) -> str: """Output this probability in the internal string format.""" return f"P({self.distribution.to_text()})"
def _help_level_2_distribution(self): # if all parts of distribution have same intervention set, then put it out front intervention_sets = { x.interventions if isinstance(x, CounterfactualVariable) else tuple() for x in itt.chain(self.children, self.parents) } # check that there's only one intervention set and that it's not an empty one if len(intervention_sets) == 1 and (interventions := intervention_sets.pop()): unintervened_distribution = Distribution( parents=tuple(Variable(name=v.name, star=v.star) for v in self.parents), children=tuple(Variable(name=v.name, star=v.star) for v in self.children), ) return interventions, unintervened_distribution else: return None, None
[docs] def to_y0(self) -> str: """Output this probability instance as y0 internal DSL code.""" interventions, unintervened_distribution = self._help_level_2_distribution() if not interventions: return f"P({self.distribution.to_y0()})" # only keep the + if necessary, otherwise show regular intervention_str = ",".join( f"+{intervention.name}" if intervention.star else intervention.name for intervention in interventions ) return f"P[{intervention_str}]({unintervened_distribution.to_y0()})"
[docs] def to_latex(self) -> str: """Output this probability in the LaTeX string format.""" interventions, unintervened_distribution = self._help_level_2_distribution() if not interventions: return f"P({self.distribution.to_latex()})" intervention_str = ",".join(intervention.to_latex() for intervention in interventions) return f"P_{{{intervention_str}}}({unintervened_distribution.to_latex()})"
@property def parents(self) -> Tuple[Variable, ...]: """Get the distribution's parents.""" return self.distribution.parents @property def children(self) -> Tuple[Variable, ...]: """Get the distribution's children.""" return self.distribution.children
[docs] def is_conditioned(self) -> bool: """Return if this distribution is conditioned.""" return self.distribution.is_conditioned()
[docs] def is_markov_kernel(self) -> bool: """Return if this distribution a markov kernel -> one child variable and one or more conditionals.""" return self.distribution.is_markov_kernel()
def __mul__(self, other: Expression) -> Expression: if isinstance(other, Zero): return other elif isinstance(other, One): return self elif isinstance(other, Product): return Product.safe((self, *other.expressions)) elif isinstance(other, Fraction): return Fraction(self * other.numerator, other.denominator) else: return Product.safe((self, other)) def _new(self, distribution: Distribution): # This is implemented this way to make overriding easier return Probability(distribution)
[docs] def intervene(self, variables: VariableHint) -> Probability: """Return a new probability where the underlying distribution has been intervened by the given variables.""" return self._new(self.distribution.intervene(variables))
def __matmul__(self, variables: VariableHint) -> Probability: return self.intervene(variables)
[docs] def uncondition(self) -> Probability: """Return a new probability where the underlying distribution is no longer conditioned by the parents. :returns: A new probability over a distribution over the children and parents of the previous distribution >>> from y0.dsl import P, A, B >>> P(A | B).uncondition() == P(A, B) """ return self._new(self.distribution.uncondition())
[docs] def conditional(self, ranges: VariableHint) -> Expression: """Return this expression, conditioned by the given variables. :param ranges: A variable or list of variables over which to marginalize this expression :returns: A fraction in which the denominator is represents the sum over the given ranges >>> from y0.dsl import P, A, B >>> assert P(A, B).conditional(A) == P(A, B) / Sum[B](P(A, B)) >>> assert P(A, B, C).conditional([A, B]) == P(A, B, C) / Sum[C](P(A, B, C)) """ ranges = _upgrade_ordering([r.get_base() for r in _upgrade_variables(ranges)]) ranges_complement = set( [c.get_base() for c in self._iter_variables() if not isinstance(c, Intervention)] ) - set(ranges) return self.normalize_marginalize(ranges_complement)
def _iter_variables(self) -> Iterable[Variable]: """Get the set of variables used in the distribution in this probability.""" yield from self.distribution._iter_variables()
DistributionHint = Union[VariableHint, Distribution] class ProbabilityBuilderType: """A base class for building probability distributions.""" def __call__( self, distribution: DistributionHint, *args: Union[str, Variable], interventions: Optional[VariableHint] = None, ) -> Probability: return Probability.safe(distribution, *args, interventions=interventions) def __getitem__(self, interventions: VariableHint): """Generate a probability builder closure. :param interventions: A variable or variables to intervene on using the do-calculus level 2 rules, meaning they are all applied to all parent and children variables in the resulting expression :returns: A function with the same semantics as :meth:`__call__` such that you can build a probability expression. >>> from y0.dsl import P, W, X, Y, Z >>> assert P[X](Y) == P(Y @ X) >>> assert P[X](Y, Z) == P(Y @ X & Z @ X) >>> assert P[X](Y | Z) == P(Y @ X | Z @ X) >>> assert P[X](Y @ Z) == P(Y @ Z @ X) >>> assert P[X](Y @ Z | W) == P(Y @ Z @ X | W @ X) """ return functools.partial(self, interventions=interventions) P = ProbabilityBuilderType() """``P`` is a magical object of mystery and wonder that can be used to create :class:`Probability` instances. It itself is a singleton instance of :class:`ProbabilityBuilderType` and can be used wither via the :meth:`ProbabilityBuilderType.__call__`, as if it were a function like ``P(Y)`` or it can be used as a combination with the :meth:`ProbabilityBuilderType.__getitem__` and a call, like ``P[X](Y)`` to denote interventions using the do-Calculus $L_2$ notation. Here are some examples: A univariate distribution can be created either with a string or a :class:`Variable`: >>> from y0.dsl import P, A >>> P('A') == P(A) **Multivariate Distributions** A joint distribution can be created with several strings or :class:`Variable` instances with variadic arguments: >>> from y0.dsl import P, A, B >>> P(A, B) == P('A', 'B') A joint distribution can also be created with a single argument that is either an iterable of either strings or :class:`Variable` instances >>> from y0.dsl import P, A, B >>> P((A, B)) == P([A, B]) == P(('A', 'B')) == P(['A', 'B']) This even extends to fancy generators, for which you can omit the parentheses: Creation with a fancy generator of variables: >>> from y0.dsl import P, A, B >>> P(Variable(name) for name in 'AB') == P(name for name in 'AB') == P(A, B) **Conditional Distributions** Creation with a conditional distribution: >>> from y0.dsl import P, A, B >>> P(A | B) Creation with a mixed joint/conditional distribution: >>> from y0.dsl import P, A, B, C >>> P(A & B | C) **Specifying an Intervention with L2 do-Calculus Notation** Intervene on a single variable: >>> from y0.dsl import P, X, Y >>> P[X](Y) == P(Y @ X) Intervene on multiple children: >>> from y0.dsl import P, X, Y, Z >>> P[X](Y, Z) == P(Y @ X & Z @ X) Intervene on multiple parents: >>> from y0.dsl import P, W, X, Y, Z >>> P[X](Y | (W, Z)) == P(Y @ X | (W @ X, Z @ X)): Intervene on both children and parents: >>> from y0.dsl import P, X, Y, Z >>> P[X](Y | Z) == P(Y @ X | Z @ X) Intervene on X on top of previous interventions: >>> from y0.dsl import P, X, Y, Z >>> P[X](Y @ Z) == P(Y @ X @ Z) Allow mixing with L3, where each variable can have different interventions: >>> from y0.dsl import P, W, X, Y, Z >>> P[X](Y @ Z | W) == P(Y @ X @ Z | W @ X) **Specifying Multiple Interventions with L2 do-Calculus Notation** Multiple interventions on a single variable: >>> from y0.dsl import P, X1, X2, Y >>> P[X1, X2](Y) == P(Y @ X) Multiple interventions on multiple children: >>> from y0.dsl import P, X1, X2, Y, Z >>> P[X1, X2](Y, Z) == P(Y @ X1 @ X2 & Z @ X1 @ X2) ... and so on """
[docs] @dataclass(frozen=True, repr=False) class Product(Expression): """Represent the product of several probability expressions.""" expressions: Tuple[Expression, ...] def __post_init__(self): if len(self.expressions) < 2: raise ValueError("Product() must two or more expressions")
[docs] @classmethod def safe(cls, expressions: Union[Expression, Iterable[Expression]]) -> Expression: """Construct a product from any iterable of expressions. :param expressions: An expression or iterable of expressions which should be multiplied :returns: A :class:`Product` object Standard usage, same as the normal ``__init__``: >>> from y0.dsl import Product, X, Y, A, P >>> Product.safe((P(X, Y), )) Use a list or other iterable: >>> Product.safe([P(X), P(Y | X)]) Use an inline generator: >>> Product.safe(P(v) for v in [X, Y]) Use a single expression: >>> Product.safe(P(X, Y)) """ if isinstance(expressions, Expression): return expressions # Remove multiplications of one expressions = tuple(expression for expression in expressions if expression != One()) # If any multiplications are by zero, then return zero if any(expression == Zero() for expression in expressions): return Zero() if not expressions: return One() if len(expressions) == 1: return expressions[0] return cls(expressions=tuple(sorted(expressions)))
def _get_key(self): inner_keys = (sexpr._get_key() for sexpr in self.expressions) return 2, *inner_keys
[docs] def to_text(self): """Output this product in the internal string format.""" return " ".join(expression.to_text() for expression in self.expressions)
[docs] def to_y0(self) -> str: """Output this product instance as y0 internal DSL code.""" return " * ".join(expr.to_y0() for expr in self.expressions)
[docs] def to_latex(self): """Output this product in the LaTeX string format.""" return " ".join(expression.to_latex() for expression in self.expressions)
def __mul__(self, other: Expression): if isinstance(other, Zero): return other if isinstance(other, Product): return Product.safe((*self.expressions, *other.expressions)) elif isinstance(other, Fraction): return Fraction(self * other.numerator, other.denominator) else: return Product.safe((*self.expressions, other)) def _iter_variables(self) -> Iterable[Variable]: """Get the union of the variables used in each expresison in this product.""" for expression in self.expressions: yield from expression._iter_variables()
def _list_to_text(elements: Iterable[Element]) -> str: return ", ".join(element.to_text() for element in elements) def _list_to_latex(elements: Iterable[Element]) -> str: return ", ".join(element.to_latex() for element in elements) def _list_to_y0(elements: Iterable[Element]) -> str: return ", ".join(element.to_y0() for element in elements)
[docs] @dataclass(frozen=True, repr=False) class Sum(Expression): """Represent the sum over an expression over an optional set of variables.""" #: The expression over which the sum is done expression: Expression #: The variables over which the sum is done. Defaults to an empty list, meaning no variables. ranges: frozenset[Variable] def __post_init__(self): if not isinstance(self.ranges, frozenset): raise TypeError if not self.ranges: raise ValueError("Sum must have ranges") for r in self.ranges: if isinstance(r, (CounterfactualVariable, Intervention)): raise TypeError("Ranges must not be counterfactuals nor interventions")
[docs] @classmethod def safe( cls, expression: Expression, ranges: Union[str, Variable, Iterable[Union[str, Variable]]], *, simplify: bool = False, ) -> Expression: """Construct a sum from an expression and a permissive set of things in the ranges. :param expression: The expression over which the sum is done :param ranges: The variable or list of variables over which the sum is done :param simplify: Should the sum be simplified using :func:`Sum.simplify`? :returns: A :class:`Sum` object Standard usage, same as the normal ``__init__``: >>> from y0.dsl import Sum, X, Y, A, P >>> Sum.safe(P(X, Y), (X,)) Use a list or other iterable: >>> Sum.safe(P(X, Y), [X]) Use a single variable: >>> Sum.safe(P(X, Y), X) """ if isinstance(ranges, str): ranges = (Variable(ranges),) elif isinstance(ranges, Variable): ranges = (ranges,) else: ranges = _upgrade_ordering(ranges) if not ranges: return expression if isinstance(expression, Zero): return expression rv = cls( expression=expression, ranges=frozenset(ranges), ) if simplify: return rv.simplify() return rv
[docs] def simplify(self) -> Expression: """Simplify this sum.""" expression = self.expression ranges = set(self.ranges) # Special case when ranges cover if isinstance(expression, Probability) and not expression.parents: # i.e., no conditions children = { child.get_base(): child for child in expression.children # FIXME what happens if same name appears with multiple different counterfactual variables? # this should actually evaluate to zero since that's impossible } if ranges == set(children): return One() elif ranges > set(children): keep = ranges - set(children) return Sum.safe( expression=One(), ranges=frozenset(v for k, v in children.items() if k in keep), ) elif ranges < set(children): keep = set(children) - ranges return expression._new( Distribution.safe(v for k, v in children.items() if k in keep) ) else: # partial or no overlap intersection = ranges.intersection(children) keep = set(children) - intersection prob = expression._new( Distribution.safe(v for k, v in children.items() if k in keep) ) return Sum.safe( expression=prob, ranges=ranges - intersection, ) return self
def _get_key(self): return 1, *self.expression._get_key() def _get_sorted_ranges(self) -> Sequence[Variable]: return sorted(self.ranges, key=attrgetter("name"))
[docs] def to_text(self) -> str: """Output this sum in the internal string format.""" ranges = _list_to_text(self._get_sorted_ranges()) return f"[ sum_{{{ranges}}} {self.expression.to_text()} ]"
[docs] def to_latex(self) -> str: """Output this sum in the LaTeX string format.""" ranges = _list_to_latex(self._get_sorted_ranges()) return rf"\sum\limits_{{{ranges}}} {self.expression.to_latex()}"
[docs] def to_y0(self): """Output this sum instance as y0 internal DSL code.""" if isinstance(self.expression, Fraction): s = self.expression.to_y0(parens=False) else: s = self.expression.to_y0() if not self.ranges: return f"Sum({s})" ranges = _list_to_y0(self._get_sorted_ranges()) return f"Sum[{ranges}]({s})"
def __mul__(self, expression: Expression): if isinstance(expression, Zero): return expression elif isinstance(expression, Product): return Product.safe((self, *expression.expressions)) else: return Product.safe((self, expression)) def _iter_variables(self) -> Iterable[Variable]: """Get the union of the variables used in the range of this sum and variables in its summand.""" yield from self.expression._iter_variables() for variable in self.ranges: yield from variable._iter_variables() @classmethod def __class_getitem__(cls, ranges: VariableHint) -> Callable[[Expression], Expression]: """Create a partial sum object over the given ranges. :param ranges: The variables over which the partial sum will be done :returns: A partial :class:`Sum` that can be called solely on an expression Example single variable sum: >>> from y0.dsl import Sum, P, A, B >>> Sum[B](P(A | B) * P(B)) Example multiple variable sum: >>> from y0.dsl import Sum, P, A, B, C >>> Sum[B, C](P(A | B) * P(B)) """ return functools.partial(Sum.safe, ranges=_upgrade_ordering(ranges))
[docs] @dataclass(frozen=True, repr=False) class Fraction(Expression): """Represents a fraction of two expressions.""" #: The expression in the numerator of the fraction numerator: Expression #: The expression in the denominator of the fraction denominator: Expression def __post_init__(self): if isinstance(self.denominator, Zero): raise ZeroDivisionError def _get_key(self): return ( 3, self.numerator._get_key(), self.denominator._get_key(), )
[docs] def to_text(self) -> str: """Output this fraction in the internal string format.""" return f"frac_{{{self.numerator.to_text()}}}{{{self.denominator.to_text()}}}"
[docs] def to_latex(self) -> str: """Output this fraction in the LaTeX string format.""" return rf"\frac{{{self.numerator.to_latex()}}}{{{self.denominator.to_latex()}}}"
[docs] def to_y0(self, parens: bool = True) -> str: """Output this fraction as y0 internal DSL code.""" s = f"({self.numerator.to_y0()} / {self.denominator.to_y0()})" return f"({s})" if parens else s
def __mul__(self, expression: Expression) -> Expression: if isinstance(expression, Zero): return expression elif isinstance(expression, Fraction): return Fraction( self.numerator * expression.numerator, self.denominator * expression.denominator, ) else: return Fraction(self.numerator * expression, self.denominator) def __truediv__(self, expression: Expression) -> Fraction: if isinstance(expression, One): return self elif isinstance(expression, Fraction): return Fraction( self.numerator * expression.denominator, self.denominator * expression.numerator, ) else: return Fraction(self.numerator, self.denominator * expression) def _iter_variables(self) -> Iterable[Variable]: """Get the set of variables used in the numerator and denominator of this fraction.""" yield from self.numerator._iter_variables() yield from self.denominator._iter_variables()
[docs] def flip(self) -> Fraction: """Exchange the numerator and denominator.""" return Fraction(self.denominator, self.numerator)
[docs] def simplify(self) -> Expression: """Simplify this fraction.""" if isinstance(self.denominator, One): return self.numerator if isinstance(self.numerator, Zero): return self.numerator if isinstance(self.numerator, One): if isinstance(self.denominator, Fraction): return self.denominator.flip().simplify() else: return self if self.numerator == self.denominator: return One() if isinstance(self.numerator, Product) and isinstance(self.denominator, Product): return self._simplify_parts(self.numerator.expressions, self.denominator.expressions) elif isinstance(self.numerator, Product): return self._simplify_parts(self.numerator.expressions, [self.denominator]) elif isinstance(self.denominator, Product): return self._simplify_parts([self.numerator], self.denominator.expressions) return self
@classmethod def _simplify_parts( cls, numerator: Sequence[Expression], denominator: Sequence[Expression] ) -> Expression: """Calculate the minimum fraction. :param numerator: A sequence of expressions that are multiplied in the product in the numerator :param denominator: A sequence of expressions that are multiplied in the product in the denominator :returns: A simplified fraction. """ new_numerator, new_denominator = cls._simplify_parts_helper(numerator, denominator) if new_numerator and new_denominator: return Fraction( Product.safe(new_numerator), Product.safe(new_denominator), ) elif new_numerator: return Product.safe(new_numerator) elif new_denominator: return One() / Product.safe(new_denominator) else: return One() @staticmethod def _simplify_parts_helper( numerator: Sequence[Expression], denominator: Sequence[Expression], ) -> Tuple[Tuple[Expression, ...], Tuple[Expression, ...]]: numerator_cancelled = set() denominator_cancelled = set() for i, n_expr in enumerate(numerator): for j, d_expr in enumerate(denominator): if j in denominator_cancelled: continue if n_expr == d_expr: numerator_cancelled.add(i) denominator_cancelled.add(j) break return ( tuple(expr for i, expr in enumerate(numerator) if i not in numerator_cancelled), tuple(expr for i, expr in enumerate(denominator) if i not in denominator_cancelled), )
[docs] class One(Expression): """The multiplicative identity (1)."""
[docs] def to_text(self) -> str: """Output this identity variable in the internal string format.""" return "1"
[docs] def to_latex(self) -> str: """Output this identity instance in the LaTeX string format.""" return "1"
[docs] def to_y0(self) -> str: """Output this identity instance as y0 internal DSL code.""" return "One()"
def _get_key(self): return 4, self.to_text() def __rmul__(self, expression: Expression) -> Expression: return expression def __mul__(self, expression: Expression) -> Expression: return expression def __eq__(self, other): return isinstance(other, One) # all ones are equal def _iter_variables(self) -> Iterable[Variable]: """Get the set of variables used in this expression.""" return iter([])
[docs] class Zero(Expression): """The additive identity (0)."""
[docs] def to_text(self) -> str: """Output this identity variable in the internal string format.""" return "0"
[docs] def to_latex(self) -> str: """Output this identity instance in the LaTeX string format.""" return "0"
[docs] def to_y0(self) -> str: """Output this identity instance as y0 internal DSL code.""" return "Zero()"
def _get_key(self): return 4, self.to_text() def __rmul__(self, expression: Expression) -> Expression: return self def __mul__(self, expression: Expression) -> Expression: return self def __truediv__(self, other: Expression) -> Expression: if isinstance(other, Zero): raise ZeroDivisionError return self def __eq__(self, other): return isinstance(other, Zero) # all zeros are equal def _iter_variables(self) -> Iterable[Variable]: """Get the set of variables used in this expression.""" return iter([])
class QBuilder(Protocol[T_co]): """A protocol for annotating the special class getitem functionality of the :class:`QFactor` class.""" def __call__(self, arg: VariableHint, *args: Union[str, Variable]) -> T_co: ...
[docs] @dataclass(frozen=True, repr=False) class QFactor(Expression): """A function from the variables in the domain to a probability function over variables in the codomain.""" domain: frozenset[Variable] codomain: frozenset[Variable]
[docs] @classmethod def safe( cls, domain: VariableHint, *args: Union[str, Variable], codomain: VariableHint, ) -> QFactor: """Create a Q factor with various input types.""" return cls( domain=cls._prepare_domain(domain, *args), codomain=frozenset(_upgrade_variables(codomain)), )
@staticmethod def _prepare_domain( arg: VariableHint, *args: Union[str, Variable], ) -> frozenset[Variable]: """Prepare a list of variables from a potentially unruly set of args and variadic args.""" if isinstance(arg, (str, Variable)): return frozenset((Variable.norm(arg), *_upgrade_ordering(args))) if args: raise ValueError("can not use variadic arguments with combination of first arg") return frozenset(_sorted_variables(_upgrade_ordering(arg))) @classmethod def __class_getitem__(cls, codomain: Union[Variable, Iterable[Variable]]) -> QBuilder[QFactor]: """Create a partial Q Factor object over the given codomain. :param codomain: The variables over which the partial Q Factor will be done :returns: A partial :class:`QFactor` that can be called solely on an expression Example single variable codomain Q expression: >>> from y0.dsl import Sum, Q, A, B, C >>> Q[C](A, B) Example multiple variable codomain Q expression: >>> from y0.dsl import Sum, Q, A, B, C, D >>> Q[C, D](A, B) """ return functools.partial(cls.safe, codomain=codomain) def _get_key(self) -> tuple: return -5, min(v.name for v in self.domain), min(v.name for v in self.codomain) def _sorted_codomain(self): return sorted(self.codomain, key=attrgetter("name")) def _sorted_domain(self): return sorted(self.domain, key=attrgetter("name"))
[docs] def to_text(self) -> str: """Output this Q factor in the internal string format.""" codomain = _list_to_text(self._sorted_codomain()) domain = _list_to_text(self._sorted_domain()) return f"Q[{codomain}]({domain})"
[docs] def to_latex(self) -> str: """Output this Q factor in the LaTeX string format.""" codomain = _list_to_latex(self._sorted_codomain()) domain = _list_to_latex(self._sorted_domain()) return rf"Q_{{{codomain}}}({{{domain}}})"
[docs] def to_y0(self) -> str: """Output this Q factor instance as y0 internal DSL code.""" codomain = _list_to_y0(self._sorted_codomain()) domain = _list_to_y0(self._sorted_domain()) return f"Q[{codomain}]({domain})"
def __mul__(self, other: Expression): if isinstance(other, Product): return Product.safe((self, *other.expressions)) elif isinstance(other, Fraction): return Fraction(self * other.numerator, other.denominator) else: return Product.safe((self, other)) def _iter_variables(self) -> Iterable[Variable]: yield from self.codomain yield from self.domain
Q = QFactor AA = Variable("AA") A, B, C, D, E, F, G, M, R, S, T, U, W, X, Y, Z = map(Variable, "ABCDEFGMRSTUWXYZ") # type: ignore U1, U2, U3, U4, U5, U6 = [Variable(f"U{i}") for i in range(1, 7)] V1, V2, V3, V4, V5, V6 = [Variable(f"V{i}") for i in range(1, 7)] W0, W1, W2, W3, W4, W5, W6 = [Variable(f"W{i}") for i in range(7)] M0, M1, M2, M3, M4, M5, M6 = [Variable(f"M{i}") for i in range(7)] X1, X2, X3, X4, X5, X6 = [Variable(f"X{i}") for i in range(1, 7)] Y1, Y2, Y3, Y4, Y5, Y6 = [Variable(f"Y{i}") for i in range(1, 7)] Z1, Z2, Z3, Z4, Z5, Z6 = [Variable(f"Z{i}") for i in range(1, 7)] π1, π2, π3, π4, π5, π6 = Pi1, Pi2, Pi3, Pi4, Pi5, Pi6 = [Variable(f{i}") for i in range(1, 7)] def _sort_interventions(interventions: Iterable[Intervention]) -> Tuple[Intervention, ...]: return tuple(sorted(interventions, key=lambda i: (i.name, i.star))) def _variable_sort_key(variable: Variable) -> tuple[str, str]: if isinstance(variable, CounterfactualVariable): return variable.name, ",".join( i.to_y0() for i in _sort_interventions(variable.interventions) ) else: return variable.name, "" def _sorted_variables(variables: Iterable[Variable]) -> Tuple[Variable, ...]: return tuple(sorted(variables, key=_variable_sort_key)) def _upgrade_variables(variables: VariableHint) -> Tuple[Variable, ...]: if isinstance(variables, str): return (Variable(variables),) elif isinstance(variables, Variable): return (variables,) else: return tuple(Variable.norm(variable) for variable in variables) def _upgrade_ordering(variables: VariableHint) -> Tuple[Variable, ...]: return _sorted_variables(set(_upgrade_variables(variables))) OrderingHint = Optional[Iterable[Union[str, Variable]]]
[docs] def ensure_ordering( expression: Expression, *, ordering: OrderingHint = None, ) -> Sequence[Variable]: """Get a canonical ordering of the variables in the expression, or pass one through. The canonical ordering of the variables in a given expression is based on the alphabetical sort order of the variables based on their names. :param expression: The expression to get a canonical ordering from. :param ordering: A given ordering to pass through if not none, otherwise calculate it. :returns: The ordering """ if ordering is not None: return _upgrade_ordering(ordering) # use alphabetical ordering return _sorted_variables(expression.get_variables())
def _get_treatment_variables(variables: set[Variable]) -> set[Variable]: return {variable for variable in variables if isinstance(variable, Intervention)} def _get_outcome_variables(variables: set[Variable]) -> set[Variable]: return {variable for variable in variables if not isinstance(variable, Intervention)} def get_outcomes_and_treatments(*, query: Expression) -> tuple[set[Variable], set[Variable]]: """Get outcomes and treatments sets from the query expression.""" variables = query.get_variables() return ( _get_outcome_variables(variables), _get_treatment_variables(variables), ) def outcomes_and_treatments_to_query( *, outcomes: set[Variable], treatments: Optional[set[Variable]] = None ) -> Expression: """Create a query expression from a set of outcome and treatment variables.""" if not treatments: return P(outcomes) return P(Variable.norm(y) @ _upgrade_ordering(treatments) for y in outcomes)
[docs] def vmap_pairs(edges: Iterable[Tuple[str, str]]) -> List[Tuple[Variable, Variable]]: """Map pair of strings to pairs of variables.""" return [(Variable(source), Variable(target)) for source, target in edges]
[docs] def vmap_adj(adjacency_dict): """Map an adjacency dictionary of strings to variables.""" return { Variable(source): [Variable(target) for target in targets] for source, targets in adjacency_dict.items() }
#: A conjunction of factual and counterfactual events Event = Dict[Variable, Intervention] Population = Variable
[docs] @dataclass(frozen=True, repr=False) class PopulationProbability(Probability): """A probability that is annotated with a population. >>> from y0.dsl import PP, Pi1, Y, X >>> # Make a population-annotated probability of Y >>> PP[Pi1](Y) >>> # Make a conditioned population of Y @ X >>> PP[Pi1][X](Y) Related publications: - `Surrogate Outcomes and Transportability <https://arxiv.org/abs/1806.07172>`_ (Tikka and Karvanen, 2018) """ population: Population def _new(self, distribution) -> PopulationProbability: return PopulationProbability(population=self.population, distribution=distribution) def _get_key(self): return -1, self.population, self.children[0].name
[docs] def to_y0(self) -> str: """Output this probability instance as y0 internal DSL code.""" interventions, unintervened_distribution = self._help_level_2_distribution() if not interventions: return f"P({self.distribution.to_y0()})" # only keep the + if necessary, otherwise show regular intervention_str = ",".join( f"+{intervention.name}" if intervention.star else intervention.name for intervention in interventions ) return f"PP[{self.population.to_y0()}][{intervention_str}]({unintervened_distribution.to_y0()})"
[docs] def to_latex(self) -> str: """Output this probability in the LaTeX string format.""" interventions, unintervened_distribution = self._help_level_2_distribution() if self.population == TARGET_DOMAIN: pop_latex = r"\pi^\ast" else: pop_latex = self.population.to_latex() if not interventions: return f"P^{{{pop_latex}}}({self.distribution.to_latex()})" intervention_str = ",".join(intervention.to_latex() for intervention in interventions) return f"P_{{{intervention_str}}}^{{{pop_latex}}}({unintervened_distribution.to_latex()})"
class PopulationProbabilityBuilderType(ProbabilityBuilderType): """A magical type for building population probabilities.""" def __init__(self, population: Population): """Initialize the builder with a given population.""" self.population = population @classmethod def __class_getitem__(cls, population: Population) -> "PopulationProbabilityBuilderType": """Get a population probability builder class initialized with the given population.""" return cls(population) def __call__(self, *args, **kwargs) -> PopulationProbability: # noqa:D102 probability = super().__call__(*args, **kwargs) return PopulationProbability( population=self.population, distribution=probability.distribution ) PP = PopulationProbabilityBuilderType TARGET_DOMAIN = Population("pi*")