Pyarrow basic auth: How to prevent `Stream is closed`?

I am new to Arrow Flight and PyArrow (v6.0.1) and am trying to implement basic auth, but I always get an error:

OSError: Stream is closed

I have created a minimal reproducible example. Run the following two files one after the other (the first is the server, the second the client):

from typing import Dict, Union
from pyarrow.lib import tobytes
from pyarrow.flight import BasicAuth, FlightUnauthenticatedError, ServerAuthHandler, FlightServerBase
from pyarrow._flight import ServerAuthSender, ServerAuthReader


class ServerBasicAuthHandler(ServerAuthHandler):
    """Check a username/password sent through the Flight handshake stream.

    ``authenticate`` expects the client to write one serialized ``BasicAuth``
    message to the handshake stream and replies with the username;
    ``is_valid`` then accepts that username as the token on later calls.

    NOTE(review): the token is simply the username, so it is not a secret.
    That is fine for a demo but not for production.
    """

    def __init__(self, creds: Dict[str, str]):
        # Store users and passwords as bytes so lookups match the bytes
        # values this handler receives (see ``token: bytes`` in is_valid).
        self.creds = {user.encode(): pw.encode() for user, pw in creds.items()}

    def authenticate(self, outgoing: ServerAuthSender, incoming: ServerAuthReader):
        """Read a ``BasicAuth`` message from ``incoming`` and reply with the username.

        Raises FlightUnauthenticatedError for an unknown user or a wrong password.
        """
        import hmac  # local import keeps this snippet's import block unchanged

        # This read raises "OSError: Stream is closed" when the client calls
        # authenticate_basic_token. That method sends HTTP Basic credentials
        # in a header rather than writing a BasicAuth message to this stream.
        buf = incoming.read()
        auth = BasicAuth.deserialize(buf)
        expected = self.creds.get(auth.username)
        if expected is None:
            raise FlightUnauthenticatedError("unknown user")
        # Constant-time comparison, so response timing does not reveal how
        # much of the password matched.
        if not hmac.compare_digest(expected, tobytes(auth.password)):
            raise FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(auth.username))

    def is_valid(self, token: bytes) -> Union[bytes, str]:
        """Accept ``token`` if it is a known username, and return it.

        Raises FlightUnauthenticatedError if the token is empty or unknown.
        """
        if not token:
            raise FlightUnauthenticatedError("no basic auth provided")
        if token not in self.creds:
            raise FlightUnauthenticatedError("unknown user")
        return token

# Server script: listen on the IPv6 wildcard address, port 50051, using the
# handshake-based ServerBasicAuthHandler. Run this first, then the client.
service = FlightServerBase(
    location="grpc://[::]:50051",
    auth_handler=ServerBasicAuthHandler({"user": "pw"}),
)

service.serve()
from pyarrow.flight import FlightClient

# Client script: connects to the server above.
# NOTE: authenticate_basic_token sends HTTP Basic credentials as a header. It
# does not write the serialized BasicAuth message that the server handler's
# incoming.read() waits for, which is why that read fails with
# "OSError: Stream is closed". The matching client side is presumably a
# ClientAuthHandler passed to client.authenticate(); check the PyArrow docs.
client = FlightClient(location="grpc://localhost:50051")
client.authenticate_basic_token("user", "pw")

I basically copied the ServerAuthHandler implementation from PyArrow's own tests, so it is known to work there. However, I can't get it to work.

The error message `Stream is closed` is hard to debug. I can't trace where it comes from anywhere in the PyArrow implementation, on either the Python or the C++ side.

Any help or hints on how to prevent this error would be appreciated.


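The example in the question mixes up two different authentication mechanisms, which is admittedly confusing. The `BasicAuth` object is not the HTTP Basic authentication that the `authenticate_basic_token` method implements:

- A `ServerAuthHandler` reads a `BasicAuth` message from the handshake stream (the `incoming.read()` call in the question).
- `authenticate_basic_token` sends the credentials as an HTTP `Authorization: Basic …` header instead.

The handler therefore finds nothing to read, hence `Stream is closed`. Both mechanisms exist because contributors have added several authentication methods over the years. The actual test is as follows: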

# Excerpt from PyArrow's own Flight test suite. HeaderAuthServerMiddlewareFactory,
# NoopAuthHandler and HeaderAuthFlightServer are presumably defined in that
# test module, not in this snippet.
# NOTE(review): this module-level factory is not used by the test below, which
# constructs its own HeaderAuthServerMiddlewareFactory() instance.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
# Presumably a do-nothing ServerAuthHandler, leaving authentication to the
# middleware registered on the server.
no_op_auth_handler = NoopAuthHandler()


def test_authenticate_basic_token():
    """Test authenticate_basic_token with bearer token and auth headers.

    Credentials are checked by the middleware registered under "auth", not by
    the auth handler. authenticate_basic_token returns the (header name,
    header value) pair that the server sent back.
    """
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
        "auth": HeaderAuthServerMiddlewareFactory()
    }) as server:
        client = FlightClient(('localhost', server.port))
        token_pair = client.authenticate_basic_token(b'test', b'password')
        # The server answers with an "authorization: Bearer <token>" header.
        assert token_pair[0] == b'authorization'
        assert token_pair[1] == b'Bearer token1234'

That is, authentication is not done in the auth handler's `authenticate` method. Instead, a server "middleware" inspects the `authorization` header of every call. A full example looks as follows:

import base64
import pyarrow.flight as flight

class BasicAuthServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Authenticate every call from its ``authorization`` header.

    ``Basic`` credentials (what ``FlightClient.authenticate_basic_token``
    sends) are checked against ``creds``. The bearer token handed back to the
    client is also accepted, so it can be used on later calls. Either way the
    call gets a ``BasicAuthServerMiddleware`` that sends
    ``authorization: Bearer <token>`` back to the client.
    """

    def __init__(self, creds, token="BearerTokenValue"):
        """Store the accepted credentials.

        :param creds: mapping of username (str) to password (str).
        :param token: bearer token issued after a successful Basic login and
            accepted on later calls. One fixed token keeps the example short;
            a real server should issue unpredictable per-client tokens.
        """
        self.creds = creds
        self.token = token

    def start_call(self, info, headers):
        """Validate the ``authorization`` header of an incoming call.

        :param info: call information (unused).
        :param headers: incoming headers; each value is a list of values.
        :returns: a ``BasicAuthServerMiddleware`` carrying the bearer token.
        :raises flight.FlightUnauthenticatedError: if the header is missing,
            malformed, uses an unsupported scheme, or has wrong credentials.
        """
        # Header names are matched case-insensitively; the first match wins.
        values = next(
            (value for name, value in headers.items()
             if name.lower() == "authorization"),
            None,
        )
        if not values:
            raise flight.FlightUnauthenticatedError("No credentials supplied")

        # The header has the form "<scheme> <credentials>". Auth scheme names
        # are case-insensitive (RFC 7235). partition() never raises on a
        # header without a space, unlike indexing the result of split().
        scheme, _, credentials = values[0].partition(" ")
        scheme = scheme.lower()
        if scheme == "basic" and self._basic_ok(credentials):
            return BasicAuthServerMiddleware(self.token)
        if scheme == "bearer" and self._same(credentials, self.token):
            return BasicAuthServerMiddleware(self.token)
        # Use one message for every failure so callers cannot probe which
        # usernames exist.
        raise flight.FlightUnauthenticatedError("Invalid credentials")

    def _basic_ok(self, encoded):
        """Return True if *encoded* is base64 ``username:password`` matching ``creds``."""
        try:
            decoded = base64.b64decode(encoded).decode("utf-8")
        except ValueError:
            # binascii.Error (bad base64) and UnicodeDecodeError both
            # subclass ValueError: treat them as bad credentials, not a crash.
            return False
        # Split on the first colon only, because a password may contain ':'.
        username, sep, password = decoded.partition(":")
        expected = self.creds.get(username)
        return bool(sep) and expected is not None and self._same(password, expected)

    @staticmethod
    def _same(supplied, expected):
        """Compare two strings in constant time so timing does not leak a partial match."""
        import hmac  # local import keeps this example's import block unchanged

        return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


class BasicAuthServerMiddleware(flight.ServerMiddleware):
    """Per-call middleware that hands the bearer token back to the client."""

    def __init__(self, token):
        # Value placed after "Bearer " in the outgoing authorization header.
        self.token = token

    def sending_headers(self):
        """Return the headers to send: ``authorization: Bearer <token>``."""
        bearer = "Bearer {}".format(self.token)
        return dict(authorization=bearer)


class NoOpAuthHandler(flight.ServerAuthHandler):
    """Auth handler that performs no checks; the middleware does the real work.

    NOTE(review): presumably needed so the handshake issued by
    authenticate_basic_token has a handler to run. Confirm against the
    PyArrow docs.
    """

    def authenticate(self, outgoing, incoming):
        """Do nothing: credentials arrive as a header and the middleware checks them."""

    def is_valid(self, token):
        """Accept any token (never raises) and return an empty string."""
        return ""


# No location is passed to the server; the client reads its port from server.port.
with flight.FlightServerBase(
    auth_handler=NoOpAuthHandler(),
    middleware={"basic": BasicAuthServerMiddlewareFactory({"test": "password"})},
) as server:
    client = flight.connect(("localhost", server.port))
    # Sends "authorization: Basic ..." and returns the header the server sent back.
    token_pair = client.authenticate_basic_token(b"test", b"password")
    print(token_pair)
    assert token_pair[0] == b"authorization"
    assert token_pair[1] == b"Bearer BearerTokenValue"