我正在使用TorchServe运行一个Yolov8对象检测器。在我的custom_handler中,我试图获取检测输出JSON,并获取带注解的边界框的图像。
当我运行下面的代码时,我没有得到任何错误,但没有保存任何图像。我还尝试使用Python的基本文件IO创建随机文件,它也不会创建这些文件。
是否可以直接在此处保存图像?如果不可以,最佳做法是什么?
import logging
import os
from collections import Counter
from PIL import Image
import torch
from torchvision import transforms
from ultralytics import YOLO
from ts.torch_handler.object_detector import ObjectDetector
logger = logging.getLogger(__name__)
try:
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
except ImportError as error:
XLA_AVAILABLE = False
class Yolov8Handler(ObjectDetector):
image_processing = transforms.Compose(
[transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
)
def __init__(self):
super(Yolov8Handler, self).__init__()
def initialize(self, context):
# Set device type
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif XLA_AVAILABLE:
self.device = xm.xla_device()
else:
self.device = torch.device("cpu")
# Load the model
properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
self.model_pt_path = os.path.join(model_dir, serialized_file)
self.model = self._load_torchscript_model(self.model_pt_path)
logger.debug("Model file %s loaded successfully", self.model_pt_path)
self.initialized = True
def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.
Args:
model_pt_path (str): denotes the path of the model file.
Returns:
(NN Model Object) : Loads the model object.
"""
# TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved
model = YOLO(model_pt_path)
model.to(self.device)
return model
def postprocess(self, res):
output = []
for data in res:
classes = data.boxes.cls.tolist()
names = data.names
# Map to class names
classes = map(lambda cls: names[int(cls)], classes)
# Get a count of objects detected
result = Counter(classes)
output.append(dict(result))
img_array = data.plot()
im = Image.fromarray(img_array[..., ::-1])
im.save('./result.jpg')
f = open("random.txt", "w")
f.write("Save me!")
f.close()
return output
字符串
1条答案
按热度按时间sr4lhrrt1#
我使用一个日志程序进行了调试,并通过os.getcwd()发现TorchServe将会话的文件存储在/tmp/models/目录中
在我的例子中,文件存储在/tmp/models/b3 c9 cda 84767441 ab 93 c842245 ee 2dfb/result. jpg中
路径可以在im.save()中指定,指向更合适的目录
字符串