国产探花免费观看_亚洲丰满少妇自慰呻吟_97日韩有码在线_资源在线日韩欧美_一区二区精品毛片,辰东完美世界有声小说,欢乐颂第一季,yy玄幻小说排行榜完本

首頁(yè) > 編程 > Python > 正文

tensorflow saver 保存和恢復(fù)指定 tensor的實(shí)例講解

2020-01-04 14:46:05
字體:
來(lái)源:轉(zhuǎn)載
供稿:網(wǎng)友

在實(shí)踐中經(jīng)常會(huì)遇到這樣的情況:

1、用簡(jiǎn)單的模型預(yù)訓(xùn)練參數(shù)

2、把預(yù)訓(xùn)練的參數(shù)導(dǎo)入復(fù)雜的模型后訓(xùn)練復(fù)雜的模型

這時(shí)就產(chǎn)生一個(gè)問題:

如何加載預(yù)訓(xùn)練的參數(shù)。

下面就是我的總結(jié)。

為了方便說(shuō)明,做一個(gè)假設(shè):簡(jiǎn)單的模型只有一個(gè)卷基層,復(fù)雜模型有兩個(gè)。

卷積層的實(shí)現(xiàn)代碼如下:

import tensorflow as tf# PS:本篇的重?fù)?dān)是saver,不過為了方便閱讀還是說(shuō)明下參數(shù)# 參數(shù)# name:創(chuàng)建卷基層的代碼這么多,必須要函數(shù)化,而為了防止變量沖突就需要用tf.name_scope# input_data:輸入數(shù)據(jù)# width, high:卷積小窗口的寬、高# deep_before, deep_after:卷積前后的神經(jīng)元數(shù)量# stride:卷積小窗口的移動(dòng)步長(zhǎng)def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'): global parameters with tf.name_scope(name) asscope:  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)  bias = tf.add(conv,biases)  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己寫的batchnorm函數(shù)  conv =tf.maximum(0.1*bias, bias)  return conv

簡(jiǎn)單的預(yù)訓(xùn)練模型就下面一句話

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

復(fù)雜的模型是兩個(gè)卷基層,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)pool1= make_max_pool('layer1-pool1', conv1, 2, 2)conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

這時(shí)簡(jiǎn)簡(jiǎn)單單的在預(yù)訓(xùn)練模型中:

saver = tf.train.Saver()with tf.Session() as sess:saver.save(sess,'model.ckpt')

就不行了,因?yàn)椋?/p>

1,如果你在預(yù)訓(xùn)練模型中使用下面的話打印所有tensor

all_v =tf.global_variables()for i in all_v: print i

會(huì)發(fā)現(xiàn)tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref><tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref><tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref><tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref><tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref><tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在復(fù)雜模型中就是complex-conv1/weights和complex-conv1/biases,這是對(duì)不上號(hào)的。

2,預(yù)訓(xùn)練模型中只有1個(gè)卷積層,而復(fù)雜模型中有兩個(gè),而tensorflow默認(rèn)會(huì)從模型文件('model.ckpt')中找所有的“可訓(xùn)練的”tensor,找不到會(huì)報(bào)錯(cuò)。

解決方法:

1,在預(yù)訓(xùn)練模型中定義全局變量

parm_dict={}

并在“return conv”上面添加下面兩行

parm_dict['complex-conv1/weights']= weightsparm_dict['complex-conv1/']= biases

然后在定義saver時(shí)使用下面這句話:

saver= tf.train.Saver(parm_dict)

這樣保存后的模型文件就對(duì)應(yīng)到復(fù)雜模型上了。

2,在復(fù)雜模型中定義全局變量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判斷如果是第二個(gè)卷積層就不更新parameters。

接著在定義saver時(shí)使用下面這句話:

saver= tf.train.Saver(parameters)

這樣就可以告訴saver,只需要從模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3統(tǒng)統(tǒng)滾一邊去(上面紅色部分)。

最后使用下面的代碼加載就可以了

with tf.Session() as sess: ckpt= tf.train.get_checkpoint_state('.') if ckpt and ckpt.model_checkpoint_path:  saver.restore(sess,ckpt.model_checkpoint_path) else:  print ' no saver.'  exit()     

以上這篇tensorflow saver 保存和恢復(fù)指定 tensor的實(shí)例講解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持VEVB武林網(wǎng)。


注:相關(guān)教程知識(shí)閱讀請(qǐng)移步到python教程頻道。
發(fā)表評(píng)論 共有條評(píng)論
用戶名: 密碼:
驗(yàn)證碼: 匿名發(fā)表
主站蜘蛛池模板: 阜新| 三原县| 奉节县| 晋州市| 双城市| 龙门县| 彩票| 辰溪县| 龙州县| 灵石县| 两当县| 宁陕县| 望城县| 郓城县| 化德县| 尉犁县| 娄底市| 寻乌县| 天峻县| 顺昌县| 广西| 乐亭县| 横峰县| 剑阁县| 五寨县| 呼玛县| 仙桃市| 雷波县| 高陵县| 山丹县| 永靖县| 西充县| 余干县| 收藏| 图木舒克市| 抚州市| 古田县| 甘孜| 鄂托克前旗| 南川市| 阿勒泰市|