xonsh/amalgamate.py

566 lines
17 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
2016-06-12 11:07:09 -04:00
"""A package-based, source code amalgamater."""
import os
2016-06-12 12:31:50 -04:00
import sys
2016-06-12 13:18:50 -04:00
import pprint
2016-06-13 00:46:35 -04:00
from itertools import repeat
2016-06-12 12:31:50 -04:00
from collections import namedtuple
from collections.abc import Mapping
2016-07-18 10:07:30 +02:00
from ast import parse, walk, Import, ImportFrom
2016-06-12 12:31:50 -04:00
2018-08-30 09:17:34 -05:00
__version__ = "0.1.2"
2016-07-21 00:12:11 -04:00
2018-08-30 09:17:34 -05:00
# Lightweight record describing one module in the package dependency graph.
ModNode = namedtuple("ModNode", ["name", "pkgdeps", "extdeps", "futures"])
ModNode.__doc__ = """Module node for dependency graph.

Attributes
----------
name : str
    Module name.
pkgdeps : frozenset of str
    Module dependencies in the same package.
extdeps : frozenset of str
    External module dependencies from outside of the package.
futures : frozenset of str
    Import directive names antecedent to 'from __future__ import'
"""
2016-06-16 20:20:10 -04:00
2016-06-12 12:31:50 -04:00
class SourceCache(Mapping):
    """Stores / loads source code for files based on package and module names.

    Keys are ``(pkg, name)`` pairs. A key that has not been seen before is
    resolved to the file ``<pkg with dots replaced by os.sep>/<name>.py``
    (relative to the current working directory), read once, and remembered.
    """

    def __init__(self, *args, **kwargs):
        # Accepts the same arguments as dict() to pre-populate the cache.
        self._d = dict(*args, **kwargs)

    def __getitem__(self, key):
        """Returns the source text for ``(pkg, name)``, reading it on a miss."""
        d = self._d
        if key in d:
            return d[key]
        pkg, name = key
        pkgdir = pkg.replace(".", os.sep)
        # os.path.join matches the other path helpers in this file and does
        # not produce a rooted path when the package directory is empty.
        fname = os.path.join(pkgdir, name + ".py")
        with open(fname, encoding="utf-8", errors="surrogateescape") as f:
            raw = f.read()
        d[key] = raw
        return raw

    def __iter__(self):
        yield from self._d

    def __len__(self):
        return len(self._d)


# Module-wide cache shared by every function that needs a module's source.
SOURCES = SourceCache()
2016-07-18 23:46:33 -04:00
class GlobalNames(object):
    """Stores globally defined names that have been seen on ast nodes.

    ``cache`` maps each recorded name to a set of
    ``("<pkg>.<module>", lineno, topnode)`` tuples, where ``topnode`` is the
    lowercased class name of the statement that was most recently added with
    ``istopnode=True``.
    """

    # Node kinds that are plain imports; a name that has only ever been bound
    # by one and the same kind of import statement is recorded just once.
    impnodes = frozenset(["import", "importfrom"])

    def __init__(self, pkg="<pkg>"):
        self.cache = {}
        self.pkg = pkg
        self.module = "<mod>"
        self.topnode = None

    def warn_duplicates(self):
        """Prints a warning to stderr for names defined in multiple modules."""
        chunks = []
        for key in sorted(self.cache):
            val = self.cache[key]
            if len(val) < 2:
                continue
            val = sorted(val)
            first_module = val[0][0]
            if all(loc[0] == first_module for loc in val[1:]):
                # every definition lives in the same module; nothing to report
                continue
            chunks.append(
                "WARNING: {0!r} defined in multiple locations:\n".format(key)
            )
            chunks.extend(" {}:{} ({})\n".format(*loc) for loc in val)
        if chunks:
            print("".join(chunks), end="", flush=True, file=sys.stderr)

    def entry(self, name, lineno):
        """Records that ``name`` was bound at ``lineno`` of the current module.

        Names starting with a double underscore are ignored.
        """
        if name.startswith("__"):
            return
        topnode = self.topnode
        e = (self.pkg + "." + self.module, lineno, topnode)
        if name in self.cache:
            if topnode in self.impnodes and all(
                topnode == x[2] for x in self.cache[name]
            ):
                # the same kind of import repeated elsewhere is not a clash
                return
            self.cache[name].add(e)
        else:
            self.cache[name] = {e}

    def add(self, node, istopnode=False):
        """Adds the names from the node to the cache."""
        nodename = node.__class__.__name__.lower()
        if istopnode:
            self.topnode = nodename
        # dispatch on the node type; unsupported node types are skipped
        meth = getattr(self, "_add_" + nodename, None)
        if meth is not None:
            meth(node)

    def _add_name(self, node):
        self.entry(node.id, node.lineno)

    def _add_tuple(self, node):
        for x in node.elts:
            self.add(x)

    def _add_assign(self, node):
        for target in node.targets:
            self.add(target)

    def _add_functiondef(self, node):
        self.entry(node.name, node.lineno)

    def _add_classdef(self, node):
        self.entry(node.name, node.lineno)

    def _add_import(self, node):
        lineno = node.lineno
        for target in node.names:
            if target.asname is None:
                # "import a.b.c" binds only the leading name "a"
                name, _, _ = target.name.partition(".")
            else:
                name = target.asname
            self.entry(name, lineno)

    def _add_importfrom(self, node):
        pkg, _ = resolve_package_module(node.module, self.pkg, node.level)
        if pkg == self.pkg:
            # imports from within the package itself are not recorded
            return
        lineno = node.lineno
        for target in node.names:
            if target.asname is None:
                name = target.name
            else:
                name = target.asname
            self.entry(name, lineno)

    def _add_with(self, node):
        for item in node.items:
            if item.optional_vars is None:
                continue
            self.add(item.optional_vars)
        for child in node.body:
            self.add(child, istopnode=True)

    def _add_for(self, node):
        self.add(node.target)
        for child in node.body:
            self.add(child, istopnode=True)

    def _add_while(self, node):
        for child in node.body:
            self.add(child, istopnode=True)

    def _add_if(self, node):
        for child in node.body:
            self.add(child, istopnode=True)
        for child in node.orelse:
            self.add(child, istopnode=True)

    def _add_try(self, node):
        # only the try body is inspected
        for child in node.body:
            self.add(child, istopnode=True)
2016-07-23 13:53:49 -04:00
def module_is_package(module, pkg, level):
    """Returns whether or not the module name refers to the package.

    ``module`` and ``level`` are the corresponding attributes of an
    ``ImportFrom`` node: an absolute import (level 0) refers to the package
    when the names are equal, and a single-dot relative import (level 1)
    does so when no module name is given (``from . import x``). Deeper
    relative levels never match.
    """
    if level == 0:
        return module == pkg
    if level == 1:
        return module is None
    return False
def module_from_package(module, pkg, level):
    """Returns whether or not a module is from the package.

    An absolute import (level 0) qualifies when ``module`` starts with
    ``pkg + "."``; a single-dot relative import (level 1) always qualifies;
    deeper relative levels never do.
    """
    if level == 0:
        return module.startswith(pkg + ".")
    if level == 1:
        return True
    return False
2016-07-23 13:53:49 -04:00
def resolve_package_module(module, pkg, level, default=None):
    """Returns a 2-tuple of package and module name, even for relative
    imports

    For an absolute import (level 0) the dotted ``module`` is split at its
    last dot. For a single-dot relative import (level 1) the package is
    ``pkg`` and the module is ``module``, or ``default`` when ``module`` is
    empty or None. Deeper relative levels give ``(None, None)``.
    """
    if level == 0:
        p, _, m = module.rpartition(".")
    elif level == 1:
        p = pkg
        m = module or default
    else:
        p = m = None
    return p, m
2016-07-18 23:46:33 -04:00
def make_node(name, pkg, allowed, glbnames):
    """Makes a node by parsing a file and traversing its AST.

    Parameters
    ----------
    name : str
        Module name inside the package.
    pkg : str
        Dotted package name.
    allowed : container of str
        Names of the package modules that take part in the amalgamation.
    glbnames : GlobalNames
        Collector that is fed every top-level statement of the module.

    Returns
    -------
    ModNode
        The module's in-package dependencies, external dependencies and
        ``__future__`` directives.
    """
    raw = SOURCES[pkg, name]
    tree = parse(raw, filename=name)
    # we only want to deal with global import statements
    pkgdeps = set()
    extdeps = set()
    futures = set()
    glbnames.module = name
    for a in tree.body:
        glbnames.add(a, istopnode=True)
        if isinstance(a, Import):
            for n in a.names:
                p, _, m = n.name.rpartition(".")
                if p == pkg and m in allowed:
                    pkgdeps.add(m)
                else:
                    extdeps.add(n.name)
        elif isinstance(a, ImportFrom):
            if module_is_package(a.module, pkg, a.level):
                # "from pkg import mod" / "from . import mod"
                pkgdeps.update(n.name for n in a.names if n.name in allowed)
            elif module_from_package(a.module, pkg, a.level):
                # "from pkg.mod import x" / "from .mod import x"
                p, m = resolve_package_module(
                    a.module, pkg, a.level, default=a.names[0].name
                )
                if p == pkg and m in allowed:
                    pkgdeps.add(m)
                else:
                    extdeps.add(a.module)
            elif a.module == "__future__":
                futures.update(n.name for n in a.names)
    return ModNode(name, frozenset(pkgdeps), frozenset(extdeps), frozenset(futures))
2016-06-12 12:31:50 -04:00
2016-06-12 13:26:59 -04:00
def make_graph(pkg, exclude=None):
    """Create a graph (dict) of module dependencies.

    Parameters
    ----------
    pkg : str
        Dotted package name; its directory is listed relative to the
        current working directory.
    exclude : iterable of str, optional
        Module names to leave out of the graph.

    Returns
    -------
    dict
        Maps each remaining module name to the node built by ``make_node``.
    """
    pkgdir = pkg.replace(".", os.sep)
    allowed = set()
    for fname in os.listdir(pkgdir):
        base, ext = os.path.splitext(fname)
        # skip dunder files (__init__, __amalgam__, ...) and non-Python files
        if base.startswith("__") or ext != ".py":
            continue
        allowed.add(base)
    if exclude:
        # difference_update accepts any iterable, not only sets
        allowed.difference_update(exclude)
    glbnames = GlobalNames(pkg=pkg)
    graph = {}
    # Visit modules in sorted order so that the shared name collector sees
    # them in the same sequence on every run.
    for base in sorted(allowed):
        graph[base] = make_node(base, pkg, allowed, glbnames)
    glbnames.warn_duplicates()
    return graph
2016-06-12 13:18:50 -04:00
def depsort(graph):
    """Sort modules by dependency.

    Parameters
    ----------
    graph : dict
        Maps module names to nodes exposing a ``pkgdeps`` set of module names.

    Returns
    -------
    list of str
        Module names ordered so that each module follows all of its
        ``pkgdeps``; modules that become available in the same round are
        listed alphabetically.

    Raises
    ------
    RuntimeError
        If some remaining modules can never have all of their dependencies
        solved (reported as a cycle).
    """
    remaining = set(graph)
    seder = []
    solved = set()
    while remaining:
        # modules whose dependencies have all been emitted already
        nodeps = {m for m in remaining if len(graph[m].pkgdeps - solved) == 0}
        if not nodeps:
            msg = (
                "\nsolved order = {0}\nremaining = {1}\nCycle detected in "
                "module graph!"
            ).format(pprint.pformat(seder), pprint.pformat(remaining))
            raise RuntimeError(msg)
        solved |= nodeps
        remaining -= nodeps
        seder += sorted(nodeps)
    return seder
2016-06-12 12:31:50 -04:00
2016-06-13 22:12:41 -04:00
# Source text that is placed near the top of every generated __amalgam__.py.
# It defines _LazyModule, which the rewritten "import x" lines use so that the
# real import only happens on first attribute access.
LAZY_IMPORTS = """
from sys import modules as _modules
from types import ModuleType as _ModuleType
from importlib import import_module as _import_module


class _LazyModule(_ModuleType):

    def __init__(self, pkg, mod, asname=None):
        '''Lazy module 'pkg.mod' in package 'pkg'.'''
        self.__dct__ = {
            'loaded': False,
            'pkg': pkg,  # pkg
            'mod': mod,  # pkg.mod
            'asname': asname,  # alias
            }

    @classmethod
    def load(cls, pkg, mod, asname=None):
        if mod in _modules:
            key = pkg if asname is None else mod
            return _modules[key]
        else:
            return cls(pkg, mod, asname)

    def __getattribute__(self, name):
        if name == '__dct__':
            return super(_LazyModule, self).__getattribute__(name)
        dct = self.__dct__
        mod = dct['mod']
        if dct['loaded']:
            m = _modules[mod]
        else:
            m = _import_module(mod)
            glbs = globals()
            pkg = dct['pkg']
            asname = dct['asname']
            if asname is None:
                glbs[pkg] = m = _modules[pkg]
            else:
                glbs[asname] = m
            dct['loaded'] = True
        return getattr(m, name)
"""
2016-07-18 11:30:57 +02:00
2016-06-13 00:46:35 -04:00
def get_lineno(node, default=0):
    """Gets the lineno of a node or returns the default."""
    return getattr(node, "lineno", default)


def min_line(node):
    """Computes the minimum lineno.

    Every node reachable from ``node`` is considered; descendants without a
    ``lineno`` attribute count as being on ``node``'s own line.
    """
    node_line = get_lineno(node)
    return min(map(get_lineno, walk(node), repeat(node_line)))
def format_import(names):
    """Format an import line

    ``names`` is an iterable of 3-tuples whose second and third items are
    the module name and its alias (or None); the first item is ignored.
    """
    parts = [
        name if asname is None else name + " as " + asname
        for _, name, asname in names
    ]
    return "import " + ", ".join(parts) + "\n"
2016-06-13 22:39:14 -04:00
def format_lazy_import(names):
    """Formats lazy import lines

    ``names`` is an iterable of 3-tuples whose second and third items are
    the dotted module name and its alias (or None). One
    ``_LazyModule.load(...)`` assignment line is produced per entry; without
    an alias the assigned name is the part of the module name before the
    first dot.
    """
    lines = []
    for _, name, asname in names:
        pkg, _, _ = name.partition(".")
        if asname is None:
            line = "{pkg} = _LazyModule.load({pkg!r}, {mod!r})\n"
        else:
            line = "{asname} = _LazyModule.load({pkg!r}, {mod!r}, {asname!r})\n"
        lines.append(line.format(pkg=pkg, mod=name, asname=asname))
    return "".join(lines)
2016-06-13 01:23:11 -04:00
def format_from_import(names):
    """Format a from import line

    ``names`` is a non-empty iterable of 4-tuples whose last three items are
    the module, the imported name and its alias (or None). The module of the
    last entry is the one written after ``from``.
    """
    parts = []
    for _, module, name, asname in names:  # noqa
        if asname is None:
            parts.append(name)
        else:
            parts.append(name + " as " + asname)
    line = "from " + module
    line += " import " + ", ".join(parts) + "\n"
    return line
2016-06-13 00:46:35 -04:00
def rewrite_imports(name, pkg, order, imps):
    """Rewrite the global imports in the file given the amalgamation.

    Parameters
    ----------
    name : str
        Module name inside the package.
    pkg : str
        Dotted package name.
    order : container of str
        Names of the modules being amalgamated.
    imps : set
        Imports already emitted by earlier modules; updated in place with
        the imports this module contributes.

    Returns
    -------
    str
        The module source with its top-level import statements rewritten.

    Raises
    ------
    RuntimeError
        If the module imports an amalgamated module as a whole
        (``import pkg.mod`` or ``from pkg import mod``).
    """
    raw = SOURCES[pkg, name]
    tree = parse(raw, filename=name)
    lines = raw.splitlines(keepends=True)
    replacements = []  # list of (startline, stopline, str) tuples
    # collect replacements in forward direction
    for a, b in zip(tree.body, tree.body[1:] + [None]):
        if not isinstance(a, (Import, ImportFrom)):
            continue
        start = min_line(a) - 1
        if b is not None:
            stop = min_line(b) - 1
        else:
            # Last statement of the file: stop at the end of the import
            # itself. (The number of top-level statements is not a line
            # index.) Fall back to the end of the file when the node has no
            # end_lineno.
            stop = getattr(a, "end_lineno", None) or len(lines)
        if isinstance(a, Import):
            keep = []
            for n in a.names:
                p, _, m = n.name.rpartition(".")
                if p == pkg and m in order:
                    # m is the bare module name, so the message shows the
                    # statement as written rather than "pkg.pkg.mod"
                    msg = (
                        "Cannot amalgamate import of amalgamated module:"
                        "\n\n import {0}.{1}\n\nin {0}/{2}.py"
                    ).format(pkg, m, name)
                    raise RuntimeError(msg)
                imp = (Import, n.name, n.asname)
                if imp not in imps:
                    imps.add(imp)
                    keep.append(imp)
            if len(keep) == 0:
                s = ", ".join(n.name for n in a.names)
                s = "# amalgamated " + s + "\n"
            else:
                s = format_lazy_import(keep)
            replacements.append((start, stop, s))
        elif isinstance(a, ImportFrom):
            p, m = resolve_package_module(a.module, pkg, a.level, default="")
            if module_is_package(a.module, pkg, a.level):
                for n in a.names:
                    if n.name in order:
                        msg = (
                            "Cannot amalgamate import of "
                            "amalgamated module:\n\n from {0} import {1}\n"
                            "\nin {0}/{2}.py"
                        ).format(pkg, n.name, name)
                        raise RuntimeError(msg)
            elif p == pkg and m in order:
                replacements.append(
                    (start, stop, "# amalgamated " + p + "." + m + "\n")
                )
            elif a.module == "__future__":
                replacements.append(
                    (start, stop, "# amalgamated __future__ directive\n")
                )
            else:
                # Keep the leading dots of a relative import so that a
                # partially de-duplicated statement still names the same
                # module; a.module is None for "from .. import x".
                modname = "." * a.level + (a.module or "")
                keep = []
                for n in a.names:
                    imp = (ImportFrom, modname, n.name, n.asname)
                    if imp not in imps:
                        imps.add(imp)
                        keep.append(imp)
                if len(keep) == len(a.names):
                    continue  # all new imports
                elif len(keep) == 0:
                    s = ", ".join(n.name for n in a.names)
                    s = "# amalgamated from " + modname + " import " + s + "\n"
                else:
                    s = format_from_import(keep)
                replacements.append((start, stop, s))
    # apply replacements in reverse so earlier line indices stay valid
    for start, stop, s in replacements[::-1]:
        # the statement's first line is always replaced, even if stop <= start
        lines[start : max(stop, start + 1)] = [s]
    return "".join(lines)
2016-06-13 00:46:35 -04:00
2016-07-21 00:12:11 -04:00
def sorted_futures(graph):
    """Returns a sorted, unique list of future imports.

    ``graph`` maps module names to nodes exposing a ``futures`` set.
    """
    f = set()
    for value in graph.values():
        f |= value.futures
    return sorted(f)
2016-06-13 00:46:35 -04:00
def amalgamate(order, graph, pkg):
    """Create amalgamated source.

    Parameters
    ----------
    order : list of str
        Module names in the order they are to be concatenated.
    graph : dict
        Module graph; used to collect the ``__future__`` directives.
    pkg : str
        Dotted package name.

    Returns
    -------
    str
        A module docstring listing the modules, the combined ``__future__``
        import (if any), the lazy-import support code, and then every
        module's rewritten source under a ``# <name>`` banner.
    """
    pieces = [
        (
            '"""Amalgamation of {0} package, made up of the following '
            "modules, in order:\n\n* "
        ).format(pkg),
        "\n* ".join(order),
        '\n\n"""\n',
    ]
    futures = sorted_futures(graph)
    if len(futures) > 0:
        pieces.append("from __future__ import " + ", ".join(futures) + "\n")
    pieces.append(LAZY_IMPORTS)
    # imports already emitted; shared across modules so each appears once
    imps = set()
    for name in order:
        lines = rewrite_imports(name, pkg, order, imps)
        pieces.append("#\n# " + name + "\n#\n" + lines + "\n")
    return "".join(pieces)
def write_amalgam(src, pkg):
    """Write out __amalgam__.py file

    The file is created inside the package directory (dots in ``pkg``
    replaced by the path separator), relative to the current working
    directory.
    """
    pkgdir = pkg.replace(".", os.sep)
    fname = os.path.join(pkgdir, "__amalgam__.py")
    with open(fname, "w", encoding="utf-8", errors="surrogateescape") as f:
        f.write(src)
2016-06-13 02:37:54 -04:00
def _init_name_lines(pkg):
    """Returns the path of the package's __init__.py and its lines.

    The lines are returned without their line endings.
    """
    pkgdir = pkg.replace(".", os.sep)
    fname = os.path.join(pkgdir, "__init__.py")
    with open(fname, encoding="utf-8", errors="surrogateescape") as f:
        raw = f.read()
    lines = raw.splitlines()
    return fname, lines


def read_exclude(pkg):
    """reads in modules to exclude from __init__.py

    Every line starting with ``# amalgamate exclude`` contributes the
    whitespace-separated names that follow those three words.
    """
    _, lines = _init_name_lines(pkg)
    exclude = set()
    for line in lines:
        if line.startswith("# amalgamate exclude"):
            # the first three tokens are "#", "amalgamate" and "exclude"
            exclude.update(line.split()[3:])
    return exclude
# Template for the code inserted into a package's __init__.py. Unless the
# environment variable named by {debug} is set to a non-empty value, it
# imports __amalgam__ and runs the {load} lines; an ImportError is ignored.
# The {load} placeholder sits at 8 spaces of indentation, so multi-line
# replacements must indent their continuation lines to match.
FAKE_LOAD = """
import os as _os
if _os.getenv("{debug}", ""):
    pass
else:
    import sys as _sys
    try:
        from {pkg} import __amalgam__
        {load}
        del __amalgam__
    except ImportError:
        pass
    del _sys
del _os
""".strip()
2018-08-30 09:17:34 -05:00
def rewrite_init(pkg, order, debug="DEBUG"):
    """Rewrites the init file to insert modules.

    Everything between the last line starting with ``# amalgamate`` (an
    ``# amalgamate exclude ...`` line counts) and the last line starting
    with ``# amalgamate end`` is replaced by the formatted ``FAKE_LOAD``
    code.

    Parameters
    ----------
    pkg : str
        Dotted package name.
    order : iterable of str
        Names of the amalgamated modules.
    debug : str, optional
        Name of the environment variable that disables the amalgamated load.

    Raises
    ------
    RuntimeError
        If the start/end markers are missing or out of order. Nothing is
        written in that case, so the file is left untouched.
    """
    fname, lines = _init_name_lines(pkg)
    start, stop = -1, -1
    for i, line in enumerate(lines):
        if line.startswith("# amalgamate end"):
            stop = i
        elif line.startswith("# amalgamate"):
            start = i
    if not 0 <= start < stop:
        # Without both markers the splice below would overwrite unrelated
        # lines of the file.
        msg = (
            "Cannot rewrite {0}: expected a line starting with "
            "'# amalgamate' followed by a line starting with "
            "'# amalgamate end'"
        ).format(fname)
        raise RuntimeError(msg)
    # Continuation lines carry 8 spaces so that they line up with the
    # {load} placeholder inside the try block of FAKE_LOAD.
    t = "{1} = __amalgam__\n        " '_sys.modules["{0}.{1}"] = __amalgam__'
    load = "\n        ".join(t.format(pkg, m) for m in order)
    s = FAKE_LOAD.format(pkg=pkg, load=load, debug=debug)
    # replace whatever sits between the two markers with the generated code
    lines[start + 1 : stop] = [s]
    init = "\n".join(lines) + "\n"
    with open(fname, "w", encoding="utf-8", errors="surrogateescape") as f:
        f.write(init)
2016-06-12 12:31:50 -04:00
def main(args=None):
    """Amalgamates every package named on the command line.

    Parameters
    ----------
    args : list of str, optional
        Argument vector including the program name; defaults to
        ``sys.argv``. An argument of the form ``--debug=NAME`` sets the
        environment variable name used for the packages that follow it; every
        other argument is treated as a package name.
    """
    if args is None:
        args = sys.argv
    debug = "DEBUG"
    for pkg in args[1:]:
        if pkg.startswith("--debug="):
            debug = pkg[len("--debug=") :]
            continue
        print("Amalgamating " + pkg)
        exclude = read_exclude(pkg)
        print(" excluding {}".format(pprint.pformat(exclude or None)))
        graph = make_graph(pkg, exclude=exclude)
        order = depsort(graph)
        src = amalgamate(order, graph, pkg)
        write_amalgam(src, pkg)
        rewrite_init(pkg, order, debug=debug)
        print(" collapsed {} modules".format(len(order)))


if __name__ == "__main__":
    main()