機(jī)器翻譯中的seq2seq的模型框架及相應(yīng)參數(shù);
Seq2SeqModel(source_vocab_size, target_vocab_size, buckets, size, num_layers, max_gradient_norm, batch_size, learning_rate, learning_rate_decay_factor, use_lstm=False, num_samples=512, forward_only=False, dtype=tf.float32)參數(shù)詳解: source_vocab_size,在序列到序列的任務(wù)中,訓(xùn)練數(shù)據(jù)的源數(shù)據(jù)的詞匯表大小;如序列對(duì)(A,B)中A的大小 target_vocab_size,同上目標(biāo)詞匯表B的大小 buckets,為了解決不同長(zhǎng)度輸入而設(shè)定的,如[(5,10),(10,15),(15,20),(20,40)],如輸入長(zhǎng)度為9時(shí),選擇(10,15)的范圍; size,某一層的單元數(shù) num_layers,網(wǎng)絡(luò)層數(shù) max_gradient_norm,表示梯度最大限度的被削減到這個(gè)規(guī)范 batch_size,每批讀取數(shù)據(jù)數(shù) learning_rate,學(xué)習(xí)率 learning_rate_decay_factor,學(xué)習(xí)率衰減因子 use_lstm=False,使用lstm嗎?GRU num_samples=512,采樣softmax的個(gè)數(shù),當(dāng)個(gè)數(shù)小于詞匯表時(shí)才有意義; forward_only=False,是否更新參數(shù)
Sequence-to-Sequence中的一些重要的函數(shù):(model內(nèi)部的)
model.get_batch(self,data,bucket_id):該函數(shù)返回的是batch_encoder_inputs,batch_decoder_inputs,batch_weights三個(gè)參數(shù)encoder_size, decoder_size = self.buckets[bucket_id]將輸入與輸出補(bǔ)成同encoder_size, decoder_size大小一樣的尺寸。encoder_inputs, # Encoder inputs are padded and then reversed.(不夠補(bǔ)0)decoder_inputs,#留一個(gè)給GO_ID,其余類(lèi)似decoder_inputsweight:對(duì)于補(bǔ)充的數(shù),weight值為0,原有的為1batch_encoder_inputs,batch_size大小一批的輸入batch_decoder_inputs,batch_size大小的一批輸出batch_weights,尺寸同batch_decoder_inputsdef step(self, session, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only):該函數(shù)返回三個(gè)參數(shù)值(Gradient norm, loss, outputs) session, 會(huì)話狀態(tài) 以下三個(gè)參數(shù)來(lái)源于get_batch的返回值 encoder_inputs, decoder_inputs, target_weights, bucket_id, #使用那個(gè)bucket,會(huì)被指定 forward_only,forward_only當(dāng)為T(mén)rue時(shí),返回# No gradient norm, loss, outputs;False時(shí),返回Gradient norm, loss, no outputs.
新聞熱點(diǎn)
疑難解答
圖片精選
網(wǎng)友關(guān)注