keras 在Colab中断后继续进行U-Net培训的问题

jhdbpxl9  于 7个月前  发布在  其他
关注(0)|答案(1)|浏览(79)

我正在使用Google Colab来训练一个用于物体检测的U-Net网络(单个对象)。我的问题是,很多时候,由于各种原因,训练被中断,迫使我重新开始这个过程。我使用下面的代码试图从我停止的地方继续训练,但每次我尝试,算法都从epoch 1开始训练。例如,如果我训练了30个epoch,它应该从时期31向前继续,但是相反,算法从第一时期开始,这指示训练不是继续而是重新开始。
下面是我试图从我离开的地方恢复训练的特定部分:

import os
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Caminho para o diretório de checkpoints (o mesmo usado durante o treinamento)
checkpoint_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/checkpoints/'

# Carregue o modelo (certifique-se de que a arquitetura é a mesma)
model = unet()

# Defina o otimizador e compile o modelo (use os mesmos parâmetros do treinamento original)
optimizer = SGD(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# Carregue os pesos do modelo a partir do último checkpoint
model.load_weights(os.path.join(checkpoint_dir, 'best_checkpoint.ckpt'))

# ModelCheckpoint para salvar os melhores checkpoints
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'best_checkpoint.ckpt'),
    save_weights_only=True,  # Salvar apenas os pesos do modelo
    save_best_only=True,  # Salvar apenas o melhor checkpoint com base na métrica de validação
    monitor='val_loss',  # Métrica de validação a ser monitorada (ajuste conforme necessário)
    verbose=1
)

# EarlyStopping para interromper o treinamento se não houver melhora na métrica de validação
early_stopping = EarlyStopping(
    patience=10,  # Número de épocas sem melhora para interromper o treinamento
    restore_best_weights=True,  # Restaurar os melhores pesos do modelo
    verbose=1
)

# Continue o treinamento a partir da última época
history = model.fit(
    np.array(train_images),
    np.array(train_masks),
    epochs=500,  # Ajuste conforme necessário
    batch_size=6,  # Ajuste conforme necessário
    validation_data=(np.array(test_images), np.array(test_masks)),
    callbacks=[model_checkpoint, early_stopping]
)

字符串
值得注意的是,该算法使用检查点生成三个具有默认训练的文件:best_checkpoint.ckpt.data-00000-of-00001、best_checkpoint.ckpt.index和checkpoint。
以下是我的全部代码,给那些想看的人:
输入块

import pandas as pd
import cv2
import numpy as np
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from sklearn.metrics import accuracy_score, recall_score, average_precision_score
from scipy.ndimage.measurements import label
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import load_model


列车与试验之间的划分(CSV)

# Caminhos para os diretórios de treinamento e teste
train_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/train/images/'
test_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/val/images/'

# Carregando o CSV geral
csv_path = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/nodules.csv'
data = pd.read_csv(csv_path)

# Criando os DataFrames para treinamento e teste
train_data = pd.DataFrame(columns=data.columns)
test_data = pd.DataFrame(columns=data.columns)

# Iterando sobre as imagens no diretório de treinamento
for img_name in os.listdir(train_dir):
    if img_name.endswith('.png'):
        img_name = img_name.strip()  # Remover espaços em branco extras, se houver
        # Procurar pela imagem correspondente no CSV geral
        rows = data[data['img_name'] == img_name]
        if not rows.empty:
            # Adicionar as linhas correspondentes ao DataFrame de treinamento
            train_data = pd.concat([train_data, rows])

# Iterando sobre as imagens no diretório de teste
for img_name in os.listdir(test_dir):
    if img_name.endswith('.png'):
        img_name = img_name.strip()  # Remover espaços em branco extras, se houver
        # Procurar pela imagem correspondente no CSV geral
        rows = data[data['img_name'] == img_name]
        if not rows.empty:
            # Adicionar as linhas correspondentes ao DataFrame de teste
            test_data = pd.concat([test_data, rows])

# Salvar os DataFrames em arquivos CSV separados
train_data.to_csv('/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/train_data.csv', index=False)
test_data.to_csv('/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/test_data.csv', index=False)


加载图像及其蒙版

# Carregar os CSVs de anotações para treinamento e teste
train_annotations = pd.read_csv('/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/train_data.csv')
test_annotations = pd.read_csv('/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/test_data.csv')

# Diretórios das imagens de treinamento e teste
train_image_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/train/images/'
test_image_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/val/images/'

# Inicializar as máscaras e as listas de imagens
image_shape = (640, 640)
train_masks = np.zeros((len(train_annotations),) + image_shape, dtype=np.uint8)
test_masks = np.zeros((len(test_annotations),) + image_shape, dtype=np.uint8)

# Carregar as imagens de treinamento e suas máscaras
train_images = []
for img_name in train_annotations['img_name']:
    img = cv2.imread(train_image_dir + img_name, cv2.IMREAD_GRAYSCALE)
    train_images.append(img)

    # Itere sobre as anotações e crie as máscaras de treinamento
    for i, row in train_annotations.iterrows():
        x, y, width, height = row['x'], row['y'], row['width'], row['height']
        train_masks[i, y:y+height, x:x+width] = 1

# Carregar as imagens de teste e suas máscaras
test_images = []
for img_name in test_annotations['img_name']:
    img = cv2.imread(test_image_dir + img_name, cv2.IMREAD_GRAYSCALE)
    test_images.append(img)

    # Iterar sobre as anotações e crie as máscaras de teste
    for i, row in test_annotations.iterrows():
        x, y, width, height = row['x'], row['y'], row['width'], row['height']
        test_masks[i, y:y+height, x:x+width] = 1

# Dividir os dados em conjuntos de treinamento e teste
X_train, X_test, y_train, y_test = train_test_split(train_images, train_masks, test_size=0.2, random_state=42)


模型

# Função para construir a U-Net
def unet(input_shape=(640, 640, 1)):
    inputs = Input(input_shape)

    # Parte de codificação (downsampling)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    (...)

    conv5 = Conv2D(1024, 3, activation='relu', padding='same')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same')(conv5)

    # Parte de decodificação (upsampling)
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Concatenate()([up6, conv4])
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(up6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(conv6)

    (...)

    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Concatenate()([up9, conv1])
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(up9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same')(conv9)

    # Saída
    outputs = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=outputs)
    return model

model = unet()

# Compilação do modelo
from tensorflow.keras.optimizers import SGD

model.compile(optimizer=SGD(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# Visualização da arquitetura da U-Net
model.summary()


培训脚本

# Caminho para o diretório onde deseja salvar os checkpoints
checkpoint_dir = '/content/drive/MyDrive/Node21/YOLOv8/tentativa_1/checkpoints/'

# Certifique-se de que o diretório de checkpoints exista, caso contrário, crie-o
os.makedirs(checkpoint_dir, exist_ok=True)

# ModelCheckpoint para salvar apenas os melhores checkpoints
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, 'best_checkpoint.ckpt'),
    save_weights_only=True,  # Salvar apenas os pesos do modelo
    save_best_only=True,  # Salvar apenas o melhor checkpoint com base na métrica de validação
    monitor='val_loss',  # Métrica de validação a ser monitorada (pode ser ajustada)
    verbose=1
)

# EarlyStopping para interromper o treinamento se não houver melhora na métrica de validação
early_stopping = EarlyStopping(
    patience=10,  # Número de épocas sem melhora para interromper o treinamento
    restore_best_weights=True,  # Restaurar os melhores pesos do modelo
    verbose=1
)

# Continue o treinamento a partir da última época
history = model.fit(
    np.array(train_images),
    np.array(train_masks),
    epochs=100,
    batch_size=4,
    validation_data=(np.array(test_images), np.array(test_masks)),
    callbacks=[model_checkpoint, early_stopping]
)

zu0ti5jz

zu0ti5jz1#

正如@stateMachine所述,我们需要在fit()参数中使用initial_epoch。当涉及到我的代码时,我需要使用保存_weights_only = true来保存优化器状态。
感谢@stateMachine在评论中的回复。

相关问题