DLPack 协议#

DLPack 协议 是一种稳定的内存数据结构,允许在处理多维数组或张量的主要框架之间进行交换。它旨在提供跨硬件支持,这意味着它允许在 CPU 以外的设备(例如 GPU)上交换数据。

DLPack 协议已被 Python 数据 API 标准联盟 选定为 Python 数组 API 标准,旨在实现 Python 生态系统中数组/张量库之间的设备感知数据交换。更多关于该标准的信息,请参阅 协议文档;更多关于 DLPack 的信息,请参阅 DLPack Python 规范

PyArrow 中 DLPack 的实现#

DLPack 协议的生产端已为 pa.Array 实现,可用于在 PyArrow 和其他张量库之间交换数据。支持的数据类型包括整数、无符号整数和浮点数。该协议不支持缺失数据,这意味着包含缺失值的 PyArrow 数组无法通过 DLPack 协议传输。目前,Arrow 对该协议的实现仅支持 CPU 设备上的数据。

协议的数据交换语法包括

  1. from_dlpack(x):使用实现 __dlpack__ 方法的数组对象,并在共享内存的同时创建一个新数组。

  2. __dlpack__(self, stream=None)__dlpack_device__:生成包含 DLPack 结构的 PyCapsule,该结构在 from_dlpack(x) 内部被调用。

PyArrow 实现了协议的后一部分(__dlpack__(self, stream=None)__dlpack_device__),因此可以被实现了 from_dlpack 的库所使用。

示例#

将 PyArrow CPU 数组转换为 NumPy 数组

>>> import pyarrow as pa
>>> import numpy as np
>>> array = pa.array([2, 0, 2, 4])
>>> array
<pyarrow.lib.Int64Array object at ...>
[
  2,
  0,
  2,
  4
]
>>> np.from_dlpack(array)
array([2, 0, 2, 4])

将 PyArrow CPU 数组转换为 PyTorch 张量

>>> import torch
>>> torch.from_dlpack(array)
tensor([2, 0, 2, 4])

将 PyArrow CPU 数组转换为 JAX 数组

>>> import jax
>>> jax.numpy.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)
>>> jax.dlpack.from_dlpack(array)
Array([2, 0, 2, 4], dtype=int32)