How to slice an input in keras?

I pass Keras an input of shape input_shape=(500,).

For various reasons, I would like to decompose the input vector into two vectors of respective shapes input_shape_1=(300,) and input_shape_2=(200,).

I want to do this within the definition of the model, using the Functional API. In a way, I would like to perform slicing on a tf.Tensor object.

Help is welcome!


If it is only the input that you would like to decompose, you can split the input data during preprocessing and use two input layers:

import numpy as np
import tensorflow as tf

inputs_first_half = tf.keras.Input(shape=(300,))
inputs_second_half = tf.keras.Input(shape=(200,))

# do something with it
first_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(inputs_first_half)
second_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(inputs_second_half)
outputs = tf.keras.layers.Add()([first_half, second_half])

model = tf.keras.Model(inputs=[inputs_first_half,inputs_second_half],outputs=outputs)

data = np.random.randn(10,500)
out = model.predict([data[:,:300],data[:,300:]])
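The column-wise split passed to model.predict above can also be factored into a small helper. This is only a sketch using NumPy: split_features is a hypothetical name introduced here for illustration and is not part of Keras or TensorFlow.

```python
import numpy as np

def split_features(data, sizes):
    """Split a (batch, features) array column-wise into consecutive blocks.

    `split_features` is a hypothetical helper, not a Keras or TensorFlow function.
    """
    if sum(sizes) != data.shape[1]:
        raise ValueError("block sizes must add up to the number of features")
    # np.split expects the column indices at which to cut, i.e. the
    # cumulative block sizes without the final total.
    boundaries = np.cumsum(sizes)[:-1]
    return np.split(data, boundaries, axis=1)

data = np.random.randn(10, 500)
first, second = split_features(data, [300, 200])
print(first.shape, second.shape)  # (10, 300) (10, 200)
```

The resulting list can be handed to a two-input model in the same order as its inputs, e.g. model.predict([first, second]).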

If you would like to split after the input layer, you could try reshaping and cropping, e.g.:

inputs = tf.keras.Input(shape=(500,))

# do something
intermediate = tf.keras.layers.Dense(500,activation=tf.nn.relu)(inputs)

# split the vector with cropping: Cropping1D needs a 3D tensor, so add a trailing axis
intermediate = tf.keras.layers.Reshape((500, 1))(intermediate)

# keep the first 300 entries by cropping 200 from the end
first_half = tf.keras.layers.Cropping1D(cropping=(0, 200))(intermediate)
first_half = tf.keras.layers.Reshape((300,))(first_half)

# keep the last 200 entries by cropping 300 from the start
second_half = tf.keras.layers.Cropping1D(cropping=(300, 0))(intermediate)
second_half = tf.keras.layers.Reshape((200,))(second_half)


# do something with decomposed vectors
first_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(first_half)
second_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(second_half)
outputs = tf.keras.layers.Add()([first_half, second_half])

model = tf.keras.Model(inputs=inputs, outputs=outputs)

data = np.random.randn(10,500)
out = model.predict(data)

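The Cropping1D layer expects a three-dimensional input of shape (batch_size, axis_to_crop, features) and crops only along axis 1, the first dimension after the batch axis. We therefore need to add a "pseudo-dimension" to our vector by reshaping it to (500, 1) before cropping, and then reshape each piece back to a flat vector afterwards.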
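To make the shape bookkeeping concrete, here is a rough NumPy-only sketch that mimics the reshape, crop, reshape sequence on a plain array. It does not call Keras; the slices are stand-ins for what Cropping1D does along axis 1, and the batch size of 2 is an arbitrary choice for illustration.

```python
import numpy as np

batch = 2  # arbitrary batch size, chosen only for this illustration
x = np.arange(batch * 500, dtype=np.float32).reshape(batch, 500)

# Stand-in for Reshape((500, 1)): add a trailing "pseudo-dimension".
x3d = x.reshape(batch, 500, 1)

# Stand-in for Cropping1D(cropping=(0, 200)): drop 0 steps from the
# start and 200 steps from the end of axis 1.
first_half = x3d[:, 0:500 - 200, :]

# Stand-in for Cropping1D(cropping=(300, 0)): drop 300 steps from the
# start and 0 steps from the end of axis 1.
second_half = x3d[:, 300:, :]

# Stand-in for the final Reshape layers: remove the pseudo-dimension.
first_half = first_half.reshape(batch, 300)
second_half = second_half.reshape(batch, 200)

print(first_half.shape, second_half.shape)  # (2, 300) (2, 200)
```

On a plain array the result matches ordinary slicing, x[:, :300] and x[:, 300:], which is the operation the cropping layers reproduce inside the model.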
