We often need to save a TensorFlow model as a .pb (serialized GraphDef) file; the `tf.graph_util.convert_variables_to_constants` function makes this very convenient.
1. Training the network: fully_conected.py
"""Train a small convolutional network on MNIST with the TensorFlow 1.x graph API.

Model weights are checkpointed under ``FLAGS.log_dir`` so that ``export.py``
can freeze them into a ``.pb`` file.
"""
import argparse
import os
import time

import tensorflow as tf

import datasets_mnist

# Basic model parameters as external flags; populated by argparse in __main__.
FLAGS = None

# The MNIST dataset has 10 classes, representing the digits 0 through 9.
NUM_CLASSES = 10
# The MNIST images are always 28x28 pixels.
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE


def placeholder_inputs(batch_size):
    """Create placeholders for one batch of flattened images and their labels.

    Args:
        batch_size: Number of examples per batch; baked into the static shapes.

    Returns:
        images_placeholder: float32 placeholder of shape (batch_size, IMAGE_PIXELS).
        labels_placeholder: int32 placeholder of shape (batch_size,).
    """
    images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, IMAGE_PIXELS))
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
    return images_placeholder, labels_placeholder


def fill_feed_dict(data_set, images_pl, labels_pl):
    """Return a feed dict mapping the placeholders to the next batch of ``data_set``."""
    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)
    return {images_pl: images_feed, labels_pl: labels_feed}


def conv2d(x, W):
    """2-D convolution with stride 1 and SAME padding (spatial size preserved)."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    """2x2 max pooling with stride 2 (halves height and width)."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def _weight_variable(shape):
    """Weight variable initialised from a truncated normal with stddev 0.1."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def _bias_variable(shape):
    """Zero-initialised bias variable with one entry per output unit."""
    return tf.Variable(tf.constant(0.0, shape=shape))


def inference(images, keep_prob=1.0):
    """Build the MNIST convnet up to where it is used for inference.

    Args:
        images: float32 tensor of flattened images, shape (batch, IMAGE_PIXELS).
        keep_prob: Dropout keep probability for the fully connected layer, as a
            Python float or a scalar float tensor. The default of 1.0 disables
            dropout (evaluation / export); training should feed e.g. 0.5.

    Returns:
        logits: float32 tensor of shape (batch, NUM_CLASSES).
    """
    x_image = tf.reshape(images, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])

    # conv1 + pool: IMAGE_SIZE x IMAGE_SIZE x 1 -> (IMAGE_SIZE/2)^2 x 32.
    w_conv1 = _weight_variable([5, 5, 1, 32])
    b_conv1 = _bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # conv2 + pool: 32 -> 64 channels, spatial size halved again.
    w_conv2 = _weight_variable([5, 5, 32, 64])
    b_conv2 = _bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # fc1: flatten the pooled feature maps into 1024 hidden units.
    pooled_size = IMAGE_SIZE // 4  # two 2x2 poolings with SAME padding
    flat_dim = pooled_size * pooled_size * 64
    w_fc1 = _weight_variable([flat_dim, 1024])
    b_fc1 = _bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, flat_dim])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # fc2: hidden units -> class logits.
    w_fc2 = _weight_variable([1024, NUM_CLASSES])
    b_fc2 = _bias_variable([NUM_CLASSES])
    return tf.matmul(h_fc1_drop, w_fc2) + b_fc2


def loss(logits, labels):
    """Mean sparse softmax cross-entropy between ``logits`` and integer ``labels``."""
    labels = tf.to_int64(labels)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name='xentropy')
    return tf.reduce_mean(cross_entropy, name='xentropy_mean')


def training(loss):
    """Return an op that applies one Adam update step to minimise ``loss``."""
    optimizer = tf.train.AdamOptimizer()
    return optimizer.minimize(loss)


def evaluation(logits, labels):
    """Build two batch-accuracy tensors for integer ``labels``.

    Returns:
        A pair of float32 scalars, each the fraction (not count) of correct
        predictions in the batch: one via ``tf.nn.in_top_k`` and one via
        comparing argmaxes against one-hot labels. They can differ only when
        logits are tied.
    """
    correct = tf.nn.in_top_k(logits, labels, 1)
    labels_one_hot = tf.one_hot(labels, NUM_CLASSES)
    correct_argmax = tf.equal(tf.argmax(logits, 1), tf.argmax(labels_one_hot, 1))
    return (tf.reduce_mean(tf.cast(correct, tf.float32)),
            tf.reduce_mean(tf.cast(correct_argmax, tf.float32)))


def do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set):
    """Run one epoch of full batches over ``data_set`` and print the precision.

    ``eval_correct`` yields the per-batch *fraction* correct, so each batch
    result is scaled back to a count by ``FLAGS.batch_size`` before summing.
    """
    steps_per_epoch = data_set.num_examples // FLAGS.batch_size
    num_examples = steps_per_epoch * FLAGS.batch_size
    if num_examples == 0:
        print('Num examples: 0  (data set is smaller than one batch)')
        return
    true_count = 0  # Counts the number of correct predictions.
    for _ in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder)
        batch_precision = sess.run(eval_correct, feed_dict=feed_dict)
        true_count += int(round(batch_precision * FLAGS.batch_size))
    precision = float(true_count) / num_examples
    print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f'
          % (num_examples, true_count, precision))


def run_training(logits, labels_placeholder, keep_prob=None, train_keep_prob=0.5):
    """Train the model, logging summaries and periodically checkpointing/evaluating.

    Reads module-level globals set in ``__main__``: ``train``, ``validation``,
    ``test``, ``images_placeholder``, ``saver`` and ``checkpoint_file``.
    Deletes and recreates ``FLAGS.log_dir`` before training.

    Args:
        logits: Output tensor of ``inference``.
        labels_placeholder: int32 labels placeholder matching ``logits``.
        keep_prob: Optional scalar tensor given to ``inference`` as its dropout
            keep probability. When provided it is fed ``train_keep_prob`` on
            training steps only, so the evaluations run without dropout.
        train_keep_prob: Keep probability fed to ``keep_prob`` while training.
    """
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    loss_op = loss(logits, labels_placeholder)
    tf.summary.scalar('loss', loss_op)
    train_op = training(loss_op)
    eval_correct, eval_correct1 = evaluation(logits, labels_placeholder)
    tf.summary.scalar('precision', eval_correct)
    summary = tf.summary.merge_all()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        sess.run(init)
        for step in range(FLAGS.max_steps):
            start_time = time.time()
            feed_dict = fill_feed_dict(train, images_placeholder, labels_placeholder)
            if keep_prob is not None:
                feed_dict[keep_prob] = train_keep_prob
            _, loss_value, precision, precision1 = sess.run(
                [train_op, loss_op, eval_correct, eval_correct1], feed_dict=feed_dict)
            duration = time.time() - start_time

            if step % 100 == 0:
                print('Step %d: loss = %.2f (%.3f sec), precision=%.3f, %.3f'
                      % (step, loss_value, duration, precision, precision1))
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                saver.save(sess, checkpoint_file, global_step=step)
                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess, eval_correct, images_placeholder, labels_placeholder, train)
                # Evaluate against the validation set.
                print('Validation Data Eval:')
                do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
                        validation)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess, eval_correct, images_placeholder, labels_placeholder, test)
        summary_writer.close()


def run_testing():
    """Restore the newest checkpoint in ``FLAGS.log_dir`` and print accuracy on one test batch.

    Raises:
        ValueError: If ``FLAGS.log_dir`` contains no checkpoint.
    """
    precision_op, _ = evaluation(logits, labels_placeholder)
    latest_ckpt = tf.train.latest_checkpoint(FLAGS.log_dir)
    if latest_ckpt is None:
        raise ValueError('No checkpoint found in %r' % FLAGS.log_dir)
    with tf.Session() as sess:
        saver.restore(sess, latest_ckpt)
        feed_dict = fill_feed_dict(test, images_placeholder, labels_placeholder)
        accuracy = sess.run(precision_op, feed_dict=feed_dict)
    print('accuracy is %f' % accuracy)


def _str2bool(value):
    """Parse a boolean CLI value; ``type=bool`` would turn the string 'False' into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_steps', type=int, default=500,
                        help='Number of steps to run trainer.')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Batch size. Must divide evenly into the dataset sizes.')
    parser.add_argument('--input_data_dir', type=str, default=os.path.join('datasets'),
                        help='Directory to put the input data.')
    parser.add_argument('--log_dir', type=str, default=os.path.join('log'),
                        help='Directory to put the log data.')
    parser.add_argument('--fake_data', default=False, action='store_true',
                        help='If true, uses fake data for unit testing.')
    parser.add_argument('--train', type=_str2bool, default=True)
    parser.add_argument('--test', type=_str2bool, default=True)
    FLAGS, _ = parser.parse_known_args()

    # Keep checkpoints inside the (possibly user-supplied) log directory.
    checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
    train, validation, test = datasets_mnist.read_data_sets(
        FLAGS.input_data_dir, FLAGS.fake_data)

    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)
    # Dropout is off (1.0) unless a different value is fed; run_training feeds 0.5.
    keep_prob = tf.placeholder_with_default(1.0, shape=(), name='keep_prob')
    # Build a graph that computes predictions from the inference model.
    logits = inference(images_placeholder, keep_prob)
    # Created before the optimizer, so it saves only the model variables.
    saver = tf.train.Saver()

    if FLAGS.train:
        run_training(logits, labels_placeholder, keep_prob)
    if FLAGS.test:
        run_testing()
2. Exporting the .pb file: export.py
"""Freeze the latest MNIST convnet checkpoint into a binary GraphDef (.pb) file."""
import os

import tensorflow as tf

import fully_conected as model


def export_graph(model_name, checkpoint_dir='log'):
    """Restore the newest checkpoint and write a frozen graph to disk.

    A fresh graph is built with a variable-batch input placeholder named
    ``inputdata`` and a softmax output named ``outputdata``. All variables are
    folded into constants, so the resulting file loads without a checkpoint.

    Args:
        model_name: File name of the .pb file to write inside ``checkpoint_dir``.
        checkpoint_dir: Directory containing the training checkpoints; the frozen
            graph is written there as well.

    Returns:
        The path of the written .pb file.

    Raises:
        ValueError: If ``checkpoint_dir`` contains no checkpoint.
    """
    graph = tf.Graph()
    with graph.as_default():
        input_image = tf.placeholder(tf.float32, shape=[None, 28 * 28], name='inputdata')
        # NOTE(review): model.inference must build the network with dropout disabled
        # here; otherwise the frozen graph's predictions are random.
        logits = model.inference(input_image)
        tf.nn.softmax(logits, name='outputdata')
        restore_saver = tf.train.Saver()

    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt is None:
        raise ValueError('No checkpoint found in %r' % checkpoint_dir)

    with tf.Session(graph=graph) as sess:
        # restore() assigns every variable the Saver covers, so no initializer is needed.
        restore_saver.restore(sess, latest_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['outputdata'])

    output_path = os.path.join(checkpoint_dir, model_name)
    with tf.gfile.GFile(output_path, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    return output_path


if __name__ == '__main__':
    export_graph('mnist.pb')
3. Testing the exported model: test.py
"""Evaluate the frozen MNIST graph (log/mnist.pb) on the MNIST test set."""
from __future__ import absolute_import, unicode_literals

import numpy as np
import tensorflow as tf

from datasets_mnist import read_data_sets

train, validation, test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # The frozen graph contains only constants, so nothing needs initialising.
        input_x = sess.graph.get_tensor_by_name("inputdata:0")
        output = sess.graph.get_tensor_by_name("outputdata:0")
        y_conv_2 = sess.run(output, {input_x: test.images})
        print("y_conv_2", y_conv_2)

        # Score the fetched predictions with NumPy: building new tf ops here would
        # only print symbolic Tensor objects instead of their values.
        y__2 = test.labels
        correct_prediction_2 = np.equal(np.argmax(y_conv_2, 1), np.argmax(y__2, 1))
        print("correct_prediction_2", correct_prediction_2)
        accuracy_2 = np.mean(correct_prediction_2.astype(np.float32))
        print("accuracy_2", accuracy_2)
        print("check accuracy %g" % accuracy_2)
4. The data used here is the MNIST dataset; the loading code is in datasets_mnist.py:
Import gzip import OS import numpy from six.moves import xrange # Pylint:disable=redefined-builtin from Tensorflow.cont Rib.learn.python.learn.datasets Import Base from tensorflow.python.framework import Dtypes Tensorflow.python.framework Import Random_seed # CVDF Mirror of Http://yann.lecun.com/exdb/mnist/SOURCE_URL = ' https:// storage.googleapis.com/cvdf-datasets/mnist/' Def _read32 (bytestream): dt = Numpy.dtype (Numpy.uint32). Newbyteorder (' > ') return Numpy.frombuffer (Bytestream.read (4), DTYPE=DT) [0] def extract_images (f): "" "extract the images into a
4D uint8 numpy Array [index, y, x, depth].
Args:f: A file object that can is passed into a gzip reader.
Returns:data:a 4D uint8 numpy Array [index, y, x, depth].
Raises:ValueError:If The ByteStream does not start with 2051. "" "Print (' extracting ', f.name) with Gzip. Gzipfile (fileobj=f) as Bytestream:magic = _read32 (ByteStream) if Magic!= 2051:raise (' ValueErrorMagic number%d in Mnist image file:%s '% (Magic, f.name)) Num_images = _read32 (ByteStream) rows = _read32 (bytestream) cols = _read32 (bytestream) buf = bytestream.read (Rows * cols * num_images) data = Numpy.frombuffer (buf, dtype=numpy.uint8) data = Data.reshape (num_images, rows, cols, 1) return Data def dense_t
O_one_hot (Labels_dense, num_classes): "" "Convert class labels from scalars to one-hot vectors." " Num_labels = labels_dense.shape[0] Index_offset = Numpy.arange (num_labels) * num_classes labels_one_hot = Numpy.zeros (
(Num_labels, num_classes)) Labels_one_hot.flat[index_offset + labels_dense.ravel ()] = 1 return labels_one_hot # def extract_labels (F, one_hot=Fals
E, num_classes=10): "" "Extract the labels into a 1D uint8 numpy array [index].
Args:f: A file object that can is passed into a gzip reader.
One_hot:does one hot encoding to the result. Num_classes:number of classes for the one hot EncodinG. returns:labels:a 1D uint8 numpy Array.
Raises:ValueError:If the Bystream doesn ' t start with 2049. "" "Print (' extracting ', f.name) with Gzip. Gzipfile (fileobj=f) as Bytestream:magic = _read32 (ByteStream) if Magic!= 2049:raise valueerror (' Invalid m
Agic number%d in mnist label file:%s '% (Magic, f.name)) Num_items = _read32 (ByteStream) BUF = Bytestream.read (num_items) labels = Numpy.frombuffer (buf, dtype=numpy.uint8) if One_hot:return dense_
To_one_hot (labels, num_classes) return labels # class DataSet (object): Def __init__ (self, images,
Labels, fake_data=false, One_hot=false, Dtype=dtypes.float32,
Reshape=true, Seed=none): "" "Construct a DataSet. One_hot arg is used only if Fake_data is true. ' Dtype ' can be either ' uint8 ' to leave the input as ' [0, 255] ', or ' float32 ' to RescAle into ' [0, 1] '.
Seed ARG provides for convenient deterministic testing. "" "seed1, Seed2 = Random_seed.get_seed (Seed) # If op level seed isn't set, use whatever graph level seed is Retu rned numpy.random.seed (seed1 If seed is None else seed2) Dtype = Dtypes.as_dtype (dtype). Base_dtype If Dtype No
T in (Dtypes.uint8, Dtypes.float32): Raise TypeError (' Invalid image Dtype%r, expected uint8 or float32 '% Dtype) If Fake_data:self._num_examples = 10000 Self.one_hot = One_hot Else:assert
Images.shape[0] = = Labels.shape[0], (' Images.shape:%s labels.shape:%s '% (Images.shape, labels.shape)) Self._num_examples = images.shape[0] # Convert shape from [num examples, rows, columns, depth] # to [Num exa Mples, Rows*columns] (assuming depth = = 1) if Reshape:assert images.shape[3] = = 1 Images = IMAGES.R Eshape (Images.shape[0], images. shape[1] * images.shape[2]) if Dtype = = Dtypes.float32: # Convert from [0, 255]-> [0.0, 1.0]. Images = Images.astype (numpy.float32) images = numpy.multiply (images, 1.0/255.0) self._images = images S Elf._labels = Labels self._epochs_completed = 0 Self._index_in_epoch = 0 @property def images (self): Retu RN Self._images @property def labels (self): return self._labels @property def num_examples (self): return Self._num_examples @property def epochs_completed (self): return self._epochs_completed def next_batch (self, b
Atch_size, Fake_data=false, shuffle=true): "" "Return to the Next ' batch_size ' examples from this data set." " If fake_data:fake_image = [1] * 784 if Self.one_hot:fake_label = [1] + [0] * 9 else:f
Ake_label = 0 Return [fake_image to _ in Xrange (batch_size)], [Fake_label to _ in Xrange (batch_size) ] Start = self._index_in_epoch # Shuffle for the ' the ' the ' the ' the ' the ' the ' the ' the ' epoch if self._epochs_completed = 0 and start = 0 Y.arange (self._num_examples) numpy.random.shuffle (perm0) self._images = Self.images[perm0] Self._labels
= Self.labels[perm0] # go to the next epoch if start + batch_size > Self._num_examples: # Finished Epoch self._epochs_completed + + 1 # Get the rest examples into this epoch Rest_num_examples = Self._num_examples -Start Images_rest_part = self._images[start:self._num_examples] Labels_rest_part = self._labels[start:self. _num_examples] # Shuffle the data if shuffle:perm = Numpy.arange (self._num_examples) Numpy.ra Ndom.shuffle (perm) self._images = self.images[perm] Self._labels = self.labels[perm] # Start Next EP och start = 0 Self._index_in_epoch = Batch_size-rest_num_examples end = Self._index_in_epoch ima Ges_new_part = Self._imaGes[start:end] Labels_new_part = Self._labels[start:end] return numpy.concatenate (Images_rest_part, Images_ne W_part), axis=0), Numpy.concatenate ((Labels_rest_part, Labels_new_part), axis=0) Else:self._index_in_epoch = Batch_size end = Self._index_in_epoch return self._images[start:end], Self._labels[start:end] def Read_data_ Sets (Train_dir, Fake_data=false, One_hot=false, Dtype=dtypes.floa T32, Reshape=true, validation_size=5000, Seed=none): If Fake_da
Ta:def fake (): Return DataSet ([], [], Fake_data=true, One_hot=one_hot, Dtype=dtype, Seed=seed) Train = fake () validation = fake () test = fake () return base. Datasets (Train=train, Validation=validation, test=test) # train_images = ' train-images-idx3-ubyte.gz ' # TRAIN_LABELS = ' train-labels-idx1-ubyte.gz ' # test_images = ' t10k-images-idx3-ubyte.gz' # test_labels = ' t10k-labels-idx1-ubyte.gz ' local_file = Os.path.join (' Datasets ', ' train-images-idx3-ubyte.gz ') wit H Open (Local_file, ' RB ') as F:train_images = Extract_images (f) local_file = Os.path.join (' Datasets ', ' train-labels- Idx1-ubyte.gz ') with open (Local_file, ' RB ') as F:train_labels = Extract_labels (f, one_hot=one_hot) Local_file = Os.path.join (' Datasets ', ' t10k-images-idx3-ubyte.gz ') with open (Local_file, ' RB ')