byrd.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict, defaultdict
  5. from string import Formatter
  6. import argparse
  7. import io
  8. import logging
  9. import os
  10. import posixpath
  11. import subprocess
  12. import sys
  13. import threading
  14. import paramiko
  15. import yaml
  16. try:
  17. import keyring
  18. except ImportError:
  19. keyring = None
  # Version of this module.
  __version__ = '0.0'

  # "%(asctime).19s" keeps only the first 19 characters of the timestamp.
  log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'

  # Default handler; kept in a module-level name so that other code in this
  # module can re-level or replace it.
  log_handler = logging.StreamHandler()
  log_handler.setFormatter(logging.Formatter(log_fmt))
  log_handler.setLevel(logging.INFO)

  logger = logging.getLogger('byrd')
  logger.addHandler(log_handler)
  logger.setLevel(logging.INFO)

  # Directory containing this file, and the 'pkg' directory next to it.
  basedir = os.path.split(__file__)[0]
  PKG_DIR = os.path.join(basedir, 'pkg')

  # Prefix/separator used when logging multi-line text.
  TAB = '\n '
  class ByrdException(Exception):
      """Root of every error raised deliberately by this module."""


  class FmtException(ByrdException):
      """A string could not be formatted with the current environment."""


  class ExecutionException(ByrdException):
      """A command finished unsuccessfully."""


  class RemoteException(ExecutionException):
      """Execution failure of a command run through SSH."""


  class LocalException(ExecutionException):
      """Execution failure of a command run on the local machine."""
  def enable_logging_color():
      """
      Replace the plain log handler with one that colours messages.

      Does nothing when the optional `colorama` package is not installed.
      INFO messages are wrapped in magenta, WARNING/ERROR/CRITICAL messages
      in red; records of any other level are formatted unchanged.
      """
      try:
          import colorama
      except ImportError:
          return
      colorama.init()
      reset = colorama.Style.RESET_ALL
      colors = {
          'INFO': colorama.Fore.MAGENTA,
          'WARNING': colorama.Fore.RED,
          'ERROR': colorama.Fore.RED,
          'CRITICAL': colorama.Fore.RED,
      }

      # We define a custom handler ..
      class Handler(logging.StreamHandler):
          def format(self, record):
              color = colors.get(record.levelname)
              if color is None:
                  return super().format(record)
              plain = record.msg
              # str() because record.msg may be any object (e.g. an
              # exception instance), not only a string.
              record.msg = color + str(plain) + reset
              try:
                  return super().format(record)
              finally:
                  # Leave the record as we found it: it may be formatted
                  # again by another handler.
                  record.msg = plain

      # .. and plug it in place of the default one
      logger.removeHandler(log_handler)
      handler = Handler()
      handler.setFormatter(logging.Formatter(log_fmt))
      logger.addHandler(handler)
      logger.propagate = False
  def yaml_load(stream):
      """
      Parse YAML from `stream`, building every mapping as an `OrderedDict`.

      NOTE(review): this relies on `yaml.Loader`, which is not the safe
      loader and can construct arbitrary Python objects. Only feed it
      trusted files, or consider `yaml.SafeLoader` as the base class.
      """
      class OrderedLoader(yaml.Loader):
          """Private loader class, so the constructor below stays local."""

      def build_ordered_mapping(loader, node):
          loader.flatten_mapping(node)
          pairs = loader.construct_pairs(node)
          return OrderedDict(pairs)

      mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
      OrderedLoader.add_constructor(mapping_tag, build_ordered_mapping)
      return yaml.load(stream, OrderedLoader)
  def edits(word):
      """
      Yield `word` itself, then every string obtained by deleting exactly
      one character from it (left to right).
      """
      yield word
      for pos in range(len(word)):
          yield word[:pos] + word[pos + 1:]
  def gen_candidates(wordlist):
      """
      Build the lookup index used by `spell`.

      Each word of `wordlist` is registered under itself and under every
      variant reachable by up to two rounds of `edits`. The result maps a
      variant to the set of original words that produce it.
      """
      index = defaultdict(set)
      for known in wordlist:
          for variant in chain.from_iterable(map(edits, edits(known))):
              index[variant].add(known)
      return index
  def spell(candidates, word):
      """
      Return the set of known words that `word` may be a misspelling of.

      `candidates` is an index as built by `gen_candidates`; `word` and
      each of its `edits` are looked up in it and the hits are merged.
      """
      found = set()
      for variant in edits(word):
          if variant in candidates:
              found.update(candidates[variant])
      return found
  def spellcheck(objdict, word):
      """
      Make sure `word` is a key of `objdict`.

      Returns None when the key exists. Otherwise raises `ByrdException`
      with a message that mentions where the lookup happened (the
      `_path` attribute of `objdict`, when it has one) and lists similar
      keys as suggestions.

      The candidate index is cached on `objdict` as `_candidates` when the
      object supports private attributes; plain dicts are simply re-indexed
      on every call.
      """
      if word in objdict:
          return

      # The cache is a private attribute, not a dict key. Depending on the
      # object, a missing one shows up as AttributeError or KeyError.
      try:
          candidates = objdict._candidates
      except (AttributeError, KeyError):
          candidates = None
      if not candidates:
          candidates = gen_candidates(list(objdict))
          try:
              objdict._candidates = candidates
          except AttributeError:
              # Plain dict: nowhere to keep the cache
              pass

      try:
          where = objdict._path
      except (AttributeError, KeyError):
          where = None

      if where:
          msg = '"%s" not found in %s' % (word, where)
      else:
          msg = '"%s" not found' % word
      matches = spell(candidates, word)
      if matches:
          msg += ', try: %s' % ' or '.join(sorted(matches))
      raise ByrdException(msg)
  class ObjectDict(dict):
      """
      Simple dict sub-class that allows to use a dict like an object:
      `ObjectDict({'ham': 'spam'}).ham == 'spam'`

      Attribute names fall in two families:

      - public names (no leading underscore) are aliases of dict items;
        reading a missing one gives None.
      - private names (leading underscore, like `_path`) are metadata.
        They are stored on the instance itself, never appear among the
        dict items, and reading a missing one raises AttributeError, so
        `hasattr` and `getattr(obj, name, default)` behave as usual.
      """

      def copy(self):
          """Return a shallow copy that also carries the metadata."""
          res = ObjectDict(super().copy())
          res.__dict__.update(self.__dict__)
          return res

      def __getattr__(self, key):
          # Only reached when regular attribute lookup failed, i.e. for
          # dict-backed names and for metadata that was never set.
          if key.startswith('_'):
              raise AttributeError(key)
          return self.get(key)

      def __setattr__(self, key, value):
          if key.startswith('_'):
              # Metadata lives in the instance __dict__: it follows the
              # object's lifetime and cannot be confused with another
              # object's.
              super().__setattr__(key, value)
          else:
              self[key] = value
  class Node:
      """
      Base class of the config schema.

      A sub-class describes what it accepts with one of:

      - `_children` as a dict: a mapping, either with a fixed set of keys
        or, with the single key '*', any key (all values sharing one node
        class);
      - `_children` as a one-item list: a list of that node class;
      - `_type`: a tuple of accepted python types for a plain value.
      """

      @staticmethod
      def fail(path, kind):
          """Raise a ByrdException telling that `kind` was expected at `path`."""
          template = 'Error while parsing config: expecting "%s" while parsing "%s"'
          location = '->'.join(path)
          raise ByrdException(template % (kind, location))

      @classmethod
      def parse(cls, cfg, path=tuple()):
          """
          Validate `cfg` against this node class and return the parsed value.

          `path` is the tuple of keys leading to `cfg`; it is only used to
          build error messages and is handed to `setup`.
          """
          children = getattr(cls, '_children', None)
          if children:
              type_name = type(children).__name__
          else:
              type_name = ' or '.join(c.__name__ for c in cls._type)

          res = None
          if type_name == 'dict':
              if not isinstance(cfg, dict):
                  cls.fail(path, type_name)
              res = ObjectDict()

              if '*' in children:
                  # Free-form keys, every value parsed by the same class
                  assert len(children) == 1, "Don't mix '*' and other keys"
                  wildcard = children['*']
                  for name, value in cfg.items():
                      res[name] = wildcard.parse(value, path + (name,))
              else:
                  # Refuse keys that are not part of the schema
                  for key in cfg:
                      if key in children:
                          continue
                      where = ' -> '.join(path)
                      if where:
                          msg = 'Attribute "%s" not understood in %s' % (
                              key, where)
                      else:
                          msg = 'Top-level attribute "%s" not understood' % (
                              key)
                      matches = spell(gen_candidates(children.keys()), key)
                      if matches:
                          msg += ', try: %s' % ' or '.join(matches)
                      raise ByrdException(msg)

                  # Parse what is present, in schema order
                  for name, child_class in children.items():
                      if name in cfg:
                          res[name] = child_class.parse(
                              cfg[name], path + (name,))

          elif type_name == 'list':
              if not isinstance(cfg, list):
                  cls.fail(path, type_name)
              item_class = children[0]
              res = []
              for pos, item in enumerate(cfg):
                  res.append(item_class.parse(item, path + ('[%s]' % pos,)))

          else:
              if not isinstance(cfg, cls._type):
                  cls.fail(path, type_name)
              res = cfg

          return cls.setup(res, path)

      @classmethod
      def setup(cls, values, path):
          """Post-parsing hook: record the location on ObjectDict results."""
          if isinstance(values, ObjectDict):
              values._path = '->'.join(path)
          return values
  class Atom(Node):
      """Leaf of the schema: a plain string or boolean value."""

      _type = (str, bool)
  class AtomList(Node):
      """A list whose items are all `Atom` values."""

      _children = [Atom]
  class Hosts(Node):
      """The `hosts` list of a host group: a list of `Atom` values."""

      _children = [Atom]
  class Auth(Node):
      """The `auth` section: free-form keys, each holding an `Atom`."""

      _children = {'*': Atom}
  class EnvNode(Node):
      """An `env` section: free-form variable names, each an `Atom`."""

      _children = {'*': Atom}
  class HostGroup(Node):
      """One named entry of `networks`; its only known key is `hosts`."""

      _children = dict(hosts=Hosts)
  class Network(Node):
      """The `networks` section: free-form names, each a `HostGroup`."""

      _children = {'*': HostGroup}
  class Multi(Node):
      """
      One step of a `multi` list. These are the keys specific to a step;
      the class attribute is a plain dict so that more keys can be
      registered on it later.
      """

      _children = dict(task=Atom, export=Atom, network=Atom)
  class MultiList(Node):
      """The `multi` attribute of a task: a list of `Multi` steps."""

      _children = [Multi]
  class Task(Node):
      """
      A task definition. Besides the attributes listed in `_children`,
      a parsed task always carries a `name` (last item of its path, or
      '') and a `desc` (defaulting to the name).
      """

      _children = {
          'desc': Atom,
          'local': Atom,
          'python': Atom,
          'once': Atom,
          'run': Atom,
          'sudo': Atom,
          'send': Atom,
          'to': Atom,
          'assert': Atom,
          'env': EnvNode,
          'multi': MultiList,
      }

      @classmethod
      def setup(cls, values, path):
          """Fill in `name` and the default `desc`, then record the path."""
          last = path[-1] if path else ''
          values['name'] = last or ''
          values.setdefault('desc', values['name'])
          super().setup(values, path)
          return values
  # A multi step accepts every task attribute on top of its own keys
  Multi._children.update(Task._children.items())
  class TaskGroup(Node):
      """The `tasks` section: free-form task names, each a `Task`."""

      _children = {'*': Task}
  class LoadNode(Node):
      """One entry of the `load` list, with keys `file`, `pkg` and `as`."""

      _children = {
          'file': Atom,
          'pkg': Atom,
          'as': Atom,
      }
  class LoadList(Node):
      """The `load` section: a list of `LoadNode` entries."""

      _children = [LoadNode]
  class ConfigRoot(Node):
      """Top level of a config file and the sections it may contain."""

      _children = dict(
          networks=Network,
          tasks=TaskGroup,
          auth=Auth,
          env=EnvNode,
          load=LoadList,
      )
  class Env(ChainMap):
      """
      Layered environment: a ChainMap that ignores `None` layers and knows
      how to interpolate `{name}` placeholders with its own content.
      """

      def __init__(self, *dicts):
          super().__init__(*(d for d in dicts if d is not None))

      def fmt_env(self, child_env):
          """
          Format the values of `child_env` with this environment.

          Returns a new Env whose first layer holds the values that were
          changed by formatting, on top of `child_env` itself. Values
          that are neither strings nor mappings (e.g. YAML booleans) have
          nothing to interpolate and are left as they are.
          """
          new_env = {}
          for key, val in child_env.items():
              if not (isinstance(val, str) or hasattr(val, 'items')):
                  continue
              # env wrap-around: child values may refer to our variables
              new_val = self.fmt(val)
              if new_val != val:
                  new_env[key] = new_val
          return Env(new_env, child_env)

      def fmt_string(self, string):
          """
          Return `string` with its `{name}` placeholders replaced.

          Raises FmtException when a name is unknown (with spelling
          suggestions), when a positional placeholder like `{}` is used,
          or when the braces are malformed.
          """
          try:
              return string.format(**self)
          except KeyError as exc:
              key = exc.args[0]
              msg = 'Unable to format "%s" (missing: "%s")' % (string, key)
              matches = spell(gen_candidates(self.keys()), key)
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              raise FmtException(msg) from exc
          except IndexError as exc:
              msg = 'Unable to format "%s", positional argument not supported'
              raise FmtException(msg % string) from exc
          except ValueError as exc:
              # str.format complains this way about unbalanced braces
              msg = 'Unable to format "%s" (%s)' % (string, exc)
              raise FmtException(msg) from exc

      def fmt(self, what):
          """Format a string, or every value of a mapping."""
          if isinstance(what, str):
              return self.fmt_string(what)
          return self.fmt_env(what)
  def get_passphrase(key_path):
      """
      Return the passphrase of the private key stored at `key_path`.

      When the optional `keyring` module is available, the passphrase is
      looked up there (under the MD5 checksum of the key file, used only
      as an identifier) and, if absent, asked to the user and saved.
      Without `keyring` the user is simply prompted.
      """
      prompt = 'Password for %s: ' % key_path
      if keyring is None:
          return getpass(prompt)

      service = 'SSH private key'
      with open(key_path, 'rb') as key_file:
          csum = md5(key_file.read()).hexdigest()
      ssh_pass = keyring.get_password(service, csum)
      if not ssh_pass:
          ssh_pass = getpass(prompt)
          keyring.set_password(service, csum, ssh_pass)
      return ssh_pass
  def get_password(host):
      """
      Return the SSH password for `host`.

      When the optional `keyring` module is available, the password is
      looked up there and, if absent, asked to the user and saved.
      Without `keyring` the user is simply prompted.
      """
      prompt = 'Password for %s: ' % host
      if keyring is None:
          return getpass(prompt)

      service = 'SSH password'
      ssh_pass = keyring.get_password(service, host)
      if not ssh_pass:
          ssh_pass = getpass(prompt)
          keyring.set_password(service, host, ssh_pass)
      return ssh_pass
  def get_sudo_passwd():
      """
      Return the sudo password.

      When the optional `keyring` module is available, the password is
      looked up there (under the fixed user name '-') and, if absent,
      asked to the user and saved. Without `keyring` the user is simply
      prompted.
      """
      prompt = 'Sudo password:'
      if keyring is None:
          return getpass(prompt)

      service = "Sudo password"
      passwd = keyring.get_password(service, '-')
      if not passwd:
          passwd = getpass(prompt)
          keyring.set_password(service, '-', passwd)
      return passwd
  # Opened SSH clients, keyed by the host string they were requested with
  CONNECTION_CACHE = {}


  def connect(host, auth):
      """
      Return a connected `paramiko.SSHClient` for `host`.

      `host` is either "user@hostname" or a bare "hostname" (in which case
      no username is passed and paramiko picks its default). `auth` is the
      `auth` section of the config, or None; when it provides
      `ssh_private_key`, that key is used and its passphrase is requested,
      otherwise a password is requested for the host.

      Clients are cached: asking twice for the same host gives the same
      client. Raises ByrdException when the private key file is missing.
      """
      if host in CONNECTION_CACHE:
          return CONNECTION_CACHE[host]

      # Parse the host first, so that a malformed value never costs the
      # user a password prompt.
      username, sep, hostname = host.partition('@')
      if not sep:
          username, hostname = None, host

      private_key_file = password = None
      if auth and auth.get('ssh_private_key'):
          private_key_file = auth.ssh_private_key
          if not os.path.exists(private_key_file):
              msg = 'Private key file "%s" not found' % private_key_file
              raise ByrdException(msg)
          password = get_passphrase(private_key_file)
      else:
          password = get_password(host)

      client = paramiko.SSHClient()
      # NOTE(review): AutoAddPolicy accepts unknown host keys without any
      # verification; consider loading known hosts and rejecting others.
      client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
      try:
          client.connect(
              hostname,
              username=username,
              password=password,
              key_filename=private_key_file,
          )
      except Exception:
          # Do not leave a half-opened client behind
          client.close()
          raise
      CONNECTION_CACHE[host] = client
      return client
  def run_local(cmd, env, cli):
      """
      Run the shell command `cmd` on the local machine.

      `cmd` is first formatted with `env`, which is also given to the
      process as its environment. With `cli.dry_run` the command is only
      logged and None is returned.

      Returns an ObjectDict with `stdout` (decoded text; stderr is merged
      into it) and `stderr` (always an empty string). Raises
      LocalException, carrying the output, when the exit status is not 0.
      """
      cmd = env.fmt(cmd)
      logger.info(env.fmt('{task_desc}'))
      if cli.dry_run:
          logger.info('[dry-run] ' + cmd)
          return None

      logger.debug(TAB + TAB.join(cmd.splitlines()))
      process = subprocess.Popen(
          cmd, shell=True,
          stdout=subprocess.PIPE,
          stderr=subprocess.STDOUT,
          env=dict(env),
      )
      raw_out, _ = process.communicate()
      # Hand text over to callers, like the other runners do; undecodable
      # bytes must not turn a successful command into a crash.
      stdout = raw_out.decode(errors='replace') if raw_out else ''
      if stdout:
          logger.debug(TAB + TAB.join(stdout.splitlines()))
      if process.returncode != 0:
          raise LocalException(stdout)
      return ObjectDict(stdout=stdout, stderr='')
  def run_python(task, env, cli):
      """
      Execute the python snippet `task.python` locally.

      The code is piped to the stdin of a `python` child process, run
      through `sudo -u <user>` when `task.sudo` is set (`True` meaning
      root). `env` is given to the process as its environment. With
      `cli.dry_run` the code is only logged and None is returned.

      Returns an ObjectDict with the decoded `stdout` and `stderr`. Raises
      LocalException, carrying both outputs, when the exit status is not 0.
      """
      code = task.python
      logger.info(env.fmt('{task_desc}'))
      if cli.dry_run:
          logger.info('[dry-run] ' + code)
          return None

      logger.debug(TAB + TAB.join(code.splitlines()))
      # An argument list (not a command string): without shell=True, POSIX
      # would take a whole string as the name of the program to run.
      argv = ['python', '-c', 'import sys;exec(sys.stdin.read())']
      if task.sudo:
          user = 'root' if task.sudo is True else task.sudo
          argv = ['sudo', '-u', str(user), '--'] + argv

      process = subprocess.Popen(
          argv,
          stdout=subprocess.PIPE,
          stderr=subprocess.PIPE,
          stdin=subprocess.PIPE,
          env=dict(env),
      )
      # communicate() feeds stdin, drains both pipes and waits for the
      # process, so nothing is lost and no pipe can fill up and block.
      raw_out, raw_err = process.communicate(code.encode())
      out = raw_out.decode(errors='replace')
      err = raw_err.decode(errors='replace')
      if out:
          logger.debug(TAB + TAB.join(out.splitlines()))
      if process.returncode != 0:
          raise LocalException(out + err)
      return ObjectDict(stdout=out, stderr=err)
  def log_stream(stream, buff):
      """
      Copy `stream` into `buff` from a background thread.

      `stream` is read with `readline` until it reports the end of the
      data (an empty read, text or bytes) or gets closed; bytes are
      decoded before being written to `buff`. Returns the started thread,
      so the caller can `join` it before reading `buff`.
      """
      def _log():
          try:
              while True:
                  chunk = stream.readline(2048)
                  if not chunk:
                      # '' or b'': end of stream
                      break
                  if isinstance(chunk, bytes):
                      # A size-limited read may cut a multi-byte character
                      # in two; do not let that kill the thread.
                      chunk = chunk.decode(errors='replace')
                  buff.write(chunk)
          except ValueError:
              # read raises a ValueError on closed stream
              pass

      t = threading.Thread(target=_log)
      t.start()
      return t
  def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
      """
      Run `cmd` through the SSH `client` and collect its output.

      - `env`: when given, sent as environment of the channel.
      - `in_buff`: text written to the command's stdin.
      - `sudo`: when set, a `sudo -s` shell is started instead (as the
        given user if `sudo` is a string) and `cmd` is written to its
        stdin; it can not be combined with `in_buff`.

      Returns an ObjectDict with `stdout` and `stderr`. Raises
      RemoteException, carrying both outputs, when the exit status is
      not 0.
      """
      channel = client.get_transport().open_session()
      if env:
          channel.update_environment(env)

      # Streams are plugged before the command is started
      remote_in = channel.makefile('wb')
      captured_out = io.StringIO()
      captured_err = io.StringIO()
      readers = [
          log_stream(channel.makefile('r'), captured_out),
          log_stream(channel.makefile_stderr('r'), captured_err),
      ]

      if not sudo:
          channel.exec_command(cmd)
      else:
          assert not in_buff, 'in_buff and sudo can not be combined'
          shell = 'sudo -u %s -s' % sudo if isinstance(sudo, str) else 'sudo -s'
          channel.exec_command(shell)
          in_buff = cmd

      if in_buff:
          # XXX use a real buff (not a simple str) ?
          remote_in.write(in_buff)
          remote_in.flush()
          remote_in.close()
          channel.shutdown_write()

      exit_ok = channel.recv_exit_status() == 0
      for reader in readers:
          reader.join()

      out_text = captured_out.getvalue()
      err_text = captured_err.getvalue()
      if not exit_ok:
          raise RemoteException(out_text + err_text)
      return ObjectDict(stdout=out_text, stderr=err_text)
  def run_remote(task, host, env, cli):
      """
      Execute a remote task (`run` or `send`) on `host`.

      - `run`: the command is formatted with `env` and executed through
        SSH (with sudo when `task.sudo` is set).
      - `send`: the local file or directory `task.send` is uploaded to
        `task.to` through SFTP; directories are walked and re-created
        remotely with `mkdir -p`.

      With `cli.dry_run` no connection is opened and actions are only
      logged. Returns the result of the command for `run` tasks, None
      otherwise. Raises ByrdException when the task has neither `run` nor
      `send`, when `send` comes without `to`, or when the local path is
      missing or is neither a file nor a directory.
      """
      res = None
      host = env.fmt(host)
      env.update({
          'host': host,
      })
      client = None if cli.dry_run else connect(host, cli.cfg.auth)

      if task.run:
          cmd = env.fmt(task.run)
          if not task.sudo:
              prefix = ''
          elif task.sudo is True:
              prefix = '[sudo] '
          else:
              prefix = '[sudo as %s] ' % task.sudo
          logger.info(env.fmt(prefix + '{host}: {task_desc}'))
          logger.debug(TAB + TAB.join(cmd.splitlines()))
          if cli.dry_run:
              logger.info('[dry-run] ' + cmd)
          else:
              res = run_helper(client, cmd, env=env, sudo=task.sudo)

      elif task.send:
          if not task.to:
              msg = 'Task "%s": "send" needs a "to" destination'
              raise ByrdException(msg % task.name)
          local_path = env.fmt(task.send)
          remote_path = env.fmt(task.to)
          logger.info(f'[send] {local_path} -> {host}:{remote_path}')
          if not os.path.exists(local_path):
              raise ByrdException('Path "%s" not found' % local_path)
          if cli.dry_run:
              logger.info('[dry-run]')
              return

          with client.open_sftp() as sftp:
              if os.path.isfile(local_path):
                  sftp.put(os.path.abspath(local_path), remote_path)
              elif os.path.isdir(local_path):
                  for root, subdirs, files in os.walk(local_path):
                      rel_dir = os.path.relpath(root, local_path)
                      if rel_dir == os.curdir:
                          rem_dir = remote_path
                      else:
                          # Rebuild the path with the remote (posix)
                          # separator, one component at a time
                          rem_dir = posixpath.join(
                              remote_path, *rel_dir.split(os.sep))
                      run_helper(client, 'mkdir -p {}'.format(rem_dir))
                      for f in files:
                          rel_f = os.path.join(root, f)
                          rem_file = posixpath.join(rem_dir, f)
                          sftp.put(os.path.abspath(rel_f), rem_file)
              else:
                  msg = 'Unexpected path "%s" (not a file, not a directory)'
                  raise ByrdException(msg % local_path)

      else:
          raise ByrdException('Unable to run task "%s"' % task.name)

      if res and res.stdout:
          logger.debug(TAB + TAB.join(res.stdout.splitlines()))
      return res
  def run_task(task, host, cli, parent_env=None):
      '''
      Execute one task on one host (or locally)

      The task runs with an environment made of its own `env` on top of
      `parent_env`, plus `task_desc`, `task_name` and `host`.

      When the task has an `assert` attribute, it is formatted with the
      environment and evaluated as a python expression with `stdout` and
      `stderr` (stripped text) in scope; ByrdException is raised when it
      is false. The check is skipped when the task produced no result
      (dry-run, file transfer).

      Returns the result of the task (None when there is none).
      '''
      # Prepare environment
      env = Env(
          {},
          # Env on the task itself
          task.get('env'),
          # Env from parent task
          parent_env,
      ).new_child()
      env.update({
          'task_desc': env.fmt(task.desc),
          'task_name': task.name,
          'host': host or '',
      })

      if task.local:
          res = run_local(task.local, env, cli)
      elif task.python:
          res = run_python(task, env, cli)
      else:
          res = run_remote(task, host, env, cli)

      if task.get('assert'):
          if res is None:
              logger.debug('Assert skipped, no output to check')
              return res

          def as_text(value):
              # Runners may hand over bytes, or nothing at all
              if value is None:
                  return ''
              if isinstance(value, bytes):
                  value = value.decode(errors='replace')
              return value.strip()

          eval_env = {
              'stdout': as_text(res.stdout),
              'stderr': as_text(res.stderr),
          }
          assert_ = env.fmt(task['assert'])
          # NOTE(review): eval() runs arbitrary code taken from the config
          # file; configs must come from a trusted source.
          ok = eval(assert_, eval_env)
          if ok:
              logger.info('Assert ok')
          else:
              raise ByrdException('Assert "%s" failed!' % assert_)
      return res
  def run_batch(task, hosts, cli, global_env=None):
      '''
      Run one task on a list of hosts

      A task with a `multi` attribute is a sequence of steps, each being
      either a reference to a sibling task (`task: name`) or an inline
      task made of the step's own attributes. Steps run in order; the
      output of each one is made available to the next ones as `{_}` and,
      when the step has an `export` attribute, under that name too.

      Any other task is run once per host, or only once when it has the
      `once` attribute (without any host for local and python tasks).

      Returns the stripped stdout of the last execution ('' when there
      is none).
      '''
      if global_env is None:
          global_env = Env()
      out = None
      export_env = {}
      task_env = global_env.fmt(task.get('env', {}))
      parent_env = Env(export_env, task_env, global_env)

      if task.get('multi'):
          parent_sudo = task.sudo
          for pos, step in enumerate(task.multi):
              task_name = step.task
              if task_name:
                  # _cfg contain "local" config wrt the task
                  siblings = task._cfg.tasks
                  spellcheck(siblings, task_name)
                  shared = siblings[task_name]
                  sudo = step.sudo or shared.sudo or parent_sudo
                  # Work on a copy: the sibling belongs to the config and
                  # must not keep the sudo value of this particular run.
                  sub_task = ObjectDict(shared)
                  sub_task._cfg = shared._cfg
                  sub_task._path = shared._path
              else:
                  # reify a task out of attributes (step-only keys like
                  # "export" or "network" are not task attributes)
                  inline = {k: v for k, v in step.items()
                            if k in Task._children}
                  sub_task = Task.parse(inline)
                  sub_task._path = '%s->[%s]' % (task._path, pos)
                  try:
                      # Lets a nested "multi" resolve sibling tasks
                      sub_task._cfg = task._cfg
                  except (AttributeError, KeyError):
                      pass
                  sudo = sub_task.sudo or parent_sudo
              sub_task.sudo = sudo

              network = step.get('network')
              if network:
                  spellcheck(cli.cfg.networks, network)
                  # NOTE(review): this also changes the hosts of the
                  # following steps - confirm this is intended.
                  hosts = cli.cfg.networks[network].hosts

              child_env = parent_env.fmt(step.get('env', {}))
              out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
              out = out.decode() if isinstance(out, bytes) else out
              export_env['_'] = out
              if step.export:
                  export_env[step.export] = out

      else:
          res = None
          if task.once and (task.local or task.python):
              res = run_task(task, None, cli, parent_env)
          elif hosts:
              for host in hosts:
                  res = run_task(task, host, cli, parent_env)
                  if task.once:
                      break
          else:
              logger.warning('Nothing to do for task "%s"' % task._path)

          stdout = res.stdout if res else None
          if isinstance(stdout, bytes):
              stdout = stdout.decode(errors='replace')
          out = stdout.strip() if stdout else ''
      return out
  def abort(msg):
      """Log `msg` as an error, then leave the program with exit status 1."""
      logger.error(msg)
      raise SystemExit(1)
  def load_cfg(path, prefix=None):
      """
      Load and validate the config file at `path`, following its `load`
      entries recursively.

      `prefix` (a list of name parts) is used for loaded files: every key
      of their sections is renamed to "<prefix>/<key>" before being merged
      in the parent config. For a `load` entry the prefix is its `as`
      attribute or, by default, the loaded path without extension.

      Returns the parsed config. Raises ByrdException when the file does
      not exist, is not valid wrt the schema, or when a `load` entry has
      neither `file` nor `pkg`.
      """
      load_sections = ('networks', 'tasks', 'auth', 'env')

      if not os.path.isfile(path):
          raise ByrdException('Config file "%s" not found' % path)
      logger.debug('Load config %s' % path)
      with open(path) as stream:
          cfg = yaml_load(stream)
      if cfg is None:
          # Empty file
          cfg = {}
      cfg = ConfigRoot.parse(cfg)

      # Define useful defaults
      for section in ('networks', 'tasks'):
          if not cfg.get(section):
              default = ObjectDict()
              # Same location info as a parsed section would carry
              default._path = section
              cfg[section] = default

      # Create backrefs between tasks to the local config. This is done
      # before prefixing, so that tasks see their siblings under the
      # names used in their own file.
      local_cfg = ObjectDict(cfg)
      for task in cfg['tasks'].values():
          task._cfg = local_cfg

      if prefix:
          # Apply prefix
          for section in load_sections:
              if section not in cfg:
                  continue
              cfg[section] = {
                  '/'.join(prefix + [key]): value
                  for key, value in cfg[section].items()
              }

      # Recursive load
      if cfg.load:
          cfg_path = os.path.dirname(path)
          for item in cfg.load:
              if item.get('file'):
                  rel_path = item.file
                  child_path = os.path.join(cfg_path, item.file)
              elif item.get('pkg'):
                  rel_path = item.pkg
                  child_path = os.path.join(PKG_DIR, item.pkg)
              else:
                  msg = 'Load entry without "file" or "pkg" in "%s"'
                  raise ByrdException(msg % path)

              if item.get('as'):
                  child_prefix = item['as']
              else:
                  child_prefix, _ = os.path.splitext(rel_path)

              child_cfg = load_cfg(child_path, child_prefix.split('/'))
              for section in load_sections:
                  if section not in cfg:
                      continue
                  cfg[section].update(child_cfg.get(section, {}))
      return cfg
  def load_cli(args=None):
      """
      Parse the command line (`args`, or `sys.argv` when None) and load
      the config file it designates.

      Returns an ObjectDict with the parsed options, plus:

      - `cfg`: the loaded config;
      - `hosts` and `tasks`: what the positional names and the custom
        commands resolve to;
      - `env`: the `-e KEY=VALUE` items as a dict.

      Raises ByrdException when an `-e` item is not of the form KEY=VALUE.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument('names', nargs='*',
                          help='Hosts and commands to run them on')
      parser.add_argument('-c', '--config', default='bd.yaml',
                          help='Config file')
      parser.add_argument('-R', '--run', nargs='*', default=[],
                          help='Run remote task')
      parser.add_argument('-L', '--run-local', nargs='*', default=[],
                          help='Run local task')
      parser.add_argument('-P', '--run-python', nargs='*', default=[],
                          help='Run python task')
      parser.add_argument('-d', '--dry-run', action='store_true',
                          help='Do not run actual tasks, just print them')
      parser.add_argument('-e', '--env', nargs='*', default=[],
                          help='Add value to execution environment '
                          '(ex: -e foo=bar "name=John Doe")')
      parser.add_argument('-s', '--sudo', default='auto',
                          help='Enable sudo (auto|yes|no')
      parser.add_argument('-v', '--verbose', action='count',
                          default=0, help='Increase verbosity')
      parser.add_argument('-q', '--quiet', action='count',
                          default=0, help='Decrease verbosity')
      parser.add_argument('-n', '--no-color', action='store_true',
                          help='Disable colored logs')
      parser.add_argument('-i', '--info', action='store_true',
                          help='Print info')
      cli = parser.parse_args(args=args)
      cli = ObjectDict(vars(cli))

      # Load config
      cfg = load_cfg(cli.config)
      cli.cfg = cfg
      cli.update(get_hosts_and_tasks(cli, cfg))

      # Transform env strings into a dict. Only the first "=" separates
      # the name from the value, so values may contain "=" themselves.
      env = {}
      for item in cli.env:
          key, sep, value = item.partition('=')
          if not sep or not key:
              msg = 'Invalid env value "%s", expecting KEY=VALUE'
              raise ByrdException(msg % item)
          env[key] = value
      cli.env = env
      return cli
  def get_hosts_and_tasks(cli, cfg):
      """
      Resolve the names given on the command line.

      Each item of `cli.names` must be the name of a network or of a task
      of `cfg`. Custom commands given through `cli.run`, `cli.run_local`
      and `cli.run_python` are turned into tasks and appended after the
      named ones.

      Returns a dict with `hosts` (the hosts of all the selected networks)
      and `tasks`. Raises ByrdException when a name is unknown (with
      spelling suggestions) or when a task and a network share a name.
      """
      # Make sure we don't have overlap between hosts and tasks
      collisions = set(cfg.networks) & set(cfg.tasks)
      if collisions:
          msg = 'Name collision between tasks and networks'
          raise ByrdException('%s: %s' % (msg, ', '.join(sorted(collisions))))

      # Build task list
      tasks = []
      networks = []
      for name in cli.names:
          if name in cfg.networks:
              networks.append(cfg.networks[name])
          elif name in cfg.tasks:
              tasks.append(cfg.tasks[name])
          else:
              msg = 'Name "%s" not understood' % name
              # spell() works on an index of the known names, not on the
              # config sections themselves
              known = gen_candidates(chain(cfg.networks, cfg.tasks))
              matches = spell(known, name)
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              raise ByrdException(msg)

      # Collect custom tasks from cli. Tasks are built straight from the
      # command strings: they may contain any character (":", "#", ..)
      # without being reinterpreted.
      for cli_key in ('run', 'run_local', 'run_python'):
          cmd_key = cli_key.rsplit('_', 1)[-1]
          for cmd in cli[cli_key]:
              task = Task.parse({cmd_key: cmd})
              task.desc = 'Custom command'
              tasks.append(task)

      # A network may come without any "hosts" list
      hosts = list(chain.from_iterable(n.hosts or [] for n in networks))
      return dict(hosts=hosts, tasks=tasks)
  def info(cli):
      """
      Print a summary of every task of `cli.cfg.tasks`: its name, its
      kind, its description and, when it has any, the sorted names of the
      format variables found in its string attributes.
      """
      parser = Formatter()
      # Checked in this order, first match wins; 'remote' otherwise
      kinds = (
          ('python', 'python'),
          ('local', 'local'),
          ('multi', 'multi'),
          ('send', 'send file'),
      )
      for name, attr in cli.cfg.tasks.items():
          kind = next(
              (label for key, label in kinds if getattr(attr, key)),
              'remote')
          print(f'{name} [{kind}]:\n\tDescription: {attr.desc}')

          # Flatten attributes one level deep, then keep the strings
          flat = []
          for value in attr.values():
              if isinstance(value, list):
                  flat += value
              elif isinstance(value, dict):
                  flat += value.values()
              else:
                  flat.append(value)
          texts = [item for item in flat if isinstance(item, str)]

          fields = set()
          for text in texts:
              for parsed in parser.parse(text):
                  if parsed[1]:
                      fields.add(parsed[1])
          if fields:
              variables = ', '.join(sorted(fields))
              print(f'\tVariables: {variables}')
  def main():
      """
      Command line entry point: load the cli and the config, set up
      logging, then either print the task summary (`--info`) or run every
      selected task on the selected hosts.

      A ByrdException is reported through `abort`, unless the verbosity
      is above 2, in which case it is re-raised.
      """
      cli = None
      try:
          cli = load_cli()
          if not cli.no_color:
              enable_logging_color()

          # One -v step up, one -q step down, starting from INFO
          levels = ('WARNING', 'INFO', 'DEBUG')
          cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
          level = levels[min(cli.verbose, len(levels) - 1)]
          log_handler.setLevel(level)
          logger.setLevel(level)

          if cli.info:
              info(cli)
              return

          base_env = Env(
              cli.env,             # Highest-priority
              cli.cfg.get('env'),
              os.environ,          # Lowest
          )
          for task in cli.tasks:
              run_batch(task, cli.hosts, cli, base_env)

      except ByrdException as exc:
          if cli and cli.verbose > 2:
              raise
          abort(str(exc))


  if __name__ == '__main__':
      main()