pytorch 如何使用TorchServe保存图像或文件?

cfh9epnr  于 6个月前  发布在  其他
关注(0)|答案(1)|浏览(82)

我正在使用TorchServe运行一个Yolov8对象检测器。在我的custom_handler中,我试图获取检测输出JSON,并获取带注解的边界框的图像。
当我运行下面的代码时,我没有得到任何错误,但没有保存任何图像。我还尝试使用Python的基本文件IO创建随机文件,它也不会创建这些文件。
是否可以直接在此处保存图像?如果不可以,最佳做法是什么?

import logging
import os
from collections import Counter
from PIL import Image

import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

logger = logging.getLogger(__name__)

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError as error:
    XLA_AVAILABLE = False

class Yolov8Handler(ObjectDetector):
    image_processing = transforms.Compose(
        [transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
    )

    def __init__(self):
        super(Yolov8Handler, self).__init__()

    def initialize(self, context):
        # Set device type
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
            self.device = torch.device("cpu")

        # Load the model
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        self.model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            self.model_pt_path = os.path.join(model_dir, serialized_file)
        self.model = self._load_torchscript_model(self.model_pt_path)
        logger.debug("Model file %s loaded successfully", self.model_pt_path)

        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Loads the PyTorch model and returns the NN model object.

        Args:
            model_pt_path (str): denotes the path of the model file.

        Returns:
            (NN Model Object) : Loads the model object.
        """
        # TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved

        model = YOLO(model_pt_path)
        model.to(self.device)
        return model

    def postprocess(self, res):
        output = []
        for data in res:

            classes = data.boxes.cls.tolist()
            names = data.names

            # Map to class names
            classes = map(lambda cls: names[int(cls)], classes)

            # Get a count of objects detected
            result = Counter(classes)
            output.append(dict(result))

            img_array = data.plot()
            im = Image.fromarray(img_array[..., ::-1])
            im.save('./result.jpg')

            f = open("random.txt", "w")
            f.write("Save me!")
            f.close()

        return output

字符串

sr4lhrrt

sr4lhrrt1#

我使用一个日志程序进行了调试,并通过os.getcwd()发现TorchServe将会话的文件存储在/tmp/models/目录中
在我的例子中,文件存储在/tmp/models/b3 c9 cda 84767441 ab 93 c842245 ee 2dfb/result. jpg中
路径可以在im.save()中指定,指向更合适的目录

im.save('/preferred/output/path/result.jpg')

字符串

相关问题