tf.contrib.learn.Estimator.__init__()

tf.contrib.learn.Estimator.__init__(model_fn=None, model_dir=None, config=None, params=None, feature_engineering_fn=None)

Constructs an Estimator instance.

Args:
  • model_fn: Model function. Takes features and targets tensors (or dicts of tensors) and returns predictions, loss and train_op. The following three signatures are supported:

    • (features, targets) -> (predictions, loss, train_op)
    • (features, targets, mode) -> (predictions, loss, train_op)
    • (features, targets, mode, params) -> (predictions, loss, train_op)

    Where

    • features is a single Tensor or a dict of Tensors (depending on the data passed to fit).
    • targets is a Tensor or a dict of Tensors (for multi-head models). If mode is ModeKeys.INFER, targets=None will be passed. If the model_fn's signature does not accept mode, the model_fn must still be able to handle targets=None.
    • mode indicates whether this is training, evaluation or prediction. See ModeKeys.
    • params is a dict of hyperparameters. It receives whatever is passed to the Estimator's params argument, which allows Estimators to be configured by hyperparameter tuning.
  • model_dir: Directory in which to save model parameters, the graph, etc. It can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.

  • config: Configuration object.

  • params: dict of hyperparameters that will be passed into model_fn. Keys are parameter names; values are basic Python types.

  • feature_engineering_fn: Feature engineering function. Takes features and targets, which are the output of input_fn, and returns the features and targets that will be fed into model_fn. See model_fn for the definition of features and targets.
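The model_fn and feature_engineering_fn conventions above can be illustrated with a pure-Python sketch that does not import TensorFlow. Everything in it is hypothetical: `call_model_fn`, the toy model functions and the string stand-ins for ModeKeys are illustrative names, plain lists take the place of Tensors, and choosing a signature by the argument names model_fn declares is an assumption about how an Estimator might dispatch, not the library's actual implementation.

```python
import inspect

# Hypothetical string stand-ins for ModeKeys; the real constants are defined
# on tf.contrib.learn.ModeKeys and are not imported in this sketch.
TRAIN, EVAL, INFER = "train", "eval", "infer"


def call_model_fn(model_fn, features, targets, mode, params=None,
                  feature_engineering_fn=None):
    """Hypothetical dispatcher for the three model_fn signatures.

    Assumption: the signature is chosen by checking which argument names
    model_fn declares; this is an illustration, not TensorFlow's code.
    """
    # feature_engineering_fn sits between the input_fn output and model_fn.
    if feature_engineering_fn is not None:
        features, targets = feature_engineering_fn(features, targets)
    arg_names = inspect.signature(model_fn).parameters
    if "params" in arg_names:
        return model_fn(features, targets, mode, params)
    if "mode" in arg_names:
        return model_fn(features, targets, mode)
    return model_fn(features, targets)


def two_arg_model_fn(features, targets):
    """(features, targets) signature: must still cope with targets=None."""
    predictions = list(features)
    loss = None if targets is None else 0.0
    return predictions, loss, None


def four_arg_model_fn(features, targets, mode, params):
    """(features, targets, mode, params) signature with a toy linear model."""
    predictions = [params["weight"] * x for x in features]
    if mode == INFER:
        # targets is None during inference, so no loss or train_op is built.
        return predictions, None, None
    loss = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)
    train_op = None  # placeholder: a real model_fn would return an optimizer op
    return predictions, loss, train_op


def double_features(features, targets):
    """Toy feature_engineering_fn: rescales features, passes targets through."""
    return [2.0 * x for x in features], targets


# Training call: features are doubled before they reach model_fn.
preds, loss, _ = call_model_fn(four_arg_model_fn, [1.0, 2.0], [2.0, 4.0],
                               TRAIN, params={"weight": 1.0},
                               feature_engineering_fn=double_features)
print(preds, loss)  # [2.0, 4.0] 0.0

# Inference call on the two-argument signature: targets is None.
print(call_model_fn(two_arg_model_fn, [5.0], None, INFER))  # ([5.0], None, None)
```

Dispatching on argument names rather than argument count is a choice made for this sketch; it keeps the three-argument and four-argument forms distinguishable even if a model_fn adds defaulted parameters.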

Raises:
  • ValueError: parameters of model_fn don't match params.
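The documented ValueError can be sketched as a check on the arguments model_fn declares. This sketch assumes that a mismatch means params was supplied to a model_fn that does not accept a params argument; `check_params` and the two model functions are hypothetical, and TensorFlow is not imported.

```python
import inspect


def check_params(model_fn, params):
    """Hypothetical check mirroring the documented ValueError.

    Assumption: "parameters of model_fn don't match params" means params was
    supplied but model_fn does not declare a `params` argument.
    """
    arg_names = inspect.signature(model_fn).parameters
    if params is not None and "params" not in arg_names:
        raise ValueError(
            "model_fn %s does not accept a `params` argument, but params=%r "
            "was passed." % (getattr(model_fn, "__name__", model_fn), params))


def two_arg_model_fn(features, targets):
    return features, None, None


def four_arg_model_fn(features, targets, mode, params):
    return features, None, None


check_params(four_arg_model_fn, {"learning_rate": 0.1})  # accepted
check_params(two_arg_model_fn, None)  # no params, nothing to match
try:
    check_params(two_arg_model_fn, {"learning_rate": 0.1})
except ValueError as err:
    print("ValueError:", err)
```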
2016-10-14 13:05:46