自定义层函数需要继承layers.Layer,自定义网络需要继承keras.Model。
其内部需要定义两个函数:
1、__init__初始化函数,内部需要定义构造形式;
2、call函数,内部需要定义计算形式及返回值。
#self def layer class MyDense(layers.Layer):#inherit layers.Layer def __init__(self,input_dim,output_dim):#init super(MyDense,self).__init__() self.kernal = self.add_variable('w',[input_dim,output_dim]) self.bias = self.add_variable('b',[output_dim]) def call(self,inputs,training=None):#compute out = inputs @ self.kernal + self.bias return out 1234567891011
#self def network class MyModel(keras.Model):#inherit keras.Model def __init__(self):#init super(MyModel,self).__init__() self.fc1 = MyDense(input_dim=28*28,output_dim=512) self.fc2 = MyDense(input_dim=512, output_dim=256) self.fc3 = MyDense(input_dim=256, output_dim123456