欢迎来到三一文库! | 帮助中心 三一文库31doc.com 一个上传文档投稿赚钱的网站
三一文库
全部分类
  • 研究报告>
  • 工作总结>
  • 合同范本>
  • 心得体会>
  • 工作报告>
  • 党团相关>
  • 幼儿/小学教育>
  • 高等教育>
  • 经济/贸易/财会>
  • 建筑/环境>
  • 金融/证券>
  • 医学/心理学>
  • ImageVerifierCode 换一换
    首页 三一文库 > 资源分类 > DOC文档下载
     

    基于TensorFlow的数据导入机制.doc

    • 资源ID:3416347       资源大小:24.50KB        全文页数:4页
    • 资源格式: DOC        下载积分:2
    快捷下载 游客一键下载
    会员登录下载
    微信登录下载
    三方登录下载: 微信开放平台登录 QQ登录   微博登录  
    二维码
    微信扫一扫登录
    下载资源需要2
    邮箱/手机:
    温馨提示:
    用户名和密码都是您填写的邮箱或者手机号,方便查询和重复下载(系统自动生成)
    支付方式: 支付宝    微信支付   
    验证码:   换一换

    加入VIP免费专享
     
    账号:
    密码:
    验证码:   换一换
      忘记密码?
        
    友情提示
    2、PDF文件下载后,可能会被浏览器默认打开,此种情况可以点击浏览器菜单,保存网页到桌面,就可以正常下载了。
    3、本站不支持迅雷下载,请使用电脑自带的IE浏览器,或者360浏览器、谷歌浏览器下载即可。
    4、本站资源下载后的文档和图纸-无水印,预览文档经过压缩,下载后原文更清晰。
    5、试题试卷类文档,如果标题没有明确说明有答案则都视为没有答案,请知晓。

    基于TensorFlow的数据导入机制.doc

    基于TensorFlow的数据导入机制聊一聊TensorFlow的数据导入机制今天我们要讲的是TensorFlow中的数据导入机制,传统的做法是习惯于先构建好TF图模型,然后开启一个会话(Session),在运行图模型之前将数据feed到图中,这种做法的缺点是数据IO带来的时间消耗很大,那么在训练非常庞大的数据集的时候,不提倡采用这种做法,TensorFlow中取而代之的是tf.data.Dataset模块,今天我们重点介绍这个。tf.data是一个十分强大的可以用于构建复杂的数据导入机制的API,例如,如果你要处理的是图像,那么tf.data可以帮助你把分布在不同位置的文件整合到一起,并且对每幅图片添加微小的随机噪声,以及随机选取一部分图片作为一个batch进行训练;又或者是你要处理文本,那么tf.data可以帮助从文本中解析符号并且转换成embedding矩阵,然后将不同长度的序列变成一个个batch。我们可以用tf.data.Dataset来构建一个数据集,数据集的来源可以有多种方式,例如如果你的数据集是预先以TFRecord格式写在硬盘上的,那么你可以用tf.data.TFRecordDataset来构建;如果你的数据集是内存中的tensor变量,那么可以用tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()来构建。下面我将通过代码来演示它们。首先,我们来看从内存中的tensor变量来构建数据集,如下代码所示,首先构建了一个010的数据集,然后构建迭代器,迭代器可以每次从数据集中提取一个元素:import tensorflow as tf dataset=tf.data.Dataset.range(10) iterator=dataset.make_one_shot_iterator() next_element = iterator.get_next()with tf.Session() as sess:        for _ in range(10):        print(sess.run(next_element)如上代码所示,range()是tf.data.Dataset类的一个静态函数,用于产生一段序列。需要注意的是,构建的数据集需要是同一种数据类型以及内部结构。除此之外,由于range(10)代表09一共十个数,因此,这里的iterator只能运行10次,超过以后将会抛出tf.errors.OutOfRangeError异常。如果希望不抛出异常,则可以调用dataset.repeat(count)即可实现count次自动重复的迭代器。range的范围我们也可以在运行时才确定,即定义max_range为placeholder变量,这个时候需要调用Dataset的make_initializable_iterator方法来构建迭代器,并且这个迭代器的operation需要在迭代之前被运行,代码如下所示:max_range=tf.placeholder(tf.int64, shape=) dataset = tf.data.Dataset.range(max_range) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()with tf.Session() as sess:    sess.run(iterator.initializer, feed_dict=max_range: 10)        for _ in range(10):        print(sess.run(next_element)也可以为不同的数据集创建同一个迭代器,为了使得这个迭代器可以被重复使用,需要保证不同数据集的类型和维度是一致的。例如,下面的代码演示了如何使用同一个迭代器来构建训练集和验证集,可以看到,当我们开始训练训练集的时候,就需要先执行training_init_op,目的是使得迭代器开始加载训练数据;而当进行验证的时候,则需要先执行validation_init_op,道理一样。training_data = tf.data.Dataset.range(100).map(lambda x: x+tf.random_uniform(, -10, 10, tf.int64) validation_data = tf.data.Dataset.range(50) iterator = tf.Iterator.from_structure(training_data.output_types, training_data.output_shapes) iterator = tf.data.Iterator.from_structure(training_data.output_types, training_data.output_shapes) next_element = iterator.get_next() training_init_op=iterator.make_initializer(training_data) validation_init_op=iterator.make_initializer(validation_data)with tf.Session() as sess:        for epoch in range(10):        sess.run(training_init_op)                for _ in range(100):            sess.run(next_element)        sess.run(validation_init_op)                for _ in range(50):            sess.run(next_element)也可以通过Tensor变量构建tf.data.Dataset,如下代码所示,需要注意的是,这里的Tensor的维度是4×10,因此,传入到迭代器中就是可以运行4次,每次运行生成一个长度为10的向量。import tensorflow as tf dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform(4, 10) iterator = dataset.make_initializable_iterator()     next_element = iterator.get_next()with tf.Session() as sess:    sess.run(iterator.initializer)        for i in range(4):      value = sess.run(next_element)      print(value)        最后,还有一种比较常见的读取数据的方式,就是从TFRecord文件中去读取,这里再介绍一下之前在语音识别项目里采取的TFRecord的读写代码。首先是将音频特征写入到TFRecord文件之中,在语音识别中,我们最常用的两个特征就是MFCC和LogFBank,要写入文件中的不仅仅是这两个变量,还要有文本标签Label以及特征序列的长度sequence_legnth,这四个变量中,只有sequence_length是整数标量,其他三个都是列表格式,所以这里对于列表使用字节来保存,而对于标量,使用整型来保存。def _bytes_feature(value):    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)def _int64_feature(value):    return tf.train.Feature(int64_list=tf.train.Int64List(value=value)class RecordWriter(object):    def _init_(self):        pass    def write(self, content, tfrecords_filename):        writer = tf.python_io.TFRecordWriter(tfrecords_filename)                if isinstance(content, list):            feature_dict =                        for i in range(len(content):                feature = contenti                                if i=0:                    feature_raw = np.array(feature).tostring()                    feature_dictmfccFeat=_bytes_feature(feature_raw)                                elif i=1:                    feature_raw = np.array(feature).tostring()                    feature_dictlogfbankFeat=_bytes_feature(feature_raw)                                elif i=2:                    feature_raw = np.array(feature).tostring()                    feature_dictlabel=_bytes_feature(feature_raw)                                else:                    feature_dictsequence_length=_int64_feature(feature)            features_to_write = tf.train.Example(features=tf.train.Features(feature=feature_dict)            writer.write(features_to_write.SerializeToString()            writer.close()            print(Record has been writen:+tfrecords_filename)写好TFRecord以后,在读取的时候首先需要对TFRecord格式文件进行解析,解析函数如下:   def parse(self, serialized):        feature_dict=        feature_dictmfccFeat=tf.FixedLenFeature(, tf.string)        feature_dictlogfbankFeat=tf.FixedLenFeature(, tf.string)        feature_dictlabel=tf.FixedLenFeature(, tf.string)        feature_dictsequence_length=tf.FixedLenFeature(1, tf.int64)        features = tf.parse_single_example(            serialized,            features=feature_dict)        mfcc = tf.reshape(tf.decode_raw(featuresmfccFeat, tf.float32), -1, self.feature_num)        logfbank = tf.reshape(tf.decode_raw(featureslogfbankFeat, tf.float32), -1, self.feature_num)        label = tf.decode_raw(featureslabel, tf.int64)            return mfcc, logfbank, label, featuressequence_length然后我们可以直接通过调用tf.data.TFRecordDataset来导入TFRecord文件列表,以及对每个文件调用parse函数进行解析,并且由于每个文件的特征矩阵长度不一,所以需要对齐进行padding操作,最终可以获得迭代器,代码如下:       self.fileNameList = tf.placeholder(tf.string, None, )        padded_shapes= (-1,feature_num,-1,feature_num,-1,1)        padded_values = (0.0,0.0,np.int64(-1),np.int64(0)        dataset = tf.data.TFRecordDataset(self.fileNameList, buffer_size=self.buffer_size).map(self.parse, num_parallel_call).padded_batch(batch_size, padded_shapes, padded_values)        self.iterator = tf.data.Iterator.from_structure(tf.float32, tf.float32, tf.int64, tf.int64),                            (tf.TensorShape(None, None, 60), tf.TensorShape(None, None, 60),                               tf.TensorShape(None, None),  tf.TensorShape(None, None)        self.initializer = self.iterator.make_initializer(dataset)于是,关于TFRecord文件的读写就介绍完了,并且,基于TensorFlow的数据导入机制也介绍完了。

    注意事项

    本文(基于TensorFlow的数据导入机制.doc)为本站会员(白大夫)主动上传,三一文库仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对上载内容本身不做任何修改或编辑。 若此文所含内容侵犯了您的版权或隐私,请立即通知三一文库(点击联系客服),我们立即给予删除!

    温馨提示:如果因为网速或其他原因下载失败请重新下载,重复下载不扣分。




    经营许可证编号:宁ICP备18001539号-1

    三一文库
    收起
    展开