mypy: argument should be one of a fixed set of functions

from enum import Enum

def strategy_a(x: float, y: float) -> float:
    return x + y

def strategy_b(x: float, y: float) -> float:
    return x * y

class Strategy(Enum):
    A = strategy_a
    B = strategy_b

def run_strategy(x: float, y: float, strategy: Strategy) -> float:
    return strategy(x, y)

Let's say I have something like this, where run_strategy's strategy argument accepts one of a fixed set of functions. How can I annotate it so that only those functions can be passed, without mypy reporting an error?

Note: the above code fails type checking, because mypy complains that Strategy is not callable.

The function is then called as follows:

run_strategy(5, 17, Strategy.A)

A workaround is to define a protocol Strategy that declares a __call__ method with the required signature, and to make the Enum class inherit from that protocol as well. See the example blah.py:

from enum import Enum
from typing import Protocol


def strategy_a(x: float, y: float) -> float:
    return x + y


def strategy_b(x: float, y: float) -> float:
    return x * y


class Strategy(Protocol):
    def __call__(self, x: float, y: float) -> float:
        ...


class StrategyChoice(Enum, Strategy):
    A = strategy_a
    B = strategy_b


def run_strategy(x: float, y: float, strategy: StrategyChoice) -> float:
    return strategy(x, y)

Running mypy on this file reports no errors:

mypy ./blah.py
Success: no issues found in 1 source file
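
A clean mypy run does not guarantee that the file also imports cleanly. Enum and Protocol use different metaclasses, so listing both as bases can raise a metaclass conflict in CPython, and it is worth executing blah.py as well as type-checking it. Separately, plain functions assigned in an Enum body are treated as methods rather than members, so Strategy.A in the question is just strategy_a. As an alternative, here is a minimal sketch that keeps real Enum members and gives the Enum its own __call__. The string values "a" and "b" and the _IMPLEMENTATIONS table are assumptions introduced for illustration; they do not appear in the original code.

```python
from enum import Enum
from typing import Callable, Dict


def strategy_a(x: float, y: float) -> float:
    return x + y


def strategy_b(x: float, y: float) -> float:
    return x * y


class Strategy(Enum):
    # Plain string values (hypothetical), so A and B are real Enum members.
    A = "a"
    B = "b"

    def __call__(self, x: float, y: float) -> float:
        # Look up the function registered for this member and apply it.
        return _IMPLEMENTATIONS[self](x, y)


# Hypothetical dispatch table mapping each member to its function.
_IMPLEMENTATIONS: Dict[Strategy, Callable[[float, float], float]] = {
    Strategy.A: strategy_a,
    Strategy.B: strategy_b,
}


def run_strategy(x: float, y: float, strategy: Strategy) -> float:
    return strategy(x, y)


print(run_strategy(5, 17, Strategy.A))  # 22
print(run_strategy(5, 17, Strategy.B))  # 85
```

With this layout, run_strategy only accepts members of Strategy, and the instance-level __call__ is what lets a type checker treat strategy(x, y) as a valid call.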