Matching networks in TensorFlow
Now, we will build a matching network in TensorFlow step by step. The complete code is shown at the end.
First, we import the libraries:
import tensorflow as tf
slim = tf.contrib.slim
rnn = tf.contrib.rnn
Now, we define a class called Matching_network, where we define our network:
class Matching_network():
We define the __init__ method, where we initialize all of the variables:
    def __init__(self, lr, n_way, k_shot, batch_size=32):

        # placeholders for the support set
        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])
        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot])

        # placeholders for the query set
        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])
        self.query_label = tf.placeholder(tf.int32, [None])
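To make the placeholder shapes concrete, the following is a minimal sketch that builds random NumPy arrays with the same shapes and dtypes as the four placeholders. The helper name make_dummy_episode_batch and the random data are assumptions made purely for illustration; they are not part of the network code, and real episodes would be sampled from a dataset instead.

```python
import numpy as np

def make_dummy_episode_batch(n_way, k_shot, batch_size=32, seed=0):
    """Hypothetical helper: random arrays shaped like the four placeholders."""
    rng = np.random.default_rng(seed)
    n_support = n_way * k_shot  # number of support images per episode

    # Support set: one episode per batch entry, n_way * k_shot images each
    support_images = rng.random((batch_size, n_support, 28, 28, 1), dtype=np.float32)
    # Each class index 0..n_way-1 appears k_shot times in every episode
    labels_one_episode = np.repeat(np.arange(n_way, dtype=np.int32), k_shot)
    support_labels = np.tile(labels_one_episode, (batch_size, 1))

    # Query set: a single image and label per episode
    query_images = rng.random((batch_size, 28, 28, 1), dtype=np.float32)
    query_labels = rng.integers(0, n_way, size=batch_size, dtype=np.int32)

    return support_images, support_labels, query_images, query_labels

if __name__ == "__main__":
    s_img, s_lbl, q_img, q_lbl = make_dummy_episode_batch(n_way=5, k_shot=1)
    print(s_img.shape)  # (32, 5, 28, 28, 1)
    print(s_lbl.shape)  # (32, 5)
    print(q_img.shape)  # (32, 28, 28, 1)
    print(q_lbl.shape)  # (32,)
```

The leading None dimension of each placeholder corresponds to the batch of episodes, so arrays like these could be supplied through feed_dict when the graph is run in a session.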
Let's say our support set and query set contain images. Before feeding these raw images to the embedding function, we will first extract the features...