tf.contrib.bayesflow.stochastic_tensor.StochasticTensor.__init__()

tf.contrib.bayesflow.stochastic_tensor.StochasticTensor.__init__(dist_cls, name=None, dist_value_type=None, loss_fn=score_function, **dist_args)

Construct a StochasticTensor.

StochasticTensor instantiates a distribution from dist_cls and dist_args, and its value() method returns the same value each time it is called. Which value is returned is controlled by dist_value_type (defaults to SampleAndReshapeValue).
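The "build once, return the same value every time" behaviour can be sketched in plain Python. This is a hypothetical toy, not the TensorFlow API: ToyBernoulli, ToyStochasticTensor and the string value types "sample"/"mean" are illustrative stand-ins for dist_cls and dist_value_type.

```python
import random


class ToyBernoulli:
    """Hypothetical stand-in for a Distribution class (not the TF API)."""

    def __init__(self, p):
        self.p = p

    def sample(self):
        return 1 if random.random() < self.p else 0

    def mean(self):
        return self.p


class ToyStochasticTensor:
    """Hypothetical sketch: build the distribution once and cache one value."""

    def __init__(self, dist_cls, value_type="sample", **dist_args):
        # Instantiate the distribution from dist_cls and dist_args.
        self.distribution = dist_cls(**dist_args)
        # Stand-in for dist_value_type: decides what value() returns.
        self._value_type = value_type
        self._value = None

    def value(self):
        # Compute on the first call only; later calls reuse the cached result.
        if self._value is None:
            if self._value_type == "mean":
                self._value = self.distribution.mean()
            else:
                self._value = self.distribution.sample()
        return self._value


random.seed(0)
st = ToyStochasticTensor(ToyBernoulli, p=0.3)
first = st.value()
print([st.value() for _ in range(5)])  # five copies of the same draw

mean_st = ToyStochasticTensor(ToyBernoulli, value_type="mean", p=0.3)
print(mean_st.value())
```

In a TensorFlow graph the analogous effect comes from value() handing back the same tensor rather than building a new sampling op on every call.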

Some distributions' sample functions are not differentiable (e.g. a sample from a discrete distribution such as a Bernoulli), so differentiating with respect to parameters upstream of the sample requires a gradient estimator such as the score function estimator. This is accomplished by passing a differentiable loss_fn to the StochasticTensor; by default it is a function whose derivative is the score function estimator. Calling stochastic_graph.surrogate_loss(final_losses) calls loss() on every StochasticTensor upstream of the final losses.
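To illustrate what the score function estimator computes, here is a minimal pure-Python sketch for a single Bernoulli(p) sample x with a downstream loss f(x). It relies on the identity d/dp E[f(x)] = E[f(x) * d/dp log p(x; p)], which needs no derivative of f with respect to x. The loss f and the sample count are arbitrary choices for illustration.

```python
import random


def f(x):
    # Downstream loss of a discrete sample; it cannot be differentiated in x.
    return (x - 0.8) ** 2


def score(x, p):
    # d/dp log Bernoulli(x; p) = x/p - (1 - x)/(1 - p)
    return x / p - (1 - x) / (1 - p)


def score_function_gradient(p, num_samples, rng):
    # Monte Carlo average of f(x) * score(x, p) with x ~ Bernoulli(p).
    total = 0.0
    for _ in range(num_samples):
        x = 1 if rng.random() < p else 0
        total += f(x) * score(x, p)
    return total / num_samples


def exact_gradient():
    # E[f(x)] = p * f(1) + (1 - p) * f(0), so d/dp E[f(x)] = f(1) - f(0).
    return f(1) - f(0)


rng = random.Random(0)
p = 0.3
est = score_function_gradient(p, 200_000, rng)
print(est, exact_gradient())
```

The estimate is unbiased but noisy; the baselines in the stochastic_gradient_estimators module exist to reduce that variance.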

loss() returns None for StochasticTensors backed by reparameterized distributions; it also returns None if the value type is MeanValueType or if loss_fn=None.
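The three cases in which loss() yields None can be restated as a small decision function. This is a hypothetical sketch of the documented rules, not TensorFlow's implementation; loss_or_none, ToyTensor and echo_loss_fn are illustrative names, and "mean" stands in for MeanValueType.

```python
def loss_or_none(dt, influenced_loss, reparameterized, value_type, loss_fn):
    """Hypothetical restatement of when StochasticTensor.loss() yields None."""
    if reparameterized:
        # Gradients already flow through the sample itself; no surrogate needed.
        return None
    if value_type == "mean":
        # The value is a deterministic mean, not a sample.
        return None
    if loss_fn is None:
        return None
    # Otherwise loss_fn receives (dt, dt.value(), influenced_loss).
    return loss_fn(dt, dt.value(), influenced_loss)


class ToyTensor:
    """Hypothetical stand-in for a StochasticTensor with a fixed value."""

    def value(self):
        return 1


def echo_loss_fn(dt, value, influenced_loss):
    return value * influenced_loss


dt = ToyTensor()
print(loss_or_none(dt, 2.0, True, "sample", echo_loss_fn))
print(loss_or_none(dt, 2.0, False, "sample", echo_loss_fn))
```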

Args:
  • dist_cls: a Distribution class.
  • name: a name for this StochasticTensor and its ops.
  • dist_value_type: a _StochasticValueType, which determines the value of this StochasticTensor. If not provided, the value type set with the value_type context manager is used.
  • loss_fn: callable that takes (dt, dt.value(), influenced_loss), where dt is this StochasticTensor, and returns a Tensor loss. By default, loss_fn is score_function, or more precisely the integral of the score function, so that taking its gradient yields the score function. See the stochastic_gradient_estimators module for additional loss functions and baselines.
  • **dist_args: keyword arguments to be passed through to dist_cls on construction.
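A custom loss_fn only has to accept (dt, dt.value(), influenced_loss). The sketch below is a hypothetical pure-Python stand-in: the toy classes and score_function_with_baseline are illustrative, not TensorFlow's score_function. It forms log_prob(value) * (influenced_loss - baseline), holds influenced_loss fixed (in a TF graph you would typically wrap it in stop_gradient), and checks by finite differences that the derivative with respect to p is (influenced_loss - baseline) times the score.

```python
import functools
import math


class ToyBernoulli:
    """Hypothetical stand-in exposing only log_prob."""

    def __init__(self, p):
        self.p = p

    def log_prob(self, x):
        return math.log(self.p) if x == 1 else math.log(1.0 - self.p)


class ToyStochasticTensor:
    """Hypothetical stand-in holding a .distribution attribute."""

    def __init__(self, p):
        self.distribution = ToyBernoulli(p)


def score_function_with_baseline(dt, value, influenced_loss, baseline=0.0):
    # influenced_loss is treated as a constant, so only log_prob depends on p.
    return dt.distribution.log_prob(value) * (influenced_loss - baseline)


def numeric_grad(fn, p, eps=1e-6):
    # Central finite difference approximation of d fn / dp.
    return (fn(p + eps) - fn(p - eps)) / (2 * eps)


p, x, L, b = 0.3, 1, 0.04, 0.5
# functools.partial keeps the required three-argument signature.
loss_fn = functools.partial(score_function_with_baseline, baseline=b)
g = numeric_grad(lambda q: loss_fn(ToyStochasticTensor(q), x, L), p)
expected = (L - b) * (x / p - (1 - x) / (1 - p))
print(g, expected)
```

Subtracting a constant baseline leaves the expected gradient unchanged, because the score has expectation zero, while it can reduce the estimator's variance.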
Raises:
  • TypeError: if dist_cls is not a Distribution.
  • TypeError: if loss_fn is not callable.
doc_TensorFlow
2016-10-14 12:44:25