Q-Learning experience replay: how to feed the neural network?

I'm trying to replicate the DQN Atari experiment. My DQN isn't performing well, and while checking someone else's code I ran into something about experience replay that I don't understand. When you define your CNN, in the first layer you have to specify the input size (I'm using Keras + TensorFlow, so in my case it's something like (105, 80, 4), which corresponds to the height, width and number of stacked frames I feed the CNN). In the code I reviewed, when they get the minibatch from the memory, they feed the CNN directly without "packing" it into groups of 4. How is that possible? I mean, if you get 32 randomly sampled experiences, don't you need to make batches of 4 before feeding them? Here is an example of what I mean: https://github.com/yilundu/DQN-DDQN-on-Space-Invaders/blob/master/replay_buffer.py https://github.com/yilundu/DQN-DDQN-on-Space-Invaders/blob/master/deep_Q.py In this code, this is how the experiences are stored:

def add(self, s, a, r, d, s2):
    """Add an experience to the buffer"""
    # s is the current state, a the action, r the reward,
    # d whether the episode ended, and s2 the next state
    experience = (s, a, r, d, s2)
    if self.count < self.buffer_size:
        self.buffer.append(experience)
        self.count += 1
    else:
        self.buffer.popleft()
        self.buffer.append(experience)
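
(For reference, the buffer behind this method is presumably initialized roughly like the sketch below; the attribute names match add(), but the exact constructor is my guess from the linked replay_buffer.py.)

from collections import deque

class ReplayBuffer:
    def __init__(self, buffer_size):
        # Maximum number of experiences to keep, the current count,
        # and the deque that holds the (s, a, r, d, s2) tuples
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()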

Then when you need to use them:

def sample(self, batch_size):
    """Samples a total of elements equal to batch_size from the buffer
    if the buffer contains enough elements. Otherwise returns all elements."""

    batch = []

    if self.count < batch_size:
        batch = random.sample(self.buffer, self.count)
    else:
        batch = random.sample(self.buffer, batch_size)

    # Unzips the batch into arrays of states, actions, rewards,
    # done flags and next states
    s_batch, a_batch, r_batch, d_batch, s2_batch = list(map(np.array, list(zip(*batch))))

    return s_batch, a_batch, r_batch, d_batch, s2_batch

OK, so now you have a batch of 32 states, actions, rewards, done flags and next states.
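
To make the shapes concrete, a quick check of what sample() returns might look like this (assuming 84x84 frames and NUM_FRAMES = 4 as in the train() code below; buffer here is just a hypothetical ReplayBuffer instance):

s_batch, a_batch, r_batch, d_batch, s2_batch = buffer.sample(32)

print(s_batch.shape)   # (32, 84, 84, 4) -- 32 stacked-frame states
print(a_batch.shape)   # (32,)           -- one action per experience
print(r_batch.shape)   # (32,)           -- one reward per experience
print(d_batch.shape)   # (32,)           -- one done flag per experience
print(s2_batch.shape)  # (32, 84, 84, 4) -- 32 next states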

This is how you feed the state batch (s_batch) and next state batch (s2_batch) to the CNN:

def train(self, s_batch, a_batch, r_batch, d_batch, s2_batch, observation_num):
    """Trains the network to fit the given parameters"""
    batch_size = s_batch.shape[0]
    targets = np.zeros((batch_size, NUM_ACTIONS))

    for i in range(batch_size):
        # Current Q-value predictions for this state, computed one sample at a time
        targets[i] = self.model.predict(s_batch[i].reshape(1, 84, 84, NUM_FRAMES), batch_size=1)
        fut_action = self.target_model.predict(s2_batch[i].reshape(1, 84, 84, NUM_FRAMES), batch_size=1)
        # Bellman update for the action that was actually taken
        targets[i, a_batch[i]] = r_batch[i]
        if d_batch[i] == False:
            targets[i, a_batch[i]] += DECAY_RATE * np.max(fut_action)

    # The whole batch of 32 states is passed to the network in one call
    loss = self.model.train_on_batch(s_batch, targets)

    # Print the loss every 10 iterations
    if observation_num % 10 == 0:
        print("We had a loss equal to ", loss)

In my code (https://bitbucket.org/jocapal/dqn_public/src/master/Deimos_v2_13.py) I get a batch of 32 experiences, then make small batches of 4 experiences and feed them to the CNN. My question is: am I doing it wrong? And if so, how can I feed 32 experiences at once when my CNN expects inputs of 4 frames?

Another example of what I mean: https://yanpanlau.github.io/2016/07/10/FlappyBird-Keras.html

Topic dqn keras-rl q-learning reinforcement-learning python

Category Data Science


The input is a 4D tensor [batch_size, height, width, channels]. A single state is already 4 frames stacked together, so when you sample one state from the experience replay you get a 3D tensor [height, width, channels]. When you sample 32 states you actually get 32 of those 3D tensors, which stack into a single 4D tensor that you feed directly to the network; the batch dimension is not fixed by the model. For more details on the preprocessing, refer to page 6 of the original DQN paper.
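
As a rough sketch of what that means in Keras (the layer sizes below are illustrative, not the exact architecture from the paper), input_shape only fixes the shape of a single state; the batch dimension is left free, so a whole minibatch of 32 states goes through predict or train_on_batch in one call:

import numpy as np
from tensorflow.keras import layers, models

NUM_FRAMES = 4
NUM_ACTIONS = 6

# input_shape is the shape of ONE state (height, width, stacked frames);
# the batch dimension is implicit and can be any size
model = models.Sequential([
    layers.Conv2D(32, 8, strides=4, activation="relu",
                  input_shape=(84, 84, NUM_FRAMES)),
    layers.Conv2D(64, 4, strides=2, activation="relu"),
    layers.Flatten(),
    layers.Dense(256, activation="relu"),
    layers.Dense(NUM_ACTIONS, activation="linear"),
])
model.compile(optimizer="adam", loss="mse")

# 32 sampled states form a single 4D tensor (batch, height, width, channels)
s_batch = np.random.rand(32, 84, 84, NUM_FRAMES).astype("float32")
q_values = model.predict(s_batch)  # one forward pass, shape (32, NUM_ACTIONS)

So there is no need to split the 32 experiences into groups of 4; the 4 in your input shape is the number of stacked frames per state, not a batch size.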
