byrd.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict, defaultdict
  5. from string import Formatter
  6. import argparse
  7. import io
  8. import logging
  9. import os
  10. import posixpath
  11. import subprocess
  12. import sys
  13. import threading
  14. import keyring
  15. import paramiko
  16. import yaml
  17. __version__ = '0.0'
  18. log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  19. logger = logging.getLogger('byrd')
  20. logger.setLevel(logging.INFO)
  21. log_handler = logging.StreamHandler()
  22. log_handler.setLevel(logging.INFO)
  23. log_handler.setFormatter(logging.Formatter(log_fmt))
  24. logger.addHandler(log_handler)
  25. basedir, _ = os.path.split(__file__)
  26. PKG_DIR = os.path.join(basedir, 'pkg')
  27. TAB = '\n '
  28. class ByrdException(Exception):
  29. pass
  30. class FmtException(ByrdException):
  31. pass
  32. class ExecutionException(ByrdException):
  33. pass
  34. class RemoteException(ExecutionException):
  35. pass
  36. class LocalException(ExecutionException):
  37. pass
  38. def enable_logging_color():
  39. try:
  40. import colorama
  41. except ImportError:
  42. return
  43. colorama.init()
  44. MAGENTA = colorama.Fore.MAGENTA
  45. RED = colorama.Fore.RED
  46. RESET = colorama.Style.RESET_ALL
  47. # We define custom handler ..
  48. class Handler(logging.StreamHandler):
  49. def format(self, record):
  50. if record.levelname == 'INFO':
  51. record.msg = MAGENTA + record.msg + RESET
  52. elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
  53. record.msg = RED + record.msg + RESET
  54. return super(Handler, self).format(record)
  55. # .. and plug it
  56. logger.removeHandler(log_handler)
  57. handler = Handler()
  58. handler.setFormatter(logging.Formatter(log_fmt))
  59. logger.addHandler(handler)
  60. logger.propagate = 0
  61. def yaml_load(stream):
  62. class OrderedLoader(yaml.Loader):
  63. pass
  64. def construct_mapping(loader, node):
  65. loader.flatten_mapping(node)
  66. return OrderedDict(loader.construct_pairs(node))
  67. OrderedLoader.add_constructor(
  68. yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
  69. construct_mapping)
  70. return yaml.load(stream, OrderedLoader)
  71. def edits(word):
  72. yield word
  73. splits = ((word[:i], word[i:]) for i in range(len(word) + 1))
  74. for left, right in splits:
  75. if right:
  76. yield left + right[1:]
  77. def gen_candidates(wordlist):
  78. candidates = defaultdict(set)
  79. for word in wordlist:
  80. for ed1 in edits(word):
  81. for ed2 in edits(ed1):
  82. candidates[ed2].add(word)
  83. return candidates
  84. def spell(candidates, word):
  85. matches = set(chain.from_iterable(
  86. candidates[ed] for ed in edits(word) if ed in candidates
  87. ))
  88. return matches
  89. def spellcheck(objdict, word):
  90. if word in objdict:
  91. return
  92. candidates = objdict.get('_candidates')
  93. if not candidates:
  94. candidates = gen_candidates(list(objdict))
  95. objdict._candidates = candidates
  96. msg = '"%s" not found in %s' % (word, objdict._path)
  97. matches = spell(candidates, word)
  98. if matches:
  99. msg += ', try: %s' % ' or '.join(matches)
  100. raise ByrdException(msg)
  101. class ObjectDict(dict):
  102. """
  103. Simple objet sub-class that allows to transform a dict into an
  104. object, like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`
  105. """
  106. _meta = {}
  107. def copy(self):
  108. res = ObjectDict(super().copy())
  109. ObjectDict._meta[id(res)] = ObjectDict._meta.get(id(self), {}).copy()
  110. return res
  111. def __getattr__(self, key):
  112. if key.startswith('_'):
  113. return ObjectDict._meta[id(self), key]
  114. if key in self:
  115. return self[key]
  116. else:
  117. return None
  118. def __setattr__(self, key, value):
  119. if key.startswith('_'):
  120. ObjectDict._meta[id(self), key] = value
  121. else:
  122. self[key] = value
  123. class Node:
  124. @staticmethod
  125. def fail(path, kind):
  126. msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
  127. raise ByrdException(msg % (kind, '->'.join(path)))
  128. @classmethod
  129. def parse(cls, cfg, path=tuple()):
  130. children = getattr(cls, '_children', None)
  131. type_name = children and type(children).__name__ \
  132. or ' or '.join((c.__name__ for c in cls._type))
  133. res = None
  134. if type_name == 'dict':
  135. if not isinstance(cfg, dict):
  136. cls.fail(path, type_name)
  137. res = ObjectDict()
  138. if '*' in children:
  139. assert len(children) == 1, "Don't mix '*' and other keys"
  140. child_class = children['*']
  141. for name, value in cfg.items():
  142. res[name] = child_class.parse(value, path + (name,))
  143. else:
  144. # Enforce known pre-defined
  145. for key in cfg:
  146. if key not in children:
  147. path = ' -> '.join(path)
  148. if path:
  149. msg = 'Attribute "%s" not understood in %s' % (
  150. key, path)
  151. else:
  152. msg = 'Top-level attribute "%s" not understood' % (
  153. key)
  154. candidates = gen_candidates(children.keys())
  155. matches = spell(candidates, key)
  156. if matches:
  157. msg += ', try: %s' % ' or '.join(matches)
  158. raise ByrdException(msg)
  159. for name, child_class in children.items():
  160. if name not in cfg:
  161. continue
  162. res[name] = child_class.parse(cfg[name], path + (name,))
  163. elif type_name == 'list':
  164. if not isinstance(cfg, list):
  165. cls.fail(path, type_name)
  166. child_class = children[0]
  167. res = [child_class.parse(c, path+ ('[%s]' % pos,))
  168. for pos, c in enumerate(cfg)]
  169. else:
  170. if not isinstance(cfg, cls._type):
  171. cls.fail(path, type_name)
  172. res = cfg
  173. return cls.setup(res, path)
  174. @classmethod
  175. def setup(cls, values, path):
  176. if isinstance(values, ObjectDict):
  177. values._path = '->'.join(path)
  178. return values
  179. class Atom(Node):
  180. _type = (str, bool)
  181. class AtomList(Node):
  182. _children = [Atom]
  183. class Hosts(Node):
  184. _children = [Atom]
  185. class Auth(Node):
  186. _children = {'*': Atom}
  187. class EnvNode(Node):
  188. _children = {'*': Atom}
  189. class HostGroup(Node):
  190. _children = {
  191. 'hosts': Hosts,
  192. }
  193. class Network(Node):
  194. _children = {
  195. '*': HostGroup,
  196. }
  197. class Multi(Node):
  198. _children = {
  199. 'task': Atom,
  200. 'export': Atom,
  201. 'network': Atom,
  202. }
  203. class MultiList(Node):
  204. _children = [Multi]
  205. class Task(Node):
  206. _children = {
  207. 'desc': Atom,
  208. 'local': Atom,
  209. 'python': Atom,
  210. 'once': Atom,
  211. 'run': Atom,
  212. 'sudo': Atom,
  213. 'send': Atom,
  214. 'to': Atom,
  215. 'assert': Atom,
  216. 'env': EnvNode,
  217. 'multi': MultiList,
  218. 'fmt': Atom,
  219. }
  220. @classmethod
  221. def setup(cls, values, path):
  222. values['name'] = path and path[-1] or ''
  223. if 'desc' not in values:
  224. values['desc'] = values.get('name', '')
  225. super().setup(values, path)
  226. return values
  227. # Multi can also accept any task attribute:
  228. Multi._children.update(Task._children)
  229. class TaskGroup(Node):
  230. _children = {
  231. '*': Task,
  232. }
  233. class LoadNode(Node):
  234. _children = {
  235. 'file': Atom,
  236. 'pkg': Atom,
  237. 'as': Atom,
  238. }
  239. class LoadList(Node):
  240. _children = [LoadNode]
  241. class ConfigRoot(Node):
  242. _children = {
  243. 'networks': Network,
  244. 'tasks': TaskGroup,
  245. 'auth': Auth,
  246. 'env': EnvNode,
  247. 'load': LoadList,
  248. }
  249. class Env(ChainMap):
  250. def __init__(self, *dicts):
  251. self.fmt_kind = 'new'
  252. return super().__init__(*filter(lambda x: x is not None, dicts))
  253. def fmt_env(self, child_env, kind=None):
  254. new_env = {}
  255. for key, val in child_env.items():
  256. # env wrap-around!
  257. new_val = self.fmt(val, kind=kind)
  258. if new_val == val:
  259. continue
  260. new_env[key] = new_val
  261. return Env(new_env, child_env)
  262. def fmt_string(self, string, kind=None):
  263. fmt_kind = kind or self.fmt_kind
  264. try:
  265. if fmt_kind == 'old':
  266. return string % self
  267. else:
  268. return string.format(**self)
  269. except KeyError as exc:
  270. msg = 'Unable to format "%s" (missing: "%s")'% (string, exc.args[0])
  271. candidates = gen_candidates(self.keys())
  272. key = exc.args[0]
  273. matches = spell(candidates, key)
  274. if matches:
  275. msg += ', try: %s' % ' or '.join(matches)
  276. raise FmtException(msg )
  277. except IndexError as exc:
  278. msg = 'Unable to format "%s", positional argument not supported'
  279. raise FmtException(msg)
  280. def fmt(self, what, kind=None):
  281. if isinstance(what, str):
  282. return self.fmt_string(what, kind=kind)
  283. return self.fmt_env(what, kind=kind)
  284. def get_secret(service, resource, resource_id=None):
  285. resource_id = resource_id or resource
  286. secret = keyring.get_password(service, resource_id)
  287. if not secret:
  288. secret = getpass('Password for %s: ' % resource)
  289. keyring.set_password(service, resource_id, secret)
  290. return secret
  291. def get_passphrase(key_path):
  292. service = 'SSH private key'
  293. csum = md5(open(key_path, 'rb').read()).digest().hex()
  294. return get_secret(service, key_path, csum)
  295. def get_password(host):
  296. service = 'SSH password'
  297. return get_secret(service, host)
  298. def get_sudo_passwd():
  299. service = "Sudo password"
  300. return get_secret(service, 'sudo')
  301. CONNECTION_CACHE = {}
  302. def connect(host, auth):
  303. if host in CONNECTION_CACHE:
  304. return CONNECTION_CACHE[host]
  305. private_key_file = password = None
  306. if auth and auth.get('ssh_private_key'):
  307. private_key_file = auth.ssh_private_key
  308. if not os.path.exists(auth.ssh_private_key):
  309. msg = 'Private key file "%s" not found' % auth.ssh_private_key
  310. raise ByrdException(msg)
  311. password = get_passphrase(auth.ssh_private_key)
  312. else:
  313. password = get_password(host)
  314. username, hostname = host.split('@', 1)
  315. client = paramiko.SSHClient()
  316. client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  317. client.connect(hostname, username=username, password=password,
  318. key_filename=private_key_file,
  319. )
  320. CONNECTION_CACHE[host] = client
  321. return client
  322. def run_local(cmd, env, cli):
  323. # Run local task
  324. cmd = env.fmt(cmd)
  325. logger.info(env.fmt('{task_desc}', kind='new'))
  326. if cli.dry_run:
  327. logger.info('[dry-run] ' + cmd)
  328. return None
  329. logger.debug(TAB + TAB.join(cmd.splitlines()))
  330. process = subprocess.Popen(
  331. cmd, shell=True,
  332. stdout=subprocess.PIPE,
  333. stderr=subprocess.STDOUT,
  334. env=env,
  335. )
  336. stdout, stderr = process.communicate()
  337. success = process.returncode == 0
  338. if stdout:
  339. logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
  340. if not success:
  341. raise LocalException(stdout, stderr)
  342. return ObjectDict(stdout=stdout, stderr=stderr)
  343. def run_python(task, env, cli):
  344. # Execute a piece of python localy
  345. code = task.python
  346. logger.info(env.fmt('{task_desc}', kind='new'))
  347. if cli.dry_run:
  348. logger.info('[dry-run] ' + code)
  349. return None
  350. logger.debug(TAB + TAB.join(code.splitlines()))
  351. cmd = 'python -c "import sys;exec(sys.stdin.read())"'
  352. if task.sudo:
  353. user = 'root' if task.sudo is True else task.sudo
  354. cmd = 'sudo -u {} -- {}'.format(user, cmd)
  355. process = subprocess.Popen(
  356. cmd,
  357. stdout=subprocess.PIPE,
  358. stderr=subprocess.PIPE,
  359. stdin=subprocess.PIPE,
  360. env=env,
  361. )
  362. # Plug io
  363. out_buff = io.StringIO()
  364. err_buff = io.StringIO()
  365. log_stream(process.stdout, out_buff)
  366. log_stream(process.stderr, err_buff)
  367. process.stdin.write(code.encode())
  368. process.stdin.flush()
  369. process.stdin.close()
  370. success = process.wait() == 0
  371. process.stdout.close()
  372. process.stderr.close()
  373. out = out_buff.getvalue()
  374. if out:
  375. logger.debug(TAB + TAB.join(out.splitlines()))
  376. if not success:
  377. raise LocalException(out + err_buff.getvalue())
  378. return ObjectDict(stdout=out, stderr=err_buff.getvalue())
  379. def log_stream(stream, buff):
  380. def _log():
  381. try:
  382. for chunk in iter(lambda: stream.readline(2048), ""):
  383. if isinstance(chunk, bytes):
  384. chunk = chunk.decode()
  385. buff.write(chunk)
  386. except ValueError:
  387. # read raises a ValueError on closed stream
  388. pass
  389. t = threading.Thread(target=_log)
  390. t.start()
  391. return t
  392. def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
  393. chan = client.get_transport().open_session()
  394. if env:
  395. chan.update_environment(env)
  396. stdin = chan.makefile('wb')
  397. stdout = chan.makefile('r')
  398. stderr = chan.makefile_stderr('r')
  399. out_buff = io.StringIO()
  400. err_buff = io.StringIO()
  401. out_thread = log_stream(stdout, out_buff)
  402. err_thread = log_stream(stderr, err_buff)
  403. if sudo:
  404. assert not in_buff, 'in_buff and sudo can not be combined'
  405. if isinstance(sudo, str):
  406. sudo_cmd = 'sudo -u %s -s' % sudo
  407. else:
  408. sudo_cmd = 'sudo -s'
  409. chan.exec_command(sudo_cmd)
  410. in_buff = cmd
  411. else:
  412. chan.exec_command(cmd)
  413. if in_buff:
  414. # XXX use a real buff (not a simple str) ?
  415. stdin.write(in_buff)
  416. stdin.flush()
  417. stdin.close()
  418. chan.shutdown_write()
  419. success = chan.recv_exit_status() == 0
  420. out_thread.join()
  421. err_thread.join()
  422. if not success:
  423. raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
  424. res = ObjectDict(
  425. stdout = out_buff.getvalue(),
  426. stderr = err_buff.getvalue(),
  427. )
  428. return res
  429. def run_remote(task, host, env, cli):
  430. res = None
  431. host = env.fmt(host)
  432. env.update({
  433. 'host': extract_host(host),
  434. })
  435. if cli.dry_run:
  436. client = None
  437. else:
  438. client = connect(host, cli.cfg.auth)
  439. if task.run:
  440. cmd = env.fmt(task.run)
  441. prefix = ''
  442. if task.sudo:
  443. if task.sudo is True:
  444. prefix = '[sudo] '
  445. else:
  446. prefix = '[sudo as %s] ' % task.sudo
  447. msg = prefix + '{host}: {task_desc}'
  448. logger.info(env.fmt(msg, kind='new'))
  449. logger.debug(TAB + TAB.join(cmd.splitlines()))
  450. if cli.dry_run:
  451. logger.info('[dry-run] ' + cmd)
  452. else:
  453. res = run_helper(client, cmd, env=env, sudo=task.sudo)
  454. elif task.send:
  455. local_path = env.fmt(task.send)
  456. remote_path = env.fmt(task.to)
  457. logger.info(f'[send] {local_path} -> {host}:{remote_path}')
  458. if not os.path.exists(local_path):
  459. ByrdException('Path "%s" not found' % local_path)
  460. if cli.dry_run:
  461. logger.info('[dry-run]')
  462. return
  463. else:
  464. with client.open_sftp() as sftp:
  465. if os.path.isfile(local_path):
  466. sftp.put(os.path.abspath(local_path), remote_path)
  467. elif os.path.isdir(local_path):
  468. for root, subdirs, files in os.walk(local_path):
  469. rel_dir = os.path.relpath(root, local_path)
  470. rel_dirs = os.path.split(rel_dir)
  471. rem_dir = posixpath.join(remote_path, *rel_dirs)
  472. run_helper(client, 'mkdir -p {}'.format(rem_dir))
  473. for f in files:
  474. rel_f = os.path.join(root, f)
  475. rem_file = posixpath.join(rem_dir, f)
  476. sftp.put(os.path.abspath(rel_f), rem_file)
  477. else:
  478. msg = 'Unexpected path "%s" (not a file, not a directory)'
  479. ByrdException(msg % local_path)
  480. else:
  481. raise ByrdException('Unable to run task "%s"' % task.name)
  482. if res and res.stdout:
  483. logger.debug(TAB + TAB.join(res.stdout.splitlines()))
  484. return res
  485. def run_task(task, host, cli, env=None):
  486. '''
  487. Execute one task on one host (or locally)
  488. '''
  489. if task.local:
  490. res = run_local(task.local, env, cli)
  491. elif task.python:
  492. res = run_python(task, env, cli)
  493. else:
  494. res = run_remote(task, host, env, cli)
  495. if task.get('assert'):
  496. eval_env = {
  497. 'stdout': res.stdout.strip(),
  498. 'stderr': res.stderr.strip(),
  499. }
  500. assert_ = env.fmt(task['assert'])
  501. ok = eval(assert_, eval_env)
  502. if ok:
  503. logger.info('Assert ok')
  504. else:
  505. raise ByrdException('Assert "%s" failed!' % assert_)
  506. return res
  507. def run_batch(task, hosts, cli, global_env=None):
  508. '''
  509. Run one task on a list of hosts
  510. '''
  511. out = None
  512. export_env = {}
  513. task_env = global_env.fmt(task.get('env', {}))
  514. if task.get('multi'):
  515. parent_env = Env(export_env, task_env, global_env)
  516. parent_sudo = task.sudo
  517. for pos, step in enumerate(task.multi):
  518. task_name = step.task
  519. if task_name:
  520. # _cfg contain "local" config wrt the task
  521. siblings = task._cfg.tasks
  522. spellcheck(siblings, task_name)
  523. sub_task = siblings[task_name]
  524. sudo = step.sudo or sub_task.sudo or parent_sudo
  525. else:
  526. # reify a task out of attributes
  527. sub_task = Task.parse(step)
  528. sub_task._path = '%s->[%s]' % (task._path, pos)
  529. sudo = sub_task.sudo or parent_sudo
  530. sub_task.sudo = sudo
  531. network = step.get('network')
  532. if network:
  533. spellcheck(cli.cfg.networks, network)
  534. hosts = cli.cfg.networks[network].hosts
  535. child_env = step.get('env', {})
  536. child_env = parent_env.fmt(child_env)
  537. out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
  538. out = out.decode() if isinstance(out, bytes) else out
  539. export_env['_'] = out
  540. if step.export:
  541. export_env[step.export] = out
  542. else:
  543. task_env.update({
  544. 'task_desc': global_env.fmt(task.desc),
  545. 'task_name': task.name,
  546. })
  547. parent_env = Env(task_env, global_env)
  548. if task.get('fmt'):
  549. parent_env.fmt_kind = task.fmt
  550. res = None
  551. if task.once and (task.local or task.python):
  552. res = run_task(task, None, cli, parent_env)
  553. elif hosts:
  554. for host in hosts:
  555. env_host = extract_host(host)
  556. parent_env.update({
  557. 'host': env_host,
  558. })
  559. res = run_task(task, host, cli, parent_env)
  560. if task.once:
  561. break
  562. else:
  563. logger.warning('Nothing to do for task "%s"' % task._path)
  564. out = res and res.stdout.strip() or ''
  565. return out
  566. def extract_host(host_string):
  567. return host_string and host_string.split('@')[-1] or ''
  568. def abort(msg):
  569. logger.error(msg)
  570. sys.exit(1)
  571. def load_cfg(path, prefix=None):
  572. load_sections = ('networks', 'tasks', 'auth', 'env')
  573. if os.path.isfile(path):
  574. logger.debug('Load config %s' % path)
  575. cfg = yaml_load(open(path))
  576. cfg = ConfigRoot.parse(cfg)
  577. else:
  578. raise ByrdException('Config file "%s" not found' % path)
  579. # Define useful defaults
  580. cfg.networks = cfg.networks or ObjectDict()
  581. cfg.tasks = cfg.tasks or ObjectDict()
  582. # Create backrefs between tasks to the local config
  583. if cfg.get('tasks'):
  584. items = cfg['tasks'].items()
  585. for k, v in items:
  586. v._cfg = ObjectDict(cfg.copy())
  587. if prefix:
  588. key_fn = lambda x: '/'.join(prefix + [x])
  589. # Apply prefix
  590. for section in load_sections:
  591. if not section in cfg:
  592. continue
  593. items = cfg[section].items()
  594. cfg[section] = {key_fn(k): v for k, v in items}
  595. # Recursive load
  596. if cfg.load:
  597. cfg_path = os.path.dirname(path)
  598. for item in cfg.load:
  599. if item.get('file'):
  600. rel_path = item.file
  601. child_path = os.path.join(cfg_path, item.file)
  602. elif item.get('pkg'):
  603. rel_path = item.pkg
  604. child_path = os.path.join(PKG_DIR, item.pkg)
  605. if item.get('as'):
  606. child_prefix = item['as']
  607. else:
  608. child_prefix, _ = os.path.splitext(rel_path)
  609. child_cfg = load_cfg(child_path, child_prefix.split('/'))
  610. for section in load_sections:
  611. if not section in cfg:
  612. continue
  613. cfg[section].update(child_cfg.get(section, {}))
  614. return cfg
  615. def load_cli(args=None):
  616. parser = argparse.ArgumentParser()
  617. parser.add_argument('names', nargs='*',
  618. help='Hosts and commands to run them on')
  619. parser.add_argument('-c', '--config', default='bd.yaml',
  620. help='Config file')
  621. parser.add_argument('-R', '--run', nargs='*', default=[],
  622. help='Run remote task')
  623. parser.add_argument('-L', '--run-local', nargs='*', default=[],
  624. help='Run local task')
  625. parser.add_argument('-P', '--run-python', nargs='*', default=[],
  626. help='Run python task')
  627. parser.add_argument('-d', '--dry-run', action='store_true',
  628. help='Do not run actual tasks, just print them')
  629. parser.add_argument('-e', '--env', nargs='*', default=[],
  630. help='Add value to execution environment '
  631. '(ex: -e foo=bar "name=John Doe")')
  632. parser.add_argument('-s', '--sudo', default='auto',
  633. help='Enable sudo (auto|yes|no')
  634. parser.add_argument('-v', '--verbose', action='count',
  635. default=0, help='Increase verbosity')
  636. parser.add_argument('-q', '--quiet', action='count',
  637. default=0, help='Decrease verbosity')
  638. parser.add_argument('-n', '--no-color', action='store_true',
  639. help='Disable colored logs')
  640. parser.add_argument('-i', '--info', action='store_true',
  641. help='Print info')
  642. cli = parser.parse_args(args=args)
  643. cli = ObjectDict(vars(cli))
  644. # Load config
  645. cfg = load_cfg(cli.config)
  646. cli.cfg = cfg
  647. cli.update(get_hosts_and_tasks(cli, cfg))
  648. # Transformt env string into dict
  649. cli.env = dict(e.split('=') for e in cli.env)
  650. return cli
  651. def get_hosts_and_tasks(cli, cfg):
  652. # Make sure we don't have overlap between hosts and tasks
  653. items = list(cfg.networks) + list(cfg.tasks)
  654. msg = 'Name collision between tasks and networks'
  655. assert len(set(items)) == len(items), msg
  656. # Build task list
  657. tasks = []
  658. networks = []
  659. for name in cli.names:
  660. if name in cfg.networks:
  661. host = cfg.networks[name]
  662. networks.append(host)
  663. elif name in cfg.tasks:
  664. task = cfg.tasks[name]
  665. tasks.append(task)
  666. else:
  667. msg = 'Name "%s" not understood' % name
  668. matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
  669. if matches:
  670. msg += ', try: %s' % ' or '.join(matches)
  671. raise ByrdException (msg)
  672. # Collect custom tasks from cli
  673. customs = []
  674. for cli_key in ('run', 'run_local', 'run_python'):
  675. cmd_key = cli_key.rsplit('_', 1)[-1]
  676. customs.extend('%s: %s' % (cmd_key, ck) for ck in cli[cli_key])
  677. for custom_task in customs:
  678. task = Task.parse(yaml_load(custom_task))
  679. task.desc = 'Custom command'
  680. tasks.append(task)
  681. hosts = list(chain.from_iterable(n.hosts for n in networks))
  682. return dict(hosts=hosts, tasks=tasks)
  683. def info(cli):
  684. formatter = Formatter()
  685. for name, attr in cli.cfg.tasks.items():
  686. kind = 'remote'
  687. if attr.python:
  688. kind = 'python'
  689. elif attr.local:
  690. kind = 'local'
  691. elif attr.multi:
  692. kind = 'multi'
  693. elif attr.send:
  694. kind = 'send file'
  695. print(f'{name} [{kind}]:\n\tDescription: {attr.desc}')
  696. values = []
  697. for v in attr.values():
  698. if isinstance(v, list):
  699. values.extend(v)
  700. elif isinstance(v, dict):
  701. values.extend(v.values())
  702. else:
  703. values.append(v)
  704. values = filter(lambda x: isinstance(x, str), values)
  705. fmt_fields = [i[1] for v in values for i in formatter.parse(v) if i[1]]
  706. if fmt_fields:
  707. variables = ', '.join(sorted(set(fmt_fields)))
  708. else:
  709. variables = None
  710. if variables:
  711. print(f'\tVariables: {variables}')
  712. def main():
  713. cli = None
  714. try:
  715. cli = load_cli()
  716. if not cli.no_color:
  717. enable_logging_color()
  718. cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
  719. level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
  720. log_handler.setLevel(level)
  721. logger.setLevel(level)
  722. if cli.info:
  723. info(cli)
  724. return
  725. base_env = Env(
  726. cli.env, # Highest-priority
  727. cli.cfg.get('env'),
  728. os.environ, # Lowest
  729. )
  730. for task in cli.tasks:
  731. run_batch(task, cli.hosts, cli, base_env)
  732. except ByrdException as e:
  733. if cli and cli.verbose > 2:
  734. raise
  735. abort(str(e))
  736. if __name__ == '__main__':
  737. main()