TensorFlow memory leak caused by tf.function + a custom training function

jobtbby3 · asked 6 months ago

I have the following model:

import tensorflow as tf

class FRAE(tf.keras.Model):
    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, bypass=False, trainable=True, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        # Rolling buffer of the last `ht` outputs, flattened into a single row.
        self.buffer = tf.Variable(initial_value=tf.zeros(shape=(1, shape[0] * self.ht), dtype=tf.float32), trainable=False)
        self.bypass = bypass
        self.quantizer = None
        self.trainable = trainable

        # Encoder
        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n2, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')

        # Decoder
        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

    def get_config(self):
        config = super(FRAE, self).get_config().copy()
        # Only plain constructor arguments are serializable here.
        config.update({'latent_dim': self.latent_dim, 'bypass': self.bypass,
                       'ht': self.ht, 'shape': self.shape, 'name': self.name})
        return config

    @tf.function(experimental_compile=True)
    def update_buffer(self, new_element):
        # Shift register: prepend the newest output, drop the oldest n entries.
        n = self.shape[0]
        self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))

    @tf.function(experimental_compile=True)
    def resetBuffer(self):
        self.buffer.assign(tf.zeros_like(self.buffer))

    @tf.function(experimental_compile=True)
    def call(self, x):
        x = tf.squeeze(x, axis=0)  # batch size 1 is assumed
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[0])
        for i in tf.range(tf.shape(x)[0]):
            xexpand = tf.expand_dims(x[i], axis=0)
            xin = tf.concat((xexpand, self.buffer), axis=1)

            encoded = self.ls(self.l2(self.l1(xin)))
            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i, y)
            self.update_buffer(y)

        return tf.transpose(decoded.stack(), [1, 0, 2])

    @tf.function(experimental_compile=True)
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current values
        return {m.name: m.result() for m in self.metrics}

which, in the end, runs fine. However, I noticed that when I train this model, my program's memory usage grows every few seconds. The most likely cause is the buffer I use, which is needed to store intermediate outputs. Currently I use a tf.Variable for it. More precisely, this line

self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))


seems to be the cause. Is there an alternative? Calling gc.collect() occasionally does nothing. I am using TensorFlow 2.0, so there is no tf.placeholder. What can I use in my situation?
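For clarity, here is that update in isolation: it is a shift register that pushes the newest output in at the front and drops the oldest n entries off the end. A standalone snippet with made-up sizes:

import tensorflow as tf

n, ht = 2, 3  # made-up sizes, for illustration only
buffer = tf.Variable(tf.zeros((1, n * ht)), trainable=False)
new_element = tf.ones((1, n))
# Prepend the newest element; drop the oldest n entries off the end.
buffer.assign(tf.concat([new_element, buffer[:, :-n]], axis=1))
print(buffer.numpy())  # [[1. 1. 0. 0. 0. 0.]]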
Edit: I just tested this: plain prediction does not increase memory consumption, so the leak happens during training. My training call is

stoploss = trainutil.stopAtLossValue()
resetbuffercallback = trainutil.ResetBufferCallback()

frae.fit(train_enc_left, train_enc_left, batch_size=1, epochs=10,
         callbacks=[stoploss, resetbuffercallback], verbose=1)


The callbacks are defined as

import math
import tensorflow as tf

class ResetBufferCallback(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # Keras sets self.model before training starts.
        self.model.resetBuffer()

class stopAtLossValue(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        THR = 10  # loss threshold at which training should stop
        if logs.get('loss') > THR or math.isnan(logs.get('loss')):
            self.model.stop_training = True


I don't see how the callbacks could cause a memory leak, so presumably the custom train_step is the problem.
Edit: So apparently this is known to happen when you use @tf.function, see https://github.com/tensorflow/tensorflow/issues/50765. Veeery unfortunate, because I need the compilation to get the model up to training speed, and I cannot compile the model without also compiling the train_step function. Does anyone know how to work around this?
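For anyone reproducing this: on recent TF versions you can at least check how often a tf.function has been traced, to tell graph re-tracing apart from the per-call accumulation described in the issue. A minimal check (my own sketch, not from the issue):

import tensorflow as tf

@tf.function
def step(x):
    return x * 2

for i in range(5):
    step(tf.constant(float(i)))

# Same dtype and shape on every call, so this should print 1. If the
# count grows across identical calls, every call is building a new graph.
print(step.experimental_get_tracing_count())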

m528fe3b · Answer 1

What problem is the memory leak actually causing you? Short of waiting for TensorFlow to fix the underlying issue, there is not necessarily much you can do to stop the leak itself, but there are mitigation strategies.
If the problem is running out of memory during training, can you get more memory, or make changes that shrink training's memory footprint enough to get through it?
If the problem is that you are left without enough memory after training and cannot free it, one option is to run the training in a separate process, save the model, and then kill the process. At that point the memory is released, and you can reload the model.
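A rough sketch of the separate-process idea, assuming spawn-based multiprocessing; the tiny Sequential model and the paths are stand-ins for the real training code:

import multiprocessing as mp
import tensorflow as tf

def train_and_save(path):
    # Stand-in for the real model and data; any Keras training loop works.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")
    x = tf.random.normal((32, 4))
    y = tf.random.normal((32, 1))
    model.fit(x, y, epochs=1, verbose=0)
    model.save(path)

if __name__ == "__main__":
    mp.set_start_method("spawn")  # avoid forking an initialized TF runtime
    p = mp.Process(target=train_and_save, args=("saved_model_dir",))
    p.start()
    p.join()  # when the child exits, any memory it leaked is released
    model = tf.keras.models.load_model("saved_model_dir")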

r7knjye2 · Answer 2

Unfortunately, there is no fix for this. I had to live without the considerable speed-up and opted for a larger batch size instead.
