Pyarrow basic auth: How to prevent `Stream is closed`?
I am new to Arrow Flight and pyarrow (v=6.0.1), and am trying to implement basic auth but I am always getting an error:
OSError: Stream is closed
I have created a minimal reproducible example; the error occurs when running the following two files one after the other (the server first, then the client):
from typing import Dict, Union
from pyarrow.lib import tobytes
from pyarrow.flight import BasicAuth, FlightUnauthenticatedError, ServerAuthHandler, FlightServerBase
from pyarrow._flight import ServerAuthSender, ServerAuthReader
class ServerBasicAuthHandler(ServerAuthHandler):
    """Auth handler that checks ``BasicAuth`` payloads against known users.

    User names and passwords are kept as bytes so they can be compared
    directly with the fields of the deserialized ``BasicAuth`` message.
    """

    def __init__(self, creds: Dict[str, str]):
        """Encode the ``user -> password`` mapping to bytes once, up front."""
        encoded = {}
        for name, secret in creds.items():
            encoded[name.encode()] = secret.encode()
        self.creds = encoded

    def authenticate(
        self, outgoing: ServerAuthSender, incoming: ServerAuthReader
    ):
        """Read a serialized ``BasicAuth`` message and validate it.

        On success the user name is written to ``outgoing``.

        Raises:
            FlightUnauthenticatedError: If the user is unknown or the
                password does not match.
        """
        buf = incoming.read()  # this line raises "OSError: Stream is closed"
        auth = BasicAuth.deserialize(buf)
        user = auth.username
        if user not in self.creds:
            raise FlightUnauthenticatedError("unknown user")
        if auth.password != self.creds[user]:
            raise FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(user))

    def is_valid(self, token: bytes) -> Union[bytes, str]:
        """Return ``token`` unchanged when it names a known user.

        Raises:
            FlightUnauthenticatedError: If the token is empty or is not a
                known user name.
        """
        if token:
            if token in self.creds:
                return token
            raise FlightUnauthenticatedError("unknown user")
        raise FlightUnauthenticatedError("no basic auth provided")
# Server script: expose a bare Flight server on port 50051, guarded by the
# handler defined above with a single user "user" / password "pw".
service = FlightServerBase(
    location="grpc://[::]:50051",
    auth_handler=ServerBasicAuthHandler({"user": "pw"}),
)
# Run the server; this is the last statement of the server file.
service.serve()
from pyarrow.flight import FlightClient
# Client script: connect to the server on port 50051 and attempt the login
# with the same user/password pair the server was configured with.
client = FlightClient(location="grpc://localhost:50051")
client.authenticate_basic_token("user", "pw")
I basically copied the `ServerAuthHandler` implementation from the pyarrow tests, so it should be known to work. However, I can't get it to work.
The error message `Stream is closed` is hard to debug: I don't know where it comes from, and I can't trace it to anywhere within the pyarrow implementation (neither on the Python side nor on the C++ side).
Any help or hints on how to prevent this error would be appreciated.
The example in the OP is mixing up two authentication implementations (which is indeed confusing). The `BasicAuth` object does not implement the actual HTTP basic authentication that the `authenticate_basic_token` method uses; both exist because contributors have implemented a variety of authentication methods over the years. The actual test for `authenticate_basic_token` is as follows:
# Module-level instances of the middleware factory and the auth handler used
# by the test; both classes are defined outside this excerpt.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
no_op_auth_handler = NoopAuthHandler()
def test_authenticate_basic_token():
    """Check that authenticate_basic_token returns the bearer-token header."""
    # Authentication is performed by the middleware; the auth handler passed
    # alongside it is the module-level no-op instance.
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(
        auth_handler=no_op_auth_handler,
        middleware=middleware,
    ) as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        header_pair = client.authenticate_basic_token(b'test', b'password')
        assert header_pair[0] == b'authorization'
        assert header_pair[1] == b'Bearer token1234'
i.e. the server does not rely on the `authenticate` method of the auth handler, but rather on a "middleware" to implement the authentication. A full example looks as follows:
import base64
import hmac

import pyarrow.flight as flight
class BasicAuthServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Middleware factory that validates HTTP Basic credentials per call.

    The ``authorization`` header is expected to hold
    ``Basic base64(username:password)``. On success a
    ``BasicAuthServerMiddleware`` carrying a fixed bearer token is returned.
    """

    def __init__(self, creds):
        """Store the accepted credentials.

        Args:
            creds: Mapping of user name to password, both ``str``.
        """
        self.creds = creds

    def start_call(self, info, headers):
        """Validate the ``authorization`` header of an incoming call.

        Args:
            info: Call information; not used by this factory.
            headers: Mapping of header name to a sequence of header values.

        Returns:
            A ``BasicAuthServerMiddleware`` for the authenticated call.

        Raises:
            flight.FlightUnauthenticatedError: If the header is missing, uses
                another scheme, is malformed, or names an unknown user or a
                wrong password.
        """
        token = None
        for header in headers:
            # Header names are matched case-insensitively.
            if header.lower() == "authorization":
                token = headers[header]
                break
        if not token:
            raise flight.FlightUnauthenticatedError("No credentials supplied")

        # partition() never raises, so a bare "Basic" without a payload ends
        # up in the malformed branch below instead of an IndexError.
        scheme, _, payload = token[0].partition(' ')
        # Authentication scheme names are case-insensitive.
        if scheme.lower() != 'basic':
            raise flight.FlightUnauthenticatedError("No credentials supplied")

        try:
            decoded = base64.b64decode(payload).decode("utf-8")
        except ValueError:
            # binascii.Error (bad base64) and UnicodeDecodeError (bad UTF-8)
            # are both ValueError subclasses.
            raise flight.FlightUnauthenticatedError(
                "Malformed credentials") from None

        # Split on the FIRST colon only: the password may itself contain ':'.
        username, colon, password = decoded.partition(':')
        if not colon:
            raise flight.FlightUnauthenticatedError("Malformed credentials")

        # Unknown user and wrong password share one message so the response
        # does not reveal which user names exist.
        if username not in self.creds:
            raise flight.FlightUnauthenticatedError("Invalid credentials")
        expected = self.creds[username]
        # Constant-time comparison; bytes are used because compare_digest
        # only accepts ASCII-only str arguments.
        if not hmac.compare_digest(password.encode("utf-8"),
                                   expected.encode("utf-8")):
            raise flight.FlightUnauthenticatedError("Invalid credentials")
        return BasicAuthServerMiddleware("BearerTokenValue")
class BasicAuthServerMiddleware(flight.ServerMiddleware):
    """Middleware that reports a bearer token through its sending headers."""

    def __init__(self, token):
        """Remember the token to advertise."""
        self.token = token

    def sending_headers(self):
        """Return the ``authorization`` header holding ``Bearer <token>``."""
        bearer = 'Bearer {}'.format(self.token)
        return {'authorization': bearer}
class NoOpAuthHandler(flight.ServerAuthHandler):
    """Auth handler that performs no checks of its own."""

    def authenticate(self, outgoing, incoming):
        """Do nothing: neither ``incoming`` nor ``outgoing`` is touched."""
        return None

    def is_valid(self, token):
        """Return an empty string for every token."""
        return ""
# Demo: start a server whose authentication is handled by the middleware,
# log in with HTTP basic credentials and check the returned header pair.
middleware = {
    "basic": BasicAuthServerMiddlewareFactory({"test": "password"}),
}
with flight.FlightServerBase(
    auth_handler=NoOpAuthHandler(),
    middleware=middleware,
) as server:
    location = ('localhost', server.port)
    client = flight.connect(location)
    header_pair = client.authenticate_basic_token(b'test', b'password')
    print(header_pair)
    assert header_pair[0] == b'authorization'
    assert header_pair[1] == b'Bearer BearerTokenValue'