使用 Ray 和 Apache Arrow 实现快速 Python 序列化


已发布 2017 年 10 月 15 日
作者 Philipp Moritz, Robert Nishihara

本文最初发表在 Ray 博客上。 Philipp MoritzRobert Nishihara 是加州大学伯克利分校的研究生。

本文详细阐述了 RayApache Arrow 之间的集成。它主要解决的问题是 数据序列化

来自 维基百科序列化

… 将数据结构或对象状态转换为可以存储… 或传输… 并稍后(可能在不同的计算机环境中)重建的格式的过程。

为什么需要任何转换? 好吧,当您创建一个 Python 对象时,它可能具有指向其他 Python 对象的指针,并且这些对象都分配在内存的不同区域中,所有这些都必须在另一台机器上的另一个进程解包时有意义。

序列化和反序列化是 并行和分布式计算中的瓶颈,尤其是在具有大型对象和大量数据的机器学习应用程序中。

设计目标

由于 Ray 针对机器学习和 AI 应用程序进行了优化,因此我们非常关注序列化和数据处理,并具有以下设计目标

  1. 它应该对 大型数值数据 非常有效(包括 NumPy 数组和 Pandas DataFrames,以及递归包含 Numpy 数组和 Pandas DataFrames 的对象)。
  2. 对于 通用 Python 类型,它应该与 Pickle 一样快。
  3. 它应该与 共享内存 兼容,允许多个进程使用相同的数据而无需复制。
  4. 反序列化 应该非常快(如果可能,它不应要求读取整个序列化对象)。
  5. 它应该是 与语言无关的 (最终,我们希望 Python 工作进程能够使用 Java 或其他语言的工作进程创建的对象,反之亦然)。

我们的方法和替代方案

Python 中首选的序列化方法是 pickle 模块。 Pickle 非常通用,特别是如果您使用像 cloudpickle 这样的变体。 但是,它不满足要求 1、3、4 或 5。 像 json 这样的替代方案满足 5,但不满足 1-4。

我们的方法: 为了满足要求 1-5,我们选择使用 Apache Arrow 格式作为我们的底层数据表示。 与 Apache Arrow 团队合作,我们构建了 ,用于将通用 Python 对象映射到 Arrow 格式以及从 Arrow 格式映射。 此方法的一些属性

  • 数据布局与语言无关(要求 5)。
  • 可以在恒定时间内计算到序列化数据 blob 中的偏移量,而无需读取完整对象(要求 1 和 4)。
  • Arrow 支持 零拷贝读取,因此对象可以自然地存储在共享内存中并由多个进程使用(要求 1 和 3)。
  • 对于我们无法很好处理的任何内容,我们可以自然地回退到 pickle(要求 2)。

Arrow 的替代方案: 我们可以构建在 Protocol Buffers 之上,但协议缓冲区实际上并非专为数值数据而设计,并且该方法不满足 1、3 或 4。 构建在 Flatbuffers 之上实际上可以实现,但这需要实现 Arrow 已经具有的许多功能,并且我们更喜欢针对大数据优化的柱状数据布局。

加速

在这里,我们展示了 Python 的 pickle 模块的一些性能改进。 这些实验是使用 pickle.HIGHEST_PROTOCOL 完成的。 用于生成这些图的代码包含在文章末尾。

使用 NumPy 数组: 在机器学习和 AI 应用程序中,数据(例如,图像、神经网络权重、文本文档)通常表示为包含 NumPy 数组的数据结构。 当使用 NumPy 数组时,速度提升令人印象深刻。

Ray 的反序列化条几乎不可见,这不是错误。 这是支持零拷贝读取的结果(节省的费用主要来自缺少内存移动)。

请注意,最大的优势在于反序列化。 此处的速度提升是多个数量级,并且随着 NumPy 数组变得越来越大而变得更好(感谢设计目标 1、3 和 4)。 快速 反序列化 非常重要,原因有两个。 首先,一个对象可能会被序列化一次,然后被反序列化多次(例如,广播到所有工作进程的对象)。 其次,一种常见的模式是并行序列化许多对象,然后在一个工作进程上一次聚合和反序列化一个对象,从而使反序列化成为瓶颈。

不使用 NumPy 数组: 当使用常规 Python 对象时,我们无法利用共享内存,结果与 pickle 相当。

这些只是几个有趣的 Python 对象的示例。 最重要的案例是 NumPy 数组嵌套在其他对象中的案例。 请注意,我们的序列化库适用于非常通用的 Python 类型,包括自定义 Python 类和深度嵌套的对象。

API

序列化库可以通过 pyarrow 直接使用,如下所示。 更多文档请参见 这里

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
serialized_x = pyarrow.serialize(x).to_buffer()
deserialized_x = pyarrow.deserialize(serialized_x)

它可以直接通过 Ray API 使用,如下所示。

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
x_id = ray.put(x)
deserialized_x = ray.get(x_id)

数据表示

我们使用 Apache Arrow 作为底层与语言无关的数据布局。 对象存储在两个部分中: schemadata blob。 在高层次上,data blob 大致是对象中递归包含的所有数据值的扁平连接,并且 schema 定义了 data blob 的类型和嵌套结构。

技术细节: Python 序列(例如,字典、列表、元组、集合)被编码为其他类型(例如,布尔值、整数、字符串、字节、浮点数、双精度浮点数、date64s、张量(即 NumPy 数组)、列表、元组、字典和集合)的 Arrow UnionArrays。 嵌套序列使用 Arrow ListArrays 编码。 所有张量都被收集并附加到序列化对象的末尾,并且 UnionArray 包含对这些张量的引用。

为了给出一个具体的例子,考虑以下对象。

[(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]

它将在 Arrow 中以以下结构表示。

UnionArray(type_ids=[tuple, string, int, int, ndarray],
           tuples=ListArray(offsets=[0, 2],
                            UnionArray(type_ids=[int, int],
                                       ints=[1, 2])),
           strings=['hello'],
           ints=[3, 4],
           ndarrays=[<offset of numpy array>])

Arrow 使用 Flatbuffers 对序列化 schema 进行编码。 仅使用 schema,我们就可以计算数据 blob 中每个值的偏移量,而无需扫描数据 blob(与 Pickle 不同,这是实现快速反序列化的原因)。 这意味着我们可以避免在反序列化期间复制或以其他方式转换大型数组和其他值。 张量附加在 UnionArray 的末尾,并且可以使用共享内存有效地共享和访问。

请注意,实际对象将在内存中按如下所示布局。

Python 对象在堆中的布局。 每个框都分配在不同的内存区域中,并且框之间的箭头表示指针。


Arrow 序列化表示将如下所示。

Arrow 序列化对象的内存布局。


参与其中

我们欢迎贡献,尤其是在以下领域。

  • 使用 Arrow 的 C++ 和 Java 实现来实现 C++ 和 Java 的此版本。
  • 实现对更多 Python 类型的支持和更好的测试覆盖率。

重现上述图形

为了参考,可以使用以下代码重现这些图。 基准测试 ray.putray.get 而不是 pyarrow.serializepyarrow.deserialize 给出了相似的图形。 这些图是在此 commit 生成的。

import pickle
import pyarrow
import matplotlib.pyplot as plt
import numpy as np
import timeit


def benchmark_object(obj, number=10):
    # Time serialization and deserialization for pickle.
    pickle_serialize = timeit.timeit(
        lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
        number=number)
    serialized_obj = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    pickle_deserialize = timeit.timeit(lambda: pickle.loads(serialized_obj),
                                       number=number)

    # Time serialization and deserialization for Ray.
    ray_serialize = timeit.timeit(
        lambda: pyarrow.serialize(obj).to_buffer(), number=number)
    serialized_obj = pyarrow.serialize(obj).to_buffer()
    ray_deserialize = timeit.timeit(
        lambda: pyarrow.deserialize(serialized_obj), number=number)

    return [[pickle_serialize, pickle_deserialize],
            [ray_serialize, ray_deserialize]]


def plot(pickle_times, ray_times, title, i):
    fig, ax = plt.subplots()
    fig.set_size_inches(3.8, 2.7)

    bar_width = 0.35
    index = np.arange(2)
    opacity = 0.6

    plt.bar(index, pickle_times, bar_width,
            alpha=opacity, color='r', label='Pickle')

    plt.bar(index + bar_width, ray_times, bar_width,
            alpha=opacity, color='c', label='Ray')

    plt.title(title, fontweight='bold')
    plt.ylabel('Time (seconds)', fontsize=10)
    labels = ['serialization', 'deserialization']
    plt.xticks(index + bar_width / 2, labels, fontsize=10)
    plt.legend(fontsize=10, bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.yticks(fontsize=10)
    plt.savefig('plot-' + str(i) + '.png', format='png')


test_objects = [
    [np.random.randn(50000) for i in range(100)],
    {'weight-' + str(i): np.random.randn(50000) for i in range(100)},
    {i: set(['string1' + str(i), 'string2' + str(i)]) for i in range(100000)},
    [str(i) for i in range(200000)]
]

titles = [
    'List of large numpy arrays',
    'Dictionary of large numpy arrays',
    'Large dictionary of small sets',
    'Large list of strings'
]

for i in range(len(test_objects)):
    plot(*benchmark_object(test_objects[i]), titles[i], i)