main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from itertools import islice
  5. from string import Formatter
  6. import argparse
  7. import io
  8. import os
  9. import posixpath
  10. import subprocess
  11. import sys
  12. import threading
  13. from .config import Task, yaml_load, ConfigRoot
  14. from .utils import (ByrdException, LocalException, ObjectDict, RemoteException,
  15. DummyClient, Env, spellcheck, spell, enable_logging_color,
  16. logger, log_handler)
  17. try:
  18. # This file is imported by setup.py at install time
  19. import keyring
  20. import paramiko
  21. except ImportError:
  22. pass
  23. __version__ = '0.0.2'
  24. basedir, _ = os.path.split(__file__)
  25. PKG_DIR = os.path.join(basedir, 'pkg')
  26. TAB = '\n '
  27. def get_secret(service, resource, resource_id=None):
  28. resource_id = resource_id or resource
  29. secret = keyring.get_password(service, resource_id)
  30. if not secret:
  31. secret = getpass('Password for %s: ' % resource)
  32. keyring.set_password(service, resource_id, secret)
  33. return secret
  34. def get_passphrase(key_path):
  35. service = 'SSH private key'
  36. csum = md5(open(key_path, 'rb').read()).digest().hex()
  37. return get_secret(service, key_path, csum)
  38. def get_password(host):
  39. service = 'SSH password'
  40. return get_secret(service, host)
  41. def get_sudo_passwd():
  42. service = "Sudo password"
  43. return get_secret(service, 'sudo')
  44. CONNECTION_CACHE = {}
  45. def connect(host, auth):
  46. if host in CONNECTION_CACHE:
  47. return CONNECTION_CACHE[host]
  48. private_key_file = password = None
  49. if auth and auth.get('ssh_private_key'):
  50. private_key_file = os.path.expanduser(auth.ssh_private_key)
  51. if not os.path.exists(private_key_file):
  52. msg = 'Private key file "%s" not found' % private_key_file
  53. raise ByrdException(msg)
  54. password = get_passphrase(private_key_file)
  55. else:
  56. password = get_password(host)
  57. username, hostname = host.split('@', 1)
  58. client = paramiko.SSHClient()
  59. client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  60. client.connect(hostname, username=username, password=password,
  61. key_filename=private_key_file,
  62. )
  63. logger.debug(f'Connected to {hostname} as {username}')
  64. CONNECTION_CACHE[host] = client
  65. return client
  66. def run_local(cmd, env, cli):
  67. # Run local task
  68. cmd = env.fmt(cmd)
  69. logger.info(env.fmt('{task_desc}', kind='new'))
  70. if cli.dry_run:
  71. logger.info('[dry-run] ' + cmd)
  72. return None
  73. logger.debug(TAB + TAB.join(cmd.splitlines()))
  74. process = subprocess.Popen(
  75. cmd, shell=True,
  76. stdout=subprocess.PIPE,
  77. stderr=subprocess.STDOUT,
  78. env=env,
  79. )
  80. stdout, stderr = process.communicate()
  81. success = process.returncode == 0
  82. if stdout:
  83. logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
  84. if not success:
  85. raise LocalException(stdout, stderr)
  86. return ObjectDict(stdout=stdout, stderr=stderr)
  87. def run_python(task, env, cli):
  88. # Execute a piece of python localy
  89. code = task.python
  90. logger.info(env.fmt('{task_desc}', kind='new'))
  91. if cli.dry_run:
  92. logger.info('[dry-run] ' + code)
  93. return None
  94. logger.debug(TAB + TAB.join(code.splitlines()))
  95. cmd = ['python', '-c', 'import sys;exec(sys.stdin.read())']
  96. if task.sudo:
  97. user = 'root' if task.sudo is True else task.sudo
  98. cmd = 'sudo -u {} -- {}'.format(user, cmd)
  99. process = subprocess.Popen(
  100. cmd,
  101. stdout=subprocess.PIPE,
  102. stderr=subprocess.PIPE,
  103. stdin=subprocess.PIPE,
  104. env=env,
  105. )
  106. # Plug io
  107. out_buff = io.StringIO()
  108. err_buff = io.StringIO()
  109. log_stream(process.stdout, out_buff)
  110. log_stream(process.stderr, err_buff)
  111. process.stdin.write(code.encode())
  112. process.stdin.flush()
  113. process.stdin.close()
  114. success = process.wait() == 0
  115. process.stdout.close()
  116. process.stderr.close()
  117. out = out_buff.getvalue()
  118. if out:
  119. logger.debug(TAB + TAB.join(out.splitlines()))
  120. if not success:
  121. raise LocalException(out + err_buff.getvalue())
  122. return ObjectDict(stdout=out, stderr=err_buff.getvalue())
  123. def log_stream(stream, buff):
  124. def _log():
  125. try:
  126. for chunk in iter(lambda: stream.readline(2048), ""):
  127. if isinstance(chunk, bytes):
  128. chunk = chunk.decode()
  129. buff.write(chunk)
  130. except ValueError:
  131. # read raises a ValueError on closed stream
  132. pass
  133. t = threading.Thread(target=_log)
  134. t.start()
  135. return t
  136. def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
  137. '''
  138. Helper function to run `cmd` command on remote host
  139. '''
  140. chan = client.get_transport().open_session()
  141. if env:
  142. chan.update_environment(env)
  143. stdin = chan.makefile('wb')
  144. stdout = chan.makefile('r')
  145. stderr = chan.makefile_stderr('r')
  146. out_buff = io.StringIO()
  147. err_buff = io.StringIO()
  148. out_thread = log_stream(stdout, out_buff)
  149. err_thread = log_stream(stderr, err_buff)
  150. if sudo:
  151. assert not in_buff, 'in_buff and sudo can not be combined'
  152. if isinstance(sudo, str):
  153. sudo_cmd = 'sudo -u %s -s' % sudo
  154. else:
  155. sudo_cmd = 'sudo -s'
  156. chan.exec_command(sudo_cmd)
  157. in_buff = cmd
  158. else:
  159. chan.exec_command(cmd)
  160. if in_buff:
  161. # XXX use a real buff (not a simple str) ?
  162. stdin.write(in_buff)
  163. stdin.flush()
  164. stdin.close()
  165. chan.shutdown_write()
  166. success = chan.recv_exit_status() == 0
  167. out_thread.join()
  168. err_thread.join()
  169. if not success:
  170. raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
  171. res = ObjectDict(
  172. stdout = out_buff.getvalue(),
  173. stderr = err_buff.getvalue(),
  174. )
  175. return res
  176. def run_remote(task, host, env, cli):
  177. res = None
  178. host = env.fmt(host)
  179. env.update({
  180. 'host': extract_host(host),
  181. })
  182. if cli.dry_run:
  183. client = DummyClient()
  184. else:
  185. client = connect(host, cli.cfg.auth)
  186. if task.run:
  187. cmd = env.fmt(task.run)
  188. prefix = ''
  189. if task.sudo:
  190. if task.sudo is True:
  191. prefix = '[sudo] '
  192. else:
  193. prefix = '[sudo as %s] ' % task.sudo
  194. msg = prefix + '{host}: {task_desc}'
  195. logger.info(env.fmt(msg, kind='new'))
  196. logger.debug(TAB + TAB.join(cmd.splitlines()))
  197. if cli.dry_run:
  198. logger.info('[dry-run] ' + cmd)
  199. else:
  200. res = run_helper(client, cmd, env=env, sudo=task.sudo)
  201. elif task.send:
  202. local_path = env.fmt(task.send)
  203. if not os.path.exists(local_path):
  204. raise ByrdException('Path "%s" not found' % local_path)
  205. else:
  206. send(client, env, cli, task)
  207. else:
  208. raise ByrdException('Unable to run task "%s"' % task.name)
  209. if res and res.stdout:
  210. logger.debug(TAB + TAB.join(res.stdout.splitlines()))
  211. return res
  212. def send(client, env, cli, task):
  213. fmt = task.fmt and Env(env, {'fmt': 'new'}).fmt(task.fmt) or None
  214. local_path = env.fmt(task.send)
  215. remote_path = env.fmt(task.to)
  216. dry_run = cli.dry_run
  217. with client.open_sftp() as sftp:
  218. if os.path.isfile(local_path):
  219. send_file(sftp, os.path.abspath(local_path), remote_path, env,
  220. dry_run=dry_run, fmt=fmt)
  221. elif os.path.isdir(local_path):
  222. for root, subdirs, files in os.walk(local_path):
  223. rel_dir = os.path.relpath(root, local_path)
  224. rel_dirs = os.path.split(rel_dir)
  225. rem_dir = posixpath.join(remote_path, *rel_dirs)
  226. run_helper(client, 'mkdir -p {}'.format(rem_dir))
  227. for f in files:
  228. rel_f = os.path.join(root, f)
  229. rem_file = posixpath.join(rem_dir, f)
  230. send_file(sftp, os.path.abspath(rel_f), rem_file, env,
  231. dry_run=dry_run, fmt=fmt)
  232. else:
  233. msg = 'Unexpected path "%s" (not a file, not a directory)'
  234. raise ByrdException(msg % local_path)
  235. def send_file(sftp, local_path, remote_path, env, dry_run=False, fmt=None):
  236. if not fmt:
  237. logger.info(f'[send] {local_path} -> {remote_path}')
  238. lines = islice(open(local_path), 30)
  239. logger.debug('File head:' + TAB.join(lines))
  240. if not dry_run:
  241. sftp.put(local_path, remote_path)
  242. return
  243. # Format file content and save it on remote
  244. local_relpath = os.path.relpath(local_path)
  245. logger.info(f'[fmt] {local_relpath} -> {remote_path}')
  246. content = env.fmt(open(local_path).read(), kind=fmt)
  247. lines = islice(content.splitlines(), 30)
  248. logger.debug('File head:' + TAB.join(lines))
  249. if not dry_run:
  250. fh = sftp.open(remote_path, mode='w')
  251. fh.write(content)
  252. fh.close()
  253. def run_task(task, host, cli, env=None):
  254. '''
  255. Execute one task on one host (or locally)
  256. '''
  257. if task.local:
  258. res = run_local(task.local, env, cli)
  259. elif task.python:
  260. res = run_python(task, env, cli)
  261. else:
  262. res = run_remote(task, host, env, cli)
  263. if task.get('assert'):
  264. eval_env = {
  265. 'stdout': res.stdout.strip(),
  266. 'stderr': res.stderr.strip(),
  267. }
  268. assert_ = env.fmt(task['assert'])
  269. ok = eval(assert_, eval_env)
  270. if ok:
  271. logger.info('Assert ok')
  272. else:
  273. raise ByrdException('Assert "%s" failed!' % assert_)
  274. if task.get('warn'):
  275. msg = Env(env, res).fmt(task['warn'])
  276. logger.warning(msg)
  277. return res
  278. def run_batch(task, hosts, cli, global_env=None):
  279. '''
  280. Run one task on a list of hosts
  281. '''
  282. out = None
  283. export_env = {}
  284. task_env = global_env.fmt(task.get('env', {}))
  285. if not hosts and task.networks:
  286. hosts = list(chain(*(spellcheck(cli.cfg.networks, n).hosts
  287. for n in task.networks)))
  288. if task.get('multi'):
  289. parent_env = Env(export_env, task_env, global_env)
  290. parent_sudo = task.sudo
  291. for pos, step in enumerate(task.multi):
  292. task_name = step.task
  293. if task_name:
  294. # _cfg contain "local" config wrt the task
  295. siblings = task._cfg.tasks
  296. sub_task = spellcheck(siblings, task_name)
  297. sudo = step.sudo or sub_task.sudo or parent_sudo
  298. else:
  299. # reify a task out of attributes
  300. sub_task = Task.parse(step)
  301. sub_task._path = '%s->[%s]' % (task._path, pos)
  302. sudo = sub_task.sudo or parent_sudo
  303. sub_task.sudo = sudo
  304. network = step.get('network')
  305. if network:
  306. hosts = spellcheck(cli.cfg.networks, network).hosts
  307. child_env = step.get('env', {})
  308. child_env = parent_env.fmt(child_env)
  309. out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
  310. out = out.decode() if isinstance(out, bytes) else out
  311. export_env['_'] = out
  312. if step.export:
  313. export_env[step.export] = out
  314. else:
  315. task_env.update({
  316. 'task_desc': global_env.fmt(task.desc),
  317. 'task_name': task.name,
  318. })
  319. parent_env = Env(task_env, global_env)
  320. if task.get('fmt'):
  321. parent_env.fmt_kind = task.fmt
  322. res = None
  323. if task.once and (task.local or task.python):
  324. res = run_task(task, None, cli, parent_env)
  325. elif hosts:
  326. for host in hosts:
  327. env_host = extract_host(host)
  328. parent_env.update({
  329. 'host': env_host,
  330. })
  331. res = run_task(task, host, cli, parent_env)
  332. if task.once:
  333. break
  334. else:
  335. logger.warning('Nothing to do for task "%s"' % task._path)
  336. out = res and res.stdout.strip() or ''
  337. return out
  338. def extract_host(host_string):
  339. return host_string and host_string.split('@')[-1] or ''
  340. def abort(msg):
  341. logger.error(msg)
  342. sys.exit(1)
  343. def load_cfg(path, prefix=None):
  344. load_sections = ('networks', 'tasks', 'auth', 'env')
  345. if os.path.isfile(path):
  346. logger.debug('Load config %s' % os.path.relpath(path))
  347. cfg = yaml_load(open(path))
  348. cfg = ConfigRoot.parse(cfg)
  349. else:
  350. raise ByrdException('Config file "%s" not found' % path)
  351. # Define useful defaults
  352. cfg.networks = cfg.networks or ObjectDict()
  353. cfg.tasks = cfg.tasks or ObjectDict()
  354. # Create backrefs between tasks to the local config
  355. if cfg.get('tasks'):
  356. cfg_cp = cfg.copy()
  357. for k, v in cfg['tasks'].items():
  358. v._cfg = cfg_cp
  359. # Recursive load
  360. if cfg.load:
  361. cfg_path = os.path.dirname(path)
  362. for item in cfg.load:
  363. if item.get('file'):
  364. rel_path = item.file
  365. child_path = os.path.join(cfg_path, item.file)
  366. elif item.get('pkg'):
  367. rel_path = item.pkg
  368. child_path = os.path.join(PKG_DIR, item.pkg)
  369. if item.get('as'):
  370. child_prefix = item['as']
  371. else:
  372. child_prefix, _ = os.path.splitext(rel_path)
  373. child_cfg = load_cfg(child_path, child_prefix)
  374. key_fn = lambda x: '/'.join([child_prefix, x])
  375. for section in load_sections:
  376. if not section in child_cfg:
  377. continue
  378. items = {key_fn(k): v for k, v in child_cfg[section].items()}
  379. cfg[section].update(items)
  380. return cfg
  381. def load_cli(args=None):
  382. parser = argparse.ArgumentParser()
  383. parser.add_argument('names', nargs='*',
  384. help='Hosts and commands to run them on')
  385. parser.add_argument('-c', '--config', default='bd.yaml',
  386. help='Config file')
  387. parser.add_argument('-R', '--run', nargs='*', default=[],
  388. help='Run remote task')
  389. parser.add_argument('-L', '--run-local', nargs='*', default=[],
  390. help='Run local task')
  391. parser.add_argument('-P', '--run-python', nargs='*', default=[],
  392. help='Run python task')
  393. parser.add_argument('-d', '--dry-run', action='store_true',
  394. help='Do not run actual tasks, just print them')
  395. parser.add_argument('-e', '--env', nargs='*', default=[],
  396. help='Add value to execution environment '
  397. '(ex: -e foo=bar "name=John Doe")')
  398. parser.add_argument('-s', '--sudo', default='auto',
  399. help='Enable sudo (auto|yes|no')
  400. parser.add_argument('-v', '--verbose', action='count',
  401. default=0, help='Increase verbosity')
  402. parser.add_argument('-q', '--quiet', action='count',
  403. default=0, help='Decrease verbosity')
  404. parser.add_argument('-n', '--no-color', action='store_true',
  405. help='Disable colored logs')
  406. parser.add_argument('-i', '--info', action='store_true',
  407. help='Print info')
  408. cli = parser.parse_args(args=args)
  409. cli = ObjectDict(vars(cli))
  410. # Load config
  411. cfg = load_cfg(cli.config)
  412. cli.cfg = cfg
  413. cli.update(get_hosts_and_tasks(cli, cfg))
  414. # Transformt env string into dict
  415. cli.env = dict(e.split('=') for e in cli.env)
  416. return cli
  417. def get_hosts_and_tasks(cli, cfg):
  418. # Make sure we don't have overlap between hosts and tasks
  419. items = list(cfg.networks) + list(cfg.tasks)
  420. msg = 'Name collision between tasks and networks'
  421. assert len(set(items)) == len(items), msg
  422. # Build task list
  423. tasks = []
  424. networks = []
  425. for name in cli.names:
  426. if name in cfg.networks:
  427. host = cfg.networks[name]
  428. networks.append(host)
  429. elif name in cfg.tasks:
  430. task = cfg.tasks[name]
  431. tasks.append(task)
  432. else:
  433. msg = 'Name "%s" not understood' % name
  434. matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
  435. if matches:
  436. msg += ', try: %s' % ' or '.join(matches)
  437. raise ByrdException (msg)
  438. # Collect custom tasks from cli
  439. customs = []
  440. for cli_key in ('run', 'run_local', 'run_python'):
  441. cmd_key = cli_key.rsplit('_', 1)[-1]
  442. customs.extend('%s: %s' % (cmd_key, ck) for ck in cli[cli_key])
  443. for custom_task in customs:
  444. task = Task.parse(yaml_load(custom_task))
  445. task.desc = 'Custom command'
  446. tasks.append(task)
  447. hosts = list(chain.from_iterable(n.hosts for n in networks))
  448. return dict(hosts=hosts, tasks=tasks)
  449. def info(cli):
  450. formatter = Formatter()
  451. for name, attr in cli.cfg.tasks.items():
  452. kind = 'remote'
  453. if attr.python:
  454. kind = 'python'
  455. elif attr.local:
  456. kind = 'local'
  457. elif attr.multi:
  458. kind = 'multi'
  459. elif attr.send:
  460. kind = 'send file'
  461. print(f'{name} [{kind}]:\n\tDescription: {attr.desc}')
  462. values = []
  463. for v in attr.values():
  464. if isinstance(v, list):
  465. values.extend(v)
  466. elif isinstance(v, dict):
  467. values.extend(v.values())
  468. else:
  469. values.append(v)
  470. values = filter(lambda x: isinstance(x, str), values)
  471. fmt_fields = [i[1] for v in values for i in formatter.parse(v) if i[1]]
  472. if fmt_fields:
  473. variables = ', '.join(sorted(set(fmt_fields)))
  474. else:
  475. variables = None
  476. if variables:
  477. print(f'\tVariables: {variables}')
  478. def main():
  479. cli = None
  480. try:
  481. cli = load_cli()
  482. if not cli.no_color:
  483. enable_logging_color()
  484. cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
  485. level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
  486. logger.setLevel(level)
  487. log_handler.setLevel(level)
  488. if cli.info:
  489. info(cli)
  490. return
  491. base_env = Env(
  492. cli.env, # Highest-priority
  493. cli.cfg.get('env'),
  494. os.environ, # Lowest
  495. )
  496. for task in cli.tasks:
  497. run_batch(task, cli.hosts, cli, base_env)
  498. except ByrdException as e:
  499. if cli and cli.verbose > 2:
  500. raise
  501. abort(str(e))
  502. if __name__ == '__main__':
  503. main()