Arrow Flight¶
与利用 Arrow Flight 协议相关的食谱
使用 Arrow Flight 的简单 Parquet 存储服务¶
假设您想实现一个服务,使用 Arrow Flight 协议来存储、发送和接收 Parquet 文件。pyarrow 在 pyarrow.flight 中提供了一个实现框架,具体来说就是 pyarrow.flight.FlightServerBase 类。
import pathlib
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
class FlightServer(pa.flight.FlightServerBase):
    """Flight service that keeps every dataset as one Parquet file on disk.

    Parameters
    ----------
    location : str
        URI passed to ``FlightServerBase``; it is also advertised as the
        location of every endpoint this server returns.
    repo : pathlib.Path
        Directory that holds one Parquet file per dataset.
    **kwargs
        Forwarded unchanged to ``FlightServerBase.__init__``.
    """

    def __init__(self, location="grpc://0.0.0.0:8815",
                 repo=pathlib.Path("./datasets"), **kwargs):
        super().__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _dataset_path(self, dataset):
        """Return the path of ``dataset`` inside the repository directory.

        Dataset names come from clients, so only a single plain path
        component is accepted. Anything containing a separator, an absolute
        path, or ``.``/``..`` would let a caller read, overwrite or delete
        files outside the repository.

        Raises
        ------
        ValueError
            If ``dataset`` is empty or is not a single path component.
        """
        if (dataset in ("", ".", "..")
                or pathlib.PurePath(dataset).name != dataset):
            raise ValueError(f"Invalid dataset name: {dataset!r}")
        return self._repo / dataset

    def _make_flight_info(self, dataset):
        """Build the ``FlightInfo`` describing the stored dataset ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        # The dataset name doubles as the ticket that do_get() later decodes.
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        # NOTE(review): serialized_size is passed as the total byte count; it
        # looks like the size of the Parquet metadata, not the data - confirm.
        return pa.flight.FlightInfo(schema,
                                    descriptor,
                                    endpoints,
                                    metadata.num_rows,
                                    metadata.serialized_size)

    def list_flights(self, context, criteria):
        """Yield a ``FlightInfo`` for every entry in the repository directory."""
        for dataset in self._repo.iterdir():
            yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        """Return the ``FlightInfo`` for the dataset named by the descriptor path."""
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload into a table and store it as a Parquet file."""
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        data_table = reader.read_all()
        pa.parquet.write_table(data_table, dataset_path)

    def do_get(self, context, ticket):
        """Load the dataset named by the ticket and return it as a stream."""
        dataset = ticket.ticket.decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        return pa.flight.RecordBatchStream(pa.parquet.read_table(dataset_path))

    def list_actions(self, context):
        """Return the custom actions supported by this server."""
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        """Dispatch a custom action; only ``drop_dataset`` is implemented.

        Raises
        ------
        NotImplementedError
            If ``action.type`` is not a known action.
        """
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError(f"Unknown action: {action.type!r}")

    def do_drop_dataset(self, dataset):
        """Delete the Parquet file backing ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        dataset_path.unlink()
示例服务器公开了 pyarrow.flight.FlightServerBase.list_flights()
,它是负责返回可用于获取的数据流列表的方法。
同样,pyarrow.flight.FlightServerBase.get_flight_info()
提供有关单个特定数据流的信息。
然后我们公开 pyarrow.flight.FlightServerBase.do_get()
,它负责实际获取公开的数据流并将其发送到客户端。
如果我们没有公开一种创建数据流的方法,那么允许列出和下载数据流将毫无用处,这是 pyarrow.flight.FlightServerBase.do_put()
的责任,它负责接收来自客户端的新数据并处理它(在本例中将其保存到 parquet 文件中)。
这些是最常见的 Arrow Flight 请求,如果我们需要添加更多功能,我们可以使用自定义操作来实现。
在前面的示例中,添加了一个 drop_dataset
自定义操作。所有自定义操作都通过 pyarrow.flight.FlightServerBase.do_action()
方法执行,因此服务器子类负责正确调度它们。在本例中,当 action.type 是我们期望的类型时,我们调用 do_drop_dataset 方法。
然后,我们的服务器可以使用 pyarrow.flight.FlightServerBase.serve()
启动。
if __name__ == '__main__':
    # Create the server with its default location and repository, make sure
    # the repository directory exists, then serve requests.
    server = FlightServer()
    server._repo.mkdir(exist_ok=True)
    server.serve()
服务器启动后,我们可以构建一个客户端来向其发出请求。
import pyarrow as pa
import pyarrow.flight
# Open a client connection to the address the example server listens on.
client = pa.flight.connect("grpc://0.0.0.0:8815")
我们可以创建一个新表并将其上传,以便将其存储到一个新的 parquet 文件中。
# Upload a new dataset
data_table = pa.table(
    [["Mario", "Luigi", "Peach"]],
    names=["Character"]
)
upload_descriptor = pa.flight.FlightDescriptor.for_path("uploaded.parquet")
writer, _ = client.do_put(upload_descriptor, data_table.schema)
# Use the writer as a context manager so the upload stream is closed even
# if writing the table raises.
with writer:
    writer.write_table(data_table)
上传后,我们应该能够检索新上传表的元数据。
# Ask the server to describe the dataset that was just uploaded and print
# its path, row count, reported size and schema.
flight = client.get_flight_info(upload_descriptor)
descriptor = flight.descriptor
print(
    "Path:", descriptor.path[0].decode('utf-8'),
    "Rows:", flight.total_records,
    "Size:", flight.total_bytes,
)
print("=== Schema ===")
print(flight.schema)
print("==============")
Path: uploaded.parquet Rows: 3 Size: ...
=== Schema ===
Character: string
==============
我们可以获取数据集的内容。
# Read content of the dataset: fetch the stream behind the first endpoint's
# ticket, load all of it into one table, and show the first rows.
reader = client.do_get(flight.endpoints[0].ticket)
read_table = reader.read_all()
print(read_table.to_pandas().head())
Character
0 Mario
1 Luigi
2 Peach
完成后,我们可以调用我们的自定义操作来删除我们新上传的数据集。
# Invoke the custom "drop_dataset" action; the action body carries the
# UTF-8 encoded name of the dataset to delete.
client.do_action(
    pa.flight.Action("drop_dataset", b"uploaded.parquet")
)
为了确认我们的数据集已被删除,我们可以列出服务器当前存储的所有 parquet 文件。
# Print path, row count, reported size and schema of every dataset the
# server currently lists.
for flight in client.list_flights():
    descriptor = flight.descriptor
    print(
        "Path:", descriptor.path[0].decode('utf-8'),
        "Rows:", flight.total_records,
        "Size:", flight.total_bytes,
    )
    print("=== Schema ===")
    print(flight.schema)
    print("==============")
    print("")
流式 Parquet 存储服务¶
我们可以改进 Parquet 存储服务,并通过流式传输数据来避免将整个数据集保存在内存中。Flight 读取器和写入器与 PyArrow 中的其他读取器和写入器一样,可以进行迭代,因此让我们更新之前的服务器以利用这一点。
import pathlib
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
class FlightServer(pa.flight.FlightServerBase):
    """Flight service that streams datasets to and from Parquet files.

    Parameters
    ----------
    location : str
        URI passed to ``FlightServerBase``; it is also advertised as the
        location of every endpoint this server returns.
    repo : pathlib.Path
        Directory that holds one Parquet file per dataset.
    **kwargs
        Forwarded unchanged to ``FlightServerBase.__init__``.
    """

    def __init__(self, location="grpc://0.0.0.0:8815",
                 repo=pathlib.Path("./datasets"), **kwargs):
        super().__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _dataset_path(self, dataset):
        """Return the path of ``dataset`` inside the repository directory.

        Dataset names come from clients, so only a single plain path
        component is accepted. Anything containing a separator, an absolute
        path, or ``.``/``..`` would let a caller read, overwrite or delete
        files outside the repository.

        Raises
        ------
        ValueError
            If ``dataset`` is empty or is not a single path component.
        """
        if (dataset in ("", ".", "..")
                or pathlib.PurePath(dataset).name != dataset):
            raise ValueError(f"Invalid dataset name: {dataset!r}")
        return self._repo / dataset

    def _make_flight_info(self, dataset):
        """Build the ``FlightInfo`` describing the stored dataset ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        # The dataset name doubles as the ticket that do_get() later decodes.
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        # NOTE(review): serialized_size is passed as the total byte count; it
        # looks like the size of the Parquet metadata, not the data - confirm.
        return pa.flight.FlightInfo(schema,
                                    descriptor,
                                    endpoints,
                                    metadata.num_rows,
                                    metadata.serialized_size)

    def list_flights(self, context, criteria):
        """Yield a ``FlightInfo`` for every entry in the repository directory."""
        for dataset in self._repo.iterdir():
            yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        """Return the ``FlightInfo`` for the dataset named by the descriptor path."""
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        """Append each uploaded batch to the dataset's Parquet file as it arrives."""
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._dataset_path(dataset)
        # Read the uploaded data and write to Parquet incrementally. The
        # Parquet writer gets its own name so it does not shadow the
        # ``writer`` argument supplied by Flight.
        with dataset_path.open("wb") as sink:
            with pa.parquet.ParquetWriter(sink, reader.schema) as parquet_writer:
                for chunk in reader:
                    parquet_writer.write_batch(chunk.data)

    def do_get(self, context, ticket):
        """Stream the dataset named by the ticket batch by batch."""
        dataset = ticket.ticket.decode('utf-8')
        # Stream data from a file
        dataset_path = self._dataset_path(dataset)
        reader = pa.parquet.ParquetFile(dataset_path)
        return pa.flight.GeneratorStream(
            reader.schema_arrow, reader.iter_batches())

    def list_actions(self, context):
        """Return the custom actions supported by this server."""
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        """Dispatch a custom action; only ``drop_dataset`` is implemented.

        Raises
        ------
        NotImplementedError
            If ``action.type`` is not a known action.
        """
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError(f"Unknown action: {action.type!r}")

    def do_drop_dataset(self, dataset):
        """Delete the Parquet file backing ``dataset``."""
        dataset_path = self._dataset_path(dataset)
        dataset_path.unlink()
首先,我们修改了 pyarrow.flight.FlightServerBase.do_put()
。我们不再在写入之前将所有上传的数据读入 pyarrow.Table
,而是迭代每个批次,并在它到来时将其添加到 Parquet 文件中。
然后,我们修改了 pyarrow.flight.FlightServerBase.do_get()
以将数据流式传输到客户端。这使用了 pyarrow.flight.GeneratorStream
,它接受一个模式和任何可迭代对象或迭代器。Flight 然后迭代并向客户端发送每个记录批次,允许我们处理即使是无法放入内存的大型 Parquet 文件。
虽然 GeneratorStream 具有可以流式传输数据的优点,但这意味着 Flight 必须为要发送的每个记录批次回调到 Python。相比之下,RecordBatchStream 要求所有数据都预先在内存中,但一旦创建,所有数据传输都完全在 C++ 中处理,无需调用 Python 代码。
让我们让服务器运行起来。和以前一样,我们将启动服务器。
if __name__ == '__main__':
    # Create the streaming server with its defaults, make sure the
    # repository directory exists, then serve requests.
    server = FlightServer()
    server._repo.mkdir(exist_ok=True)
    server.serve()
我们创建一个客户端,这次我们将向写入器写入批次,就好像我们有一个数据流而不是内存中的表一样。
import pyarrow as pa
import pyarrow.flight
client = pa.flight.connect("grpc://0.0.0.0:8815")

# Upload a new dataset by sending the same record batch NUM_BATCHES times,
# as if the data were arriving as a stream.
NUM_BATCHES = 1024
ROWS_PER_BATCH = 4096
upload_descriptor = pa.flight.FlightDescriptor.for_path("streamed.parquet")
batch = pa.record_batch(
    [pa.array(range(ROWS_PER_BATCH))],
    names=["ints"],
)
writer, _ = client.do_put(upload_descriptor, batch.schema)
with writer:
    for _ in range(NUM_BATCHES):
        writer.write_batch(batch)
和以前一样,我们也可以读回它。同样,我们将从流中读取每个批次,而不是将它们全部读入一个表中。
# Read the dataset back one chunk at a time, counting rows instead of
# collecting everything into a single table.
flight = client.get_flight_info(upload_descriptor)
reader = client.do_get(flight.endpoints[0].ticket)
total_rows = 0
for chunk in reader:
    total_rows += chunk.data.num_rows
print(
    "Got", total_rows, "rows total, expected",
    NUM_BATCHES * ROWS_PER_BATCH,
)
Got 4194304 rows total, expected 4194304
使用用户名/密码进行身份验证¶
通常,服务需要一种方法来对用户进行身份验证并识别其身份。Flight 提供了几种实现身份验证的方法;最简单的方法是使用用户名/密码方案。在启动时,客户端使用用户名和密码向服务器进行身份验证。服务器返回一个授权令牌,客户端需要在之后的请求中携带该令牌。
警告
身份验证只能在安全的加密通道上使用,即应启用 TLS。
注意
虽然该方案被称为“(HTTP) 基本身份验证”,但它实际上并没有实现 HTTP 身份验证(RFC 7235)本身。
虽然 Flight 提供了一些接口来实现这种方案,但服务器必须提供实际的实现,如下所示。这里的实现并不安全,仅作为最小示例提供。
import base64
import secrets
import pyarrow as pa
import pyarrow.flight
class EchoServer(pa.flight.FlightServerBase):
    """Server whose DoAction simply sends the request back to the caller."""

    def do_action(self, context, action):
        """Return the action type (UTF-8 encoded) followed by the action body."""
        encoded_type = action.type.encode("utf-8")
        payload = action.body
        return [encoded_type, payload]
class BasicAuthServerMiddlewareFactory(pa.flight.ServerMiddlewareFactory):
    """
    Middleware that implements username-password authentication.

    Parameters
    ----------
    creds: Dict[str, str]
        A dictionary of username-password values to accept.
    """

    def __init__(self, creds):
        self.creds = creds
        # Map generated bearer tokens to users
        self.tokens = {}

    def start_call(self, info, headers):
        """Validate credentials at the start of every call.

        A ``Basic`` header is treated as a login: the username and password
        are checked against ``creds`` and a new bearer token is issued. A
        ``Bearer`` header must carry a token issued earlier.

        Returns
        -------
        BasicAuthServerMiddleware
            Middleware carrying the bearer token for this call.

        Raises
        ------
        pyarrow.flight.FlightUnauthenticatedError
            If the header is missing, malformed, uses another scheme, or
            carries unknown credentials or an unknown token.
        """
        # Search for the authentication header (case-insensitive)
        auth_header = None
        for header in headers:
            if header.lower() == "authorization":
                auth_header = headers[header][0]
                break

        if not auth_header:
            raise pa.flight.FlightUnauthenticatedError("No credentials supplied")

        # The header has the structure "AuthType TokenValue", e.g.
        # "Basic <encoded username+password>" or "Bearer <random token>".
        auth_type, _, value = auth_header.partition(" ")

        if auth_type == "Basic":
            # Initial "login". The user provided a username/password
            # combination encoded in the same way as HTTP Basic Auth.
            try:
                decoded = base64.b64decode(value).decode("utf-8")
            except ValueError:
                # binascii.Error and UnicodeDecodeError are both ValueError
                # subclasses; report bad input as an authentication failure
                # instead of letting it escape as an unhandled error.
                raise pa.flight.FlightUnauthenticatedError(
                    "Malformed credentials") from None
            username, _, password = decoded.partition(':')
            expected = self.creds.get(username)
            # Compare in constant time so response timing does not reveal
            # how much of a guessed password was correct.
            if (not password or expected is None
                    or not secrets.compare_digest(password.encode("utf-8"),
                                                  expected.encode("utf-8"))):
                raise pa.flight.FlightUnauthenticatedError(
                    "Unknown user or invalid password")
            # Generate a secret, random bearer token for future calls.
            token = secrets.token_urlsafe(32)
            self.tokens[token] = username
            return BasicAuthServerMiddleware(token)
        elif auth_type == "Bearer":
            # An actual call. Validate the bearer token.
            username = self.tokens.get(value)
            if username is None:
                raise pa.flight.FlightUnauthenticatedError("Invalid token")
            return BasicAuthServerMiddleware(value)

        raise pa.flight.FlightUnauthenticatedError("No credentials supplied")
class BasicAuthServerMiddleware(pa.flight.ServerMiddleware):
    """Per-call middleware that hands the bearer token back to the client."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        """Return an ``authorization`` header carrying the bearer token."""
        header_value = f"Bearer {self.token}"
        return {"authorization": header_value}
class NoOpAuthHandler(pa.flight.ServerAuthHandler):
    """
    Auth handler that performs no checks of its own.

    It is registered only so that the server will respond to the internal
    Handshake RPC call, which the client calls when authenticate_basic_token
    is called. The actual authentication is implemented in middleware.
    """

    def authenticate(self, outgoing, incoming):
        """Do nothing during the handshake."""
        return None

    def is_valid(self, token):
        """Return an empty string for every token."""
        return ""
然后我们可以启动服务器。
if __name__ == '__main__':
    # Accept a single test user; the middleware factory does the real
    # credential checking, the auth handler only answers the handshake.
    basic_auth = BasicAuthServerMiddlewareFactory({
        "test": "password",
    })
    server = EchoServer(
        auth_handler=NoOpAuthHandler(),
        location="grpc://0.0.0.0:8816",
        middleware={"basic": basic_auth},
    )
    server.serve()
然后,我们可以创建一个客户端并登录。
import pyarrow as pa
import pyarrow.flight
# Connect and log in with username/password. The returned pair is the
# header name and value that later calls pass back as a call header.
client = pa.flight.connect("grpc://0.0.0.0:8816")
token_pair = client.authenticate_basic_token(b'test', b'password')
print(token_pair)
(b'authorization', b'Bearer ...')
对于未来的调用,我们将身份验证令牌与调用一起包含。
# Send the echo action with the token header attached and print each
# response body as raw bytes.
action = pa.flight.Action("echo", b"Hello, world!")
options = pa.flight.FlightCallOptions(headers=[token_pair])
for response in client.do_action(action=action, options=options):
    print(response.body.to_pybytes())
b'echo'
b'Hello, world!'
如果我们没有这样做,我们会收到身份验证错误。
# The same call without the token header must be rejected by the server.
try:
    list(client.do_action(action=action))
except pa.flight.FlightUnauthenticatedError as err:
    print("Unauthenticated:", err)
else:
    raise RuntimeError("Expected call to fail")
Unauthenticated: No credentials supplied. Detail: Unauthenticated
或者,如果我们在登录时使用错误的凭据,我们也会收到错误。
# Logging in with a user the server does not know must be rejected too.
try:
    client.authenticate_basic_token(b'invalid', b'password')
except pa.flight.FlightUnauthenticatedError as err:
    print("Unauthenticated:", err)
else:
    raise RuntimeError("Expected call to fail")
Unauthenticated: Unknown user or invalid password. Detail: Unauthenticated
使用 TLS 保护连接¶
继之前通过用户名和密码管理服务器流量的场景之后,HTTPS(更具体地说是 TLS)通信通过加密客户端和服务器之间的消息提供了额外的安全层。这是使用证书实现的。在开发阶段,最简单的方法是使用自签名证书。在启动时,服务器加载公钥和私钥,客户端使用 TLS 根证书对服务器进行身份验证。
注意
在生产环境中,建议使用由证书颁发机构签名的证书。
步骤 1 - 生成自签名证书
使用 dotnet 在 Windows 上或使用 openssl 在 Linux 或 MacOS 上生成自签名证书。或者,可以使用 Arrow 测试数据存储库 中的自签名证书。根据生成的文件,您可能需要将其转换为 .crt 和 .key 文件,如 Arrow 服务器所需。实现此目的的一种方法是 openssl,请访问此 IBM 文章 以获取更多信息。
步骤 2 - 运行启用 TLS 的服务器
以下代码是用于接收数据的 Arrow 服务器的最小工作示例,该服务器使用 TLS。
import argparse
import pyarrow
import pyarrow.flight
class FlightServer(pyarrow.flight.FlightServerBase):
    """Flight server that keeps every uploaded table in memory.

    Parameters
    ----------
    host : str
        Accepted for the caller's convenience; the constructor does not use it.
    location : str, optional
        URI to listen on, forwarded to ``FlightServerBase``.
    tls_certificates : list, optional
        Certificate/key pairs, forwarded to ``FlightServerBase``.
    verify_client : bool
        Forwarded to ``FlightServerBase``.
    root_certificates : bytes, optional
        Forwarded to ``FlightServerBase``.
    auth_handler : optional
        Forwarded to ``FlightServerBase``.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        # Uploaded tables, keyed by descriptor_to_key(descriptor).
        self.flights = {}

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Turn a descriptor into a hashable ``(type, command, path)`` tuple."""
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def do_put(self, context, descriptor, reader, writer):
        """Store the uploaded table under its descriptor key and print both."""
        key = FlightServer.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])
def main():
    """Start a TLS-enabled Flight server on ``grpc+tls://localhost:5005``.

    The certificate chain and private key are read from the two paths given
    to the ``--tls`` option. The option is required: without it there is no
    certificate to serve, so argparse rejects the command line instead of
    the script failing later with a ``TypeError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--tls", nargs=2, required=True,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="TLS certificate chain and private key files")
    args = parser.parse_args()
    cert_path, key_path = args.tls
    tls_certificates = []

    scheme = "grpc+tls"
    host = "localhost"
    port = "5005"

    with open(cert_path, "rb") as cert_file:
        tls_cert_chain = cert_file.read()
    with open(key_path, "rb") as key_file:
        tls_private_key = key_file.read()

    tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, host, port)

    server = FlightServer(host, location,
                          tls_certificates=tls_certificates)
    print("Serving on", location)
    server.serve()


if __name__ == '__main__':
    main()
运行服务器,您应该看到 Serving on grpc+tls://localhost:5005。
步骤 3 - 安全连接到服务器
假设我们想从客户端连接到服务器并向其推送一些数据。以下代码使用 TLS 加密安全地将信息发送到服务器。
import argparse
import pyarrow
import pyarrow.flight
import pandas as pd
# Assumes incoming data object is a Pandas Dataframe
def push_to_server(name, data, client):
    """Upload a pandas DataFrame to the server under the path ``name``.

    Parameters
    ----------
    name : str
        Path used to build the Flight descriptor.
    data : pandas.DataFrame
        Data to convert to an Arrow table and send.
    client : pyarrow.flight.FlightClient
        Connected client used for the DoPut call.
    """
    object_to_send = pyarrow.Table.from_pandas(data)
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(name),
        object_to_send.schema)
    # The context manager closes the upload stream even if the write fails.
    with writer:
        writer.write_table(object_to_send)
def main():
    """Connect to a TLS-enabled Flight server and upload a small table.

    ``--tls-roots`` is optional. When it is omitted the client is created
    without ``tls_root_certs`` instead of the script failing on
    ``open(None)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument('--host', default="localhost",
                        help='Host endpoint')
    parser.add_argument('--port', default=5005, type=int,
                        help='Host port')
    args = parser.parse_args()
    kwargs = {}

    if args.tls_roots:
        with open(args.tls_roots, "rb") as root_certs:
            kwargs["tls_root_certs"] = root_certs.read()
    # NOTE(review): without tls_root_certs the default trust roots are
    # presumably used - confirm against the FlightClient documentation.

    client = pyarrow.flight.FlightClient(
        f"grpc+tls://{args.host}:{args.port}", **kwargs)
    data = {'Animal': ['Dog', 'Cat', 'Mouse'], 'Size': ['Big', 'Small', 'Tiny']}
    df = pd.DataFrame(data, columns=['Animal', 'Size'])
    push_to_server("AnimalData", df, client)


if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(e)
运行客户端脚本,您应该看到服务器打印出有关它刚刚接收的数据的信息。
传播 OpenTelemetry 跟踪¶
使用 OpenTelemetry 进行分布式追踪,可以收集 Flight 服务中跨调用级别的性能指标。为了在 Flight 客户端和服务器之间关联跨度,必须在两者之间传递追踪上下文。这可以通过 pyarrow.flight.FlightCallOptions
中的标头手动传递,也可以使用中间件自动传播。
此示例展示了如何通过中间件实现追踪传播。客户端中间件需要将追踪上下文注入到调用标头中。服务器中间件需要从标头中提取追踪上下文,并将上下文传递到新的跨度中。可选地,客户端中间件还可以创建一个新的跨度来计时客户端调用。
步骤 1:定义客户端中间件
import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import inject
from opentelemetry.trace.status import StatusCode
class ClientTracingMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Factory that opens one client span per Flight call."""

    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info):
        """Start a span named after the Flight method and wrap it in middleware."""
        span_name = f"client.{info.method}"
        call_span = self._tracer.start_span(span_name)
        return ClientTracingMiddleware(call_span)
class ClientTracingMiddleware(flight.ClientMiddleware):
    """Per-call middleware that propagates and finishes the client span."""

    def __init__(self, span):
        self._span = span

    def sending_headers(self):
        """Return call headers carrying the context of this call's span."""
        headers = {}
        span_context = trace.set_span_in_context(self._span)
        inject(carrier=headers, context=span_context)
        return headers

    def call_completed(self, exception):
        """Record the call outcome on the span, then end the span."""
        if not exception:
            self._span.set_status(StatusCode.OK)
        else:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        self._span.end()
步骤 2:定义服务器中间件
import pyarrow.flight as flight
from opentelemetry import trace
from opentelemetry.propagate import extract
from opentelemetry.trace.status import StatusCode
class ServerTracingMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Factory that opens one server span per incoming Flight call."""

    def __init__(self):
        self._tracer = trace.get_tracer(__name__)

    def start_call(self, info, headers):
        """Start a span under the context extracted from the call headers."""
        parent_context = extract(headers)
        span_name = f"server.{info.method}"
        call_span = self._tracer.start_span(span_name, context=parent_context)
        return ServerTracingMiddleware(call_span)
class ServerTracingMiddleware(flight.ServerMiddleware):
    """Per-call middleware that finishes the server span."""

    def __init__(self, span):
        self._span = span

    def call_completed(self, exception):
        """Record the call outcome on the span, then end the span."""
        if not exception:
            self._span.set_status(StatusCode.OK)
        else:
            self._span.record_exception(exception)
            self._span.set_status(StatusCode.ERROR)
            print(exception)
        self._span.end()
步骤 3:配置追踪导出器、处理器和提供者
服务器和客户端都需要配置 OpenTelemetry SDK 来记录跨度并将其导出到某个地方。为了便于示例,我们将跨度收集到 Python 列表中,但在实际应用中,通常会将其配置为导出到像 Jaeger 这样的服务。有关导出器的其他示例,请参见 OpenTelemetry 导出器。
作为此步骤的一部分,您需要定义跨度运行的资源。至少需要服务名称,但也可以包含其他信息,例如主机名、进程 ID、服务版本和操作系统。
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
class TestSpanExporter(SpanExporter):
    """Exporter that collects every exported span in the ``spans`` list."""

    def __init__(self):
        self.spans = []

    def export(self, spans):
        """Append ``spans`` to the collected list and report success."""
        self.spans += spans
        return SpanExportResult.SUCCESS
def configure_tracing():
    """Install a tracer provider that exports spans into a list.

    Returns
    -------
    TestSpanExporter
        The exporter whose ``spans`` list receives every exported span.
    """
    # Most backends require a service name, so set one even though the
    # list-collecting exporter used here does not need it.
    provider = TracerProvider(
        resource=Resource(attributes={SERVICE_NAME: "my-service"})
    )
    exporter = TestSpanExporter()
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(provider)
    return exporter
步骤 4:将中间件添加到服务器
现在,我们可以在之前的 EchoServer 中使用中间件。
if __name__ == '__main__':
    # Configure tracing first so the exporter collects the spans that the
    # server middleware creates.
    exporter = configure_tracing()
    tracing_middleware = {"tracing": ServerTracingMiddlewareFactory()}
    server = EchoServer(
        location="grpc://0.0.0.0:8816",
        middleware=tracing_middleware,
    )
    server.serve()
步骤 5:将中间件添加到客户端
# Connect with the tracing middleware factory installed on the client.
client = pa.flight.connect(
    "grpc://0.0.0.0:8816",
    middleware=[ClientTracingMiddlewareFactory()],
)
步骤 6:在活动跨度内使用客户端
当我们在 OpenTelemetry 跨度内使用客户端进行调用时,我们的客户端中间件将为客户端的 Flight 调用创建一个子跨度,然后将跨度上下文传播到服务器。我们的服务器中间件将接收该追踪上下文,并创建另一个子跨度。
from opentelemetry import trace
# A real client would configure tracing as well; this example runs client
# and server in the same Python process, so the server's exporter is reused.
# exporter = configure_tracing()

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("hello_world") as span:
    action = pa.flight.Action("echo", b"Hello, world!")
    # Wrap do_action in list() so that all results are consumed.
    list(client.do_action(action=action))

print(f"There are {len(exporter.spans)} spans.")
print(f"The span names are:\n  {[span.name for span in exporter.spans]}.")
print(f"The span status codes are:\n  "
      f"{[span.status.status_code for span in exporter.spans]}.")
There are 3 spans.
The span names are:
['server.FlightMethod.DO_ACTION', 'client.FlightMethod.DO_ACTION', 'hello_world'].
The span status codes are:
[<StatusCode.OK: 1>, <StatusCode.OK: 1>, <StatusCode.UNSET: 0>].
正如预期的那样,我们有三个跨度:一个在我们的客户端代码中,一个在客户端中间件中,一个在服务器中间件中。