【Question title】:TypeError when training TensorFlow Random Forest using TensorForestEstimator
【Posted】:2017-12-18 12:38:07
【Question】:

I get a TypeError when trying to train a TensorFlow random forest using TensorForestEstimator.

TypeError: Input 'input_data' of 'CountExtremelyRandomStats' Op has type float64 that does not match expected type of float32.

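I have tried both Python 2.7 and Python 3, and I have also tried using tf.cast() to convert everything to float32, but it did not help. I checked the data type at execution time and it is float32. The problem does not seem to be the data I am supplying (a CSV containing only floats), so I am not sure where to go from here.

Any suggestions for things I could try would be greatly appreciated.

Code: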

# Build an estimator.
def build_estimator(model_dir):
  params = tensor_forest.ForestHParams(
      num_classes=2, num_features=760,
      num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
  graph_builder_class = tensor_forest.RandomForestGraphs
  if FLAGS.use_training_loss:
    graph_builder_class = tensor_forest.TrainingLossForest
  # Use the SKCompat wrapper, which gives us a convenient way to split in-memory data into batches.
  return estimator.SKCompat(random_forest.TensorForestEstimator(params, graph_builder_class=graph_builder_class, model_dir=model_dir))


# Train and evaluate the model.
def train_and_eval():

  # load datasets
  training_set = pd.read_csv('/Users/carl/Dropbox/Docs/Python/randomforest_balanced_train.csv', dtype=np.float32, header=None)
  test_set = pd.read_csv('/Users/carl/Dropbox/Docs/Python/randomforest_balanced_test.csv', dtype=np.float32, header=None)

  print('###########')
  print(training_set.loc[:,1].dtype)  # this prints float32

  # load labels
  training_labels = pd.read_csv('/Users/carl/Dropbox/Docs/Python/randomforest_balanced_train_class.csv', dtype=np.int32, names=LABEL, header=None)
  test_labels = pd.read_csv('/Users/carl/Dropbox/Docs/Python/randomforest_balanced_test_class.csv', dtype=np.int32, names=LABEL, header=None)

  # define the path where the model will be stored - default is current directory
  model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
  print('model directory = %s' % model_dir)

  # build the random forest estimator
  est = build_estimator(model_dir)

  tf.cast(training_set, tf.float32) #error occurs with/without casts
  tf.cast(test_set, tf.float32)
  # train the forest to fit the training data
  est.fit(x=training_set, y=training_labels)  #this line throws the error

【Question comments】:

    Tags: python python-3.x machine-learning tensorflow random-forest


    【Solution 1】:

    You are using tf.cast incorrectly:

    tf.cast(training_set, tf.float32) #error occurs with/without casts
    

    It should be:

    training_set = tf.cast(training_set, tf.float32)
    

    tf.cast is not an in-place method. It is a TensorFlow op and, like any other op, its result has to be assigned to a variable and then run.
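
The "not in-place" point can be illustrated without TensorFlow. The following is a minimal sketch that uses NumPy's `astype` as an analogy for `tf.cast` (a swapped-in stand-in, not the op from the question); the `features` array is a hypothetical placeholder for the CSV data. It only shows that a discarded conversion leaves the original data untouched.

```python
import numpy as np

# Hypothetical stand-in for the CSV features; NumPy defaults to float64.
features = np.array([[0.1, 0.2], [0.3, 0.4]])

# Calling the conversion and discarding the result changes nothing:
# the original array keeps its dtype.
features.astype(np.float32)
print(features.dtype)  # float64

# Only the returned object carries the new dtype, so it must be assigned.
features32 = features.astype(np.float32)
print(features32.dtype)  # float32
```

The same pattern applies to `tf.cast`: the converted tensor is the return value, so it has to be captured and used in place of the original input.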

    【Comments】:
