Keras Backpropagation when Later Layers are Frozen
I am working on a project involving facial image translation with GANs, and I still have some conceptual misunderstandings. In my model definition, I extract deep embeddings of both the generated image and the input image using a state-of-the-art CNN, which I mark as untrainable. I then calculate the distance between these embeddings and use that distance directly as a loss term in the model. If the model that produces the embeddings is untrainable, will the error still propagate backward through it and train the weights of the generator? I build on the CycleGAN code as follows:
```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Rescaling
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# vggmodel, tf_preprocess_input and face_verif_loss are defined elsewhere in my code.

# define a composite model for updating generators by adversarial and cycle loss
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # ensure the model we're updating is trainable
    g_model_1.trainable = True
    # mark discriminator as not trainable
    d_model.trainable = False
    # mark other generator model as not trainable
    g_model_2.trainable = False
    # discriminator element
    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)
    # identity element
    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)
    # forward cycle
    output_f = g_model_2(gen1_out)
    # backward cycle
    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)
    # x * 127.5 + 127.5 maps images from [-1, 1] back to [0, 255] before CNN preprocessing
    gen1_out_scaled = Rescaling(scale=127.5, offset=127.5)(gen1_out)
    gen1_out_prep = Lambda(lambda x: tf_preprocess_input(x))(gen1_out_scaled)
    # fake_embed
    gen_embed = vggmodel(gen1_out_prep)
    real_src_scaled = Rescaling(scale=127.5, offset=127.5)(input_gen)
    real_src_prep = Lambda(lambda x: tf_preprocess_input(x))(real_src_scaled)
    # true_embed_src
    real_src_embed = vggmodel(real_src_prep)
    real_tar_scaled = Rescaling(scale=127.5, offset=127.5)(input_id)
    real_tar_prep = Lambda(lambda x: tf_preprocess_input(x))(real_tar_scaled)
    # true_embed
    real_tar_embed = vggmodel(real_tar_prep)
    # stack the three embeddings into a single output consumed by face_verif_loss
    embeds = Lambda(lambda t: tf.stack(t))([gen_embed, real_src_embed, real_tar_embed])
    # define model graph
    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b, embeds])
    # define optimization algorithm configuration (`lr` is deprecated; use `learning_rate`)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    # compile model with weighted least squares, L1 and face verification losses
    model.compile(loss=['mse', 'mae', 'mae', 'mae', face_verif_loss],
                  loss_weights=[1, 5, 10, 10, 4], optimizer=opt)
    return model
```
where `vggmodel` is a state-of-the-art model that I set to `trainable = False`.
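To make the question concrete outside of Keras, here is a minimal toy sketch in NumPy, not Keras code and not the model above. A hypothetical linear "generator" matrix `G` feeds a hypothetical fixed linear "embedder" `E`, which stands in for the frozen `vggmodel`. The loss is the squared distance between the embedding of the generated output and the embedding of the input. The gradient with respect to `G` is computed by the chain rule through `E`, checked against finite differences, and then only `G` is updated.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: G plays the trainable generator, E plays the
# frozen embedding network (it is read but never modified below).
G = rng.normal(size=(4, 4))
E = rng.normal(size=(3, 4))
x = rng.normal(size=4)


def embedding_loss(G, E, x):
    """Squared distance between the embedding of G @ x and the embedding of x."""
    diff = E @ (G @ x) - E @ x
    return float(diff @ diff)


def grad_wrt_G(G, E, x):
    """Analytic dL/dG via the chain rule; the gradient passes *through* E."""
    diff = E @ (G @ x) - E @ x        # shape (3,)
    upstream = 2.0 * E.T @ diff       # dL/d(G @ x), back-propagated through frozen E
    return np.outer(upstream, x)      # dL/dG, shape (4, 4)


def numeric_grad(G, E, x, eps=1e-6):
    """Central finite-difference estimate of dL/dG, used as a sanity check."""
    g = np.zeros_like(G)
    for i in range(G.shape[0]):
        for j in range(G.shape[1]):
            Gp = G.copy()
            Gp[i, j] += eps
            Gm = G.copy()
            Gm[i, j] -= eps
            g[i, j] = (embedding_loss(Gp, E, x) - embedding_loss(Gm, E, x)) / (2 * eps)
    return g


analytic = grad_wrt_G(G, E, x)
numeric = numeric_grad(G, E, x)

E_before = E.copy()
loss_before = embedding_loss(G, E, x)
G_new = G - 1e-3 * analytic           # one small gradient step on G only
loss_after = embedding_loss(G_new, E, x)

print("analytic matches numeric:", np.allclose(analytic, numeric, atol=1e-4))
print("loss decreased:", loss_after < loss_before)
print("E unchanged:", np.array_equal(E, E_before))
```

In this toy setting, keeping `E` fixed only means its values are never updated; the gradient of the embedding distance still reaches `G` through `E`, and a step on `G` alone lowers the loss.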