• 教程 >
  • Python 用 Flask REST API 部署 PyTorch
Shortcuts

Python 用 Flask REST API 部署 PyTorch

作者Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开 REST API 进行模型推理。 尤其是我们将部署预先训练的 DenseNet 121 模型来检测图像。

提示

此处使用的所有代码都根据 MIT 许可证发布,在Github上可访问。

这是在生产环境中部署 PyTorch 模型系列中的第一个教程。 以使用 Flask 这种方式为 PyTorch 模型提供服务是迄今为止最简单的方法,但它不适用于具有高性能要求的情况。 为此:

API 定义

我们将首先定义 API 接点、请求和响应类型。 我们的 API 结点将位于/predict,该结点接受HTTP POST 请求,其带有file参数包含图像。 响应将是包含预测结果的 JSON 格式的响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖

通过运行以下命令安装所需的依赖:

$ pip install Flask==1.0.3 torchvision-0.3.0

简单的 Web 服务器

下面是一个简单的网络服务器,取自 Flask 的文档

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

将上述代码段保存在名为app.py文件中,现在可以通过键入以下代码运行 Flask 开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当你在 Web 浏览器中访问http://localhost:5000/时,你将收到 Hello World! 文本

我们将对上述代码段进行细微更改,使其适合我们的 API 定义。 首先,我们将重命名要predict的方法。 我们将更新端点路径到/predict 由于映像文件将通过 HTTP POST 请求发送,我们将更新它,以便它也只接受 POST 请求:

@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

我们还将更改响应类型,以便它返回包含 ImageNet 类 ID 和名称的 JSON 响应。 更新app.py文件现在将:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理|

在接下来的章节中,我们将重点介绍推理代码的编写。 这将涉及两个部分,一部分是我们准备图像以便将其馈送到 DenseNet;接下来,我们将编写代码,以便从模型获取实际预测。

准备图像|

DenseNet 模型要求图像为大小为 224 x 224 的 3 通道 RGB 图像。 我们还将使用所需的均值和标准偏差值对图像张数进行规范化。 你可以在这里阅读更多关于它。

我们将使用torchvision库的transforms,并构建一个转换管道,根据需要转换图像。 您可以在此处阅读有关转换的更多内容。

import io

import torchvision.transforms as transforms
from PIL import 图像

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    图像 = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法以字节为单位获取图像数据,应用一系列转换并返回张量。 要测试上述方法,请以字节模式读取图像文件(首先替换. 。/_static/img/sample_file.jpeg与计算机上文件的实际路径),并查看您是否获得回的张条:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

输出:

tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],
          [ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],
          [ 0.7077,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],
          ...,
          [ 1.3755,  1.3927,  1.4098,  ...,  1.1700,  1.3584,  1.6667],
          [ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],
          [ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],

         [[ 0.5728,  0.5378,  0.5203,  ..., -1.3704, -1.3529, -1.3529],
          [ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3179, -1.3354],
          [ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],
          ...,
          [ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],
          [ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],
          [ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],

         [[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6476],
          [ 0.8099,  0.7576,  0.7228,  ..., -1.6476, -1.6476, -1.6650],
          [ 1.0017,  0.9145,  0.8797,  ..., -1.6476, -1.6650, -1.6999],
          ...,
          [ 1.6291,  1.6291,  1.6465,  ...,  1.6291,  1.8208,  2.1346],
          [ 2.1868,  2.0300,  1.6814,  ...,  1.7685,  1.9428,  2.0125],
          [ 1.9254,  2.0997,  2.0823,  ...,  1.9428,  2.2043,  1.9080]]]])

预测|

现在将使用预先训练的 DenseNet 121 模型来预测图像类。 我们将使用一个从torchvision库,加载模型,并得到一个推论。 尽管我们在此示例中将使用预训练的模型,但您可以将此方法用于您自己的模型。 tutorial中查看有关加载模型的更多信息。

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

y_haty_hat量将包含预测类 ID 的索引。 但是,我们需要一个人类可读的类名。 为此,我们需要一个类 ID 来命名映射。 下载此文件作为imagenet_class_index.json并记住你保存它的位置(或者,如果你按照本教程中的确切步骤,保存在教程/_static)。 此文件包含 ImageNet 类 ID 到 ImageNet 类名称的映射。 我们将加载此 JSON 文件并获取预测索引的类名称。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用imagenet_class_index字典之前,首先我们将张条值转换为字符串值,因为imagenet_class_index键是字符串。 我们将测试我们的上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

输出:

['n02124075', 'Egyptian_cat']

您应该会得到这样的响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类 ID,第二项是人可读的名称。

注意

您是否注意到model变量不是get_prediction方法的一部分? 或者为什么模型是全局变量? 在内存和计算方面,加载模型可能是一项代价高昂的操作。 如果我们在get_prediction方法中加载模型,那么每次调用该方法时,模型都会不必要地加载。 因为,我们正在构建一个 Web 服务器,每秒可能会有数千个请求,我们不应浪费时间冗余地加载模型进行每个推理。 因此,我们将模型只加载到内存中一次。 在生产系统中,必须有效地使用计算才能大规模地为请求提供服务,因此通常应在提供服务请求之前加载模型。

将模型集成到我们的 API 服务器中 |

在最后一部分中,我们将模型添加到我们的 Flask API 服务器中。 由于我们的 API 服务器期望接受图像文件,我们将更新predict方法以从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现已完成。 以下是完整版本;将路径替换为你保存文件的路径则应该能够运行:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

让我们测试一下我们的网络服务器! 运行:

$ FLASK_ENV=development FLASK_APP=app.py flask run

我们可以使用 requests 库向我们的应用程序发送 POST 请求:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

现在打印 resp.json() 将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

后续步骤|

我们编写的服务器非常琐碎,可能无法执行生产应用程序所需的一切。 因此,您可以执行一些操作来使其变得更好:

  • 终结点/predict假定请求中始终存在图像文件。 这可能不适用于所有请求。 我们的用户可以发送具有不同参数的图像,或者根本不发送任何图像。

  • 用户也可以发送非图像类型文件。 由于我们不处理错误,这将破坏我们的服务器。 添加显式错误处理路径,将引发异常将使我们能够更好地处理错误输入

  • 尽管模型可以识别大量类图像,但它可能无法识别所有图像。 增强实现,以处理模型无法识别映像中的任何内容时的情况。

  • 我们在开发模式下运行 Flask 服务器,这不适合在生产中部署。 您可以查看本教程,了解在生产中部署 Flask 服务器。

  • 还可以通过创建具有采用图像并显示预测的窗体的页面来添加 UI。 查看类似项目的演示及其源代码

  • 在本教程中,我们仅演示如何构建一个一次可以返回单个图像的预测的服务。 我们可以修改我们的服务,以便能够同时返回多个图像的预测。 此外,服务流库会自动将请求排队到服务,并将它们采样到可馈入模型的微型批处理中。 你可以查看本教程

  • 最后,我们鼓励您查看我们有关在页面顶部链接的 PyTorch 模型部署的其他教程。

脚本总运行时间: ( 5 分 9.824 秒)

由狮身人面像库生成的画廊