TensorFlow implementation of a capsule network (CapsNet)

Source: Internet
Author: User

Now we all know that Geoffrey Hinton's capsule network (CapsNet) has shaken the entire AI field, pushing the limits of convolutional neural networks (CNNs) to a new level. There are already plenty of posts, articles, and research papers on the web discussing the theory of capsule networks and how they improve on traditional CNNs. So instead of covering that ground again, I will try to implement CapsNet on TensorFlow using Google's Colaboratory tool.

You can learn the theoretical side of CapsNet from a few links: Geoffrey Hinton's talk "What is wrong with convolutional neural networks?", articles on how capsule networks are shaking up the AI field, and the paper "Dynamic Routing Between Capsules".

Now we start writing code.

Before you begin, you can open my Colab notebook and execute the following code:

Colab URL: https://goo.gl/43Jvju

Now clone the repository from GitHub and install the dependencies. Then we move the MNIST dataset out of the repository into the parent directory:

!git clone https://github.com/bourdakos1/capsule-networks.git
!pip install -r capsule-networks/requirements.txt
!touch capsule-networks/__init__.py
!mv capsule-networks capsule
!mv capsule/data/ ./data/
!ls

Now let's import all the modules:

import os
import tensorflow as tf
from tqdm import tqdm

from capsule.config import cfg
from capsule.utils import load_mnist
from capsule.capsnet import CapsNet

Initialization

capsNet = CapsNet(is_training=cfg.is_training)

This is what the capsule network (CapsNet) looks like in the TensorBoard graph:
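Under the hood, the nodes in that graph implement routing-by-agreement between capsule layers. As a rough illustration, here is a simplified NumPy sketch of the algorithm from "Dynamic Routing Between Capsules", not the repository's actual TensorFlow implementation:

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    # Squashing nonlinearity: v = (|s|^2 / (1 + |s|^2)) * (s / |s|),
    # shrinking each vector to length < 1 while preserving its direction.
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def dynamic_routing(u_hat, num_iters=3):
    # u_hat: (num_in, num_out, dim) prediction vectors from the lower capsules.
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))               # routing logits b_ij
    for _ in range(num_iters):
        # Coupling coefficients: softmax of b over the output capsules.
        c = np.exp(b) / np.exp(b).sum(axis=1, keepdims=True)
        s = (c[:, :, None] * u_hat).sum(axis=0)   # weighted sum, (num_out, dim)
        v = squash(s)
        b = b + (u_hat * v[None, :, :]).sum(axis=-1)  # agreement update
    return v

# 8 input capsules routing to 10 output capsules of dimension 16.
v = dynamic_routing(np.random.RandomState(0).randn(8, 10, 16))
print(v.shape)  # (10, 16)
```

The agreement term increases the coupling between a lower capsule and whichever output capsule its prediction vector aligns with, which is what lets capsules vote for poses rather than just pool activations.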

Training

tf.logging.info('Graph loaded')
sv = tf.train.Supervisor(graph=capsNet.graph,
                         logdir=cfg.logdir,
                         save_model_secs=0)

path = cfg.results + '/accuracy.csv'
if not os.path.exists(cfg.results):
    os.mkdir(cfg.results)
elif os.path.exists(path):
    os.remove(path)

fd_results = open(path, 'w')

Now create a TF session and start execution.

By default, the model trains for 50 epochs with a batch size of 128. You can try different combinations of parameters:
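For those defaults, the step arithmetic works out as follows. This is a plain-Python sketch; the constants are MNIST's 60,000 training and 10,000 test images and the defaults quoted above, not values read from the repository's cfg:

```python
# Defaults quoted in the article, not read from the repository's cfg.
train_size, test_size = 60000, 10000
batch_size, epochs = 128, 50

num_batch = train_size // batch_size       # training batches per epoch
num_test_batch = test_size // batch_size   # test batches per evaluation
total_steps = num_batch * epochs           # total training steps

print(num_batch, num_test_batch, total_steps)  # 468 78 23400
```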

with sv.managed_session() as sess:
    num_batch = int(60000 / cfg.batch_size)
    num_test_batch = 10000 // cfg.batch_size
    teX, teY = load_mnist(cfg.dataset, False)
    for epoch in range(cfg.epoch):
        if sv.should_stop():
            break
        for step in tqdm(range(num_batch), total=num_batch,
                         ncols=70, leave=False, unit='b'):
            global_step = sess.run(capsNet.global_step)
            sess.run(capsNet.train_op)

            if step % cfg.train_sum_freq == 0:
                _, summary_str = sess.run([capsNet.train_op,
                                           capsNet.train_summary])
                sv.summary_writer.add_summary(summary_str, global_step)

            if (global_step + 1) % cfg.test_sum_freq == 0:
                test_acc = 0
                for i in range(num_test_batch):
                    start = i * cfg.batch_size
                    end = start + cfg.batch_size
                    test_acc += sess.run(capsNet.batch_accuracy,
                                         {capsNet.X: teX[start:end],
                                          capsNet.labels: teY[start:end]})
                test_acc = test_acc / (cfg.batch_size * num_test_batch)
                fd_results.write(str(global_step + 1) + ',' + str(test_acc) + '\n')
                fd_results.flush()

        if epoch % cfg.save_freq == 0:
            sv.saver.save(sess, cfg.logdir +
                          '/model_epoch_%04d_step_%02d' % (epoch, global_step))

    fd_results.close()
    tf.logging.info('Training done')

Running 50 epochs took about 6 hours on an NVIDIA Titan Xp card.
But the trained network is surprisingly effective, with the total loss reaching an impressively low 0.0038874.
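The loss being minimized here is the margin loss from the CapsNet paper: L_k = T_k * max(0, m_plus - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m_minus)^2, with m_plus = 0.9, m_minus = 0.1, lambda = 0.5. A standalone NumPy sketch of that formula (not the repository's TensorFlow code, which also adds a reconstruction term):

```python
import numpy as np

def margin_loss(v_norms, labels_onehot, m_plus=0.9, m_minus=0.1, lam=0.5):
    # v_norms: (batch, classes) capsule lengths ||v_k||; labels_onehot: T_k.
    present = np.maximum(0.0, m_plus - v_norms) ** 2   # correct capsule too short
    absent = np.maximum(0.0, v_norms - m_minus) ** 2   # wrong capsule too long
    per_class = labels_onehot * present + lam * (1.0 - labels_onehot) * absent
    return per_class.sum(axis=1).mean()

# A confident, correct prediction incurs zero loss:
print(margin_loss(np.array([[0.95, 0.05]]), np.array([[1.0, 0.0]])))  # 0.0
```

The margins mean a correct capsule is only pushed until its length exceeds 0.9, and wrong capsules until they fall below 0.1, so a tiny total loss like the one above indicates nearly all capsule lengths sit comfortably inside those margins.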

Download the trained model

CapsNet model URL: https://goo.gl/DN7SS3

Original: Running CapsuleNet on TensorFlow
