Improve the API for connection plugins and update local and ssh to use it

pull/4420/head
Toshio Kuratomi 2015-04-15 16:32:44 -07:00
parent 1f7d23fc18
commit 01df51d2ae
5 changed files with 126 additions and 72 deletions

View File

@ -374,8 +374,6 @@ class TaskExecutor:
if not connection: if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type) raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
connection.connect()
return connection return connection
def _get_action_handler(self, connection): def _get_action_handler(self, connection):

View File

@ -168,7 +168,7 @@ class ActionBase:
if result['rc'] != 0: if result['rc'] != 0:
if result['rc'] == 5: if result['rc'] == 5:
output = 'Authentication failure.' output = 'Authentication failure.'
elif result['rc'] == 255 and self._connection.get_transport() in ['ssh']: elif result['rc'] == 255 and self._connection.transport in ('ssh',):
# FIXME: more utils.VERBOSITY # FIXME: more utils.VERBOSITY
#if utils.VERBOSITY > 3: #if utils.VERBOSITY > 3:
# output = 'SSH encountered an unknown error. The output was:\n%s' % (result['stdout']+result['stderr']) # output = 'SSH encountered an unknown error. The output was:\n%s' % (result['stdout']+result['stderr'])

View File

@ -1,4 +1,5 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
# #
# This file is part of Ansible # This file is part of Ansible
# #
@ -19,6 +20,10 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from abc import ABCMeta, abstractmethod, abstractproperty
from six import add_metaclass
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
@ -29,7 +34,7 @@ from ansible.utils.display import Display
__all__ = ['ConnectionBase'] __all__ = ['ConnectionBase']
@add_metaclass(ABCMeta)
class ConnectionBase: class ConnectionBase:
''' '''
A base class for connections to contain common code. A base class for connections to contain common code.
@ -39,9 +44,15 @@ class ConnectionBase:
become_methods = C.BECOME_METHODS become_methods = C.BECOME_METHODS
def __init__(self, connection_info, *args, **kwargs): def __init__(self, connection_info, *args, **kwargs):
# All these hasattrs allow subclasses to override these parameters
if not hasattr(self, '_connection_info'):
self._connection_info = connection_info self._connection_info = connection_info
if not hasattr(self, '_display'):
self._display = Display(verbosity=connection_info.verbosity) self._display = Display(verbosity=connection_info.verbosity)
if not hasattr(self, '_connected'):
self._connected = False
self._connect()
def _become_method_supported(self, become_method): def _become_method_supported(self, become_method):
''' Checks if the current class supports this privilege escalation method ''' ''' Checks if the current class supports this privilege escalation method '''
@ -50,3 +61,33 @@ class ConnectionBase:
return True return True
raise AnsibleError("Internal Error: this connection module does not support running commands via %s" % become_method) raise AnsibleError("Internal Error: this connection module does not support running commands via %s" % become_method)
@abstractproperty
def transport(self):
"""String used to identify this Connection class from other classes"""
pass
@abstractmethod
def _connect(self):
"""Connect to the host we've been initialized with"""
pass
@abstractmethod
def exec_command(self, cmd, tmp_path, executable=None, in_data=None):
"""Run a command on the remote host"""
pass
@abstractmethod
def put_file(self, in_path, out_path):
"""Transfer a file from local to remote"""
pass
@abstractmethod
def fetch_file(self, in_path, out_path):
"""Fetch a file from remote to local"""
pass
@abstractmethod
def close(self):
"""Terminate the connection"""
pass

View File

@ -1,4 +1,5 @@
# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
# #
# This file is part of Ansible # This file is part of Ansible
# #
@ -19,13 +20,12 @@ __metaclass__ = type
import traceback import traceback
import os import os
import pipes
import shutil import shutil
import subprocess import subprocess
import select #import select
import fcntl #import fcntl
from ansible.errors import AnsibleError from ansible.errors import AnsibleError, AnsibleFileNotFound
from ansible.plugins.connections import ConnectionBase from ansible.plugins.connections import ConnectionBase
from ansible.utils.debug import debug from ansible.utils.debug import debug
@ -33,15 +33,17 @@ from ansible.utils.debug import debug
class Connection(ConnectionBase): class Connection(ConnectionBase):
''' Local based connections ''' ''' Local based connections '''
def get_transport(self): @property
def transport(self):
''' used to identify this connection object ''' ''' used to identify this connection object '''
return 'local' return 'local'
def connect(self, port=None): def _connect(self, port=None):
''' connect to the local host; nothing to do here ''' ''' connect to the local host; nothing to do here '''
self._display.vvv("ESTABLISH LOCAL CONNECTION FOR USER: %s" % self._connection_info.remote_user, host=self._connection_info.remote_addr) if not self._connected:
self._display.vvv("ESTABLISH LOCAL CONNECTION FOR USER: {0}".format(self._connection_info.remote_user, host=self._connection_info.remote_addr))
self._connected = True
return self return self
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
@ -57,7 +59,7 @@ class Connection(ConnectionBase):
executable = executable.split()[0] if executable else None executable = executable.split()[0] if executable else None
self._display.vvv("%s EXEC %s" % (self._connection_info.remote_addr, cmd)) self._display.vvv("{0} EXEC {1}".format(self._connection_info.remote_addr, cmd))
# FIXME: cwd= needs to be set to the basedir of the playbook # FIXME: cwd= needs to be set to the basedir of the playbook
debug("opening command with Popen()") debug("opening command with Popen()")
p = subprocess.Popen( p = subprocess.Popen(
@ -106,26 +108,25 @@ class Connection(ConnectionBase):
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to local ''' ''' transfer a file from local to local '''
#vvv("PUT %s TO %s" % (in_path, out_path), host=self.host) #vvv("PUT {0} TO {1}".format(in_path, out_path), host=self.host)
self._display.vvv("%s PUT %s TO %s" % (self._connection_info.remote_addr, in_path, out_path)) self._display.vvv("{0} PUT {1} TO {2}".format(self._connection_info.remote_addr, in_path, out_path))
if not os.path.exists(in_path): if not os.path.exists(in_path):
#raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) raise AnsibleFileNotFound("file or module does not exist: {0}".format(in_path))
raise AnsibleError("file or module does not exist: %s" % in_path)
try: try:
shutil.copyfile(in_path, out_path) shutil.copyfile(in_path, out_path)
except shutil.Error: except shutil.Error:
traceback.print_exc() traceback.print_exc()
raise AnsibleError("failed to copy: %s and %s are the same" % (in_path, out_path)) raise AnsibleError("failed to copy: {0} and {1} are the same".format(in_path, out_path))
except IOError: except IOError:
traceback.print_exc() traceback.print_exc()
raise AnsibleError("failed to transfer file to %s" % out_path) raise AnsibleError("failed to transfer file to {0}".format(out_path))
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
#vvv("FETCH %s TO %s" % (in_path, out_path), host=self.host)
self._display.vvv("%s FETCH %s TO %s" % (self._connection_info.remote_addr, in_path, out_path))
''' fetch a file from local to local -- for copatibility ''' ''' fetch a file from local to local -- for copatibility '''
#vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self.host)
self._display.vvv("{0} FETCH {1} TO {2}".format(self._connection_info.remote_addr, in_path, out_path))
self.put_file(in_path, out_path) self.put_file(in_path, out_path)
def close(self): def close(self):
''' terminate the connection; nothing to do here ''' ''' terminate the connection; nothing to do here '''
pass self._connected = False

View File

@ -33,15 +33,13 @@ import pty
from hashlib import sha1 from hashlib import sha1
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
from ansible.plugins.connections import ConnectionBase from ansible.plugins.connections import ConnectionBase
class Connection(ConnectionBase): class Connection(ConnectionBase):
''' ssh based connections ''' ''' ssh based connections '''
def __init__(self, connection_info, *args, **kwargs): def __init__(self, connection_info, *args, **kwargs):
super(Connection, self).__init__(connection_info)
# SSH connection specific init stuff # SSH connection specific init stuff
self.HASHED_KEY_MAGIC = "|1|" self.HASHED_KEY_MAGIC = "|1|"
self._has_pipelining = True self._has_pipelining = True
@ -52,14 +50,20 @@ class Connection(ConnectionBase):
self._cp_dir = '/tmp' self._cp_dir = '/tmp'
#fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN) #fcntl.lockf(self.runner.process_lockfile, fcntl.LOCK_UN)
def get_transport(self): super(Connection, self).__init__(connection_info)
@property
def transport(self):
''' used to identify this connection object from other classes ''' ''' used to identify this connection object from other classes '''
return 'ssh' return 'ssh'
def connect(self): def _connect(self):
''' connect to the remote host ''' ''' connect to the remote host '''
self._display.vvv("ESTABLISH SSH CONNECTION FOR USER: %s" % self._connection_info.remote_user, host=self._connection_info.remote_addr) self._display.vvv("ESTABLISH SSH CONNECTION FOR USER: {0}".format(self._connection_info.remote_user), host=self._connection_info.remote_addr)
if self._connected:
return self
self._common_args = [] self._common_args = []
extra_args = C.ANSIBLE_SSH_ARGS extra_args = C.ANSIBLE_SSH_ARGS
@ -67,11 +71,11 @@ class Connection(ConnectionBase):
# make sure there is no empty string added as this can produce weird errors # make sure there is no empty string added as this can produce weird errors
self._common_args += [x.strip() for x in shlex.split(extra_args) if x.strip()] self._common_args += [x.strip() for x in shlex.split(extra_args) if x.strip()]
else: else:
self._common_args += [ self._common_args += (
"-o", "ControlMaster=auto", "-o", "ControlMaster=auto",
"-o", "ControlPersist=60s", "-o", "ControlPersist=60s",
"-o", "ControlPath=\"%s\"" % (C.ANSIBLE_SSH_CONTROL_PATH % dict(directory=self._cp_dir)), "-o", "ControlPath=\"{0}\"".format(C.ANSIBLE_SSH_CONTROL_PATH.format(dict(directory=self._cp_dir))),
] )
cp_in_use = False cp_in_use = False
cp_path_set = False cp_path_set = False
@ -82,30 +86,34 @@ class Connection(ConnectionBase):
cp_path_set = True cp_path_set = True
if cp_in_use and not cp_path_set: if cp_in_use and not cp_path_set:
self._common_args += ["-o", "ControlPath=\"%s\"" % (C.ANSIBLE_SSH_CONTROL_PATH % dict(directory=self._cp_dir))] self._common_args += ("-o", "ControlPath=\"{0}\"".format(
C.ANSIBLE_SSH_CONTROL_PATH.format(dict(directory=self._cp_dir)))
)
if not C.HOST_KEY_CHECKING: if not C.HOST_KEY_CHECKING:
self._common_args += ["-o", "StrictHostKeyChecking=no"] self._common_args += ("-o", "StrictHostKeyChecking=no")
if self._connection_info.port is not None: if self._connection_info.port is not None:
self._common_args += ["-o", "Port=%d" % (self._connection_info.port)] self._common_args += ("-o", "Port={0}".format(self._connection_info.port))
# FIXME: need to get this from connection info # FIXME: need to get this from connection info
#if self.private_key_file is not None: #if self.private_key_file is not None:
# self._common_args += ["-o", "IdentityFile=\"%s\"" % os.path.expanduser(self.private_key_file)] # self._common_args += ("-o", "IdentityFile=\"{0}\"".format(os.path.expanduser(self.private_key_file)))
#elif self.runner.private_key_file is not None: #elif self.runner.private_key_file is not None:
# self._common_args += ["-o", "IdentityFile=\"%s\"" % os.path.expanduser(self.runner.private_key_file)] # self._common_args += ("-o", "IdentityFile=\"{0}\"".format(os.path.expanduser(self.runner.private_key_file)))
if self._connection_info.password: if self._connection_info.password:
self._common_args += ["-o", "GSSAPIAuthentication=no", self._common_args += ("-o", "GSSAPIAuthentication=no",
"-o", "PubkeyAuthentication=no"] "-o", "PubkeyAuthentication=no")
else: else:
self._common_args += ["-o", "KbdInteractiveAuthentication=no", self._common_args += ("-o", "KbdInteractiveAuthentication=no",
"-o", "PreferredAuthentications=gssapi-with-mic,gssapi-keyex,hostbased,publickey", "-o", "PreferredAuthentications=gssapi-with-mic,gssapi-keyex,hostbased,publickey",
"-o", "PasswordAuthentication=no"] "-o", "PasswordAuthentication=no")
if self._connection_info.remote_user is not None and self._connection_info.remote_user != pwd.getpwuid(os.geteuid())[0]: if self._connection_info.remote_user is not None and self._connection_info.remote_user != pwd.getpwuid(os.geteuid())[0]:
self._common_args += ["-o", "User="+self._connection_info.remote_user] self._common_args += ("-o", "User={0}".format(self._connection_info.remote_user))
# FIXME: figure out where this goes # FIXME: figure out where this goes
#self._common_args += ["-o", "ConnectTimeout=%d" % self.runner.timeout] #self._common_args += ("-o", "ConnectTimeout={0}".format(self.runner.timeout))
self._common_args += ["-o", "ConnectTimeout=15"] self._common_args += ("-o", "ConnectTimeout=15")
self._connected = True
return self return self
@ -136,13 +144,13 @@ class Connection(ConnectionBase):
except OSError: except OSError:
raise AnsibleError("to use the 'ssh' connection type with passwords, you must install the sshpass program") raise AnsibleError("to use the 'ssh' connection type with passwords, you must install the sshpass program")
(self.rfd, self.wfd) = os.pipe() (self.rfd, self.wfd) = os.pipe()
return ["sshpass", "-d%d" % self.rfd] return ("sshpass", "-d{0}".format(self.rfd))
return [] return []
def _send_password(self): def _send_password(self):
if self._connection_info.password: if self._connection_info.password:
os.close(self.rfd) os.close(self.rfd)
os.write(self.wfd, "%s\n" % self._connection_info.password) os.write(self.wfd, "{0}\n".format(self._connection_info.password))
os.close(self.wfd) os.close(self.wfd)
def _communicate(self, p, stdin, indata, su=False, sudoable=False, prompt=None): def _communicate(self, p, stdin, indata, su=False, sudoable=False, prompt=None):
@ -258,33 +266,33 @@ class Connection(ConnectionBase):
return False return False
if (hfiles_not_found == len(host_file_list)): if (hfiles_not_found == len(host_file_list)):
self._display.vvv("EXEC previous known host file not found for %s" % host) self._display.vvv("EXEC previous known host file not found for {0}".format(host))
return True return True
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
''' run a command on the remote host ''' ''' run a command on the remote host '''
ssh_cmd = self._password_cmd() ssh_cmd = self._password_cmd()
ssh_cmd += ["ssh", "-C"] ssh_cmd += ("ssh", "-C")
if not in_data: if not in_data:
# we can only use tty when we are not pipelining the modules. piping data into /usr/bin/python # we can only use tty when we are not pipelining the modules. piping data into /usr/bin/python
# inside a tty automatically invokes the python interactive-mode but the modules are not # inside a tty automatically invokes the python interactive-mode but the modules are not
# compatible with the interactive-mode ("unexpected indent" mainly because of empty lines) # compatible with the interactive-mode ("unexpected indent" mainly because of empty lines)
ssh_cmd += ["-tt"] ssh_cmd.append("-tt")
if self._connection_info.verbosity > 3: if self._connection_info.verbosity > 3:
ssh_cmd += ["-vvv"] ssh_cmd.append("-vvv")
else: else:
ssh_cmd += ["-q"] ssh_cmd.append("-q")
ssh_cmd += self._common_args ssh_cmd += self._common_args
# FIXME: ipv6 stuff needs to be figured out. It's in the connection info, however # FIXME: ipv6 stuff needs to be figured out. It's in the connection info, however
# not sure if it's all working yet so this remains commented out # not sure if it's all working yet so this remains commented out
#if self._ipv6: #if self._ipv6:
# ssh_cmd += ['-6'] # ssh_cmd += ['-6']
ssh_cmd += [self._connection_info.remote_addr] ssh_cmd.append(self._connection_info.remote_addr)
ssh_cmd.append(cmd) ssh_cmd.append(cmd)
self._display.vvv("EXEC %s" % ' '.join(ssh_cmd), host=self._connection_info.remote_addr) self._display.vvv("EXEC {0}".format(' '.join(ssh_cmd)), host=self._connection_info.remote_addr)
not_in_host_file = self.not_in_host_file(self._connection_info.remote_addr) not_in_host_file = self.not_in_host_file(self._connection_info.remote_addr)
@ -384,9 +392,9 @@ class Connection(ConnectionBase):
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to remote ''' ''' transfer a file from local to remote '''
self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
if not os.path.exists(in_path): if not os.path.exists(in_path):
raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) raise AnsibleFileNotFound("file or module does not exist: {0}".format(in_path))
cmd = self._password_cmd() cmd = self._password_cmd()
# FIXME: make a function, used in all 3 methods EXEC/PUT/FETCH # FIXME: make a function, used in all 3 methods EXEC/PUT/FETCH
@ -398,12 +406,15 @@ class Connection(ConnectionBase):
# host = '[%s]' % host # host = '[%s]' % host
if C.DEFAULT_SCP_IF_SSH: if C.DEFAULT_SCP_IF_SSH:
cmd += ["scp"] + self._common_args cmd.append('scp')
cmd += [in_path,host + ":" + pipes.quote(out_path)] cmd += self._common_args
cmd.append(in_path,host + ":" + pipes.quote(out_path))
indata = None indata = None
else: else:
cmd += ["sftp"] + self._common_args + [host] cmd.append('sftp')
indata = "put %s %s\n" % (pipes.quote(in_path), pipes.quote(out_path)) cmd += self._common_args
cmd.append(host)
indata = "put {0} {1}\n".format(pipes.quote(in_path), pipes.quote(out_path))
(p, stdin) = self._run(cmd, indata) (p, stdin) = self._run(cmd, indata)
@ -412,11 +423,11 @@ class Connection(ConnectionBase):
(returncode, stdout, stderr) = self._communicate(p, stdin, indata) (returncode, stdout, stderr) = self._communicate(p, stdin, indata)
if returncode != 0: if returncode != 0:
raise AnsibleError("failed to transfer file to %s:\n%s\n%s" % (out_path, stdout, stderr)) raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr))
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' fetch a file from remote to local ''' ''' fetch a file from remote to local '''
self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
cmd = self._password_cmd() cmd = self._password_cmd()
# FIXME: make a function, used in all 3 methods EXEC/PUT/FETCH # FIXME: make a function, used in all 3 methods EXEC/PUT/FETCH
@ -428,21 +439,24 @@ class Connection(ConnectionBase):
# host = '[%s]' % self._connection_info.remote_addr # host = '[%s]' % self._connection_info.remote_addr
if C.DEFAULT_SCP_IF_SSH: if C.DEFAULT_SCP_IF_SSH:
cmd += ["scp"] + self._common_args cmd.append('scp')
cmd += [host + ":" + in_path, out_path] cmd += self._common_args
cmd += ('{0}:{1}'.format(host, in_path), out_path)
indata = None indata = None
else: else:
cmd += ["sftp"] + self._common_args + [host] cmd.append('sftp')
indata = "get %s %s\n" % (in_path, out_path) cmd += self._common_args
cmd.append(host)
indata = "get {0} {1}\n".format(in_path, out_path)
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self._send_password() self._send_password()
stdout, stderr = p.communicate(indata) stdout, stderr = p.communicate(indata)
if p.returncode != 0: if p.returncode != 0:
raise AnsibleError("failed to transfer file from %s:\n%s\n%s" % (in_path, stdout, stderr)) raise AnsibleError("failed to transfer file from {0}:\n{1}\n{2}".format(in_path, stdout, stderr))
def close(self): def close(self):
''' not applicable since we're executing openssh binaries ''' ''' not applicable since we're executing openssh binaries '''
pass self._connected = False