Source code for tmep.template

# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This file originates from the file `beets/util/functemplate.py
<https://github.com/beetbox/beets/blob/master/beets/util/functemplate.py>`_
of the `beets project <http://beets.io>`_.

This module implements a string formatter based on the standard PEP
292 string.Template class extended with function calls. Variables, as
with string.Template, are indicated with $ and functions are delimited
with %.

This module assumes that everything is Unicode: the template and the
substitution values. Bytestrings are not supported. Also, the templates
always behave like the ``safe_substitute`` method in the standard
library: unknown symbols are left intact.

This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""

from __future__ import annotations

import ast
import dis
import functools
import re
import sys
import types
from typing import Any, Callable

from .types import FunctionCollection, Values

# Characters with syntactic meaning in the template language.
# ``$`` introduces a variable reference (``$foo`` or ``${foo}``).
SYMBOL_DELIM = "$"
# ``%`` introduces a function call (``%foo{bar,baz}``).
FUNC_DELIM = "%"
# Braces delimit a ``${...}`` identifier or a function's argument list.
GROUP_OPEN = "{"
GROUP_CLOSE = "}"
# Separates the arguments inside a function call.
ARG_SEP = ","
# Escapes a following special character (``$$``, ``$}``, ...). Note that it
# is the same character as SYMBOL_DELIM.
ESCAPE_CHAR = "$"

# Prefixes prepended to template variable / function names to form the
# parameter names of the generated Python function.
VARIABLE_PREFIX = "__var_"
FUNCTION_PREFIX = "__func_"


class Environment:
    """Bundle of the substitution values and the template functions that
    a template is evaluated against.
    """

    values: Values
    functions: FunctionCollection

    def __init__(self, values: Values, functions: FunctionCollection) -> None:
        # Both mappings are stored by reference, not copied.
        self.values, self.functions = values, functions
# Code generation helpers.
def ex_rvalue(name: str) -> ast.Name:
    """A variable load expression: an AST node that reads the variable
    called ``name``.

    :param name: For example ``'str'``, ``'map'``, ``'__func_alpha'``

    :return: An ``ast.Name`` node in ``Load`` context.
    """
    return ast.Name(id=name, ctx=ast.Load())
def ex_literal(
    val: int | float | bool | str | ast.Call | ast.List | ast.Name | None,
) -> ast.Constant:
    """Wrap ``val`` in a constant (literal) AST node.

    :param val: For example ``'abc123'``

    :return: An ``ast.Constant`` whose ``value`` is ``val``.
    """
    return ast.Constant(value=val)
def ex_call(
    func: str | ast.Attribute | ast.Name,
    args: list[ast.Call | ast.List | ast.Name | Any],
) -> ast.Call:
    """A function-call expression with only positional parameters.

    :param func: Either the name of the function to call (turned into a
        variable load expression) or an AST expression evaluating to the
        callable.
    :param args: The positional arguments. Items that are already AST
        expressions are used as they are; every other item is wrapped in
        a literal node. The list passed in is not modified.

    :return: The ``ast.Call`` node, without keyword arguments.
    """
    callee = ast.Name(id=func, ctx=ast.Load()) if isinstance(func, str) else func
    call_args = [
        arg if isinstance(arg, ast.expr) else ast.Constant(value=arg) for arg in args
    ]
    return ast.Call(callee, call_args, [])
def compile_func(
    arg_names: list[str],
    statements: list[ast.Return],
    name: str = "_the_func",
    debug: bool = False,
) -> Callable[..., Any]:
    """Turn a list of statement nodes into a real Python function.

    The statements become the body of a function called ``name`` whose
    parameters are ``arg_names``; every parameter defaults to ``None``.

    :param arg_names: The parameter names of the generated function.
    :param statements: The AST statements forming the function body.
    :param name: The name of the generated function.
    :param debug: If true, print the bytecode of the compiled module and
        of every code object found in its constants.

    :return: The compiled function object.
    """
    signature_fields = dict(
        args=[ast.arg(arg=arg_name, annotation=None) for arg_name in arg_names],
        kwonlyargs=[],
        kw_defaults=[],
        # One distinct ``None`` default node per parameter.
        defaults=[ast.Constant(None) for _ in arg_names],
    )
    if "posonlyargs" in ast.arguments._fields:
        # Field added in Python 3.8.
        signature_fields["posonlyargs"] = []

    function_node = ast.FunctionDef(
        name=name,
        args=ast.arguments(**signature_fields),
        body=statements,
        decorator_list=[],
    )

    # Since Python 3.8 ast.Module also takes the list of type-ignore
    # comments.
    if sys.version_info >= (3, 8):
        module_node = ast.Module([function_node], [])
    else:
        module_node = ast.Module([function_node])
    ast.fix_missing_locations(module_node)

    code = compile(module_node, "<generated>", "exec")

    if debug:
        # Show the bytecode of the module and of the nested code objects.
        dis.dis(code)
        for constant in code.co_consts:
            if isinstance(constant, types.CodeType):
                dis.dis(constant)

    # Executing the module defines the function in ``namespace``.
    namespace: dict[str, Callable[..., Any]] = {}
    exec(code, {}, namespace)
    return namespace[name]
# AST nodes for the template language.
class Symbol:
    """A variable-substitution symbol in a template."""

    ident: str
    original: str

    def __init__(self, ident: str, original: str) -> None:
        # ``original`` is the text handed back when the symbol is unknown.
        self.ident, self.original = ident, original

    def __repr__(self) -> str:
        return "Symbol(%s)" % repr(self.ident)

    def evaluate(self, env: Environment):
        """Look the symbol up in ``env.values``.

        :return: The stored value (returned as it is, without string
            conversion) when the identifier is present, otherwise the
            original text of the symbol.
        """
        known_values = env.values
        if self.ident not in known_values:
            # Unknown symbols are left intact.
            return self.original
        return known_values[self.ident]

    def translate(self) -> tuple[list[ast.Name], set[str], set[str]]:
        """Compile the variable lookup.

        :return: A tuple of the list holding the single lookup
            expression, the set of variable names used (just this
            identifier) and an empty set of function names.
        """
        lookup = ex_rvalue(VARIABLE_PREFIX + self.ident)
        return [lookup], {self.ident}, set()
class Call:
    """A function call in a template."""

    ident: str
    args: list[Expression]
    original: str

    def __init__(self, ident: str, args: list[Expression], original: str) -> None:
        self.ident, self.args, self.original = ident, args, original

    def __repr__(self) -> str:
        return "Call({}, {}, {})".format(
            repr(self.ident), repr(self.args), repr(self.original)
        )

    def evaluate(self, env: Environment) -> str:
        """Run the function on its evaluated arguments.

        :return: The stringified result of the function; the original
            text of the call if ``env.functions`` has no such function;
            or ``<message>`` built from the exception if the function
            raised one.
        """
        if self.ident not in env.functions:
            # Unknown functions are left intact.
            return self.original

        evaluated_args = [argument.evaluate(env) for argument in self.args]
        try:
            result = env.functions[self.ident](*evaluated_args)
        except Exception as exc:
            # The function failed: inline the exception text into the
            # output, which may help debugging.
            return "<%s>" % str(exc)
        return str(result)

    def translate(self):
        """Compile the function call.

        :return: A tuple of the list holding the single call expression,
            the set of variable names and the set of function names used
            by the call and its arguments.
        """
        varnames: set[str] = set()
        funcnames: set[str] = {self.ident}
        arg_exprs: list[ast.Call] = []

        for argument in self.args:
            subexprs, subvars, subfuncs = argument.translate()
            varnames.update(subvars)
            funcnames.update(subfuncs)

            # Each argument becomes ``"".join(map(str, [<components>]))``:
            # one string joining the result components of the argument.
            stringified = ex_call(
                "map",
                [ex_rvalue(str.__name__), ast.List(subexprs, ast.Load())],
            )
            joiner = ast.Attribute(ex_literal(""), "join", ast.Load())
            arg_exprs.append(ex_call(joiner, [stringified]))

        call_expr = ex_call(FUNCTION_PREFIX + self.ident, arg_exprs)
        return [call_expr], varnames, funcnames
class Expression:
    """Top-level template construct: an ordered list of text blobs,
    Symbols, and Calls.
    """

    parts: list[str | Symbol | Call]

    def __init__(self, parts: list[str | Symbol | Call]) -> None:
        self.parts = parts

    def __repr__(self):
        return "Expression(%s)" % (repr(self.parts))

    def evaluate(self, env: Environment) -> str:
        """Evaluate every part in the environment and concatenate the
        results into one string.
        """
        # First evaluate all the parts, then stringify them: text blobs
        # pass through, everything else is asked to evaluate itself.
        evaluated = [
            part if isinstance(part, str) else part.evaluate(env)
            for part in self.parts
        ]
        return "".join(str(piece) for piece in evaluated)

    def translate(self):
        """Compile the expression.

        :return: A tuple of the list of Python AST expressions (one or
            more per part), the set of variable names used and the set
            of function names used.
        """
        expressions: list[ast.Constant | ast.Name | ast.Call] = []
        varnames: set[str] = set()
        funcnames: set[str] = set()

        for part in self.parts:
            if isinstance(part, str):
                # Plain text compiles to a literal.
                expressions.append(ex_literal(part))
                continue
            part_exprs, part_vars, part_funcs = part.translate()
            expressions.extend(part_exprs)
            varnames.update(part_vars)
            funcnames.update(part_funcs)

        return expressions, varnames, funcnames
# Parser.
class ParseError(Exception):
    """Raised by the parser when a template cannot be scanned."""
class Parser:
    """Parses a template expression string. Instantiate the class with
    the template source and call ``parse_expression``. The ``pos`` field
    will indicate the character after the expression finished and
    ``parts`` will contain a list of Unicode strings, Symbols, and Calls
    reflecting the concatenated portions of the expression.

    This is a terrible, ad-hoc parser implementation based on a
    left-to-right scan with no lexing step to speak of; it's probably
    both inefficient and incorrect. Maybe this should eventually be
    replaced with a real, accepted parsing technique (PEG, parser
    generator, etc.).
    """

    string: str
    in_argument: bool
    pos: int
    parts: list[str | Symbol | Call]

    def __init__(self, string: str, in_argument: bool = False) -> None:
        """Create a new parser.

        :param string: the template source to parse
        :param in_argument: boolean that indicates the parser is to be
            used for parsing function arguments, ie. considering commas
            (`ARG_SEP`) a special character
        """
        self.string = string
        self.in_argument = in_argument
        self.pos = 0
        self.parts = []

    # Common parsing resources.
    special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE, ESCAPE_CHAR)
    special_char_re = re.compile(
        r"[%s]|\Z" % "".join(re.escape(c) for c in special_chars)
    )
    escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
    terminator_chars = (GROUP_CLOSE,)
    # Matches the (possibly empty) run of word characters forming an
    # identifier.
    _ident_re = re.compile(r"\w*")

    def parse_expression(self) -> None:
        """Parse a template expression starting at ``pos``. Resulting
        components (Unicode strings, Symbols, and Calls) are added to
        the ``parts`` field, a list. The ``pos`` field is updated to be
        the next character after the expression.

        :raises ParseError: if the scan for the next special character
            yields no match.
        """
        # Append comma (ARG_SEP) to the list of special characters only when
        # parsing function arguments.
        extra_special_chars: tuple[str, ...] = ()
        special_char_re = self.special_char_re
        if self.in_argument:
            extra_special_chars = (ARG_SEP,)
            special_char_re = re.compile(
                r"[%s]|\Z"
                % "".join(
                    re.escape(c) for c in self.special_chars + extra_special_chars
                )
            )

        # These character sets do not change while scanning, so build
        # them once instead of on every loop iteration.
        special_chars = self.special_chars + extra_special_chars
        terminator_chars = self.terminator_chars + extra_special_chars
        escapable_chars = self.escapable_chars + extra_special_chars

        text_parts: list[str] = []

        while self.pos < len(self.string):
            char = self.string[self.pos]

            if char not in special_chars:
                # A non-special character. Skip to the next special
                # character, treating the interstice as literal text.
                # Searching from ``pos`` avoids copying the remainder of
                # the string; the pattern also matches at the end of the
                # string, so a match is always expected.
                search = special_char_re.search(self.string, self.pos)
                if search is None:
                    raise ParseError("no special character found")
                next_pos = search.start()
                text_parts.append(self.string[self.pos : next_pos])
                self.pos = next_pos
                continue

            if self.pos == len(self.string) - 1:
                # The last character can never begin a structure, so we
                # just interpret it as a literal character (unless it
                # terminates the expression, as with , and }).
                if char not in terminator_chars:
                    text_parts.append(char)
                    self.pos += 1
                break

            next_char = self.string[self.pos + 1]
            if char == ESCAPE_CHAR and next_char in escapable_chars:
                # An escaped special character ($$, $}, etc.). Note that
                # ${ is not an escape sequence: this is ambiguous with
                # the start of a symbol and it's not necessary (just
                # using { suffices in all cases).
                text_parts.append(next_char)
                self.pos += 2  # Skip the next character.
                continue

            # Shift all characters collected so far into a single string.
            if text_parts:
                self.parts.append("".join(text_parts))
                text_parts = []

            if char == SYMBOL_DELIM:
                # Parse a symbol.
                self.parse_symbol()
            elif char == FUNC_DELIM:
                # Parse a function call.
                self.parse_call()
            elif char in terminator_chars:
                # Template terminated.
                break
            elif char == GROUP_OPEN:
                # Start of a group has no meaning here; just pass
                # through the character.
                text_parts.append(char)
                self.pos += 1
            else:
                # Every special character is handled above. Raise
                # explicitly (rather than ``assert``) so that the check
                # is not stripped under ``python -O``, where falling
                # through would loop without advancing ``pos``.
                raise AssertionError("unhandled special character %r" % char)

        # If any parsed characters remain, shift them into a string.
        if text_parts:
            self.parts.append("".join(text_parts))

    def parse_symbol(self) -> None:
        """Parse a variable reference (like ``$foo`` or ``${foo}``)
        starting at ``pos``. Possibly appends a Symbol object (or,
        failing that, text) to the ``parts`` field and updates ``pos``.
        The character at ``pos`` must, as a precondition, be ``$``.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == SYMBOL_DELIM

        if self.pos == len(self.string) - 1:
            # Last character.
            self.parts.append(SYMBOL_DELIM)
            self.pos += 1
            return

        next_char = self.string[self.pos + 1]
        start_pos = self.pos
        self.pos += 1

        if next_char == GROUP_OPEN:
            # A symbol like ${this}.
            self.pos += 1  # Skip opening.
            closer = self.string.find(GROUP_CLOSE, self.pos)
            if closer == -1 or closer == self.pos:
                # No closing brace found or identifier is empty.
                self.parts.append(self.string[start_pos : self.pos])
            else:
                # Closer found.
                ident = self.string[self.pos : closer]
                self.pos = closer + 1
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))

        else:
            # A bare-word symbol.
            ident = self._parse_ident()
            if ident:
                # Found a real symbol.
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))
            else:
                # A standalone $.
                self.parts.append(SYMBOL_DELIM)

    def parse_call(self):
        """Parse a function call (like ``%foo{bar,baz}``) starting at
        ``pos``. Possibly appends a Call object to ``parts`` and updates
        ``pos``; when the call is malformed, the text consumed so far is
        appended as a plain string instead. The character at ``pos``
        must be ``%``.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == FUNC_DELIM

        start_pos = self.pos
        self.pos += 1

        ident = self._parse_ident()
        if not ident:
            # No function name.
            self.parts.append(FUNC_DELIM)
            return

        if self.pos >= len(self.string):
            # Identifier terminates string.
            self.parts.append(self.string[start_pos : self.pos])
            return

        if self.string[self.pos] != GROUP_OPEN:
            # Argument list not opened.
            self.parts.append(self.string[start_pos : self.pos])
            return

        # Skip past opening brace and try to parse an argument list.
        self.pos += 1
        args = self.parse_argument_list()
        if self.pos >= len(self.string) or self.string[self.pos] != GROUP_CLOSE:
            # Arguments unclosed.
            self.parts.append(self.string[start_pos : self.pos])
            return

        self.pos += 1  # Move past closing brace.
        self.parts.append(Call(ident, args, self.string[start_pos : self.pos]))

    def parse_argument_list(self) -> list[Expression]:
        """Parse a list of arguments starting at ``pos``, returning a
        list of Expression objects. Does not modify ``parts``. Should
        leave ``pos`` pointing to a } character or the end of the
        string.
        """
        # Try to parse a subexpression in a subparser.
        expressions: list[Expression] = []

        while self.pos < len(self.string):
            subparser = Parser(self.string[self.pos :], in_argument=True)
            subparser.parse_expression()

            # Extract and advance past the parsed expression. The
            # subparser position is relative to the slice it was given.
            expressions.append(Expression(subparser.parts))
            self.pos += subparser.pos

            if self.pos >= len(self.string) or self.string[self.pos] == GROUP_CLOSE:
                # Argument list terminated by EOF or closing brace.
                break

            # Only other way to terminate an expression is with ,.
            # Continue to the next argument.
            assert self.string[self.pos] == ARG_SEP
            self.pos += 1

        return expressions

    def _parse_ident(self) -> str:
        """Parse an identifier and return it (possibly an empty string).
        Updates ``pos``.

        :raises ParseError: if the identifier pattern yields no match.
        """
        # Match in place from ``pos`` instead of copying the remainder
        # of the string.
        match = self._ident_re.match(self.string, self.pos)
        if match is None:
            raise ParseError("invalid identifier")
        ident = match.group(0)
        self.pos += len(ident)
        return ident
def _parse(template: str) -> Expression:
    """Parse a top-level template string into an Expression.

    Whatever the parser leaves unconsumed is kept as literal text at the
    end of the expression.
    """
    parser = Parser(template)
    parser.parse_expression()

    leftover = parser.string[parser.pos :]
    if leftover:
        # Extraneous text after the parsed expression stays literal.
        parser.parts.append(leftover)
    return Expression(parser.parts)
@functools.lru_cache(maxsize=128)
def template(fmt: str) -> Template:
    """Build the :class:`Template` for ``fmt``.

    Results are memoized per format string in an LRU cache holding up to
    128 entries, so a repeated format string yields the cached instance.
    """
    built = Template(fmt)
    return built
# External interface.
class Template:
    """A string template, including text, Symbols, and Calls.

    The template is parsed and compiled to a Python function once, on
    construction. Two templates compare equal (and hash alike) when they
    were built from the same source string.
    """

    expr: Expression
    original: str
    compiled: Callable[..., str]

    def __init__(self, template: str) -> None:
        """Parse and compile ``template``.

        :param template: The source text of the template.
        """
        self.expr = _parse(template)
        self.original = template
        self.compiled = self.translate()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Template):
            # Let Python try the reflected comparison instead of failing
            # on a missing ``original`` attribute.
            return NotImplemented
        return self.original == other.original

    def __hash__(self) -> int:
        # Defined together with ``__eq__`` so that instances stay
        # hashable; consistent with equality on ``original``.
        return hash(self.original)

    def interpret(
        self,
        values: Values | None = None,
        functions: FunctionCollection | None = None,
    ) -> str:
        """Like `substitute`, but forces the interpreter (rather than
        the compiled version) to be used. The interpreter includes
        exception-handling code for missing variables and buggy template
        functions but is much slower.

        :param values: The substitution values; empty if omitted.
        :param functions: The template functions; empty if omitted.

        :return: The evaluated template.
        """
        return self.expr.evaluate(
            Environment(
                {} if values is None else values,
                {} if functions is None else functions,
            )
        )

    def substitute(
        self,
        values: Values | None = None,
        functions: FunctionCollection | None = None,
    ) -> str:
        """Evaluate the template given the values and functions.

        The compiled version is tried first; if it raises, the template
        is evaluated again by the interpreter.

        :param values: The substitution values; empty if omitted.
        :param functions: The template functions; empty if omitted.

        :return: The evaluated template.
        """
        if values is None:
            values = {}
        if functions is None:
            functions = {}
        try:
            res = self.compiled(values, functions)
        except Exception:
            # Deliberate best-effort fallback: handle any exception
            # thrown by the compiled version (for example an unknown
            # variable or a failing template function) by interpreting.
            res = self.interpret(values, functions)
        return res

    def translate(self) -> Callable[..., str]:
        """Compile the template to a Python function.

        :return: A function taking the values and the functions and
            returning the evaluated template. It raises (for example
            ``KeyError``) when a variable or function used by the
            template is missing.
        """
        expressions, varnames, funcnames = self.expr.translate()

        # The generated function takes one prefixed parameter per
        # variable and per function used by the template.
        argnames: list[str] = [VARIABLE_PREFIX + varname for varname in varnames]
        argnames.extend(FUNCTION_PREFIX + funcname for funcname in funcnames)

        func = compile_func(
            argnames,
            [ast.Return(ast.List(expressions, ast.Load()))],
        )

        def wrapper_func(
            values: Values | None = None,
            functions: FunctionCollection | None = None,
        ) -> str:
            if values is None:
                values = {}
            if functions is None:
                functions = {}
            args = {}
            for varname in varnames:
                args[VARIABLE_PREFIX + varname] = values[varname]
            for funcname in funcnames:
                args[FUNCTION_PREFIX + funcname] = functions[funcname]
            parts = func(**args)
            # Stringify the parts like the interpreter does. Joining the
            # raw parts raises TypeError for any non-string value, which
            # forces a second, interpreted evaluation that calls the
            # template functions again.
            return "".join(map(str, parts))

        return wrapper_func
# Public API of the module: only ``Template`` is exported by a star-import.
__all__: list[str] = ["Template"]