这需要做一些额外的工作：你需要编写一个自定义层。例如，你不能在 tf.keras.layers.Lambda 中使用 tf.Variable（Lambda 层不会跟踪在其中创建的变量）。
class ConvNorm(layers.Layer):
    """2-D convolution whose kernel is modulated per sample by ``scale``/``shift``.

    For sample ``b`` and output channel ``k`` the effective kernel is
    ``filter[..., k] * scale[b, k] + shift[b, k]``.  A single
    ``tf.nn.conv2d`` call cannot apply a different kernel to every sample,
    so the layer evaluates the exactly equivalent linear decomposition::

        conv(x, W * s + t) == conv(x, W) * s + conv(x, ones) * t

    where ``conv(x, ones)`` is the sum of ``x`` over each kernel window and
    over all input channels.  The identity also holds with SAME padding,
    because the padded values are zeros and contribute nothing to either side.

    Requires ``tf`` (``import tensorflow as tf``) and ``layers``
    (``import tensorflow.keras.layers as layers``) to be in scope.

    Args:
        height: Kernel height.
        width: Kernel width.
        n_filters: Number of output channels.
        use_bias: If True, add a trainable per-channel bias (initialised to
            zeros) to the output.  Defaults to False, i.e. no bias.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, height, width, n_filters, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.height = height
        self.width = width
        self.n_filters = n_filters
        self.use_bias = use_bias

    def build(self, input_shape):
        # tf.nn.conv2d kernel layout: (height, width, in_channels, out_channels).
        self.filter = self.add_weight(
            shape=(self.height, self.width, input_shape[-1], self.n_filters),
            initializer='glorot_uniform',
            trainable=True,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.n_filters,),
                initializer='zeros',
                trainable=True,
            )

    def call(self, x, scale, shift):
        """Convolve ``x`` with the per-sample modulated kernel.

        Args:
            x: Input feature map in NHWC layout (tf.nn.conv2d's default),
                i.e. ``(batch, H, W, in_channels)``.
            scale: Per-sample, per-filter multiplier; expected shape
                ``(batch, n_filters)``.
            shift: Per-sample, per-filter additive offset applied to every
                kernel element; expected shape ``(batch, n_filters)``.

        Returns:
            Tensor of shape ``(batch, H, W, n_filters)`` (stride 1, SAME
            padding).
        """
        # (batch, n_filters) -> (batch, 1, 1, n_filters) so the factors
        # broadcast over the spatial axes of the convolution output.
        scale_reshaped = tf.expand_dims(tf.expand_dims(scale, 1), 1)
        shift_reshaped = tf.expand_dims(tf.expand_dims(shift, 1), 1)

        # conv(x, W): unmodulated response, (batch, H, W, n_filters).
        base = tf.nn.conv2d(x, self.filter, strides=(1, 1, 1, 1), padding='SAME')
        # conv(x, ones): per-window sum over all input channels,
        # (batch, H, W, 1); this is the response to the constant ``shift``
        # component of the kernel.
        ones_kernel = tf.ones_like(self.filter[..., :1])
        window_sum = tf.nn.conv2d(x, ones_kernel, strides=(1, 1, 1, 1), padding='SAME')

        norm_conv_out = base * scale_reshaped + window_sum * shift_reshaped
        if self.use_bias:
            norm_conv_out = norm_conv_out + self.bias
        return norm_conv_out
使用该层：
import tensorflow as tf
import tensorflow.keras.layers as layers
# Functional-API usage: a 28x28 single-channel image plus a class-label
# vector from which per-filter scale/shift values for ConvNorm are predicted.
input_img = layers.Input(shape=(28, 28, 1))
label = layers.Input(shape=(10,))  # label vector (presumably one-hot); 10 = number of classes
num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label)  # (batch, num_filters)
scale = layers.Dense(num_filters, activation=None, name='scale')(label)  # (batch, num_filters)
conv_norm_out = ConvNorm(3, 3, num_filters)(input_img, scale, shift)
print(conv_norm_out.shape)  # expected: (None, 28, 28, 32) with stride 1 / SAME padding
说明
注意：我没有添加偏置（bias）项。卷积层通常也需要偏置，但添加起来很简单。