利用卷積神經網絡訓練圖像數據分為以下幾個步驟
1.讀取圖片文件
2.產生用于訓練的批次
3.定義訓練的模型(包括初始化參數,卷積、池化層等參數、網絡)
4.訓練
1 讀取圖片文件
def get_files(filename): class_train = [] label_train = [] for train_class in os.listdir(filename): for pic in os.listdir(filename+train_class): class_train.append(filename+train_class+'/'+pic) label_train.append(train_class) temp = np.array([class_train,label_train]) temp = temp.transpose() #shuffle the samples np.random.shuffle(temp) #after transpose, images is in dimension 0 and label in dimension 1 image_list = list(temp[:,0]) label_list = list(temp[:,1]) label_list = [int(i) for i in label_list] #print(label_list) return image_list,label_list
這里文件名作為標簽,即類別(其數據類型要確定,后面要轉為tensor類型數據)。
然后將image和label轉為list格式數據,因為后邊用到的的一些tensorflow函數接收的是list格式數據。
2 產生用于訓練的批次
def get_batches(image,label,resize_w,resize_h,batch_size,capacity): #convert the list of images and labels to tensor image = tf.cast(image,tf.string) label = tf.cast(label,tf.int64) queue = tf.train.slice_input_producer([image,label]) label = queue[1] image_c = tf.read_file(queue[0]) image = tf.image.decode_jpeg(image_c,channels = 3) #resize image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h) #(x - mean) / adjusted_stddev image = tf.image.per_image_standardization(image) image_batch,label_batch = tf.train.batch([image,label], batch_size = batch_size, num_threads = 64, capacity = capacity) images_batch = tf.cast(image_batch,tf.float32) labels_batch = tf.reshape(label_batch,[batch_size]) return images_batch,labels_batch
首先使用tf.cast轉化為tensorflow數據格式,使用tf.train.slice_input_producer實現一個輸入的隊列。
label不需要處理,image存儲的是路徑,需要讀取為圖片,接下來的幾步就是讀取路徑轉為圖片,用于訓練。
CNN對圖像大小是敏感的,第10行圖片resize處理為大小一致,12行將其標準化,即減去所有圖片的均值,方便訓練。
接下來使用tf.train.batch函數產生訓練的批次。
最后將產生的批次做數據類型的轉換和shape的處理即可產生用于訓練的批次。
3 定義訓練的模型
(1)訓練參數的定義及初始化
def init_weights(shape): return tf.Variable(tf.random_normal(shape,stddev = 0.01))#init weightsweights = { "w1":init_weights([3,3,3,16]), "w2":init_weights([3,3,16,128]), "w3":init_weights([3,3,128,256]), "w4":init_weights([4096,4096]), "wo":init_weights([4096,2]) }#init biasesbiases = { "b1":init_weights([16]), "b2":init_weights([128]), "b3":init_weights([256]), "b4":init_weights([4096]), "bo":init_weights([2]) }
新聞熱點
疑難解答