使用 Ray 和 Apache Arrow 实现快速 Python 序列化


发布于 2017年10月15日
作者 Philipp Moritz, Robert Nishihara

这最初发布在 Ray 博客上。Philipp MoritzRobert Nishihara 是加州大学伯克利分校的研究生。

这篇帖子详细阐述了 RayApache Arrow 之间的集成。它主要解决的问题是数据序列化

根据维基百科序列化

... 将数据结构或对象状态转换为可存储... 或传输... 并稍后(可能在不同的计算机环境中)重建的格式的过程。

为什么需要任何转换?当您创建一个 Python 对象时,它可能指向其他 Python 对象,这些对象都分配在内存的不同区域,所有这些在另一台机器上的另一个进程解包时都必须有意义。

序列化和反序列化是并行和分布式计算的瓶颈,特别是在具有大型对象和大量数据的机器学习应用程序中。

设计目标

由于 Ray 针对机器学习和人工智能应用程序进行了优化,我们非常关注序列化和数据处理,并遵循以下设计目标

  1. 它对大型数值数据应该非常高效(这包括 NumPy 数组和 Pandas DataFrames,以及递归包含 NumPy 数组和 Pandas DataFrames 的对象)。
  2. 对于通用 Python 类型,它的速度应该与 Pickle 相当。
  3. 它应该与共享内存兼容,允许多个进程使用相同的数据而无需复制。
  4. 反序列化应该非常快(如果可能,不应需要读取整个序列化对象)。
  5. 它应该是语言独立的(最终我们希望 Python worker 能够使用 Java 或其他语言的 worker 创建的对象,反之亦然)。

我们的方法和替代方案

Python 中常用的序列化方法是 pickle 模块。Pickle 非常通用,特别是如果您使用像 cloudpickle 这样的变体。但是,它不满足要求 1、3、4 或 5。像 json 这样的替代方案满足 5,但不满足 1-4。

我们的方法: 为了满足要求 1-5,我们选择使用 Apache Arrow 格式作为底层数据表示。我们与 Apache Arrow 团队合作,构建了将通用 Python 对象映射到 Arrow 格式和从 Arrow 格式映射的。这种方法的一些特性:

  • 数据布局是语言独立的(要求 5)。
  • 可以在恒定时间内计算序列化数据块中的偏移量,而无需读取整个对象(要求 1 和 4)。
  • Arrow 支持零拷贝读取,因此对象可以自然地存储在共享内存中并由多个进程使用(要求 1 和 3)。
  • 对于我们无法很好处理的任何情况,我们可以自然地回退到 pickle(要求 2)。

Arrow 的替代方案: 我们本可以基于 Protocol Buffers 构建,但 protocol buffers 并非为数值数据设计,该方法无法满足 1、3 或 4。基于 Flatbuffers 构建实际上可以工作,但这需要实现 Arrow 已经拥有的许多功能,而且我们更喜欢针对大数据进行优化的列式数据布局。

加速

这里我们展示了一些相对于 Python 的 pickle 模块的性能改进。实验使用 pickle.HIGHEST_PROTOCOL 完成。生成这些图的代码包含在帖子末尾。

使用 NumPy 数组: 在机器学习和人工智能应用程序中,数据(例如,图像、神经网络权重、文本文档)通常表示为包含 NumPy 数组的数据结构。使用 NumPy 数组时,加速效果令人印象深刻。

Ray 在反序列化方面的条形图几乎不可见并非错误。这是对零拷贝读取支持的结果(节省主要来自没有内存移动)。

请注意,最大的优势在于反序列化。这里的加速是几个数量级的,并且随着 NumPy 数组变大而变得更好(这得益于设计目标 1、3 和 4)。使反序列化快速很重要,原因有两个。首先,一个对象可能被序列化一次,然后反序列化多次(例如,一个广播到所有 worker 的对象)。其次,一个常见的模式是许多对象并行序列化,然后在一个 worker 上逐个聚合和反序列化,从而使反序列化成为瓶颈。

没有 NumPy 数组: 当使用常规 Python 对象时,我们无法利用共享内存,结果与 pickle 相当。

这些只是有趣 Python 对象的几个例子。最重要的案例是 NumPy 数组嵌套在其他对象中的情况。请注意,我们的序列化库适用于非常通用的 Python 类型,包括自定义 Python 类和深度嵌套对象。

API

序列化库可以通过 pyarrow 直接使用,如下所示。更多文档可在此处获取。

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
serialized_x = pyarrow.serialize(x).to_buffer()
deserialized_x = pyarrow.deserialize(serialized_x)

它可以通过 Ray API 直接使用,如下所示。

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
x_id = ray.put(x)
deserialized_x = ray.get(x_id)

数据表示

我们使用 Apache Arrow 作为底层的语言无关数据布局。对象存储为两部分:一个模式和一个数据块。从高层次来看,数据块大致是对象中递归包含的所有数据值的扁平化连接,而模式定义了数据块的类型和嵌套结构。

技术细节: Python 序列(例如,字典、列表、元组、集合)被编码为其他类型(例如,布尔值、整数、字符串、字节、浮点数、双精度浮点数、date64s、张量(即 NumPy 数组)、列表、元组、字典和集合)的 Arrow 联合数组。嵌套序列使用 Arrow 列表数组进行编码。所有张量都被收集并附加到序列化对象的末尾,联合数组包含对这些张量的引用。

举一个具体的例子,考虑以下对象。

[(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]

它将以以下结构在 Arrow 中表示。

UnionArray(type_ids=[tuple, string, int, int, ndarray],
           tuples=ListArray(offsets=[0, 2],
                            UnionArray(type_ids=[int, int],
                                       ints=[1, 2])),
           strings=['hello'],
           ints=[3, 4],
           ndarrays=[<offset of numpy array>])

Arrow 使用 Flatbuffers 编码序列化模式。仅使用模式,我们就可以计算数据块中每个值的偏移量,而无需扫描数据块(与 Pickle 不同,这使得快速反序列化成为可能)。这意味着我们可以在反序列化期间避免复制或以其他方式转换大型数组和其他值。张量附加在联合数组的末尾,并且可以使用共享内存高效地共享和访问。

请注意,实际对象将按如下所示在内存中布局。

堆中 Python 对象的布局。每个框都分配在不同的内存区域,框之间的箭头表示指针。

Arrow 序列化表示将如下所示。

Arrow 序列化对象的内存布局。

参与其中

我们欢迎贡献,特别是在以下领域。

  • 使用 Arrow 的 C++ 和 Java 实现来为 C++ 和 Java 实现此版本。
  • 实现对更多 Python 类型的支持和更好的测试覆盖率。

重现上述图表

作为参考,可以使用以下代码重现这些图表。基准测试 ray.putray.get 而不是 pyarrow.serializepyarrow.deserialize 会给出相似的图表。这些图表是在此提交处生成的。

import pickle
import pyarrow
import matplotlib.pyplot as plt
import numpy as np
import timeit


def benchmark_object(obj, number=10):
    # Time serialization and deserialization for pickle.
    pickle_serialize = timeit.timeit(
        lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
        number=number)
    serialized_obj = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    pickle_deserialize = timeit.timeit(lambda: pickle.loads(serialized_obj),
                                       number=number)

    # Time serialization and deserialization for Ray.
    ray_serialize = timeit.timeit(
        lambda: pyarrow.serialize(obj).to_buffer(), number=number)
    serialized_obj = pyarrow.serialize(obj).to_buffer()
    ray_deserialize = timeit.timeit(
        lambda: pyarrow.deserialize(serialized_obj), number=number)

    return [[pickle_serialize, pickle_deserialize],
            [ray_serialize, ray_deserialize]]


def plot(pickle_times, ray_times, title, i):
    fig, ax = plt.subplots()
    fig.set_size_inches(3.8, 2.7)

    bar_width = 0.35
    index = np.arange(2)
    opacity = 0.6

    plt.bar(index, pickle_times, bar_width,
            alpha=opacity, color='r', label='Pickle')

    plt.bar(index + bar_width, ray_times, bar_width,
            alpha=opacity, color='c', label='Ray')

    plt.title(title, fontweight='bold')
    plt.ylabel('Time (seconds)', fontsize=10)
    labels = ['serialization', 'deserialization']
    plt.xticks(index + bar_width / 2, labels, fontsize=10)
    plt.legend(fontsize=10, bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.yticks(fontsize=10)
    plt.savefig('plot-' + str(i) + '.png', format='png')


test_objects = [
    [np.random.randn(50000) for i in range(100)],
    {'weight-' + str(i): np.random.randn(50000) for i in range(100)},
    {i: set(['string1' + str(i), 'string2' + str(i)]) for i in range(100000)},
    [str(i) for i in range(200000)]
]

titles = [
    'List of large numpy arrays',
    'Dictionary of large numpy arrays',
    'Large dictionary of small sets',
    'Large list of strings'
]

for i in range(len(test_objects)):
    plot(*benchmark_object(test_objects[i]), titles[i], i)