parent
6c25ea1b91
commit
9574f89471
|
@ -101,6 +101,7 @@ class Connection(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
|
wrong_user = False
|
||||||
tries = 3
|
tries = 3
|
||||||
self.conn = socket.socket()
|
self.conn = socket.socket()
|
||||||
self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT)
|
self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT)
|
||||||
|
@ -108,6 +109,12 @@ class Connection(object):
|
||||||
while tries > 0:
|
while tries > 0:
|
||||||
try:
|
try:
|
||||||
self.conn.connect((self.host,self.accport))
|
self.conn.connect((self.host,self.accport))
|
||||||
|
if not self.validate_user():
|
||||||
|
# the accelerated daemon was started with a
|
||||||
|
# different remote_user. The above command
|
||||||
|
# should have caused the accelerate daemon to
|
||||||
|
# shutdown, so we'll reconnect.
|
||||||
|
wrong_user = True
|
||||||
break
|
break
|
||||||
except:
|
except:
|
||||||
vvvv("failed, retrying...")
|
vvvv("failed, retrying...")
|
||||||
|
@ -116,6 +123,9 @@ class Connection(object):
|
||||||
if tries == 0:
|
if tries == 0:
|
||||||
vvv("Could not connect via the accelerated connection, exceeded # of tries")
|
vvv("Could not connect via the accelerated connection, exceeded # of tries")
|
||||||
raise errors.AnsibleError("Failed to connect")
|
raise errors.AnsibleError("Failed to connect")
|
||||||
|
elif wrong_user:
|
||||||
|
vvv("Restarting daemon with a different remote_user")
|
||||||
|
raise errors.AnsibleError("Wrong user")
|
||||||
self.conn.settimeout(constants.ACCELERATE_TIMEOUT)
|
self.conn.settimeout(constants.ACCELERATE_TIMEOUT)
|
||||||
except:
|
except:
|
||||||
if allow_ssh:
|
if allow_ssh:
|
||||||
|
@ -159,6 +169,44 @@ class Connection(object):
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
raise errors.AnsibleError("timed out while waiting to receive data")
|
raise errors.AnsibleError("timed out while waiting to receive data")
|
||||||
|
|
||||||
|
def validate_user(self):
|
||||||
|
'''
|
||||||
|
Checks the remote uid of the accelerated daemon vs. the
|
||||||
|
one specified for this play and will cause the accel
|
||||||
|
daemon to exit if they don't match
|
||||||
|
'''
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
mode='validate_user',
|
||||||
|
username=self.user,
|
||||||
|
)
|
||||||
|
data = utils.jsonify(data)
|
||||||
|
data = utils.encrypt(self.key, data)
|
||||||
|
if self.send_data(data):
|
||||||
|
raise errors.AnsibleError("Failed to send command to %s" % self.host)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# we loop here while waiting for the response, because a
|
||||||
|
# long running command may cause us to receive keepalive packets
|
||||||
|
# ({"pong":"true"}) rather than the response we want.
|
||||||
|
response = self.recv_data()
|
||||||
|
if not response:
|
||||||
|
raise errors.AnsibleError("Failed to get a response from %s" % self.host)
|
||||||
|
response = utils.decrypt(self.key, response)
|
||||||
|
response = utils.parse_json(response)
|
||||||
|
if "pong" in response:
|
||||||
|
# it's a keepalive, go back to waiting
|
||||||
|
vvvv("%s: received a keepalive packet" % self.host)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
vvvv("%s: received the response" % self.host)
|
||||||
|
break
|
||||||
|
|
||||||
|
if response.get('failed'):
|
||||||
|
raise errors.AnsibleError("Error while validating user: %s" % response.get("msg"))
|
||||||
|
else:
|
||||||
|
return response.get('rc') == 0
|
||||||
|
|
||||||
def exec_command(self, cmd, tmp_path, sudo_user=None, sudoable=False, executable='/bin/sh', in_data=None, su=None, su_user=None):
|
def exec_command(self, cmd, tmp_path, sudo_user=None, sudoable=False, executable='/bin/sh', in_data=None, su=None, su_user=None):
|
||||||
''' run a command on the remote host '''
|
''' run a command on the remote host '''
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,7 @@ import getpass
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
|
import pwd
|
||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
@ -280,6 +281,9 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
|
||||||
elif mode == 'fetch':
|
elif mode == 'fetch':
|
||||||
vvvv("received a fetch request, getting it")
|
vvvv("received a fetch request, getting it")
|
||||||
response = self.fetch(data)
|
response = self.fetch(data)
|
||||||
|
elif mode == 'validate_user':
|
||||||
|
vvvv("received a request to validate the user id")
|
||||||
|
response = self.validate_user(data)
|
||||||
|
|
||||||
vvvv("response result is %s" % str(response))
|
vvvv("response result is %s" % str(response))
|
||||||
data2 = json.dumps(response)
|
data2 = json.dumps(response)
|
||||||
|
@ -287,6 +291,10 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
|
||||||
vvvv("sending the response back to the controller")
|
vvvv("sending the response back to the controller")
|
||||||
self.send_data(data2)
|
self.send_data(data2)
|
||||||
vvvv("done sending the response")
|
vvvv("done sending the response")
|
||||||
|
|
||||||
|
if mode == 'validate_user' and response.get('rc') == 1:
|
||||||
|
vvvv("detected a uid mismatch, shutting down")
|
||||||
|
self.server.shutdown()
|
||||||
except:
|
except:
|
||||||
tb = traceback.format_exc()
|
tb = traceback.format_exc()
|
||||||
log("encountered an unhandled exception in the handle() function")
|
log("encountered an unhandled exception in the handle() function")
|
||||||
|
@ -295,6 +303,27 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
|
||||||
data2 = self.server.key.Encrypt(data2)
|
data2 = self.server.key.Encrypt(data2)
|
||||||
self.send_data(data2)
|
self.send_data(data2)
|
||||||
|
|
||||||
|
def validate_user(self, data):
|
||||||
|
if 'username' not in data:
|
||||||
|
return dict(failed=True, msg='No username specified')
|
||||||
|
|
||||||
|
vvvv("validating we're running as %s" % data['username'])
|
||||||
|
|
||||||
|
# get the current uid
|
||||||
|
c_uid = os.getuid()
|
||||||
|
try:
|
||||||
|
# the target uid
|
||||||
|
t_uid = pwd.getpwnam(data['username']).pw_uid
|
||||||
|
except:
|
||||||
|
vvvv("could not find user %s" % data['username'])
|
||||||
|
return dict(failed=True, msg='could not find user %s' % data['username'])
|
||||||
|
|
||||||
|
# and return rc=0 for success, rc=1 for failure
|
||||||
|
if c_uid == t_uid:
|
||||||
|
return dict(rc=0)
|
||||||
|
else:
|
||||||
|
return dict(rc=1)
|
||||||
|
|
||||||
def command(self, data):
|
def command(self, data):
|
||||||
if 'cmd' not in data:
|
if 'cmd' not in data:
|
||||||
return dict(failed=True, msg='internal error: cmd is required')
|
return dict(failed=True, msg='internal error: cmd is required')
|
||||||
|
@ -409,14 +438,26 @@ def daemonize(module, password, port, timeout, minutes, ipv6):
|
||||||
signal.signal(signal.SIGALRM, catcher)
|
signal.signal(signal.SIGALRM, catcher)
|
||||||
signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
|
signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
|
||||||
|
|
||||||
if ipv6:
|
tries = 5
|
||||||
server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout)
|
while tries > 0:
|
||||||
else:
|
try:
|
||||||
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
|
if ipv6:
|
||||||
server.allow_reuse_address = True
|
server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout)
|
||||||
|
else:
|
||||||
|
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
|
||||||
|
server.allow_reuse_address = True
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
vv("Failed to create the TCP server (tries left = %d)" % tries)
|
||||||
|
tries -= 1
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
if tries == 0:
|
||||||
|
vv("Maximum number of attempts to create the TCP server reached, bailing out")
|
||||||
|
raise Exception("max # of attempts to serve reached")
|
||||||
|
|
||||||
vv("serving!")
|
vv("serving!")
|
||||||
server.serve_forever(poll_interval=1.0)
|
server.serve_forever(poll_interval=0.1)
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
tb = traceback.format_exc()
|
tb = traceback.format_exc()
|
||||||
log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb))
|
log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb))
|
||||||
|
|
Loading…
Reference in New Issue