PyBind11 Stub-Generation

Top

Questions to David Rotermund

ModuleStubsGenerator

If your editor or its syntax highlighting complains, you might need module stubs. This is how I automatically generate my pybind11 stubs:

pip install pybind11-stubgen
# Based on
# https://github.com/sizmailov/pybind11-stubgen/blob/master/pybind11_stubgen/__init__.py

from __future__ import annotations

import importlib
import logging
import re
from argparse import ArgumentParser, Namespace
from pathlib import Path
import glob

from pybind11_stubgen.parser.interface import IParser
from pybind11_stubgen.parser.mixins.error_handlers import (
    IgnoreAllErrors,
    IgnoreInvalidExpressionErrors,
    IgnoreInvalidIdentifierErrors,
    IgnoreUnresolvedNameErrors,
    LogErrors,
    LoggerData,
    SuggestCxxSignatureFix,
    TerminateOnFatalErrors,
)
from pybind11_stubgen.parser.mixins.filter import (
    FilterClassMembers,
    FilterInvalidIdentifiers,
    FilterPybind11ViewClasses,
    FilterPybindInternals,
    FilterTypingModuleAttributes,
)
from pybind11_stubgen.parser.mixins.fix import (
    FixBuiltinTypes,
    FixCurrentModulePrefixInTypeNames,
    FixMissing__all__Attribute,
    FixMissing__future__AnnotationsImport,
    FixMissingEnumMembersAnnotation,
    FixMissingFixedSizeImport,
    FixMissingImports,
    FixMissingNoneHashFieldAnnotation,
    FixNumpyArrayDimAnnotation,
    FixNumpyArrayDimTypeVar,
    FixNumpyArrayFlags,
    FixNumpyArrayRemoveParameters,
    FixNumpyDtype,
    FixPEP585CollectionNames,
    FixPybind11EnumStrDoc,
    FixRedundantBuiltinsAnnotation,
    FixRedundantMethodsFromBuiltinObject,
    FixScipyTypeArguments,
    FixTypingTypeNames,
    FixValueReprRandomAddress,
    OverridePrintSafeValues,
    RemoveSelfAnnotation,
    ReplaceReadWritePropertyWithField,
    RewritePybind11EnumValueRepr,
)
from pybind11_stubgen.parser.mixins.parse import (
    BaseParser,
    ExtractSignaturesFromPybind11Docstrings,
    ParserDispatchMixin,
)
from pybind11_stubgen.printer import Printer
from pybind11_stubgen.structs import QualifiedName
from pybind11_stubgen.writer import Writer


class CLIArgs(Namespace):
    """Typed view of the namespace filled in by ``arg_parser().parse_args``.

    The class only carries annotations; the values themselves are assigned
    by argparse when an instance is passed as ``namespace=``.
    """

    # Output layout
    output_dir: str
    # argparse default is None, so the suffix is optional.
    root_suffix: str | None

    # Error filtering (compiled patterns, or None when the flag is absent)
    ignore_invalid_expressions: re.Pattern | None
    ignore_invalid_identifiers: re.Pattern | None
    ignore_unresolved_names: re.Pattern | None
    ignore_all_errors: bool

    # (enum-class-name pattern, dotted location) pairs from REGEX:LOC values
    enum_class_locations: list[tuple[re.Pattern, str]]

    # Mutually exclusive numpy annotation rewrites
    numpy_array_wrap_with_annotated: bool
    numpy_array_use_type_var: bool
    numpy_array_remove_parameters: bool

    # Printing behaviour
    print_invalid_expressions_as_is: bool
    print_safe_value_reprs: re.Pattern | None

    # Run control
    exit_code: bool
    dry_run: bool
    stub_extension: str

    # NOTE: no argument in arg_parser() sets this attribute; main() derives
    # module names from the ``*.so`` files it finds instead.
    module_name: str


def arg_parser() -> ArgumentParser:
    """Build the command-line parser for the stub generator.

    Returns:
        An ``ArgumentParser`` whose parsed attributes correspond to the
        option ``dest`` names declared below (output location, error
        filters, numpy rewrites, printing and run-control flags).

    Invalid ``REGEX`` / ``REGEX:LOC`` values are reported as normal argparse
    usage errors (exit status 2) carrying a descriptive message.
    """
    # Local import keeps this function self-contained with respect to the
    # file-level import list.
    from argparse import ArgumentTypeError

    def regex(pattern_str: str) -> re.Pattern:
        """Compile *pattern_str* or reject it with a readable CLI error."""
        try:
            return re.compile(pattern_str)
        except re.error as e:
            # argparse shows the custom message only for ArgumentTypeError;
            # a ValueError is replaced by a generic "invalid ... value" text.
            raise ArgumentTypeError(f"Invalid REGEX pattern: {e}") from e

    def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
        """Split ``REGEX:LOC`` at the last colon and validate both halves."""
        pattern_str, colon, path = regex_path.rpartition(":")
        if not colon:
            raise ArgumentTypeError(
                f"Expected REGEX:LOC format, got: {regex_path}"
            )
        # LOC must be a dotted path made of valid Python identifiers.
        if any(not part.isidentifier() for part in path.split(".")):
            raise ArgumentTypeError(f"Invalid PATH: {path}")
        return regex(pattern_str), path

    parser = ArgumentParser(
        prog="pybind11-stubgen", description="Generates stubs for specified modules"
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        help="The root directory for output stubs",
        default=".",
    )
    parser.add_argument(
        "--root-suffix",
        type=str,
        default=None,
        dest="root_suffix",
        help="Top-level module directory suffix",
    )

    parser.add_argument(
        "--ignore-invalid-expressions",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore invalid expressions matching REGEX",
    )
    parser.add_argument(
        "--ignore-invalid-identifiers",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore invalid identifiers matching REGEX",
    )

    parser.add_argument(
        "--ignore-unresolved-names",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Ignore unresolved names matching REGEX",
    )

    parser.add_argument(
        "--ignore-all-errors",
        default=False,
        action="store_true",
        help="Ignore all errors during module parsing",
    )

    parser.add_argument(
        "--enum-class-locations",
        dest="enum_class_locations",
        metavar="REGEX:LOC",
        action="append",
        default=[],
        type=regex_colon_path,
        help="Locations of enum classes in "
        "<enum-class-name-regex>:<path-to-class> format. "
        "Example: `MyEnum:foo.bar.Baz`",
    )

    # At most one of the numpy array rewrites may be selected.
    numpy_array_fix = parser.add_mutually_exclusive_group()
    numpy_array_fix.add_argument(
        "--numpy-array-wrap-with-annotated",
        default=False,
        action="store_true",
        help="Replace numpy/scipy arrays of "
        "'ARRAY_T[TYPE, [*DIMS], *FLAGS]' format with "
        "'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'",
    )
    numpy_array_fix.add_argument(
        "--numpy-array-use-type-var",
        default=False,
        action="store_true",
        help="Replace 'numpy.ndarray[numpy.float32[m, 1]]' with "
        "'numpy.ndarray[tuple[M, typing.Literal[1]], numpy.dtype[numpy.float32]]'",
    )

    numpy_array_fix.add_argument(
        "--numpy-array-remove-parameters",
        default=False,
        action="store_true",
        help="Replace 'numpy.ndarray[...]' with 'numpy.ndarray'",
    )

    parser.add_argument(
        "--print-invalid-expressions-as-is",
        default=False,
        action="store_true",
        # The trailing space matters: adjacent literals are concatenated.
        help="Suppress the replacement with '...' of invalid expressions "
        "found in annotations",
    )

    parser.add_argument(
        "--print-safe-value-reprs",
        metavar="REGEX",
        default=None,
        type=regex,
        help="Override the print-safe check for values matching REGEX",
    )

    parser.add_argument(
        "--exit-code",
        action="store_true",
        dest="exit_code",
        help="On error exits with 1 and skips stub generation",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        dest="dry_run",
        help="Don't write stubs. Parses module and report errors",
    )

    parser.add_argument(
        "--stub-extension",
        type=str,
        default="pyi",
        metavar="EXT",
        choices=["pyi", "py"],
        help="The file extension of the generated stubs. "
        "Must be 'pyi' (default) or 'py'",
    )

    return parser


def stub_parser_from_args(args: CLIArgs) -> IParser:
    """Assemble and configure a stub parser according to the CLI options.

    A ``Parser`` class is composed on the fly from a fixed chain of mixins
    plus optional ones selected by *args*, instantiated, and then handed the
    patterns / enum locations supplied on the command line.

    Args:
        args: Parsed command-line options.

    Returns:
        The configured parser instance.
    """
    # Optional error-ignoring mixins; they are listed first among the bases.
    leading_handlers: list[type] = [LoggerData]
    if args.ignore_all_errors:
        leading_handlers.append(IgnoreAllErrors)
    if args.ignore_invalid_identifiers:
        leading_handlers.append(IgnoreInvalidIdentifierErrors)
    if args.ignore_invalid_expressions:
        leading_handlers.append(IgnoreInvalidExpressionErrors)
    if args.ignore_unresolved_names:
        leading_handlers.append(IgnoreUnresolvedNameErrors)

    # Reporting mixins; they are listed last among the bases.
    trailing_handlers: list[type] = [LogErrors, SuggestCxxSignatureFix]
    if args.exit_code:
        trailing_handlers.append(TerminateOnFatalErrors)

    # The CLI declares these flags mutually exclusive, so at most one
    # rewrite is expected to be selected here.
    numpy_rewrites: list[type] = []
    if args.numpy_array_wrap_with_annotated:
        numpy_rewrites.append(FixNumpyArrayDimAnnotation)
    if args.numpy_array_use_type_var:
        numpy_rewrites.append(FixNumpyArrayDimTypeVar)
    if args.numpy_array_remove_parameters:
        numpy_rewrites.append(FixNumpyArrayRemoveParameters)

    # The order of the bases fixes the method resolution order of the mixin
    # chain, so it is kept exactly as listed.
    class Parser(
        *leading_handlers,  # type: ignore[misc]
        FixMissing__future__AnnotationsImport,
        FixMissing__all__Attribute,
        FixMissingNoneHashFieldAnnotation,
        FixMissingImports,
        FilterTypingModuleAttributes,
        FixPEP585CollectionNames,
        FixTypingTypeNames,
        FixScipyTypeArguments,
        FixMissingFixedSizeImport,
        FixMissingEnumMembersAnnotation,
        OverridePrintSafeValues,
        *numpy_rewrites,  # type: ignore[misc]
        FixNumpyDtype,
        FixNumpyArrayFlags,
        FixCurrentModulePrefixInTypeNames,
        FixBuiltinTypes,
        RewritePybind11EnumValueRepr,
        FilterClassMembers,
        ReplaceReadWritePropertyWithField,
        FilterInvalidIdentifiers,
        FixValueReprRandomAddress,
        FixRedundantBuiltinsAnnotation,
        FilterPybindInternals,
        FilterPybind11ViewClasses,
        FixRedundantMethodsFromBuiltinObject,
        RemoveSelfAnnotation,
        FixPybind11EnumStrDoc,
        ExtractSignaturesFromPybind11Docstrings,
        ParserDispatchMixin,
        BaseParser,
        *trailing_handlers,  # type: ignore[misc]
    ):
        pass

    stub_parser = Parser()

    # Each setter is looked up only when its option was given.
    enum_locations = args.enum_class_locations
    if enum_locations:
        stub_parser.set_pybind11_enum_locations(dict(enum_locations))

    identifier_pattern = args.ignore_invalid_identifiers
    if identifier_pattern is not None:
        stub_parser.set_ignored_invalid_identifiers(identifier_pattern)

    expression_pattern = args.ignore_invalid_expressions
    if expression_pattern is not None:
        stub_parser.set_ignored_invalid_expressions(expression_pattern)

    unresolved_pattern = args.ignore_unresolved_names
    if unresolved_pattern is not None:
        stub_parser.set_ignored_unresolved_names(unresolved_pattern)

    safe_value_pattern = args.print_safe_value_reprs
    if safe_value_pattern is not None:
        stub_parser.set_print_safe_value_pattern(safe_value_pattern)

    return stub_parser


def main() -> None:
    """Generate stubs for every ``*.so`` file in the current directory.

    The module name of each file is the part of its file name before the
    first dot (``foo.cpython-311-x86_64-linux-gnu.so`` -> ``foo``). Every
    module gets a freshly built parser and printer; the command-line options
    are shared by all of them.
    """
    # Neither logging setup nor CLI parsing depends on the module being
    # processed, so both run once up front. This also makes ``--help`` and
    # argument errors work when no ``*.so`` file is present.
    logging.basicConfig(
        level=logging.INFO,
        format="%(name)s - [%(levelname)7s] %(message)s",
    )
    args = arg_parser().parse_args(namespace=CLIArgs())

    # glob returns matches in arbitrary order; sort for reproducible runs.
    shared_objects = sorted(glob.glob("*.so"))
    if not shared_objects:
        logging.warning("No '*.so' files found in the current directory")
        return

    for file_name in shared_objects:
        # Everything before the first dot is the importable module name.
        module_name: str = file_name.partition(".")[0]
        print("Processing: " + module_name)

        parser = stub_parser_from_args(args)
        printer = Printer(
            invalid_expr_as_ellipses=not args.print_invalid_expressions_as_is
        )

        out_dir, sub_dir = to_output_and_subdir(
            output_dir=args.output_dir,
            module_name=module_name,
            root_suffix=args.root_suffix,
        )

        run(
            parser,
            printer,
            module_name,
            out_dir,
            sub_dir=sub_dir,
            dry_run=args.dry_run,
            writer=Writer(stub_ext=args.stub_extension),
        )


def to_output_and_subdir(
    output_dir: str, module_name: str, root_suffix: str | None
) -> tuple[Path, Path | None]:
    """Compute the output directory and optional sub-directory for a module.

    Args:
        output_dir: Root directory for the generated stubs.
        module_name: Dotted module name.
        root_suffix: Text appended to the top-level package name, or
            ``None`` to leave the name untouched.

    Returns:
        ``(out_dir, sub_dir)``. ``out_dir`` is *output_dir* joined with all
        but the last component of the (possibly suffixed) module path.
        ``sub_dir`` is the suffixed top-level name when a suffix is given and
        the module name has a single component; otherwise it is ``None``.
    """
    top_level, *nested_parts = module_name.split(".")
    sub_dir: Path | None = None

    if root_suffix is not None:
        top_level = f"{top_level}{root_suffix}"
        if not nested_parts:
            # A suffixed single-component module gets its own sub-directory.
            sub_dir = Path(top_level)

    parent_parts = [top_level, *nested_parts][:-1]
    return Path(output_dir).joinpath(*parent_parts), sub_dir


def run(
    parser: IParser,
    printer: Printer,
    module_name: str,
    out_dir: Path,
    sub_dir: Path | None,
    dry_run: bool,
    writer: Writer,
):
    """Import *module_name*, parse it and, unless *dry_run*, write its stubs.

    Args:
        parser: Parser that turns the imported module into a stub structure.
        printer: Printer passed through to the writer.
        module_name: Name of the module to import and parse.
        out_dir: Directory to write into; created if missing.
        sub_dir: Optional sub-directory passed through to the writer.
        dry_run: When true, stop after parsing and write nothing.
        writer: Writer that emits the stub files.

    Raises:
        RuntimeError: If the parser returns ``None`` for the module.
    """
    qualified_name = QualifiedName.from_str(module_name)
    imported_module = importlib.import_module(module_name)
    parsed_module = parser.handle_module(qualified_name, imported_module)
    # finalize() runs before the None check, i.e. even when parsing failed.
    parser.finalize()

    if parsed_module is None:
        raise RuntimeError(f"Can't parse {module_name}")

    if not dry_run:
        out_dir.mkdir(parents=True, exist_ok=True)
        writer.write_module(parsed_module, printer, to=out_dir, sub_dir=sub_dir)


# Run the stub generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

The source code is Open Source and can be found on GitHub.