在使用eager模式的时候,如果类继承的是tf.keras.layers.Layer
,要获得该类的变量,即类.variables
,需要在类中使用self.add_variable()
生成变量,一般可以在build函数中添加;如果继承的是tf.keras.Model,任意位置使用tf.get_variable()
生成变量,并赋值给self.变量名
。一般在__init__()
函数中 ,如果使用tf.layers
调用相关的层,比如conv2d层等,需要在init函数中声明,然后再在call函数中调用。