Custom GRU With 3D Spatial Convolution Layer In Keras

I am trying to implement the custom GRU model described in the 3D-R2N2 paper. In the GRU pipeline shown there, each image feature vector is passed through a fully connected layer while the hidden state, a voxel grid, goes through a 3D convolution, and the two results are summed inside every gate.

The original implementation is Theano-based, and I am trying to port the model to TF2/Keras.

I have tried to create a custom GRU cell based on the GRUCell from the Keras recurrent layers.

The input to the GRU model has shape (batch_size, sequence_length, 1024) and the output should have shape (batch_size, 4, 4, 4, 128). I am stuck on the convolution layer shown in the diagram because of shape incompatibilities between the flat feature input and the 3D hidden state.
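In the paper's formulation the two branches are brought to the same shape before they are added: the 1024-d feature goes through a fully connected layer and is reshaped to the voxel grid, while the hidden state goes through a 3D convolution with `same` padding. Below is a minimal TF2 sketch of just that shape bookkeeping (ignoring the LeakyReLU and initialisers; the layer names are my own, not from the paper):

```python
import tensorflow as tf

feature = tf.zeros((8, 1024))           # one time step of the encoded image feature
h_prev = tf.zeros((8, 4, 4, 4, 128))    # previous hidden state: a 4x4x4 voxel grid

fc = tf.keras.layers.Dense(4 * 4 * 4 * 128)             # FC branch for the input
conv = tf.keras.layers.Conv3D(128, 3, padding='same')   # conv branch for the hidden state

x_proj = tf.reshape(fc(feature), (-1, 4, 4, 4, 128))    # (8, 4, 4, 4, 128)
h_proj = conv(h_prev)                                   # (8, 4, 4, 4, 128)
gate = tf.sigmoid(x_proj + h_proj)                      # shapes now match
```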

This is my attempted GRU Cell:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import activations, constraints, initializers, regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, RNN
# NOTE: `_generate_dropout_mask` was a private helper in older Keras
# (keras/layers/recurrent.py); it is not part of the public tf.keras API
# and has to be copied into this file as well.


class CGRUCell(Layer):
    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=2,
                 reset_after=False,
                 **kwargs):
        super(CGRUCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.reset_after = reset_after
        self.state_size = self.units
        self.output_size = self.units
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

    def build(self, input_shape):
        
        input_dim = input_shape[-1]

        if isinstance(self.recurrent_initializer, initializers.Identity):
            def recurrent_identity(shape, gain=1., dtype=None):
                del dtype
                return gain * np.concatenate(
                    [np.identity(shape[0])] * (shape[1] // shape[0]), axis=1)

            self.recurrent_initializer = recurrent_identity

        self.kernel = self.add_weight(shape=(input_dim, self.units * 3),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if not self.reset_after:
                bias_shape = (3 * self.units,)
            else:
                # separate biases for input and recurrent kernels
                # Note: the shape is intentionally different from CuDNNGRU biases
                # `(2 * 3 * self.units,)`, so that we can distinguish the classes
                # when loading and converting saved weights.
                bias_shape = (2, 3 * self.units)
            self.bias = self.add_weight(shape=bias_shape,
                                        name='bias',
                                        initializer=self.bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
            if not self.reset_after:
                self.input_bias, self.recurrent_bias = self.bias, None
            else:
                # NOTE: need to flatten, since slicing in CNTK gives 2D array
                self.input_bias = K.flatten(self.bias[0])
                self.recurrent_bias = K.flatten(self.bias[1])
        else:
            self.bias = None

        # update gate
        self.kernel_z = self.kernel[:, :self.units]
        self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
        # reset gate
        self.kernel_r = self.kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_r = self.recurrent_kernel[:,
                                                        self.units:
                                                        self.units * 2]
        # new gate
        self.kernel_h = self.kernel[:, self.units * 2:]
        self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]

        if self.use_bias:
            # bias for inputs
            self.input_bias_z = self.input_bias[:self.units]
            self.input_bias_r = self.input_bias[self.units: self.units * 2]
            self.input_bias_h = self.input_bias[self.units * 2:]
            # bias for hidden state - just for compatibility with CuDNN
            if self.reset_after:
                self.recurrent_bias_z = self.recurrent_bias[:self.units]
                self.recurrent_bias_r = (
                    self.recurrent_bias[self.units: self.units * 2])
                self.recurrent_bias_h = self.recurrent_bias[self.units * 2:]
        else:
            self.input_bias_z = None
            self.input_bias_r = None
            self.input_bias_h = None
            if self.reset_after:
                self.recurrent_bias_z = None
                self.recurrent_bias_r = None
                self.recurrent_bias_h = None
        self.built = True

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory

        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=3)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(h_tm1),
                self.recurrent_dropout,
                training=training,
                count=3)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        if self.implementation == 1:
            if 0. < self.dropout < 1.:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            if 0. < self.recurrent_dropout < 1.:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1 
                h_tm1_r = h_tm1 
                h_tm1_h = h_tm1 

            x_z = K.dot(h_tm1_z, K.transpose(self.kernel_z))
            x_r = K.dot(h_tm1_r, K.transpose(self.kernel_r))
            x_h = K.dot(h_tm1_h, K.transpose(self.kernel_h))
            if self.use_bias:
                x_z = K.bias_add(x_z, self.input_bias_z)
                x_r = K.bias_add(x_r, self.input_bias_r)
                x_h = K.bias_add(x_h, self.input_bias_h)

            recurrent_z = K.dot(inputs_z, self.recurrent_kernel_z)
            recurrent_r = K.dot(inputs_r, self.recurrent_kernel_r)
            if self.reset_after and self.use_bias:
                recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias_z)
                recurrent_r = K.bias_add(recurrent_r, self.recurrent_bias_r)
            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = K.dot(inputs_h, self.recurrent_kernel_h)
                if self.use_bias:
                    recurrent_h = K.bias_add(recurrent_h, self.recurrent_bias_h)
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = K.dot(r * h_tm1_h, self.recurrent_kernel_h)

            hh = self.activation(x_h + recurrent_h)
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = K.dot(inputs, self.kernel)
            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = K.bias_add(matrix_x, self.input_bias)
            x_z = matrix_x[:, :self.units]
            x_r = matrix_x[:, self.units: 2 * self.units]
            x_h = matrix_x[:, 2 * self.units:]

            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
                if self.use_bias:
                    matrix_inner = K.bias_add(matrix_inner, self.recurrent_bias)
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = K.dot(h_tm1,
                                     self.recurrent_kernel[:, :2 * self.units])

            recurrent_z = matrix_inner[:, :self.units] #Changes Expected Here
            recurrent_r = matrix_inner[:, self.units: 2 * self.units]

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            if self.reset_after:
                recurrent_h = r * matrix_inner[:, 2 * self.units:]
            else:
                recurrent_h = K.dot(r * h_tm1,
                                    self.recurrent_kernel[:, 2 * self.units:])

            hh = self.activation(x_h + recurrent_h)

        # previous and candidate state mixed by update gate
        h = (1 - z) * h_tm1 + z * hh
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h]

    def get_config(self):
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation':
                      activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer':
                      initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer':
                      initializers.serialize(self.recurrent_initializer),
                  'bias_initializer': initializers.serialize(self.bias_initializer),
                  'kernel_regularizer':
                      regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer':
                      regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint':
                      constraints.serialize(self.recurrent_constraint),
                  'bias_constraint': constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'implementation': self.implementation,
                  'reset_after': self.reset_after}
        base_config = super(CGRUCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```

Equations for the GRU gates are:
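They follow the 3D convolutional GRU from the paper and match the `recurrence` function in the TF1 code further below, where $\mathcal{T}(x_t)$ is the fully connected projection of the image feature, $*$ denotes a 3D convolution over the hidden voxel grid, and $\odot$ is element-wise multiplication:

$$u_t = \sigma\left(W_u\,\mathcal{T}(x_t) + U_u * h_{t-1} + b_u\right)$$
$$r_t = \sigma\left(W_r\,\mathcal{T}(x_t) + U_r * h_{t-1} + b_r\right)$$
$$h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tanh\left(W_h\,\mathcal{T}(x_t) + U_h * (r_t \odot h_{t-1}) + b_h\right)$$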

Testing

```python
def build_3dgru(features):
    gru = RNN(CGRUCell(1024, use_bias=True, implementation=1))(features)
    print(gru)
    exit()
    
# Test
x = np.ones((200, 24, 1024), dtype = np.float32)
y = tf.constant(x)
build_3dgru(y)
```
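One thing to note about this test: with `state_size = self.units` the wrapped `RNN` only ever carries a flat `(batch, 1024)` state, so it cannot emit a `(batch, 4, 4, 4, 128)` grid. `tf.keras.layers.RNN` does accept a multi-dimensional `state_size`; a minimal sketch (the class name `VoxelStateCell` is just for illustration):

```python
import tensorflow as tf

class VoxelStateCell(tf.keras.layers.Layer):
    # hypothetical minimal cell: only demonstrates the multi-dimensional state
    state_size = tf.TensorShape([4, 4, 4, 128])
    output_size = tf.TensorShape([4, 4, 4, 128])

    def call(self, inputs, states):
        h = states[0]        # (batch, 4, 4, 4, 128)
        return h, [h]        # no-op update, just to show the shapes

out = tf.keras.layers.RNN(VoxelStateCell())(tf.ones((2, 24, 1024)))
print(out.shape)             # (2, 4, 4, 4, 128)
```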

I also found a TensorFlow 1 implementation which may help:

```python
def fcconv3d_layer(h_t, feature_x, filters, n_gru_vox, namew, nameb):
    out_shape = h_t.get_shape().as_list()
    fc_output = fully_connected_layer(feature_x, n_gru_vox * n_gru_vox * n_gru_vox * filters, namew, nameb)
    fc_output = relu_layer(fc_output)
    fc_output = tf.reshape(fc_output, out_shape)
    h_tn = fc_output + slim.conv3d(h_t, filters, [3, 3, 3])
    return h_tn

# Fully connected layer
def fully_connected_layer(x, n_out, namew, nameb):
    shape = x.get_shape().as_list()
    n_in = shape[-1]
    fcw = create_variable(name = namew, shape = [n_in, n_out], initializer = tf.uniform_unit_scaling_initializer(factor = 1.0))
    fcb = create_variable(name = nameb, shape = [n_out], initializer = tf.truncated_normal_initializer())
    
#    fcw = tf.get_variable(name = namew, shape = [n_in, n_out], initializer = tf.uniform_unit_scaling_initializer(factor = 1.0))
#    fcb = tf.get_variable(name = nameb, shape = [n_out], initializer = tf.truncated_normal_initializer())
    return tf.nn.xw_plus_b(x, fcw, fcb)

# Leaky relu layer
def relu_layer(x):
    return tf.nn.leaky_relu(x, alpha = 0.1)

def create_variable(name, shape, initializer = tf.contrib.layers.xavier_initializer()):
    regularizer = tf.contrib.layers.l2_regularizer(scale = 0.997)
    return tf.get_variable(name, shape = shape, initializer = initializer, regularizer = regularizer)

def recurrence(h_t, fc, filters, n_gru_vox, index):
    u_t = tf.sigmoid(fcconv3d_layer(h_t, fc, filters, n_gru_vox, 'u_%d_weight' % index, 'u_%d_bias' % index))
    r_t = tf.sigmoid(fcconv3d_layer(h_t, fc, filters, n_gru_vox, 'r_%d_weight' % index, 'r_%d_bias' % index))
    h_tn = (1.0 - u_t) * h_t + u_t * tf.tanh(fcconv3d_layer(r_t * h_t, fc, filters, n_gru_vox, 'h_%d_weight' % index, 'h_%d_bias' % index))
    return h_tn

# Build 3d gru network
def build_3dgru(features):
    # features is a tensor of size [bs, size, 1024]
    # Split the features into many sequences.    
    with tf.variable_scope('gru', reuse = tf.AUTO_REUSE):
        shape = features.get_shape().as_list()
#        newshape = [shape[1], shape[0], shape[2]] # [size, bs, 1024]
#        features = tf.reshape(features, newshape)
        h = [None for _ in range(shape[1] + 1)]
        h[0] = tf.zeros(shape = [shape[0], n_gru_vox, n_gru_vox, n_gru_vox, n_deconv_filters[0]], dtype = tf.float32)
        for i in range(shape[1]):
            fc = features[:, i, ...]
            h[i + 1] = recurrence(h[i], fc, n_deconv_filters[0], n_gru_vox, i)
        # [bs, 4, 4, 4, 128]
        return h[-1]
    
def test_3dgru():
    x = np.ones((16, 5, 1024), dtype = np.float32)
    y = tf.constant(x)
    output = build_3dgru(y)
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    print('OK!')
    print(output)
```
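One possible way to port that recurrence to TF2/Keras is to give the custom cell a voxel-shaped state and do the fully connected projection plus 3D convolution inside `call`, mirroring `fcconv3d_layer`. This is only a sketch of the idea (no LeakyReLU, dropout, regularisation or the paper's initialisers), with names of my own choosing:

```python
import tensorflow as tf

class Conv3DGRUCell(tf.keras.layers.Layer):
    """Sketch: GRU cell whose hidden state is a (grid, grid, grid, filters) voxel tensor."""

    def __init__(self, grid=4, filters=128, **kwargs):
        super().__init__(**kwargs)
        self.grid, self.filters = grid, filters
        self.state_size = tf.TensorShape([grid, grid, grid, filters])
        self.output_size = tf.TensorShape([grid, grid, grid, filters])

    def build(self, input_shape):
        d = int(input_shape[-1])                    # e.g. 1024
        n_out = self.grid ** 3 * self.filters       # 4*4*4*128
        conv_k = (3, 3, 3, self.filters, self.filters)
        gates = ('u', 'r', 'h')
        # per gate: FC kernel + bias for the input branch, 3x3x3 conv kernel for the state branch
        self.w = {g: self.add_weight('w_' + g, (d, n_out)) for g in gates}
        self.b = {g: self.add_weight('b_' + g, (n_out,), initializer='zeros') for g in gates}
        self.c = {g: self.add_weight('c_' + g, conv_k) for g in gates}
        self.built = True

    def _fc_conv(self, gate, x, h):
        # FC(x) reshaped to the voxel grid, plus Conv3D over the hidden state (cf. fcconv3d_layer)
        grid_shape = (-1, self.grid, self.grid, self.grid, self.filters)
        fc = tf.reshape(tf.matmul(x, self.w[gate]) + self.b[gate], grid_shape)
        return fc + tf.nn.conv3d(h, self.c[gate], strides=[1] * 5, padding='SAME')

    def call(self, inputs, states):
        h_tm1 = states[0]
        u_t = tf.sigmoid(self._fc_conv('u', inputs, h_tm1))
        r_t = tf.sigmoid(self._fc_conv('r', inputs, h_tm1))
        h_cand = tf.tanh(self._fc_conv('h', inputs, r_t * h_tm1))
        h_t = (1. - u_t) * h_tm1 + u_t * h_cand
        return h_t, [h_t]

x = tf.ones((2, 24, 1024))
out = tf.keras.layers.RNN(Conv3DGRUCell())(x)
print(out.shape)  # (2, 4, 4, 4, 128)
```

The LeakyReLU on the FC branch, dropout and a `get_config` would still have to be added to match the original 3D-R2N2 setup.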

