基于TensorFlow的数据导入机制.doc
《基于TensorFlow的数据导入机制.doc》由会员分享,可在线阅读,更多相关《基于TensorFlow的数据导入机制.doc(4页珍藏版)》请在三一文库上搜索。
1、基于TensorFlow的数据导入机制聊一聊TensorFlow的数据导入机制今天我们要讲的是TensorFlow中的数据导入机制,传统的做法是习惯于先构建好TF图模型,然后开启一个会话(Session),在运行图模型之前将数据feed到图中,这种做法的缺点是数据IO带来的时间消耗很大,那么在训练非常庞大的数据集的时候,不提倡采用这种做法,TensorFlow中取而代之的是tf.data.Dataset模块,今天我们重点介绍这个。tf.data是一个十分强大的可以用于构建复杂的数据导入机制的API,例如,如果你要处理的是图像,那么tf.data可以帮助你把分布在不同位置的文件整合到一起,并且对
2、每幅图片添加微小的随机噪声,以及随机选取一部分图片作为一个batch进行训练;又或者是你要处理文本,那么tf.data可以帮助从文本中解析符号并且转换成embedding矩阵,然后将不同长度的序列变成一个个batch。我们可以用tf.data.Dataset来构建一个数据集,数据集的来源可以有多种方式,例如如果你的数据集是预先以TFRecord格式写在硬盘上的,那么你可以用tf.data.TFRecordDataset来构建;如果你的数据集是内存中的tensor变量,那么可以用tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tenso
3、r_slices()来构建。下面我将通过代码来演示它们。首先,我们来看从内存中的tensor变量来构建数据集,如下代码所示,首先构建了一个010的数据集,然后构建迭代器,迭代器可以每次从数据集中提取一个元素:import tensorflow as tf dataset=tf.data.Dataset.range(10) iterator=dataset.make_one_shot_iterator() next_element = iterator.get_next()with tf.Session() as sess: for _ in range(10): print(sess.run(
4、next_element)如上代码所示,range()是tf.data.Dataset类的一个静态函数,用于产生一段序列。需要注意的是,构建的数据集需要是同一种数据类型以及内部结构。除此之外,由于range(10)代表09一共十个数,因此,这里的iterator只能运行10次,超过以后将会抛出tf.errors.OutOfRangeError异常。如果希望不抛出异常,则可以调用dataset.repeat(count)即可实现count次自动重复的迭代器。range的范围我们也可以在运行时才确定,即定义max_range为placeholder变量,这个时候需要调用Dataset的make_i
5、nitializable_iterator方法来构建迭代器,并且这个迭代器的operation需要在迭代之前被运行,代码如下所示:max_range=tf.placeholder(tf.int64, shape=) dataset = tf.data.Dataset.range(max_range) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess: sess.run(iterator.initializer, feed_dic
6、t=max_range: 10) for _ in range(10): print(sess.run(next_element)也可以为不同的数据集创建同一个迭代器,为了使得这个迭代器可以被重复使用,需要保证不同数据集的类型和维度是一致的。例如,下面的代码演示了如何使用同一个迭代器来构建训练集和验证集,可以看到,当我们开始训练训练集的时候,就需要先执行training_init_op,目的是使得迭代器开始加载训练数据;而当进行验证的时候,则需要先执行validation_init_op,道理一样。training_data = tf.data.Dataset.range(100).map(l
7、ambda x: x+tf.random_uniform(, -10, 10, tf.int64) validation_data = tf.data.Dataset.range(50) iterator = tf.Iterator.from_structure(training_data.output_types, training_data.output_shapes) iterator = tf.data.Iterator.from_structure(training_data.output_types, training_data.output_shapes) next_elemen
8、t = iterator.get_next() training_init_op=iterator.make_initializer(training_data) validation_init_op=iterator.make_initializer(validation_data)with tf.Session() as sess: for epoch in range(10): sess.run(training_init_op) for _ in range(100): sess.run(next_element) sess.run(validation_init_op) for _
9、in range(50): sess.run(next_element)也可以通过Tensor变量构建tf.data.Dataset,如下代码所示,需要注意的是,这里的Tensor的维度是410,因此,传入到迭代器中就是可以运行4次,每次运行生成一个长度为10的向量。import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform(4, 10) iterator = dataset.make_initializable_iterator() next_element = iterator
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- 基于 TensorFlow 数据 导入 机制
链接地址:https://www.31doc.com/p-3416347.html