Source code for pymc.variational.stein

#   Copyright 2024 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import pytensor.tensor as pt

from pytensor.graph.replace import graph_replace

from pymc.pytensorf import floatX
from pymc.util import WithMemoization, locally_cachedmethod
from pymc.variational.opvi import node_property
from pymc.variational.test_functions import rbf

__all__ = ["Stein"]


class Stein(WithMemoization):
    """Assemble the symbolic terms of a kernelized (Stein) gradient.

    The class combines a kernel evaluated on the approximation's joint
    matrix with the gradient of the approximation's normalized log-density
    and exposes the result as :attr:`grad`.

    Parameters
    ----------
    approx
        Approximation object. This class reads ``joint_histogram``,
        ``symbolic_random``, ``symbolic_randoms``, ``collect("histogram")``,
        ``sized_symbolic_logp`` and ``symbolic_normalizing_constant`` from it.
    kernel : callable, default ``rbf``
        Called with the joint input matrix. Its result is indexed with
        ``[0]`` (exposed as :attr:`Kxy`) and ``[1]`` (exposed as
        :attr:`dxkxy`).
    use_histogram : bool, default True
        If true, work on the approximation's histogram; otherwise on its
        symbolic random variables.
    temperature : scalar, default 1
        Converted with ``floatX``; the density part of the gradient is
        divided by it.

    Notes
    -----
    Members decorated with ``node_property`` are read as plain attributes
    inside this class (e.g. ``self.logp_norm.sum()``).
    NOTE(review): ``node_property`` presumably caches the built graph node
    per instance -- confirm in ``pymc.variational.opvi``.
    """

    def __init__(self, approx, kernel=rbf, use_histogram=True, temperature=1):
        self.approx = approx
        self.temperature = floatX(temperature)
        self._kernel_f = kernel
        self.use_histogram = use_histogram

    @property
    def input_joint_matrix(self):
        """Joint matrix handed to the kernel.

        ``approx.joint_histogram`` when ``use_histogram`` is set, otherwise
        ``approx.symbolic_random``.
        """
        if self.use_histogram:
            return self.approx.joint_histogram
        return self.approx.symbolic_random

    @node_property
    def approx_symbolic_matrices(self):
        """Per-group tensors the log-density is differentiated against.

        ``approx.collect("histogram")`` when ``use_histogram`` is set,
        otherwise ``approx.symbolic_randoms``.
        """
        if self.use_histogram:
            return self.approx.collect("histogram")
        return self.approx.symbolic_randoms

    @node_property
    def dlogp(self):
        """Gradient of the summed normalized log-density, joined column-wise."""
        total_logp = self.logp_norm.sum()
        grads = pt.grad(total_logp, self.approx_symbolic_matrices)
        # Each gradient is flattened to 2 dimensions so the pieces can be
        # concatenated along the last axis into a single matrix.
        return pt.concatenate([g.flatten(2) for g in grads], -1)

    @node_property
    def grad(self):
        """Combined gradient.

        Computed as
        ``(density_part_grad / temperature + repulsive_part_grad) / n``
        where ``n`` is the first dimension of :attr:`input_joint_matrix`.
        """
        n = floatX(self.input_joint_matrix.shape[0])
        temperature = self.temperature
        svgd_grad = self.density_part_grad / temperature + self.repulsive_part_grad
        return svgd_grad / n

    @node_property
    def density_part_grad(self):
        """Matrix product of :attr:`Kxy` with :attr:`dlogp`."""
        Kxy = self.Kxy
        dlogpdx = self.dlogp
        return pt.dot(Kxy, dlogpdx)

    @node_property
    def repulsive_part_grad(self):
        """:attr:`dxkxy` divided by the approximation's normalizing constant."""
        t = self.approx.symbolic_normalizing_constant
        dxkxy = self.dxkxy
        return dxkxy / t

    @property
    def Kxy(self):
        """First element of the kernel result.

        NOTE(review): presumably the kernel matrix -- confirm against the
        kernel implementation in ``pymc.variational.test_functions``.
        """
        return self._kernel()[0]

    @property
    def dxkxy(self):
        """Second element of the kernel result.

        NOTE(review): presumably the kernel's derivative term -- confirm
        against the kernel implementation.
        """
        return self._kernel()[1]

    @node_property
    def logp_norm(self):
        """Sized log-density divided by the normalizing constant.

        When ``use_histogram`` is set, the approximation's symbolic randoms
        are first replaced in the graph by the collected histogram tensors.
        """
        sized_symbolic_logp = self.approx.sized_symbolic_logp
        if self.use_histogram:
            # Pair each symbolic random with its histogram counterpart and
            # swap them in the log-density graph.
            replacements = dict(
                zip(self.approx.symbolic_randoms, self.approx.collect("histogram"))
            )
            sized_symbolic_logp = graph_replace(
                sized_symbolic_logp,
                replacements,
                strict=False,
            )
        return sized_symbolic_logp / self.approx.symbolic_normalizing_constant

    @locally_cachedmethod
    def _kernel(self):
        """Evaluate the configured kernel on :attr:`input_joint_matrix`.

        NOTE(review): ``locally_cachedmethod`` presumably memoizes the result
        so ``Kxy`` and ``dxkxy`` share one evaluation -- confirm in
        ``pymc.util``.
        """
        return self._kernel_f(self.input_joint_matrix)