tf.sparse_split()

tf.sparse_split(split_dim, num_split, sp_input, name=None)

Split a SparseTensor into num_split tensors along split_dim.

If sp_input.shape[split_dim] is not an integer multiple of num_split, each of the first shape[split_dim] % num_split slices gets one extra element along split_dim. For example, if split_dim = 1 and num_split = 2 and the input is:

input_tensor (shape = [2, 7]) =
[    a   d e  ]
[b c          ]

Graphically, the output tensors are:

output_tensor[0] =
[    a ]
[b c   ]

output_tensor[1] =
[ d e  ]
[      ]
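
The slice-size rule and the example above can be sketched in plain Python. This is only an illustration of the arithmetic on a COO-style (indices, values, shape) triple, not TensorFlow's implementation: the helper names split_sizes and sparse_split_coo are hypothetical, and the exact column positions of a, b, c, d and e are assumed, since the drawing does not pin them down.

```python
def split_sizes(dim_size, num_split):
    """Slice sizes along the split dimension.

    The first dim_size % num_split slices get one extra element.
    """
    base, extra = divmod(dim_size, num_split)
    return [base + 1 if i < extra else base for i in range(num_split)]


def sparse_split_coo(indices, values, shape, split_dim, num_split):
    """Split a COO-style sparse tensor into num_split pieces.

    Hypothetical helper: returns a list of (indices, values, shape) tuples.
    """
    sizes = split_sizes(shape[split_dim], num_split)
    # Start offset of each slice along split_dim.
    starts = [sum(sizes[:i]) for i in range(num_split)]

    pieces = []
    for size in sizes:
        out_shape = list(shape)
        out_shape[split_dim] = size
        pieces.append(([], [], out_shape))

    for idx, val in zip(indices, values):
        # Find the slice that contains this entry and shift its index.
        for k, (start, size) in enumerate(zip(starts, sizes)):
            if start <= idx[split_dim] < start + size:
                new_idx = list(idx)
                new_idx[split_dim] -= start
                pieces[k][0].append(new_idx)
                pieces[k][1].append(val)
                break
    return pieces


# Assumed positions for the non-empty entries of the [2, 7] input.
indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
values = ["a", "d", "e", "b", "c"]
shape = [2, 7]

parts = sparse_split_coo(indices, values, shape, split_dim=1, num_split=2)
for part_indices, part_values, part_shape in parts:
    print(part_shape, part_indices, part_values)
```

With 7 columns and num_split = 2, the first slice is 4 columns wide and the second is 3, so a, b and c land in the first output and d and e in the second, with their column indices shifted to start at 0.
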
Args:
  • split_dim: A 0-D int32 Tensor. The dimension along which to split.
  • num_split: A Python integer. The number of ways to split.
  • sp_input: The SparseTensor to split.
  • name: A name for the operation (optional).
Returns:

num_split SparseTensor objects resulting from splitting sp_input.

Raises:
  • TypeError: If sp_input is not a SparseTensor.
2016-10-14 13:09:16