pymc.model_graph.model_to_graphviz
- pymc.model_graph.model_to_graphviz(model=None, *, var_names=None, formatting='plain', save=None, figsize=None, dpi=300, node_formatters=None, include_dim_lengths=True)
Produce a graphviz Digraph from a PyMC model.
- Requires graphviz, which is most easily installed with:
conda install -c conda-forge python-graphviz
Alternatively, you may install the Graphviz binaries yourself and then run pip install graphviz to get the Python bindings. See http://graphviz.readthedocs.io/en/stable/manual.html for more information.
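Because the Python bindings and the Graphviz binaries are installed separately, it can help to confirm that both are present before plotting. The following is a minimal sketch using only the standard library; it assumes the Graphviz binaries put a `dot` executable on `PATH`, and the helper name `graphviz_status` is hypothetical, not part of PyMC or graphviz.

```python
import importlib.util
import shutil


def graphviz_status() -> dict:
    """Report whether the graphviz Python bindings and the `dot` binary are available.

    Hypothetical helper: assumes the Graphviz binaries expose a `dot` executable on PATH.
    """
    return {
        # The package installed by `pip install graphviz` is imported as `graphviz`.
        "python_bindings": importlib.util.find_spec("graphviz") is not None,
        # The actual layout and rendering are done by the Graphviz `dot` program.
        "dot_binary": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for component, available in graphviz_status().items():
        print(f"{component}: {'found' if available else 'missing'}")
```

If either entry reports `missing`, the corresponding installation step above has not taken effect in the current environment.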
- Parameters:
- model : pm.Model
  The model to plot. Not required when called from inside a model context (a `with Model():` block).
- var_names : iterable of variable names, optional
  Subset of variables to be plotted; these identify a subgraph of the entire model graph.
- formatting : str, optional
  One of {"plain"}.
- save : str, optional
  If provided, an image of the graph is saved to this location. The format is inferred from the file extension.
- figsize : tuple[int, int], optional
  Width and height of the figure in inches. If not provided, the default figure size is used. It only affects the size of the saved figure.
- dpi : int, optional
  Dots per inch. It only affects the resolution of the saved figure. The default is 300.
- node_formatters : dict, optional
  A dictionary mapping node types to functions that return a dictionary of node attributes. See the Graphviz documentation for the available attributes: https://graphviz.org/docs/nodes/
- include_dim_lengths : bool, optional
  Include the dim lengths in the plate label. Default is True.
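Since `figsize` is given in inches and `dpi` in dots per inch, the pixel dimensions of a saved raster image are roughly their product. The sketch below assumes the drawing is fitted inside the requested size while keeping its aspect ratio, so the product is an upper bound on each dimension rather than the exact output size; `max_pixel_size` is a hypothetical helper, not part of PyMC.

```python
def max_pixel_size(figsize: tuple[float, float], dpi: int = 300) -> tuple[int, int]:
    """Upper bound on the pixel dimensions of a saved raster image.

    Assumption: the graph is fitted inside `figsize` inches and rasterized
    at `dpi` dots per inch, so pixels = inches * dpi along each axis.
    """
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


# A 6 x 4 inch figure at the default 300 dpi is at most 1800 x 1200 pixels.
print(max_pixel_size((6, 4)))       # (1800, 1200)
# Lowering dpi shrinks the raster without changing the size in inches.
print(max_pixel_size((6, 4), 100))  # (600, 400)
```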
Examples
How to plot the graph of the model.
```python
import numpy as np

from pymc import HalfCauchy, Model, Normal, model_to_graphviz

# Eight-schools data: estimated effects and their standard errors
J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])

with Model() as schools:
    eta = Normal("eta", 0, 1, shape=J)
    mu = Normal("mu", 0, sigma=1e6)
    tau = HalfCauchy("tau", 25)
    # Non-centered parameterization of the school-level effects
    theta = mu + tau * eta
    obs = Normal("obs", theta, sigma=sigma, observed=y)

model_to_graphviz(schools)
```
Note that this code automatically displays the graph when executed in a Jupyter notebook. When executed non-interactively, such as in a script or a Python console, the graph needs to be rendered explicitly:
```python
# creates the file `schools.pdf`
model_to_graphviz(schools).render("schools")
```
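The `var_names` parameter restricts the plot to a subgraph of the model, for example `model_to_graphviz(schools, var_names=["obs"])`. As a rough illustration of which nodes such a subset could keep, the sketch below hard-codes the dependency structure of the `schools` model as a plain dictionary (`theta` is not a named model variable, so `obs` is linked directly to `eta`, `mu` and `tau`) and collects each selected variable together with its ancestors. Treating the subgraph as "selected variables plus their ancestors" is an assumption about how PyMC picks the nodes, and `PARENTS` and `subgraph_nodes` are hypothetical names.

```python
# Hypothetical dependency structure of the `schools` model: each named
# variable maps to the named variables it directly depends on.
PARENTS = {
    "eta": [],
    "mu": [],
    "tau": [],
    "obs": ["eta", "mu", "tau"],
}


def subgraph_nodes(var_names, parents=PARENTS):
    """Return the selected variables plus all of their ancestors.

    Assumption for illustration: this approximates how a `var_names`
    subset identifies a subgraph of the full model graph.
    """
    unknown = [name for name in var_names if name not in parents]
    if unknown:
        raise ValueError(f"Variables not in model: {unknown}")

    selected = set()
    stack = list(var_names)
    while stack:
        name = stack.pop()
        if name in selected:
            continue
        selected.add(name)
        # Walk upward to the variables this one depends on.
        stack.extend(parents[name])
    return selected


print(sorted(subgraph_nodes(["obs"])))  # ['eta', 'mu', 'obs', 'tau']
print(sorted(subgraph_nodes(["tau"])))  # ['tau']
```

Selecting an observed variable therefore pulls in the whole upstream model, while selecting a root variable such as `tau` keeps only that node.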
Display free random variable and observed random variable nodes with custom formatting:
```python
node_formatters = {
    "Free Random Variable": lambda var: {"shape": "circle", "label": var.name},
    "Observed Random Variable": lambda var: {"shape": "square", "label": var.name},
}
model_to_graphviz(schools, node_formatters=node_formatters)
```
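Because each formatter is an ordinary function of the variable, it can be checked without building or rendering a model. The sketch below assumes, as in the formatters above, that a formatter only reads `var.name`, so a `types.SimpleNamespace` stand-in is enough. The extra `style` and `fillcolor` entries are standard Graphviz node attributes added here for illustration, and `make_formatter` is a hypothetical helper.

```python
from types import SimpleNamespace


def make_formatter(shape, **extra_attrs):
    """Build a node formatter that labels the node with the variable's name.

    Hypothetical helper: returns a function mapping a variable to a dict of
    Graphviz node attributes.
    """
    def formatter(var):
        return {"shape": shape, "label": var.name, **extra_attrs}

    return formatter


node_formatters = {
    "Free Random Variable": make_formatter("circle"),
    "Observed Random Variable": make_formatter("square", style="filled", fillcolor="lightgrey"),
}

# A stand-in object with a `name` attribute is enough to exercise the formatters.
print(node_formatters["Free Random Variable"](SimpleNamespace(name="tau")))
# {'shape': 'circle', 'label': 'tau'}
print(node_formatters["Observed Random Variable"](SimpleNamespace(name="obs")))
# {'shape': 'square', 'label': 'obs', 'style': 'filled', 'fillcolor': 'lightgrey'}
```

Once the outputs look right, the same dictionary can be passed as `model_to_graphviz(schools, node_formatters=node_formatters)`.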