Make ziploader handle python packages as well as python modules

pull/4420/head
Toshio Kuratomi 2016-04-19 00:55:19 -07:00
parent c600ab81ee
commit 5fc90058e4
1 changed files with 124 additions and 36 deletions

View File

@ -22,6 +22,7 @@ __metaclass__ = type
import ast import ast
import base64 import base64
import imp
import json import json
import os import os
import shlex import shlex
@ -255,31 +256,48 @@ class ModuleDepFinder(ast.NodeVisitor):
# Caveats: # Caveats:
# This code currently does not handle: # This code currently does not handle:
# * relative imports from py2.6+ from . import urls # * relative imports from py2.6+ from . import urls
# * python packages (directories with __init__.py in them)
IMPORT_PREFIX_SIZE = len('ansible.module_utils.') IMPORT_PREFIX_SIZE = len('ansible.module_utils.')
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""
Walk the ast tree for the python module.
Save submodule[.submoduleN][.identifier] into self.submodules
self.submodules will end up with tuples like:
- ('basic',)
- ('urls', 'fetch_url')
- ('database', 'postgres')
- ('database', 'postgres', 'quote')
It's up to calling code to determine whether the final element of the
dotted strings are module names or something else (function, class, or
variable names)
"""
super(ModuleDepFinder, self).__init__(*args, **kwargs) super(ModuleDepFinder, self).__init__(*args, **kwargs)
self.module_files = set() self.submodules = set()
def visit_Import(self, node): def visit_Import(self, node):
# import ansible.module_utils.MODLIB[.other] # import ansible.module_utils.MODLIB[.MODLIBn] [as asname]
for alias in (a for a in node.names if a.name.startswith('ansible.module_utils.')): for alias in (a for a in node.names if a.name.startswith('ansible.module_utils.')):
py_mod = alias.name[self.IMPORT_PREFIX_SIZE:].split('.', 1)[0] py_mod = alias.name[self.IMPORT_PREFIX_SIZE:]
self.module_files.add(py_mod) self.submodules.add((py_mod,))
self.generic_visit(node) self.generic_visit(node)
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
if node.module.startswith('ansible.module_utils'): if node.module.startswith('ansible.module_utils'):
where_from = node.module[self.IMPORT_PREFIX_SIZE:] where_from = node.module[self.IMPORT_PREFIX_SIZE:]
# from ansible.module_utils.MODLIB[.other] import foo
if where_from: if where_from:
py_mod = where_from.split('.', 1)[0] # from ansible.module_utils.MODULE1[.MODULEn] import IDENTIFIER [as asname]
self.module_files.add(py_mod) # from ansible.module_utils.MODULE1[.MODULEn] import MODULEn+1 [as asname]
else: # from ansible.module_utils.MODULE1[.MODULEn] import MODULEn+1 [,IDENTIFIER] [as asname]
# from ansible.module_utils import MODLIB py_mod = tuple(where_from.split('.'))
for alias in node.names: for alias in node.names:
self.module_files.add(alias.name) self.submodules.add(py_mod + (alias.name,))
else:
# from ansible.module_utils import MODLIB [,MODLIB2] [as asname]
for alias in node.names:
self.submodules.add((alias.name,))
self.generic_visit(node) self.generic_visit(node)
@ -298,7 +316,7 @@ STRIPPED_ZIPLOADER_TEMPLATE = _strip_comments(ZIPLOADER_TEMPLATE)
def _slurp(path): def _slurp(path):
if not os.path.exists(path): if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % path) raise AnsibleError("imported module support code does not exist at %s" % os.path.abspath(path))
fd = open(path, 'rb') fd = open(path, 'rb')
data = fd.read() data = fd.read()
fd.close() fd.close()
@ -331,27 +349,101 @@ def _get_facility(task_vars):
facility = task_vars['ansible_syslog_facility'] facility = task_vars['ansible_syslog_facility']
return facility return facility
def recursive_finder(data, snippet_names, snippet_data, zf): def recursive_finder(name, data, py_module_names, py_module_cache, zf):
""" """
Using ModuleDepFinder, make sure we have all of the module_utils files that Using ModuleDepFinder, make sure we have all of the module_utils files that
the module its module_utils files needs. the module its module_utils files needs.
""" """
# Parse the module and find the imports of ansible.module_utils
tree = ast.parse(data) tree = ast.parse(data)
finder = ModuleDepFinder() finder = ModuleDepFinder()
finder.visit(tree) finder.visit(tree)
new_snippets = set() #
for snippet_name in finder.module_files.difference(snippet_names): # Determine what imports that we've found are modules (vs class, function.
fname = '%s.py' % snippet_name # variable names) for packages
new_snippets.add(snippet_name) #
if snippet_name not in snippet_data:
snippet_data[snippet_name] = _slurp(os.path.join(_SNIPPET_PATH, fname))
zf.writestr(os.path.join("ansible/module_utils", fname), snippet_data[snippet_name])
snippet_names.update(new_snippets)
for snippet_name in tuple(new_snippets): normalized_modules = set()
recursive_finder(snippet_data[snippet_name], snippet_names, snippet_data, zf) # Loop through the imports that we've found to normalize them
del snippet_data[snippet_name] # Exclude paths that match with paths we've already processed
# (Have to exclude them a second time once the paths are processed)
for py_module_name in finder.submodules.difference(py_module_names):
module_info = None
# Check whether either the last or the second to last identifier is
# a module name
for idx in (1, 2):
if len(py_module_name) < idx:
break
try:
module_info = imp.find_module(py_module_name[-idx],
[os.path.join(_SNIPPET_PATH, *py_module_name[:-idx])])
break
except ImportError:
continue
# Could not find the module. Construct a helpful error message.
if module_info is None:
msg = ['Could not find imported module support code for %s. Looked for' % name]
if idx == 2:
msg.append('either %s or %s' % (py_module_name[-1], py_module_name[-2]))
else:
msg.append(py_module_name[-1])
raise AnsibleError(' '.join(msg))
if idx == 2:
# We've determined that the last portion was an identifier and
# thus, not part of the module name
py_module_name = py_module_name[:-1]
# If not already processed then we've got work to do
if py_module_name not in py_module_names:
# If not in the cache, then read the file into the cache
# We already have a file handle for the module open so it makes
# sense to read it now
if py_module_name not in py_module_cache:
if module_info[2][2] == imp.PKG_DIRECTORY:
# Read the __init__.py instead of the module file as this is
# a python package
py_module_cache[py_module_name + ('__init__',)] = _slurp(os.path.join(os.path.join(_SNIPPET_PATH, *py_module_name), '__init__.py'))
normalized_modules.add(py_module_name + ('__init__',))
else:
py_module_cache[py_module_name] = module_info[0].read()
module_info[0].close()
normalized_modules.add(py_module_name)
# Make sure that all the packages that this module is a part of
# are also added
for i in range(1, len(py_module_name)):
py_pkg_name = py_module_name[:-i] + ('__init__',)
if py_pkg_name not in py_module_names:
normalized_modules.add(py_pkg_name)
py_module_cache[py_pkg_name] = _slurp('%s.py' % os.path.join(_SNIPPET_PATH, *py_pkg_name))
#
# iterate through all of the ansible.module_utils* imports that we haven't
# already checked for new imports
#
# set of modules that we haven't added to the zipfile
unprocessed_py_module_names = normalized_modules.difference(py_module_names)
for py_module_name in unprocessed_py_module_names:
py_module_path = os.path.join(*py_module_name)
py_module_file_name = '%s.py' % py_module_path
zf.writestr(os.path.join("ansible/module_utils",
py_module_file_name), py_module_cache[py_module_name])
# Add the names of the files we're scheduling to examine in the loop to
# py_module_names so that we don't re-examine them in the next pass
# through recursive_finder()
py_module_names.update(unprocessed_py_module_names)
for py_module_file in unprocessed_py_module_names:
recursive_finder(py_module_file, py_module_cache[py_module_file], py_module_names, py_module_cache, zf)
# Save memory; the file won't have to be read again for this ansible module.
del py_module_cache[py_module_file]
def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression): def _find_snippet_imports(module_name, module_data, module_path, module_args, task_vars, module_compression):
""" """
@ -392,7 +484,7 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
return module_data, module_style, shebang return module_data, module_style, shebang
output = BytesIO() output = BytesIO()
snippet_names = set() py_module_names = set()
if module_substyle == 'python': if module_substyle == 'python':
# ziploader for new-style python classes # ziploader for new-style python classes
@ -403,8 +495,6 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
params = dict(ANSIBLE_MODULE_ARGS=module_args, params = dict(ANSIBLE_MODULE_ARGS=module_args,
ANSIBLE_MODULE_CONSTANTS=constants, ANSIBLE_MODULE_CONSTANTS=constants,
) )
#python_repred_args = to_bytes(repr(module_args_json))
#python_repred_constants = to_bytes(repr(json.dumps(constants)), errors='strict')
python_repred_params = to_bytes(repr(json.dumps(params)), errors='strict') python_repred_params = to_bytes(repr(json.dumps(params)), errors='strict')
try: try:
@ -421,7 +511,7 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
if os.path.exists(cached_module_filename): if os.path.exists(cached_module_filename):
zipdata = open(cached_module_filename, 'rb').read() zipdata = open(cached_module_filename, 'rb').read()
# Fool the check later... I think we should just remove the check # Fool the check later... I think we should just remove the check
snippet_names.add('basic') py_module_names.add(('basic',))
else: else:
with action_write_locks[module_name]: with action_write_locks[module_name]:
# Check that no other process has created this while we were # Check that no other process has created this while we were
@ -435,8 +525,8 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
zf.writestr('ansible_module_%s.py' % module_name, module_data) zf.writestr('ansible_module_%s.py' % module_name, module_data)
snippet_data = dict() py_module_cache = { ('__init__',): b'' }
recursive_finder(module_data, snippet_names, snippet_data, zf) recursive_finder(module_name, module_data, py_module_names, py_module_cache, zf)
zf.close() zf.close()
zipdata = base64.b64encode(zipoutput.getvalue()) zipdata = base64.b64encode(zipoutput.getvalue())
@ -464,15 +554,13 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
except IOError: except IOError:
raise AnsibleError('A different worker process failed to create module file. Look at traceback for that process for debugging information.') raise AnsibleError('A different worker process failed to create module file. Look at traceback for that process for debugging information.')
# Fool the check later... I think we should just remove the check # Fool the check later... I think we should just remove the check
snippet_names.add('basic') py_module_names.add(('basic',))
shebang, interpreter = _get_shebang(u'/usr/bin/python', task_vars) shebang, interpreter = _get_shebang(u'/usr/bin/python', task_vars)
if shebang is None: if shebang is None:
shebang = u'#!/usr/bin/python' shebang = u'#!/usr/bin/python'
output.write(to_bytes(STRIPPED_ZIPLOADER_TEMPLATE % dict( output.write(to_bytes(STRIPPED_ZIPLOADER_TEMPLATE % dict(
zipdata=zipdata, zipdata=zipdata,
ansible_module=module_name, ansible_module=module_name,
#args=python_repred_args,
#constants=python_repred_constants,
params=python_repred_params, params=python_repred_params,
shebang=shebang, shebang=shebang,
interpreter=interpreter, interpreter=interpreter,
@ -484,7 +572,7 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
# modules that use ziploader may implement their own helpers and not # modules that use ziploader may implement their own helpers and not
# need basic.py. All the constants that we substituted into basic.py # need basic.py. All the constants that we substituted into basic.py
# for module_replacer are now available in other, better ways. # for module_replacer are now available in other, better ways.
if 'basic' not in snippet_names: if ('basic',) not in py_module_names:
raise AnsibleError("missing required import in %s: Did not import ansible.module_utils.basic for boilerplate helper code" % module_path) raise AnsibleError("missing required import in %s: Did not import ansible.module_utils.basic for boilerplate helper code" % module_path)
elif module_substyle == 'powershell': elif module_substyle == 'powershell':
@ -494,7 +582,7 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
if REPLACER_WINDOWS in line: if REPLACER_WINDOWS in line:
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1")) ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
output.write(ps_data) output.write(ps_data)
snippet_names.add(b'powershell') py_module_names.add((b'powershell',))
continue continue
output.write(line + b'\n') output.write(line + b'\n')
module_data = output.getvalue() module_data = output.getvalue()
@ -506,7 +594,7 @@ def _find_snippet_imports(module_name, module_data, module_path, module_args, ta
# get here if we are going to substitute powershell.ps1 into the # get here if we are going to substitute powershell.ps1 into the
# module anyway. Leaving it for when/if we add other powershell # module anyway. Leaving it for when/if we add other powershell
# module_utils files. # module_utils files.
if b'powershell' not in snippet_names: if (b'powershell',) not in py_module_names:
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path) raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
elif module_substyle == 'jsonargs': elif module_substyle == 'jsonargs':