python – 使用Tensorflow中的多层感知器模型预测文本标签
我正在按照教程进行操作,可以浏览代码,该代码可以训练神经网络并评估其准确性. 但我不知道如何在新的单个输入(字符串)上使用训练模型来预测其标签. 你能告诉我们如何做到这一点吗? 教程: https://medium.freecodecamp.org/big-picture-machine-learning-classifying-text-with-neural-networks-and-tensorflow-d94036ac2274 会话代码: # Launch the graph with tf.Session() as sess: sess.run(init) # Training cycle for epoch in range(training_epochs): avg_cost = 0. total_batch = int(len(newsgroups_train.data)/batch_size) # Loop over all batches for i in range(total_batch): batch_x,batch_y = get_batch(newsgroups_train,i,batch_size) # Run optimization op (backprop) and cost op (to get loss value) c,_ = sess.run([loss,optimizer],feed_dict={input_tensor: batch_x,output_tensor:batch_y}) # Compute average loss avg_cost += c / total_batch # Display logs per epoch step if epoch % display_step == 0: print("Epoch:",'%04d' % (epoch+1),"loss=", "{:.9f}".format(avg_cost)) print("Optimization Finished!") # Test model correct_prediction = tf.equal(tf.argmax(prediction,1),tf.argmax(output_tensor,1)) # Calculate accuracy accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) total_test_data = len(newsgroups_test.target) batch_x_test,batch_y_test = get_batch(newsgroups_test,total_test_data) print("Accuracy:",accuracy.eval({input_tensor: batch_x_test,output_tensor: batch_y_test})) 我有一些Python的经验,但基本上没有Tensorflow的经验. 解决方法首先,我们需要将文本转换为数组:def text_to_vector(text): layer = np.zeros(total_words,dtype=float) for word in text.split(' '): layer[word2index[word.lower()]] += 1 return layer # Convert text to vector so we can send it to our model vector_txt = text_to_vector(text) # Wrap vector like we do in get_batches() input_array = np.array([vector_txt]) 我们可以保存并加载模型以供重用.我们首先创建一个Saver对象然后保存会话(在训练模型之后): saver = tf.train.Saver() ... train the model ... save_path = saver.save(sess,"/tmp/model.ckpt") 在示例模型中,模型体系结构中的最后一个“步骤”(即在polym_perceptron方法中完成的最后一件事)是: 'out': tf.Variable(tf.random_normal([n_classes])) 因此,为了获得预测,我们得到该数组的最大值的索引(预测类): saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess,"/tmp/model.ckpt") print("Model restored.") classification = sess.run(tf.argmax(prediction,feed_dict={input_tensor: input_array}) print("Predicted category:",classification) 您可以在这里查看整个代码:https://github.com/dmesquita/understanding_tensorflow_nn (编辑:甘南站长网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
- python – 将参数传递给apscheduler处理函数
- python – 重新分发字典值列表
- python – 根据网络重复边缘更新权重信息
- import pyttsx在python 2.7中工作,但不在python3中
- python – ImportError:Elastic Beanstalk中没有名为djang
- Python基于Tkinter实现的记事本实例
- Django和Elastic Beanstalk URL运行状况检查
- python – 循环通过日期,除了周末
- python – 为什么使用整数作为pymongo的键不起作用?
- 有没有办法让Django的USStateField()没有预先选择的值?