tensorflow模型保存為saver = tf.train.Saver()函數(shù),saver.save()保存模型,代碼如下:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")v2= tf.Variable(tf.zeros([200]), name="v2")saver = tf.train.Saver()with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) saver.save(sess,"checkpoint/model_test",global_step=1)
當(dāng)我們保存模型后,我們可以通過(guò)saver.restore()來(lái)加載模型,初始化變量:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")v2= tf.Variable(tf.zeros([200]), name="v2")saver = tf.train.Saver()with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
神經(jīng)網(wǎng)絡(luò)訓(xùn)練時(shí),有時(shí)候我們需要從預(yù)訓(xùn)練的模型中加載部分參數(shù),初始化當(dāng)前模型,例如加入CNN有6層,我們需要從已有的模型初始化CNN前5層參數(shù).這可以通過(guò)saver.restore()實(shí)現(xiàn).
之前我們已經(jīng)介紹可以通過(guò)tf.train.Saver()的保存部分變量的方法,即需要保存的變量列表,同樣的,在變量初始化的時(shí)候,我們可以對(duì)需要單獨(dú)初始化的變量分別定義一個(gè)tf.train.Saver()函數(shù),這樣就可以單獨(dú)對(duì)該部分變量初始化,例如下面代碼,saver1用于初始化變量v1,saver2用于初始化變量v2,v3:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")v2= tf.Variable(tf.zeros([200]), name="v2")v3= tf.Variable(tf.zeros([100]), name="v3")#saver = tf.train.Saver()saver1 = tf.train.Saver([v1])saver2 = tf.train.Saver([v2]+[v3])with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver1.restore(sess, "checkpoint/model_test-1") saver2.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
以上這篇tensorflow 加載部分變量的實(shí)例講解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持VEVB武林網(wǎng)。
新聞熱點(diǎn)
疑難解答
圖片精選