我写了下面的类来执行实例分割,并返回给定类别的掩码。代码的运行结果似乎是随机的,不具有确定性:即使在同一张只包含一个人的输入图像上运行,输出的标签(以及标签的数量)在每次执行时都会改变。代码没有打印任何警告或异常。请注意,我是在 CPU 上运行代码的。
import numpy as np
import torch
from torch import Tensor
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
import torchvision.transforms as T
import PIL
from PIL import Image
class RetinaNet:
    """Wrapper around torchvision's RetinaNet (ResNet-50 FPN v2) detector.

    RetinaNet is an object detector: it predicts bounding boxes rather than
    pixel-accurate masks. The "masks" returned by this class are therefore
    filled rectangles, one per detected box of the requested category.
    """

    def __init__(self, weights: RetinaNet_ResNet50_FPN_V2_Weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1):
        """Build the detector, switch it to eval mode and move it to a device.

        Args:
            weights: Pre-trained weights to load. Their metadata is also the
                source of the category names used by the label lookups.
        """
        self.weights = weights
        # Hand the selected weights to the ``weights`` keyword so that the
        # constructor argument actually determines what is loaded.
        self.model = retinanet_resnet50_fpn_v2(weights=self.weights)
        self.model.eval()

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

        # PIL image -> float tensor; no resizing or normalisation is applied
        # here.
        self.transform = T.Compose([
            T.ToTensor(),
        ])

    def infer_on_image(self, image: PIL.Image.Image, label: str) -> Tensor:
        """Detect objects of one category and rasterise their boxes as masks.

        Args:
            image: Input image.
            label: Category name, as listed in the weights' metadata.

        Returns:
            A ``uint8`` CPU tensor of shape ``(num_boxes, height, width)``
            where ``height`` and ``width`` are those of the input tensor.
            Pixels inside a detected box are 1, all others 0.

        Raises:
            ValueError: If ``label`` is not a known category name.
        """
        # Resolve the label first so an unknown name fails before inference.
        label_index = self.get_label_index(label)

        # Add the batch dimension expected by the model.
        batch = self.transform(image).unsqueeze(0)
        # Tensor.to returns a new tensor; the result must be kept, otherwise
        # the input stays on the CPU while the model may live on the GPU.
        batch = batch.to(self.device)

        with torch.no_grad():
            predictions = self.model(batch)
        detections = predictions[0]
        print('labels', detections['labels'])  # debug: labels of all detections

        # Keep only the boxes predicted for the requested category.
        boxes = detections['boxes'][detections['labels'] == label_index]

        # The batch is 4-D, so the spatial size is given by the last two
        # dimensions (not dimensions 1 and 2, which are channels and height).
        height, width = batch.shape[-2:]
        masks = torch.zeros((len(boxes), height, width), dtype=torch.uint8)
        for i, box in enumerate(boxes.cpu().numpy()):
            x1, y1, x2, y2 = map(int, box)
            masks[i, y1:y2, x1:x2] = 1
        return masks

    def get_label_index(self, label: str) -> int:
        """Return the index of ``label`` in the weights' category list.

        Raises:
            ValueError: If ``label`` is not in the category list.
        """
        return self.weights.value.meta['categories'].index(label)

    def get_label(self, label_index: int) -> str:
        """Return the category name stored at ``label_index``."""
        return self.weights.value.meta['categories'][label_index]

    @staticmethod
    def load_image(file_path: str) -> PIL.Image.Image:
        """Open the image at ``file_path`` and convert it to RGB."""
        return Image.open(file_path).convert("RGB")
if __name__ == '__main__':
    # Demo: run the detector on one image, then show the image followed by
    # one figure per mask returned for the 'person' category.
    from matplotlib import pyplot as plt

    image_path = 'person.jpg'

    # Read the image from disk once and reuse it for inference and display.
    retinanet = RetinaNet()
    image = retinanet.load_image(image_path)
    masks = retinanet.infer_on_image(image=image, label='person')

    # Plot image
    plt.imshow(image)
    plt.show()

    # Plot masks; each one is already a 2-D array, so it can be drawn
    # directly without adding a channel axis.
    for i, mask in enumerate(masks):
        plt.title(f'mask {i}')
        plt.imshow(mask.numpy())
        plt.show()
字符串
1条答案
按热度按时间ruoxqz4g1#
就我而言,我总是使用下面的脚本,并且能够重现完全相同的结果(使用 DDP 时除外)。
在 `__main__` 脚本的开头加入以下代码(原代码片段在提取时丢失):
在 DDP 下,带有异步任务的数据加载器采样器会在不同的时刻执行数据增强,因此结果无法完全重现。这个问题可以用一些技巧来处理,但我的方法中没有使用这些技巧。
在 `dataloader` 类的实现中(原代码片段在提取时丢失):