Source code for interfacy.naming.flag_strategy

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal, Protocol

from objinspect import Parameter
from stdl.st import kebab_case, snake_case

from interfacy.naming.abbreviations import AbbreviationGenerator
from interfacy.naming.name_mapping import NameMapping
from interfacy.util import is_list_or_list_alias

# Accepted values for a strategy's ``style`` attribute.
FlagStyle = Literal["keyword_only", "required_positional"]
# Accepted translation mode names; each one is a key of NAME_TRANSLATORS below.
TranslationMode = Literal["none", "kebab", "snake"]

# Maps every translation mode to the callable applied to a name:
# "none" returns the name unchanged, "kebab" and "snake" delegate to the
# stdl.st helpers imported at the top of the file.
NAME_TRANSLATORS: dict[TranslationMode, Callable[[str], str]] = {
    "none": lambda s: s,
    "kebab": kebab_case,
    "snake": snake_case,
}


def build_name_mapping(mode: TranslationMode) -> NameMapping:
    """
    Create the NameMapping that corresponds to a translation mode.

    Args:
        mode (TranslationMode): Key of ``NAME_TRANSLATORS`` selecting the translator.

    Returns:
        NameMapping: A mapping constructed from the translator registered for ``mode``.

    Raises:
        ValueError: If ``mode`` is not a key of ``NAME_TRANSLATORS``.
    """
    # Every registered translator is a callable, so None can only mean "unknown mode".
    translator = NAME_TRANSLATORS.get(mode)
    if translator is None:
        raise ValueError(
            f"Invalid flag translation mode: {mode}. "
            f"Valid modes are: {', '.join(NAME_TRANSLATORS.keys())}"
        )

    return NameMapping(translator)


class FlagStrategy(Protocol):
    """
    Structural interface for objects that produce CLI flags for parameters.

    Attributes:
        argument_translator (NameMapping): Name mapping held for arguments.
        command_translator (NameMapping): Name mapping held for commands.
        style (FlagStyle): Flag style the strategy applies.
        translation_mode (TranslationMode): Name translation mode of the strategy.
    """

    argument_translator: NameMapping
    command_translator: NameMapping
    style: FlagStyle
    translation_mode: TranslationMode

    def get_arg_flags(
        self,
        name: str,
        param: Parameter,
        taken_flags: list[str],
        abbrev_gen: AbbreviationGenerator,
    ) -> tuple[str, ...]:
        """
        Return the flag strings to use for ``param``.

        Args:
            name (str): Name of the parameter to generate flags for.
            param (Parameter): Parameter object the flags are generated for.
            taken_flags (list[str]): Flags that are already in use.
            abbrev_gen (AbbreviationGenerator): Generator available for abbreviations.

        Returns:
            tuple[str, ...]: The flags for the parameter.
        """
        ...
@dataclass
class FlagAllocationState:
    """Per-command flag allocation state for positional parameters."""

    # Flipped to True by get_arg_flags_for_parameter the first time a required
    # list parameter is emitted as a positional while this state is in use.
    consumed_required_list_positional: bool = False


class FlagParamView:
    """
    Proxy parameter that overrides selected attributes without losing the original shape.

    Attribute reads are answered from ``overrides`` first and otherwise forwarded
    to the wrapped parameter. The wrapped parameter itself is never modified.

    Args:
        param (Parameter): Parameter to wrap.
        **overrides (Any): Attribute values that take precedence over ``param``'s own.
    """

    def __init__(self, param: Parameter, **overrides: Any) -> None:
        self._param = param
        self._overrides = overrides

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs after normal lookup has failed. Getting here for one
        # of the proxy's own attributes means __init__ never ran: copy and pickle
        # create the instance via __new__ and then probe it (e.g. for __setstate__).
        # Reading self._overrides in that state would re-enter this method forever,
        # so report the attribute as missing instead of hitting RecursionError.
        if name in ("_param", "_overrides"):
            raise AttributeError(name)
        if name in self._overrides:
            return self._overrides[name]
        return getattr(self._param, name)


def _is_required_list_positional_candidate(strategy: FlagStrategy, param: Parameter) -> bool:
    """
    Tell whether ``param`` is a required list parameter under the positional style.

    The result is truthy only when the strategy style is ``"required_positional"``,
    the parameter is required, it is not typed as ``bool``, and
    ``is_list_or_list_alias`` accepts its type.
    """
    return (
        strategy.style == "required_positional"
        and param.is_required
        and not (param.is_typed and param.type is bool)
        and is_list_or_list_alias(param.type)
    )


def get_arg_flags_for_parameter(
    strategy: FlagStrategy,
    name: str,
    param: Parameter,
    taken_flags: list[str],
    abbrev_gen: AbbreviationGenerator,
    *,
    allocation_state: FlagAllocationState | None = None,
) -> tuple[str, ...]:
    """
    Generate flags using build-local positional allocation for required list parameters.

    The strategy remains stateless. Callers that want "first greedy required list
    stays positional" can pass a local allocation state for the current command/build.

    Args:
        strategy (FlagStrategy): Strategy that produces the flags.
        name (str): Name of the parameter to generate flags for.
        param (Parameter): Parameter the flags are generated for.
        taken_flags (list[str]): Flags that are already in use.
        abbrev_gen (AbbreviationGenerator): Generator handed through to the strategy.
        allocation_state (FlagAllocationState | None): Optional state shared across the
            parameters of one command. Without it, every required list candidate is
            returned as a positional.

    Returns:
        tuple[str, ...]: ``(name,)`` for a required list parameter that stays
        positional, otherwise whatever ``strategy.get_arg_flags`` returns.
    """
    if _is_required_list_positional_candidate(strategy, param):
        if allocation_state is None:
            return (name,)
        if not allocation_state.consumed_required_list_positional:
            # First required list of this build: it claims the positional slot.
            allocation_state.consumed_required_list_positional = True
            return (name,)
        # The positional slot is already taken, so present this parameter to the
        # strategy as not required.
        param = FlagParamView(param, is_required=False)
    return strategy.get_arg_flags(name, param, taken_flags, abbrev_gen)
class DefaultFlagStrategy(FlagStrategy):
    """
    Default strategy for deriving CLI flag names from parameters.

    Args:
        style (FlagStyle): Flag style for required/optional parameters.
        translation_mode (TranslationMode): Name translation mode.
        nested_separator (str): Separator for nested model paths.
    """

    def __init__(
        self,
        style: FlagStyle = "required_positional",
        translation_mode: TranslationMode = "kebab",
        nested_separator: str = ".",
    ) -> None:
        self.style = style
        self.translation_mode = translation_mode
        self.nested_separator = nested_separator
        # Arguments and commands each get their own mapping built from the same mode.
        self.argument_translator = build_name_mapping(self.translation_mode)
        self.command_translator = build_name_mapping(self.translation_mode)

    def get_arg_flags(
        self,
        name: str,
        param: Parameter,
        taken_flags: list[str],
        abbrev_gen: AbbreviationGenerator,
    ) -> tuple[str, ...]:
        """
        Build the CLI flags for one parameter, avoiding flags that are already taken.

        Args:
            name (str): The name of the parameter for which to generate flags.
            param (Parameter): Parameter object containing type and other metadata.
            taken_flags (list[str]): Flags that are already in use.
            abbrev_gen (AbbreviationGenerator): AbbreviationGenerator instance.

        Returns:
            tuple[str, ...]: ``(name,)`` for a required non-bool parameter under the
            "required_positional" style; otherwise the long flag, preceded by a
            short flag when a usable abbreviation was generated.
        """
        bool_flag = param.is_typed and param.type is bool

        # Required non-bool parameters are positional under the positional style.
        if not bool_flag and param.is_required and self.style == "required_positional":
            return (name,)

        # Single-character names get a single dash unless they are bool flags.
        if bool_flag or len(name) > 1:
            long_flag = f"--{name}".strip()
        else:
            long_flag = f"-{name}".strip()

        # Bool flags whose default is True are abbreviated from the "no-" form.
        short_source = name
        if bool_flag and param.has_default and param.default is True:
            short_source = f"no-{name}"

        short = abbrev_gen.generate(short_source, taken_flags)
        if not short:
            return (long_flag,)

        short = short.strip()
        # An abbreviation equal to the full name adds nothing, so it is dropped.
        if short and short not in (name, short_source):
            return (f"-{short}", long_flag)
        return (long_flag,)
# Names exported by ``from interfacy.naming.flag_strategy import *``.
__all__ = [
    "DefaultFlagStrategy",
    "FlagAllocationState",
    "FlagStrategy",
    "FlagStyle",
    "TranslationMode",
    "build_name_mapping",
    "get_arg_flags_for_parameter",
]