
Some tips about using Google's TPU

About one month ago, I submitted a request to Google Research Cloud to use TPUs for free. Fortunately, I received the approval yesterday. The approval lets me use 5 regular Cloud TPUs and 100 preemptible Cloud TPUs for free for 30 days, and all I had to submit was my GCP project name.
Then I had to change my previous TensorFlow program to make it run on TPUs. I couldn't just change tf.device('/gpu:0') to tf.device('/tpu:0') in the code to run training on a Google TPU. Actually, there are many documents about how to modify code for this, such as

TPUEstimator, Using TPUs, etc.

Here are some tips about porting code for TPUs:

1. We can only use TPUEstimator for training

Python
classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=model_wrapper,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=64,
    batch_axis=[0, 0],
    params={'optimizer': opt})

Pay attention to 'batch_axis'. It tells the TPU pod to split both the data and the labels along dimension 0, since I use the 'NHWC' data format.
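
Once constructed, the estimator is driven like any tf.estimator.Estimator. A minimal sketch, assuming a data_input_fn like the one in tip 2 below and an arbitrary step count:

Python

# Hypothetical training call; the input_fn must accept a `params` argument.
classifier.train(input_fn=data_input_fn, max_steps=10000)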

2. The model_fn and data_input_fn used by TPUEstimator take more arguments than those of a regular tf.estimator.Estimator. We need to fetch some of them (e.g., 'batch_size') from params.

Python
def data_input_fn(params):
    batch = params['batch_size']
    ...

def model_fn(features, labels, mode, config, params):
    ...
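
For context, here is a minimal sketch of what a TPU-compatible model_fn can look like (my own reconstruction, not the original code; my_network is a placeholder). TPUEstimator expects the optimizer to be wrapped in tf.contrib.tpu.CrossShardOptimizer and the function to return a tf.contrib.tpu.TPUEstimatorSpec:

Python

def model_fn(features, labels, mode, config, params):
    logits = my_network(features)  # placeholder for the actual model body
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # CrossShardOptimizer averages gradients across the TPU shards.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(params['optimizer'])
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)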

3. TPUs don't support operations like

Python
images = tf.contrib.image.rotate(
    images, tf.random_uniform([1], minval=-math.pi / 4.0, maxval=math.pi / 4.0))

So try to avoid using them; one possible workaround is sketched below.
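
One workaround (my suggestion, not from the original post) is to do such augmentation in the input pipeline instead, since the input_fn runs on the host CPU rather than on the TPU:

Python

import math

def _augment(image, label):
    # Executed on the host CPU as part of the tf.data pipeline, not on the TPU.
    angle = tf.random_uniform([], minval=-math.pi / 4.0, maxval=math.pi / 4.0)
    return tf.contrib.image.rotate(image, angle), label

dataset = dataset.map(_augment)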

4. Use tf.data carefully, or else it will report data shape errors. The code below has run correctly so far:

Python
dataset = files.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, sloppy=True, cycle_length=buff_size))
dataset = dataset.map(_parse_function)
dataset = dataset.repeat()
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.shuffle(batch_size * buff_size)
iterator = dataset.make_initializable_iterator()
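
My understanding is that the shape errors come from the TPU's requirement for static tensor shapes: batch_and_drop_remainder yields a fixed batch dimension, while a plain batch() leaves it unknown. A quick check (mine, not in the original):

Python

# With batch_and_drop_remainder the batch dimension is static, e.g. (64, ...);
# a plain batch() would leave it as (?, ...), which the TPU can't handle.
print(dataset.output_shapes)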

5. Because we use TPUEstimator, we can't initialize the tf.data iterator with 'session.run()', so a little trick is needed: register the initializer in the TABLE_INITIALIZERS collection, which the Estimator's session runs automatically when it is created:

Python
def data_input_fn():
    ...
    it = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, it.initializer)
    ...

6. TensorFlow in a GCP VM instance only supports loading datasets from, and storing the model into, Google Cloud Storage.

Python

run_config = tf.contrib.tpu.RunConfig(
    master=master,
    evaluation_master=master,
    model_dir='gs://my-project/models/',
    session_config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True),
    tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards))
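
Accordingly, the training data must also live in a gs:// bucket. A one-line sketch, with a placeholder path of mine:

Python

# List TFRecord shards directly from Cloud Storage (hypothetical path).
files = tf.data.Dataset.list_files('gs://my-project/data/train-*.tfrecord')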

7. There aren't any hooks for TPUEstimator in TensorFlow 1.9 yet, so I can't see any progress report on the console after launching a TPU program. I hope Google improves this as soon as possible.