# Source code for tmep.template

# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This file originates from the file `beets/util/functemplate.py
<https://github.com/beetbox/beets/blob/master/beets/util/functemplate.py>`_
of the `beets project <http://beets.io>`_.

This module implements a string formatter based on the standard PEP
292 string.Template class extended with function calls. Variables, as
with string.Template, are indicated with $ and functions are delimited
with %.

This module assumes that everything is Unicode: the template and the
substitution values. Bytestrings are not supported. Also, the templates
always behave like the ``safe_substitute`` method in the standard
library: unknown symbols are left intact.

This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""

from __future__ import annotations

import ast
import dis
import functools
import re
import sys
import types
from typing import Any, Callable

from .types import FunctionCollection, Values

# Characters with special meaning in the template syntax.
SYMBOL_DELIM = "$"  # Starts a variable reference: ``$foo`` or ``${foo}``.
FUNC_DELIM = "%"  # Starts a function call: ``%foo{bar,baz}``.
GROUP_OPEN = "{"  # Opens a braced symbol name or an argument list.
GROUP_CLOSE = "}"  # Closes a braced symbol name or an argument list.
ARG_SEP = ","  # Separates arguments inside a function call.
ESCAPE_CHAR = "$"  # Escapes a following special character (``$$``, ``$}``, ...).

# Prefixes prepended to template identifiers to build the parameter names
# of the generated Python function.
VARIABLE_PREFIX = "__var_"
FUNCTION_PREFIX = "__func_"


class Environment:
    """Bundle of the data a template is evaluated against.

    :param values: Mapping from variable names to their substitution
        values.
    :param functions: Mapping from function names to the callables that
        implement them.
    """

    values: Values
    functions: FunctionCollection

    def __init__(self, values: Values, functions: FunctionCollection) -> None:
        self.values, self.functions = values, functions


# Code generation helpers.


def ex_rvalue(name: str) -> ast.Name:
    """Build an AST node that reads (loads) the variable called ``name``.

    :param name: For example ``'str'``, ``'map'``, ``'__func_alpha'``
    """
    return ast.Name(id=name, ctx=ast.Load())


def ex_literal(
    val: int | float | bool | str | ast.Call | ast.List | ast.Name | None,
) -> ast.Constant:
    """Wrap ``val`` in an ``ast.Constant`` node so it can be embedded in
    generated code as a literal.

    :param val: For example ``'abc123'``
    """
    literal = ast.Constant(value=val)
    return literal


def ex_call(
    func: str | ast.Attribute | ast.Name,
    args: list[ast.Call | ast.List | ast.Name | Any],
) -> ast.Call:
    """Build a call expression that passes only positional arguments.

    :param func: Either an AST expression to call, or a plain name that
        is turned into a variable load first.
    :param args: The arguments. Entries that already are AST expressions
        are used unchanged; anything else is wrapped as a literal. The
        list passed in is not modified.
    """
    callee = ex_rvalue(func) if isinstance(func, str) else func
    positional = [
        arg if isinstance(arg, ast.expr) else ex_literal(arg) for arg in args
    ]
    return ast.Call(callee, positional, [])


def compile_func(
    arg_names: list[str],
    statements: list[ast.Return],
    name: str = "_the_func",
    debug: bool = False,
) -> Callable[..., Any]:
    """Turn a list of statements into a real Python function.

    The statements become the body of a function called ``name`` whose
    parameters are ``arg_names``; every parameter defaults to ``None``.

    :param arg_names: Parameter names of the generated function.
    :param statements: AST statements forming the function body.
    :param name: Name given to the generated function.
    :param debug: If true, disassemble the generated module code and
        every code object among its constants to stdout.
    :return: The compiled function object.
    """
    arguments_kwargs = {
        "args": [ast.arg(arg=arg_name, annotation=None) for arg_name in arg_names],
        "kwonlyargs": [],
        "kw_defaults": [],
        "defaults": [ex_literal(None) for _ in arg_names],
    }
    # ``posonlyargs`` only exists on Python 3.8 and later.
    if "posonlyargs" in ast.arguments._fields:
        arguments_kwargs["posonlyargs"] = []

    function_node = ast.FunctionDef(
        name=name,
        args=ast.arguments(**arguments_kwargs),
        body=statements,
        decorator_list=[],
    )

    # Since 3.8 ``ast.Module`` takes a second argument: the list of type
    # ignores.
    module_body = [function_node]
    if sys.version_info >= (3, 8):
        module = ast.Module(module_body, [])
    else:
        module = ast.Module(module_body)

    ast.fix_missing_locations(module)
    code = compile(module, "<generated>", "exec")

    if debug:
        # Show the module bytecode, then that of each nested code object
        # (the generated function itself lives in ``co_consts``).
        dis.dis(code)
        nested_codes = [c for c in code.co_consts if isinstance(c, types.CodeType)]
        for nested_code in nested_codes:
            dis.dis(nested_code)

    # Executing the module defines the function in ``namespace``.
    namespace: dict[str, Callable[..., Any]] = {}
    exec(code, {}, namespace)
    return namespace[name]


# AST nodes for the template language.


class Symbol:
    """A variable reference (``$name`` or ``${name}``) inside a template.

    :param ident: Name of the referenced variable.
    :param original: The exact source text of the reference, used when
        the variable has no value.
    """

    ident: str
    original: str

    def __init__(self, ident: str, original: str) -> None:
        self.ident, self.original = ident, original

    def __repr__(self) -> str:
        return f"Symbol({self.ident!r})"

    def evaluate(self, env: Environment):
        """Look the symbol up in ``env.values``.

        Returns the stored value as-is when the name is known, otherwise
        the original template text of the reference.
        """
        known = self.ident in env.values
        return env.values[self.ident] if known else self.original

    def translate(self) -> tuple[list[ast.Name], set[str], set[str]]:
        """Compile the lookup into a load of the prefixed variable name.

        Returns the expression list, the set of variable names used and
        an (empty) set of function names.
        """
        lookup = ex_rvalue(VARIABLE_PREFIX + self.ident)
        return [lookup], {self.ident}, set()


class Call:
    """A function invocation (``%name{arg,...}``) inside a template.

    :param ident: Name of the function to call.
    :param args: One expression per argument.
    :param original: The exact source text of the call, used when the
        function is unknown.
    """

    ident: str
    args: list[Expression]
    original: str

    def __init__(self, ident: str, args: list[Expression], original: str) -> None:
        self.ident = ident
        self.args = args
        self.original = original

    def __repr__(self) -> str:
        return f"Call({self.ident!r}, {self.args!r}, {self.original!r})"

    def evaluate(self, env: Environment) -> str:
        """Run the function from ``env.functions`` on the evaluated
        arguments and return the result as a string.

        An unknown function yields the original template text. If the
        function raises, the exception message wrapped in ``<...>`` is
        returned instead.
        """
        if self.ident not in env.functions:
            return self.original

        arg_vals = [expr.evaluate(env) for expr in self.args]
        try:
            out = env.functions[self.ident](*arg_vals)
        except Exception as exc:
            # Inline the exception message into the output so the
            # failure is visible in the rendered text.
            return "<%s>" % str(exc)
        return str(out)

    def translate(self):
        """Compile the call into an AST call of the prefixed function.

        Returns the expression list, the set of variable names used and
        the set of function names used (including this one).
        """
        varnames: set[str] = set()
        funcnames: set[str] = {self.ident}
        arg_exprs: list[ast.Call] = []

        for arg in self.args:
            subexprs, subvars, subfuncs = arg.translate()
            varnames.update(subvars)
            funcnames.update(subfuncs)

            # Each argument becomes ``"".join(map(str, [<components>]))``
            # so its components are concatenated into one string.
            joiner = ast.Attribute(ex_literal(""), "join", ast.Load())
            stringified = ex_call(
                "map",
                [ex_rvalue(str.__name__), ast.List(subexprs, ast.Load())],
            )
            arg_exprs.append(ex_call(joiner, [stringified]))

        invocation = ex_call(FUNCTION_PREFIX + self.ident, arg_exprs)
        return [invocation], varnames, funcnames


class Expression:
    """A sequence of template parts: plain text, Symbols and Calls.

    :param parts: The parts in template order.
    """

    parts: list[str | Symbol | Call]

    def __init__(self, parts: list[str | Symbol | Call]) -> None:
        self.parts = parts

    def __repr__(self):
        return f"Expression({self.parts!r})"

    def evaluate(self, env: Environment) -> str:
        """Render the expression against ``env``.

        Text parts are used verbatim, every other part is evaluated; the
        results are converted with ``str`` and concatenated.
        """
        rendered = [
            part if isinstance(part, str) else part.evaluate(env)
            for part in self.parts
        ]
        return "".join(map(str, rendered))

    def translate(self):
        """Compile the expression.

        Returns a list of Python AST expressions (one or more per part),
        the set of variable names used and the set of function names
        used.
        """
        expressions: list[ast.Constant | ast.Name | ast.Call] = []
        varnames: set[str] = set()
        funcnames: set[str] = set()

        for part in self.parts:
            if not isinstance(part, str):
                sub_exprs, sub_vars, sub_funcs = part.translate()
                expressions.extend(sub_exprs)
                varnames.update(sub_vars)
                funcnames.update(sub_funcs)
                continue
            # Plain text is emitted as a string literal.
            expressions.append(ex_literal(part))

        return expressions, varnames, funcnames


# Parser.


class ParseError(Exception):
    """Raised by the parser when a template cannot be scanned."""


class Parser:
    """Parses a template expression string. Instantiate the class with
    the template source and call ``parse_expression``. The ``pos`` field
    will indicate the character after the expression finished and
    ``parts`` will contain a list of Unicode strings, Symbols, and Calls
    reflecting the concatenated portions of the expression.

    This is a terrible, ad-hoc parser implementation based on a
    left-to-right scan with no lexing step to speak of; it's probably
    both inefficient and incorrect. Maybe this should eventually be
    replaced with a real, accepted parsing technique (PEG, parser
    generator, etc.).
    """

    # The template source being parsed.
    string: str
    # True when parsing a function argument, where ``ARG_SEP`` is special.
    in_argument: bool
    # Index into ``string`` of the next character to examine.
    pos: int
    # Parsed components collected so far.
    parts: list[str | Symbol | Call]

    def __init__(self, string: str, in_argument: bool = False) -> None:
        """Create a new parser.

        :param string: the template source to parse
        :param in_argument: boolean that indicates the parser is to be
            used for parsing function arguments, i.e. considering commas
            (`ARG_SEP`) a special character
        """
        self.string = string
        self.in_argument = in_argument
        self.pos = 0
        self.parts = []

    # Common parsing resources.
    # Characters that interrupt a run of literal text. ``ESCAPE_CHAR`` and
    # ``SYMBOL_DELIM`` are the same character, so it is listed twice.
    special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE, ESCAPE_CHAR)
    # Matches the next special character or, via ``\Z``, the end of the
    # string, so a search over the remaining text always finds a position.
    special_char_re = re.compile(
        r"[%s]|\Z" % "".join(re.escape(c) for c in special_chars)
    )
    # Characters that lose their special meaning when preceded by
    # ``ESCAPE_CHAR``.
    escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
    # Characters that end the expression being parsed.
    terminator_chars = (GROUP_CLOSE,)

    def parse_expression(self) -> None:
        """Parse a template expression starting at ``pos``. Resulting
        components (Unicode strings, Symbols, and Calls) are added to
        the ``parts`` field, a list.  The ``pos`` field is updated to be
        the next character after the expression.

        :raises ParseError: if the search for the next special character
            finds nothing (not expected, because the pattern also
            matches the end of the string)
        """
        # Append comma (ARG_SEP) to the list of special characters only when
        # parsing function arguments.
        extra_special_chars = ()
        special_char_re = self.special_char_re
        if self.in_argument:
            extra_special_chars = (ARG_SEP,)
            special_char_re = re.compile(
                r"[%s]|\Z"
                % "".join(
                    re.escape(c) for c in self.special_chars + extra_special_chars
                )
            )

        # Literal text fragments collected since the last structure; they
        # are joined into one string before a Symbol/Call is appended.
        text_parts: list[str] = []

        while self.pos < len(self.string):
            char = self.string[self.pos]

            if char not in self.special_chars + extra_special_chars:
                # A non-special character. Skip to the next special
                # character, treating the interstice as literal text.
                search = special_char_re.search(self.string[self.pos :])
                if search is None:
                    raise ParseError("no special character found")
                # ``search.start()`` is relative to the slice, so add
                # ``pos`` to get an index into ``string``.
                next_pos = search.start() + self.pos
                text_parts.append(self.string[self.pos : next_pos])
                self.pos = next_pos
                continue

            if self.pos == len(self.string) - 1:
                # The last character can never begin a structure, so we
                # just interpret it as a literal character (unless it
                # terminates the expression, as with , and }).
                if char not in self.terminator_chars + extra_special_chars:
                    text_parts.append(char)
                    self.pos += 1
                break

            next_char = self.string[self.pos + 1]
            if char == ESCAPE_CHAR and next_char in (
                self.escapable_chars + extra_special_chars
            ):
                # An escaped special character ($$, $}, etc.). Note that
                # ${ is not an escape sequence: this is ambiguous with
                # the start of a symbol and it's not necessary (just
                # using { suffices in all cases).
                text_parts.append(next_char)
                self.pos += 2  # Skip the next character.
                continue

            # Shift all characters collected so far into a single string.
            if text_parts:
                self.parts.append("".join(text_parts))
                text_parts = []

            if char == SYMBOL_DELIM:
                # Parse a symbol.
                self.parse_symbol()
            elif char == FUNC_DELIM:
                # Parse a function call.
                self.parse_call()
            elif char in self.terminator_chars + extra_special_chars:
                # Template terminated. ``pos`` is left on the terminator
                # so the caller can inspect it.
                break
            elif char == GROUP_OPEN:
                # Start of a group has no meaning here; just pass
                # through the character.
                text_parts.append(char)
                self.pos += 1
            else:
                # Every special character is handled by a branch above.
                assert False

        # If any parsed characters remain, shift them into a string.
        if text_parts:
            self.parts.append("".join(text_parts))

    def parse_symbol(self) -> None:
        """Parse a variable reference (like ``$foo`` or ``${foo}``)
        starting at ``pos``. Possibly appends a Symbol object (or,
        failing that, text) to the ``parts`` field and updates ``pos``.
        The character at ``pos`` must, as a precondition, be ``$``.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == SYMBOL_DELIM

        if self.pos == len(self.string) - 1:
            # Last character.
            self.parts.append(SYMBOL_DELIM)
            self.pos += 1
            return

        next_char = self.string[self.pos + 1]
        # Remember where the reference starts so its source text can be
        # stored on the Symbol.
        start_pos = self.pos
        self.pos += 1

        if next_char == GROUP_OPEN:
            # A symbol like ${this}.
            self.pos += 1  # Skip opening.
            closer = self.string.find(GROUP_CLOSE, self.pos)
            if closer == -1 or closer == self.pos:
                # No closing brace found or identifier is empty. Emit the
                # consumed ``${`` as literal text; parsing resumes right
                # after the opening brace.
                self.parts.append(self.string[start_pos : self.pos])
            else:
                # Closer found.
                ident = self.string[self.pos : closer]
                self.pos = closer + 1
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))

        else:
            # A bare-word symbol.
            ident = self._parse_ident()
            if ident:
                # Found a real symbol.
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))
            else:
                # A standalone $.
                self.parts.append(SYMBOL_DELIM)

    def parse_call(self) -> None:
        """Parse a function call (like ``%foo{bar,baz}``) starting at
        ``pos``.  Possibly appends a Call object to ``parts`` and update
        ``pos``. The character at ``pos`` must be ``%``.

        If the text does not form a complete call (no name, no opening
        brace, or an unclosed argument list), the text consumed so far
        is appended to ``parts`` as a literal string instead.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == FUNC_DELIM

        start_pos = self.pos
        self.pos += 1

        ident = self._parse_ident()
        if not ident:
            # No function name.
            self.parts.append(FUNC_DELIM)
            return

        if self.pos >= len(self.string):
            # Identifier terminates string.
            self.parts.append(self.string[start_pos : self.pos])
            return

        if self.string[self.pos] != GROUP_OPEN:
            # Argument list not opened.
            self.parts.append(self.string[start_pos : self.pos])
            return

        # Skip past opening brace and try to parse an argument list.
        self.pos += 1
        args = self.parse_argument_list()
        if self.pos >= len(self.string) or self.string[self.pos] != GROUP_CLOSE:
            # Arguments unclosed.
            self.parts.append(self.string[start_pos : self.pos])
            return

        self.pos += 1  # Move past closing brace.
        self.parts.append(Call(ident, args, self.string[start_pos : self.pos]))

    def parse_argument_list(self) -> list[Expression]:
        """Parse a list of arguments starting at ``pos``, returning a
        list of Expression objects. Does not modify ``parts``. Should
        leave ``pos`` pointing to a } character or the end of the
        string.
        """
        # Try to parse a subexpression in a subparser.
        expressions: list[Expression] = []

        while self.pos < len(self.string):
            # The subparser works on the remaining text only, so its
            # ``pos`` is an offset relative to the current position.
            subparser = Parser(self.string[self.pos :], in_argument=True)
            subparser.parse_expression()

            # Extract and advance past the parsed expression.
            expressions.append(Expression(subparser.parts))
            self.pos += subparser.pos

            if self.pos >= len(self.string) or self.string[self.pos] == GROUP_CLOSE:
                # Argument list terminated by EOF or closing brace.
                break

            # Only other way to terminate an expression is with ,.
            # Continue to the next argument.
            assert self.string[self.pos] == ARG_SEP
            self.pos += 1

        return expressions

    def _parse_ident(self) -> str:
        """Parse an identifier and return it (possibly an empty string).
        Updates ``pos``.

        :raises ParseError: if the identifier pattern does not match
            (not expected, because ``\\w*`` also matches the empty
            string)
        """
        remainder = self.string[self.pos :]
        match = re.match(r"\w*", remainder)
        if match is None:
            raise ParseError("invalid identifier")
        ident = match.group(0)
        self.pos += len(ident)
        return ident


def _parse(template: str) -> Expression:
    """Parse a complete template string into an Expression.

    Text the parser leaves unconsumed (it stops at an unmatched
    terminator) is appended as one literal part.
    """
    parser = Parser(template)
    parser.parse_expression()

    leftover = parser.string[parser.pos :]
    if leftover:
        parser.parts.append(leftover)
    return Expression(parser.parts)


@functools.lru_cache(maxsize=128)
def template(fmt: str) -> Template:
    """Return the ``Template`` for ``fmt``.

    Results are memoised (up to 128 distinct format strings), so the
    same format string is parsed and compiled only once while cached.
    """
    compiled_template = Template(fmt)
    return compiled_template


# External interface.
class Template:
    """A string template, including text, Symbols, and Calls.

    The template source is parsed on construction and immediately
    compiled to a Python function. ``substitute`` runs that function and
    falls back to the interpreter when it raises.

    Two templates are equal when their source strings are equal; the
    hash is derived from the source string as well, so ``original``
    should not be reassigned while the template is stored in a set or
    used as a dict key.
    """

    expr: Expression
    original: str
    compiled: Callable[..., str]

    def __init__(self, template: str) -> None:
        """Parse and compile ``template``.

        :param template: The template source text.
        """
        self.expr = _parse(template)
        self.original = template
        self.compiled = self.translate()

    def __eq__(self, other: Any) -> bool:
        # Let Python try the reflected comparison for foreign types
        # instead of failing on a missing ``original`` attribute.
        if not isinstance(other, Template):
            return NotImplemented
        return self.original == other.original

    def __hash__(self) -> int:
        # Defining ``__eq__`` alone would make instances unhashable; keep
        # the hash consistent with equality.
        return hash(self.original)

    def interpret(
        self,
        values: Values | None = None,
        functions: FunctionCollection | None = None,
    ) -> str:
        """Like `substitute`, but forces the interpreter (rather than the
        compiled version) to be used. The interpreter includes
        exception-handling code for missing variables and buggy template
        functions but is much slower.

        :param values: Variable values by name; ``None`` means no values.
        :param functions: Template functions by name; ``None`` means no
            functions.
        :return: The rendered string.
        """
        values = {} if values is None else values
        functions = {} if functions is None else functions
        return self.expr.evaluate(Environment(values, functions))

    def substitute(
        self,
        values: Values | None = None,
        functions: FunctionCollection | None = None,
    ) -> str:
        """Evaluate the template given the values and functions.

        :param values: Variable values by name; ``None`` means no values.
        :param functions: Template functions by name; ``None`` means no
            functions.
        :return: The rendered string.
        """
        values = {} if values is None else values
        functions = {} if functions is None else functions
        try:
            return self.compiled(values, functions)
        except Exception:
            # Deliberate best-effort: the compiled version raises on
            # missing variables/functions and on failing template
            # functions, all of which the interpreter handles.
            return self.interpret(values, functions)

    def translate(self) -> Callable[..., str]:
        """Compile the template to a Python function.

        :return: A function taking ``(values, functions)`` and returning
            the rendered string. It raises (for example ``KeyError``)
            when a referenced variable or function is missing.
        """
        expressions, varnames, funcnames = self.expr.translate()

        # The generated function takes one parameter per variable and
        # function used; it is always called with keyword arguments, so
        # the parameter order is irrelevant.
        argnames: list[str] = [VARIABLE_PREFIX + name for name in varnames]
        argnames.extend(FUNCTION_PREFIX + name for name in funcnames)

        func = compile_func(
            argnames,
            [ast.Return(ast.List(expressions, ast.Load()))],
        )

        def wrapper_func(
            values: Values | None = None,
            functions: FunctionCollection | None = None,
        ) -> str:
            values = {} if values is None else values
            functions = {} if functions is None else functions
            args = {VARIABLE_PREFIX + name: values[name] for name in varnames}
            args.update(
                {FUNCTION_PREFIX + name: functions[name] for name in funcnames}
            )
            parts = func(**args)
            # Coerce like the interpreter does, so non-string values do
            # not force a fallback that re-runs template functions.
            return "".join(map(str, parts))

        return wrapper_func
# Names exported by a star-import of this module; only ``Template`` is
# listed.
__all__: list[str] = ["Template"]