Keras models break when I add batch normalization
I'm creating the models for a DDPG agent (keras-rl version), but I get errors whenever I try adding batch normalization to the first of the two networks (the actor).
Here is the creation function as I'd like it to be:
```python
from keras.layers import Input, Flatten, Dense, BatchNormalization, Activation, Concatenate
from keras.models import Model

def buildDDPGNets(actNum, obsSpace):
    # Actor network: observation -> action
    actorObsInput = Input(shape = (1,) + obsSpace, name = "actor_obs_input")
    a = Flatten()(actorObsInput)
    a = Dense(600, use_bias = False)(a)
    a = BatchNormalization()(a)
    a = Activation("relu")(a)
    a = Dense(300, use_bias = False)(a)
    a = BatchNormalization()(a)
    a = Activation("relu")(a)
    a = Dense(actNum)(a)
    a = Activation("tanh")(a) # BipedalWalker applies torques in (-1, 1).
    actor = Model(inputs = [actorObsInput], outputs = a)

    # Critic network: (action, observation) -> Q-value
    criticActInput = Input(shape = (actNum,), name = "critic_action_input")
    criticObsInput = Input(shape = (1,) + obsSpace, name = "critic_obs_input")
    c = Flatten()(criticObsInput)
    c = Dense(600, use_bias = False)(c)
    c = BatchNormalization()(c)
    c = Activation("relu")(c)
    c = Concatenate()([c, criticActInput]) # the action joins after the first layer
    c = Dense(300, use_bias = False)(c)
    c = BatchNormalization()(c)
    c = Activation("relu")(c)
    c = Dense(1)(c)
    c = Activation("linear")(c)
    critic = Model(inputs = [criticActInput, criticObsInput], outputs = c)

    return (criticActInput, actor, critic)
```
This produces the following error:
```
InvalidArgumentError: You must feed a value for placeholder tensor 'actor_obs_input' with dtype float and shape [?,1,24]
[[{{node actor_obs_input}}]]
```
However, if I remove the batch normalization from the first network (a) and leave the second network (c) completely unchanged, it runs as it should:
```python
a = Flatten()(actorObsInput)
a = Dense(600, use_bias = False)(a)
#a = BatchNormalization()(a)
a = Activation("relu")(a)
a = Dense(300, use_bias = False)(a)
#a = BatchNormalization()(a)
a = Activation("relu")(a)
a = Dense(actNum)(a)
```
It's also notable that if I do it the other way around (remove batch normalization from c and leave it in a), the error still occurs. Any idea why this is happening? It's odd, because it runs fine with batch normalization in the critic, just not in the actor. The models are being used by the keras-rl DDPG agent, btw.
Update: I've tried rewriting it with the Sequential API instead of the functional API. It didn't help; I still get exactly the same error. I'm beginning to think this is some sort of problem with Keras's BatchNormalization layer when it is used in a system of multiple models.
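For context on what BatchNormalization adds compared with the plain Dense/Activation layers: in training mode it also updates non-trainable moving-mean and moving-variance weights from the statistics of the batch it receives, so those updates depend on the layer's input. Below is a simplified pure-Python sketch of that behaviour, not Keras' actual implementation; the class name `BatchNormSketch` is hypothetical, and the defaults `momentum=0.99`, `epsilon=1e-3` are assumed to mirror Keras' `BatchNormalization` defaults.

```python
import math
from typing import List

Matrix = List[List[float]]


class BatchNormSketch:
    """Hypothetical, simplified batch normalization over a batch of feature vectors."""

    def __init__(self, num_features: int, momentum: float = 0.99, epsilon: float = 1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        # Trainable scale and shift (left at their initial values in this sketch).
        self.gamma = [1.0] * num_features
        self.beta = [0.0] * num_features
        # Non-trainable state, changed only as a side effect of training-mode calls.
        self.moving_mean = [0.0] * num_features
        self.moving_variance = [1.0] * num_features

    def __call__(self, batch: Matrix, training: bool) -> Matrix:
        columns = list(zip(*batch))  # one tuple of values per feature
        n = len(batch)
        if training:
            # Normalize with the statistics of this particular batch...
            mean = [sum(col) / n for col in columns]
            var = [sum((x - m) ** 2 for x in col) / n for col, m in zip(columns, mean)]
            # ...and update the moving averages from it: this step needs the batch itself.
            mom = self.momentum
            self.moving_mean = [mom * mm + (1 - mom) * m
                                for mm, m in zip(self.moving_mean, mean)]
            self.moving_variance = [mom * mv + (1 - mom) * v
                                    for mv, v in zip(self.moving_variance, var)]
        else:
            # Inference: read the stored moving statistics without changing them.
            mean, var = self.moving_mean, self.moving_variance
        return [
            [g * (x - m) / math.sqrt(v + self.epsilon) + b
             for x, m, v, g, b in zip(row, mean, var, self.gamma, self.beta)]
            for row in batch
        ]


if __name__ == "__main__":
    bn = BatchNormSketch(num_features=2)
    bn([[1.0, 2.0], [3.0, 6.0]], training=True)
    print(bn.moving_mean)  # approximately [0.02, 0.04]
```

Calling it with `training=True` mutates `moving_mean` and `moving_variance`, while `training=False` only reads them; a network containing batch normalization therefore carries extra state updates that a network of only Dense and Activation layers does not have.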
Topic keras-rl batch-normalization keras neural-network
Category Data Science