# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This file originates from the file `beets/util/functemplate.py
<https://github.com/beetbox/beets/blob/master/beets/util/functemplate.py>`_
of the `beets project <http://beets.io>`_.
This module implements a string formatter based on the standard PEP
292 string.Template class extended with function calls. Variables, as
with string.Template, are indicated with $ and functions are delimited
with %.
This module assumes that everything is Unicode: the template and the
substitution values. Bytestrings are not supported. Also, the templates
always behave like the ``safe_substitute`` method in the standard
library: unknown symbols are left intact.
This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""
from __future__ import annotations
import ast
import dis
import functools
import re
import sys
import types
from typing import Any, Callable
from .types import FunctionCollection, Values
# Characters with special meaning in the template language.
SYMBOL_DELIM = "$"  # Introduces a variable reference ($foo or ${foo}).
FUNC_DELIM = "%"  # Introduces a function call (%foo{...}).
GROUP_OPEN = "{"  # Opens a braced symbol name or an argument list.
GROUP_CLOSE = "}"  # Closes a braced symbol name or an argument list.
ARG_SEP = ","  # Separates the arguments of a function call.
ESCAPE_CHAR = "$"  # Prefix that turns the following special character into text.

# Prefixes used by the ``translate`` methods to build the parameter names
# of the compiled function from template variable and function names.
VARIABLE_PREFIX = "__var_"
FUNCTION_PREFIX = "__func_"
class Environment:
    """Bundle of substitution data used while evaluating a template.

    ``values`` is consulted for variable references and ``functions``
    for function calls.
    """

    values: Values
    functions: FunctionCollection

    def __init__(self, values: Values, functions: FunctionCollection) -> None:
        """Store the given value and function collections unchanged."""
        self.values, self.functions = values, functions
# Code generation helpers.
def ex_rvalue(name: str) -> ast.Name:
    """Build an ``ast.Name`` node that reads (loads) the variable ``name``.

    :param name: For example ``'str'``, ``'map'``, ``'__func_alpha'``
    """
    load_context = ast.Load()
    return ast.Name(id=name, ctx=load_context)
def ex_literal(
    val: int | float | bool | str | ast.Call | ast.List | ast.Name | None,
) -> ast.Constant:
    """Wrap ``val`` in an ``ast.Constant`` node.

    :param val: The value to embed, for example ``'abc123'``
    """
    constant = ast.Constant(value=val)
    return constant
def ex_call(
    func: str | ast.Attribute | ast.Name,
    args: list[ast.Call | ast.List | ast.Name | Any],
) -> ast.Call:
    """Build a call expression that uses positional arguments only.

    :param func: Either an AST expression to call or the name of the
        callable; a name is converted into a variable-load node.
    :param args: The arguments. Items that are already AST expressions
        are used as they are; anything else is embedded as a literal.
        The list passed in is not modified.
    """
    callee = ex_rvalue(func) if isinstance(func, str) else func
    call_args = [
        arg if isinstance(arg, ast.expr) else ex_literal(arg) for arg in args
    ]
    return ast.Call(callee, call_args, [])
def compile_func(
    arg_names: list[str],
    statements: list[ast.Return],
    name: str = "_the_func",
    debug: bool = False,
) -> Callable[..., Any]:
    """Turn a list of AST statements into a real Python function.

    :param arg_names: Names of the function's parameters. Each one is
        given a default value of ``None``.
    :param statements: The statements forming the function body.
    :param name: The name of the generated function.
    :param debug: If true, print the bytecode of the compiled module and
        of every code object among its constants.
    :return: The function object produced by executing the generated
        definition.
    """
    parameters = [ast.arg(arg=arg_name, annotation=None) for arg_name in arg_names]
    default_values = [ex_literal(None) for _ in arg_names]
    signature_fields = {
        "args": parameters,
        "kwonlyargs": [],
        "kw_defaults": [],
        "defaults": default_values,
    }
    if "posonlyargs" in ast.arguments._fields:  # Added in Python 3.8.
        signature_fields["posonlyargs"] = []

    function_node = ast.FunctionDef(
        name=name,
        args=ast.arguments(**signature_fields),
        body=statements,
        decorator_list=[],
    )

    # Since Python 3.8, ast.Module takes a second argument: the list of
    # type-ignore entries.
    if sys.version_info >= (3, 8):
        module_node = ast.Module([function_node], [])
    else:
        module_node = ast.Module([function_node])

    ast.fix_missing_locations(module_node)
    code = compile(module_node, "<generated>", "exec")

    if debug:
        # Show the module bytecode, then that of any nested code objects.
        dis.dis(code)
        nested = [c for c in code.co_consts if isinstance(c, types.CodeType)]
        for code_object in nested:
            dis.dis(code_object)

    # Executing the module defines the function inside ``namespace``.
    namespace: dict[str, Callable[..., Any]] = {}
    exec(code, {}, namespace)
    return namespace[name]
# AST nodes for the template language.
class Symbol:
    """A variable reference inside a template.

    ``ident`` is the variable name and ``original`` is the exact source
    text of the reference, which is emitted when the variable is unknown.
    """

    ident: str
    original: str

    def __init__(self, ident: str, original: str) -> None:
        self.ident, self.original = ident, original

    def __repr__(self) -> str:
        return f"Symbol({self.ident!r})"

    def evaluate(self, env: Environment):
        """Look the symbol up in ``env.values``.

        Returns the stored value when the identifier is present and the
        original template text otherwise.
        """
        known_values = env.values
        if self.ident not in known_values:
            # Unknown variable: leave the template text untouched.
            return self.original
        return known_values[self.ident]

    def translate(self) -> tuple[list[ast.Name], set[str], set[str]]:
        """Compile the lookup into a load of the prefixed variable name.

        Returns the expression list, the set of variable names used (just
        this identifier) and an empty set of function names.
        """
        lookup = ex_rvalue(VARIABLE_PREFIX + self.ident)
        return [lookup], {self.ident}, set()
class Call:
    """A function invocation inside a template.

    ``ident`` is the function name, ``args`` the argument expressions and
    ``original`` the exact source text of the call, which is emitted when
    the function is unknown.
    """

    ident: str
    args: list[Expression]
    original: str

    def __init__(self, ident: str, args: list[Expression], original: str) -> None:
        self.ident, self.args, self.original = ident, args, original

    def __repr__(self) -> str:
        return f"Call({self.ident!r}, {self.args!r}, {self.original!r})"

    def evaluate(self, env: Environment) -> str:
        """Run the function from ``env.functions`` on the evaluated arguments.

        An unknown function yields the original template text. If the
        function raises, the exception message wrapped in angle brackets
        is returned instead of propagating the error.
        """
        if self.ident not in env.functions:
            return self.original

        evaluated_args = [argument.evaluate(env) for argument in self.args]
        try:
            result = env.functions[self.ident](*evaluated_args)
        except Exception as exc:
            # Inline the error text into the output to aid debugging.
            return "<%s>" % str(exc)
        return str(result)

    def translate(self):
        """Compile the call into a Python AST call expression.

        Returns a one-element expression list, the variable names used by
        the arguments, and the function names used (this function plus
        any referenced inside the arguments).
        """
        used_variables: set[str] = set()
        used_functions: set[str] = {self.ident}
        compiled_args: list[ast.Call] = []

        for argument in self.args:
            pieces, piece_variables, piece_functions = argument.translate()
            used_variables.update(piece_variables)
            used_functions.update(piece_functions)

            # Each argument becomes ``"".join(map(str, [piece, ...]))`` so
            # that its components are concatenated into one string.
            joiner = ast.Attribute(ex_literal(""), "join", ast.Load())
            stringified = ex_call(
                "map",
                [ex_rvalue(str.__name__), ast.List(pieces, ast.Load())],
            )
            compiled_args.append(ex_call(joiner, [stringified]))

        call_expr = ex_call(FUNCTION_PREFIX + self.ident, compiled_args)
        return [call_expr], used_variables, used_functions
class Expression:
    """A sequence of template parts: plain strings, Symbols and Calls."""

    parts: list[str | Symbol | Call]

    def __init__(self, parts: list[str | Symbol | Call]) -> None:
        self.parts = parts

    def __repr__(self):
        return f"Expression({self.parts!r})"

    def evaluate(self, env: Environment) -> str:
        """Evaluate every part in ``env`` and concatenate the results.

        String parts are used verbatim; other parts are evaluated and
        their results converted with ``str``.
        """
        rendered = [
            part if isinstance(part, str) else part.evaluate(env)
            for part in self.parts
        ]
        return "".join(map(str, rendered))

    def translate(self):
        """Compile all parts to Python AST expressions.

        Returns the list of expressions, the set of variable names used
        and the set of function names used.
        """
        expressions: list[ast.Constant | ast.Name | ast.Call] = []
        varnames: set[str] = set()
        funcnames: set[str] = set()

        for part in self.parts:
            if not isinstance(part, str):
                part_exprs, part_vars, part_funcs = part.translate()
                expressions.extend(part_exprs)
                varnames.update(part_vars)
                funcnames.update(part_funcs)
                continue
            # Literal text becomes a constant node.
            expressions.append(ex_literal(part))

        return expressions, varnames, funcnames
# Parser.
class ParseError(Exception):
    """Raised by the template parser when it cannot process its input."""
class Parser:
    """Parses a template expression string. Instantiate the class with
    the template source and call ``parse_expression``. The ``pos`` field
    will indicate the character after the expression finished and
    ``parts`` will contain a list of Unicode strings, Symbols, and Calls
    reflecting the concatenated portions of the expression.

    This is a terrible, ad-hoc parser implementation based on a
    left-to-right scan with no lexing step to speak of; it's probably
    both inefficient and incorrect. Maybe this should eventually be
    replaced with a real, accepted parsing technique (PEG, parser
    generator, etc.).
    """

    string: str
    in_argument: bool
    pos: int
    parts: list[str | Symbol | Call]

    def __init__(self, string: str, in_argument: bool = False) -> None:
        """Create a new parser.

        :param string: the template source to parse
        :param in_argument: boolean that indicates the parser is to be
            used for parsing function arguments, ie. considering commas
            (`ARG_SEP`) a special character
        """
        self.string = string
        self.in_argument = in_argument
        self.pos = 0
        self.parts = []

    # Common parsing resources.
    special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE, ESCAPE_CHAR)
    special_char_re = re.compile(
        r"[%s]|\Z" % "".join(re.escape(c) for c in special_chars)
    )
    escapable_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP)
    terminator_chars = (GROUP_CLOSE,)

    # Variants used while parsing function arguments, where the comma
    # (ARG_SEP) is special too. Compiled once here instead of on every
    # ``parse_expression`` call.
    _argument_special_chars = special_chars + (ARG_SEP,)
    _argument_special_char_re = re.compile(
        r"[%s]|\Z" % "".join(re.escape(c) for c in _argument_special_chars)
    )

    # Matches a (possibly empty) run of word characters.
    _ident_re = re.compile(r"\w*")

    def parse_expression(self) -> None:
        """Parse a template expression starting at ``pos``. Resulting
        components (Unicode strings, Symbols, and Calls) are added to
        the ``parts`` field, a list. The ``pos`` field is updated to be
        the next character after the expression.
        """
        # The comma (ARG_SEP) joins the special characters only when
        # parsing function arguments.
        if self.in_argument:
            extra_special_chars: tuple[str, ...] = (ARG_SEP,)
            special_char_re = self._argument_special_char_re
        else:
            extra_special_chars = ()
            special_char_re = self.special_char_re

        # These sets do not change during the scan, so build them once.
        special_chars = self.special_chars + extra_special_chars
        terminators = self.terminator_chars + extra_special_chars
        escapables = self.escapable_chars + extra_special_chars

        text_parts: list[str] = []

        while self.pos < len(self.string):
            char = self.string[self.pos]

            if char not in special_chars:
                # A non-special character. Skip to the next special
                # character, treating the interstice as literal text.
                # Searching from ``pos`` avoids copying the remainder of
                # the string; the ``\Z`` alternative matches at the end.
                found = special_char_re.search(self.string, self.pos)
                if found is None:
                    raise ParseError("no special character found")
                next_pos = found.start()
                text_parts.append(self.string[self.pos : next_pos])
                self.pos = next_pos
                continue

            if self.pos == len(self.string) - 1:
                # The last character can never begin a structure, so we
                # just interpret it as a literal character (unless it
                # terminates the expression, as with , and }).
                if char not in terminators:
                    text_parts.append(char)
                    self.pos += 1
                break

            next_char = self.string[self.pos + 1]
            if char == ESCAPE_CHAR and next_char in escapables:
                # An escaped special character ($$, $}, etc.). Note that
                # ${ is not an escape sequence: this is ambiguous with
                # the start of a symbol and it's not necessary (just
                # using { suffices in all cases).
                text_parts.append(next_char)
                self.pos += 2  # Skip the next character.
                continue

            # Shift all characters collected so far into a single string.
            if text_parts:
                self.parts.append("".join(text_parts))
                text_parts = []

            if char == SYMBOL_DELIM:
                # Parse a symbol.
                self.parse_symbol()
            elif char == FUNC_DELIM:
                # Parse a function call.
                self.parse_call()
            elif char in terminators:
                # Template terminated.
                break
            elif char == GROUP_OPEN:
                # Start of a group has no meaning here; just pass
                # through the character.
                text_parts.append(char)
                self.pos += 1
            else:
                # Every special character is handled above. Raise
                # explicitly so the guard is not stripped under ``-O``.
                raise AssertionError("unhandled special character %r" % char)

        # If any parsed characters remain, shift them into a string.
        if text_parts:
            self.parts.append("".join(text_parts))

    def parse_symbol(self) -> None:
        """Parse a variable reference (like ``$foo`` or ``${foo}``)
        starting at ``pos``. Possibly appends a Symbol object (or,
        failing that, text) to the ``parts`` field and updates ``pos``.
        The character at ``pos`` must, as a precondition, be ``$``.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == SYMBOL_DELIM

        if self.pos == len(self.string) - 1:
            # Last character.
            self.parts.append(SYMBOL_DELIM)
            self.pos += 1
            return

        next_char = self.string[self.pos + 1]
        start_pos = self.pos
        self.pos += 1

        if next_char == GROUP_OPEN:
            # A symbol like ${this}.
            self.pos += 1  # Skip opening.
            closer = self.string.find(GROUP_CLOSE, self.pos)
            if closer == -1 or closer == self.pos:
                # No closing brace found or identifier is empty: emit
                # the "${" consumed so far as text.
                self.parts.append(self.string[start_pos : self.pos])
            else:
                # Closer found.
                ident = self.string[self.pos : closer]
                self.pos = closer + 1
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))
        else:
            # A bare-word symbol.
            ident = self._parse_ident()
            if ident:
                # Found a real symbol.
                self.parts.append(Symbol(ident, self.string[start_pos : self.pos]))
            else:
                # A standalone $.
                self.parts.append(SYMBOL_DELIM)

    def parse_call(self) -> None:
        """Parse a function call (like ``%foo{bar,baz}``) starting at
        ``pos``. Possibly appends a Call object to ``parts`` and update
        ``pos``. The character at ``pos`` must be ``%``.

        When the text does not form a complete call, the characters
        consumed so far are appended to ``parts`` as literal text.
        """
        assert self.pos < len(self.string)
        assert self.string[self.pos] == FUNC_DELIM

        start_pos = self.pos
        self.pos += 1

        ident = self._parse_ident()
        if not ident:
            # No function name.
            self.parts.append(FUNC_DELIM)
            return

        if self.pos >= len(self.string):
            # Identifier terminates string.
            self.parts.append(self.string[start_pos : self.pos])
            return

        if self.string[self.pos] != GROUP_OPEN:
            # Argument list not opened.
            self.parts.append(self.string[start_pos : self.pos])
            return

        # Skip past opening brace and try to parse an argument list.
        self.pos += 1
        args = self.parse_argument_list()
        if self.pos >= len(self.string) or self.string[self.pos] != GROUP_CLOSE:
            # Arguments unclosed.
            self.parts.append(self.string[start_pos : self.pos])
            return

        self.pos += 1  # Move past closing brace.
        self.parts.append(Call(ident, args, self.string[start_pos : self.pos]))

    def parse_argument_list(self) -> list[Expression]:
        """Parse a list of arguments starting at ``pos``, returning a
        list of Expression objects. Does not modify ``parts``. Should
        leave ``pos`` pointing to a } character or the end of the
        string.
        """
        # Try to parse a subexpression in a subparser.
        expressions: list[Expression] = []

        while self.pos < len(self.string):
            subparser = Parser(self.string[self.pos :], in_argument=True)
            subparser.parse_expression()

            # Extract and advance past the parsed expression.
            expressions.append(Expression(subparser.parts))
            self.pos += subparser.pos

            if self.pos >= len(self.string) or self.string[self.pos] == GROUP_CLOSE:
                # Argument list terminated by EOF or closing brace.
                break

            # Only other way to terminate an expression is with ,.
            # Continue to the next argument.
            assert self.string[self.pos] == ARG_SEP
            self.pos += 1

        return expressions

    def _parse_ident(self) -> str:
        """Parse an identifier and return it (possibly an empty string).
        Updates ``pos``.
        """
        # Match in place from ``pos`` rather than slicing the remainder.
        match = self._ident_re.match(self.string, self.pos)
        if match is None:
            raise ParseError("invalid identifier")
        ident = match.group(0)
        self.pos += len(ident)
        return ident
def _parse(template: str) -> Expression:
    """Parse a complete template string into an Expression.

    Whatever text the parser leaves unconsumed is kept as one trailing
    literal part.
    """
    parser = Parser(template)
    parser.parse_expression()

    leftover = parser.string[parser.pos :]
    if leftover:
        parser.parts.append(leftover)
    return Expression(parser.parts)
@functools.lru_cache(maxsize=128)
def template(fmt: str) -> Template:
    """Return a ``Template`` for ``fmt``.

    Results are memoized per format string in an LRU cache holding up to
    128 entries.
    """
    parsed_template = Template(fmt)
    return parsed_template
# External interface.
class Template:
    """A string template, including text, Symbols, and Calls.

    The template is parsed and compiled to a Python function when the
    object is created. ``substitute`` uses the compiled function and
    falls back to the slower interpreter if it raises.
    """

    expr: Expression
    original: str
    compiled: Callable[..., str]

    def __init__(self, template: str) -> None:
        """Parse and compile the template source ``template``."""
        self.expr = _parse(template)
        self.original = template
        self.compiled = self.translate()

    def __eq__(self, other: Any) -> bool:
        """Two templates are equal when their source strings are equal.

        Comparison with anything that is not a ``Template`` returns
        ``NotImplemented`` rather than raising ``AttributeError``.
        """
        if not isinstance(other, Template):
            return NotImplemented
        return self.original == other.original

    # NOTE: the default dictionaries below are shared between calls, but
    # this module only ever reads from them.
    def interpret(self, values: Values = {}, functions: FunctionCollection = {}) -> str:
        """Like `substitute`, but forces the interpreter (rather than
        the compiled version) to be used. The interpreter includes
        exception-handling code for missing variables and buggy template
        functions but is much slower.
        """
        return self.expr.evaluate(Environment(values, functions))

    def substitute(
        self, values: Values = {}, functions: FunctionCollection = {}
    ) -> str:
        """Evaluate the template given the values and functions.

        :param values: mapping from variable names to their values
        :param functions: mapping from function names to callables
        :return: the rendered string
        """
        try:
            res = self.compiled(values, functions)
        except Exception:  # Handle any exceptions thrown by compiled version.
            # Deliberate best effort: e.g. a missing variable or a failing
            # template function makes the compiled version raise, and the
            # interpreter knows how to render those cases.
            res = self.interpret(values, functions)
        return res

    def translate(self) -> Callable[..., str]:
        """Compile the template to a Python function.

        :return: a function taking ``values`` and ``functions`` and
            returning the rendered string. It raises ``KeyError`` when a
            variable or function used by the template is missing.
        """
        expressions, varnames, funcnames = self.expr.translate()

        # The generated function takes one parameter per variable and per
        # function used by the template.
        argnames: list[str] = [VARIABLE_PREFIX + varname for varname in varnames]
        argnames.extend(FUNCTION_PREFIX + funcname for funcname in funcnames)

        func = compile_func(
            argnames,
            [ast.Return(ast.List(expressions, ast.Load()))],
        )

        def wrapper_func(
            values: Values = {}, functions: FunctionCollection = {}
        ) -> str:
            args = {}
            for varname in varnames:
                args[VARIABLE_PREFIX + varname] = values[varname]
            for funcname in funcnames:
                args[FUNCTION_PREFIX + funcname] = functions[funcname]
            parts = func(**args)
            # Values and function results need not be strings. Convert
            # them the way the interpreter does, so that a non-string
            # part does not raise TypeError and force a second, slower
            # evaluation of the whole template.
            return "".join(map(str, parts))

        return wrapper_func
# Names exported by ``from <module> import *``.
# NOTE(review): the cached ``template()`` factory is not listed here;
# confirm that this is intentional.
__all__: list[str] = ["Template"]