pymc.model_graph.model_to_graphviz

pymc.model_graph.model_to_graphviz(model=None, *, var_names=None, formatting='plain', save=None, figsize=None, dpi=300, node_formatters=None, include_dim_lengths=True)

Produce a graphviz Digraph from a PyMC model.

Requires Graphviz, which is most easily installed with:

conda install -c conda-forge python-graphviz

Alternatively, install the Graphviz binaries yourself, then run pip install graphviz to get the Python bindings. See http://graphviz.readthedocs.io/en/stable/manual.html for more information.
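Because the binaries and the Python bindings are installed separately, it can help to confirm that both are present before plotting. The following is a minimal sketch using only the standard library; it assumes the Graphviz layout program is exposed on PATH under its usual name, dot.

```python
import importlib.util
import shutil

# True if the `graphviz` Python package (the bindings) can be imported.
has_bindings = importlib.util.find_spec("graphviz") is not None

# Path to the Graphviz `dot` executable, or None if it is not on PATH.
dot_path = shutil.which("dot")

print(f"Python bindings installed: {has_bindings}")
print(f"dot executable: {dot_path or 'not found on PATH'}")
```

If the bindings are found but dot is not, building the graph object may still work while rendering it to a file is likely to fail.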

Parameters:
model : pm.Model

The model to plot. Not required when called from inside a model context (a with Model(): block).

var_names : iterable of variable names, optional

Subset of variables to plot. The names identify a subgraph of the full model graph.

formatting : str, optional

One of {"plain"}. The default is "plain".

save : str, optional

If provided, an image of the graph will be saved to this location. The format is inferred from the file extension.

figsize : tuple[int, int], optional

Width and height of the figure in inches. If not provided, the default figure size is used. It only affects the size of the saved figure.

dpi : int, optional

Dots per inch. It only affects the resolution of the saved figure. The default is 300.

node_formatters : dict, optional

A dictionary mapping node types to functions that return a dictionary of node attributes. See the Graphviz documentation for the available attributes: https://graphviz.org/docs/nodes/

include_dim_lengths : bool

Include the dim lengths in the plate label. Default is True.
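As a rough illustration of how save, figsize, and dpi relate, the sketch below derives the output format from a file name and estimates the largest bitmap that a given figure size and resolution could produce. The helper describe_save_target is hypothetical and is not part of PyMC; it assumes the usual relationship of inches times dots per inch, and it does not reproduce how PyMC handles these arguments internally.

```python
from pathlib import Path


def describe_save_target(save, figsize, dpi=300):
    """Hypothetical helper (not part of PyMC) summarising a save request."""
    # The format is taken from the file extension, e.g. "schools.png" -> "png".
    fmt = Path(save).suffix.lstrip(".")
    # Inches multiplied by dots per inch gives an upper bound in pixels.
    width_in, height_in = figsize
    max_pixels = (round(width_in * dpi), round(height_in * dpi))
    return fmt, max_pixels


print(describe_save_target("schools.png", figsize=(8, 6)))
# prints ('png', (2400, 1800))
```

The saved image may be smaller than this bound in one dimension if the graph's aspect ratio differs from that of figsize.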

Examples

How to plot the graph of the model.

import numpy as np
from pymc import HalfCauchy, Model, Normal, model_to_graphviz

J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])

with Model() as schools:
    eta = Normal("eta", 0, 1, shape=J)
    mu = Normal("mu", 0, sigma=1e6)
    tau = HalfCauchy("tau", 25)

    theta = mu + tau * eta

    obs = Normal("obs", theta, sigma=sigma, observed=y)

model_to_graphviz(schools)

In a Jupyter notebook, the returned graph is displayed automatically when it is the last expression in a cell. Outside a notebook, such as in a script or a Python console, the graph must be rendered explicitly:

# creates the file `schools.pdf`
model_to_graphviz(schools).render("schools")
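To write a format other than the default PDF, the render method of the returned graphviz object accepts a format argument, and cleanup=True removes the intermediate DOT source file. A minimal sketch follows; the wrapper export_graph is a hypothetical convenience and not part of PyMC or graphviz.

```python
def export_graph(graph, stem, fmt="png"):
    """Hypothetical wrapper: render `graph` to `<stem>.<fmt>` and return the path."""
    # `render` saves the DOT source under `stem`, runs the Graphviz binary,
    # and returns the path of the rendered file. cleanup=True deletes the
    # DOT source once rendering has finished.
    return graph.render(stem, format=fmt, cleanup=True)


# Example call, assuming `schools` is the model defined above:
# export_graph(model_to_graphviz(schools), "schools", fmt="svg")
```

The DOT text itself is available as the source attribute of the graph, which can be inspected without running the Graphviz binary.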

Display Free Random Variable and Observed Random Variable nodes with custom formatting.

node_formatters = {
    "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
    "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
}
model_to_graphviz(schools, node_formatters=node_formatters)
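Formatters do not have to be lambdas. The sketch below uses named functions, which leaves room for docstrings and additional Graphviz attributes such as style and fillcolor. The function names and the stand-in variable are illustrative assumptions; the stand-in only mimics the name attribute used in the example above, so the formatters can be tried without building a model.

```python
from types import SimpleNamespace


def free_rv_node(var):
    """Grey filled circle labelled with the variable name."""
    return {
        "shape": "circle",
        "style": "filled",
        "fillcolor": "lightgrey",
        "label": var.name,
    }


def observed_rv_node(var):
    """Square node whose label marks the variable as observed."""
    return {"shape": "square", "label": f"{var.name} (observed)"}


node_formatters = {
    "Free Random Variable": free_rv_node,
    "Observed Random Variable": observed_rv_node,
}

# Stand-in for a model variable: only the `name` attribute is needed here.
fake_var = SimpleNamespace(name="eta")
print(node_formatters["Free Random Variable"](fake_var))
```

The resulting dictionary is passed in the same way as before: model_to_graphviz(schools, node_formatters=node_formatters).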