byrd.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919
  1. from contextlib import contextmanager
  2. from getpass import getpass
  3. from hashlib import md5
  4. from itertools import chain
  5. from collections import ChainMap, OrderedDict, defaultdict
  6. from itertools import islice
  7. from string import Formatter
  8. import argparse
  9. import io
  10. import logging
  11. import os
  12. import posixpath
  13. import subprocess
  14. import sys
  15. import threading
  16. import keyring
  17. import paramiko
  18. import yaml
  __version__ = '0.0'

  # Where bundled config packages live (used by the "pkg" key of `load`).
  basedir, _ = os.path.split(__file__)
  PKG_DIR = os.path.join(basedir, 'pkg')

  # Line separator used to indent multi-line debug output.
  TAB = '\n '

  # Log line layout "LEVEL:<asctime>: message"; ".19s" keeps only the first
  # 19 characters of asctime (date and time, without the milliseconds).
  log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'

  # Module logger with a plain stream handler; main() adjusts the level and
  # enable_logging_color() may swap the handler for a colored one.
  logger = logging.getLogger('byrd')
  logger.setLevel(logging.INFO)
  log_handler = logging.StreamHandler()
  log_handler.setLevel(logging.INFO)
  log_handler.setFormatter(logging.Formatter(log_fmt))
  logger.addHandler(log_handler)
  class ByrdException(Exception):
      """Base class of the errors byrd reports to the user."""


  class FmtException(ByrdException):
      """A string could not be formatted with the execution environment."""


  class ExecutionException(ByrdException):
      """Base class of command execution failures."""


  class RemoteException(ExecutionException):
      """A command executed over SSH exited with a non-zero status."""


  class LocalException(ExecutionException):
      """A local command or python snippet exited with a non-zero status."""
  def enable_logging_color():
      """
      Replace the plain byrd log handler with a colorized one.

      INFO messages are shown in magenta, WARNING/ERROR/CRITICAL in red.
      Does nothing when colorama is not installed.
      """
      try:
          import colorama
      except ImportError:
          return
      colorama.init()
      magenta = colorama.Fore.MAGENTA
      red = colorama.Fore.RED
      reset = colorama.Style.RESET_ALL

      class Handler(logging.StreamHandler):
          """Stream handler wrapping the message text in color codes."""

          def format(self, record):
              if record.levelname == 'INFO':
                  color = magenta
              elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
                  color = red
              else:
                  color = None
              if color:
                  # Color a copy: the record object is shared with any other
                  # handler, which must not receive the escape codes.
                  record = logging.makeLogRecord(record.__dict__)
                  # str(): non-string messages (e.g. exception objects) would
                  # otherwise make the concatenation raise a TypeError.
                  record.msg = color + str(record.msg) + reset
              return super().format(record)

      # Swap the default handler for the colored one
      logger.removeHandler(log_handler)
      handler = Handler()
      handler.setFormatter(logging.Formatter(log_fmt))
      logger.addHandler(handler)
      logger.propagate = 0
  def yaml_load(stream):
      """
      Parse the YAML document in *stream*, keeping mappings in file order.

      Every mapping is built as an OrderedDict, so sections and tasks keep
      the order in which they are written.

      NOTE(review): this builds on ``yaml.Loader``, PyYAML's full loader,
      which can instantiate arbitrary Python objects from tags. That is only
      acceptable while config files (and -R/-L/-P strings) are trusted;
      consider ``yaml.SafeLoader`` otherwise.
      """
      def ordered_mapping(loader, node):
          # Expand YAML merge keys first, then keep pairs in document order
          loader.flatten_mapping(node)
          return OrderedDict(loader.construct_pairs(node))

      class OrderedLoader(yaml.Loader):
          pass

      OrderedLoader.add_constructor(
          yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, ordered_mapping)
      return yaml.load(stream, OrderedLoader)
  def edits(word):
      """
      Yield *word* itself, then every string obtained by deleting one of its
      characters (left to right, duplicates included).
      """
      yield word
      for pos in range(len(word)):
          yield word[:pos] + word[pos + 1:]
  def gen_candidates(wordlist):
      """
      Build a typo index for *wordlist*.

      Every string obtained by deleting up to two characters of a word (the
      word itself included) maps to the set of words it can come from.
      """
      index = defaultdict(set)
      for original in wordlist:
          variants = (second for first in edits(original)
                      for second in edits(first))
          for variant in variants:
              index[variant].add(original)
      return index
  def spell(candidates, word):
      """
      Return the set of known words close to *word*.

      *candidates* is an index built by gen_candidates; *word* and each of
      its one-character deletions are looked up in it.
      """
      found = set()
      for variant in edits(word):
          # Membership test first: do not grow a defaultdict index
          if variant in candidates:
              found.update(candidates[variant])
      return found
  def spellcheck(objdict, word):
      """
      Make sure *word* is a key of *objdict*.

      :raises ByrdException: when it is not. The message names the config
          path of *objdict* (when it carries one) and lists close matches.
      """
      if word in objdict:
          return
      # No caching: this point always ends with an exception, and the former
      # cache was never read back (and could not be stored on plain dicts).
      matches = spell(gen_candidates(list(objdict)), word)
      # The config path is attached by Node.setup; mappings created elsewhere
      # (e.g. a defaulted section) have none.
      try:
          where = objdict._path
      except (AttributeError, KeyError):
          where = None
      if where:
          msg = '"%s" not found in %s' % (word, where)
      else:
          msg = '"%s" not found' % word
      if matches:
          msg += ', try: %s' % ' or '.join(sorted(matches))
      raise ByrdException(msg)
  class ObjectDict(dict):
      """
      Simple dict sub-class that also gives attribute access to its keys,
      like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`.

      * Reading a missing public attribute returns None instead of raising.
      * Setting a public attribute stores a dict item.
      * Attributes starting with an underscore (`_path`, `_cfg`, ...) are
        metadata: they live on the instance, never as dict items, and
        reading an unset one raises AttributeError.
      """

      def copy(self):
          """Return a shallow copy, metadata attributes included."""
          res = ObjectDict(super().copy())
          res.__dict__.update(self.__dict__)
          return res

      def __getattr__(self, key):
          # Only reached when normal attribute lookup failed
          if key.startswith('_'):
              # AttributeError (not KeyError) keeps getattr()/hasattr()
              # defaults and the copy/pickle protocols working.
              raise AttributeError(key)
          return self.get(key)

      def __setattr__(self, key, value):
          if key.startswith('_'):
              # Stored on the instance itself: no global registry keyed by
              # id() that leaks and can be hit again after id reuse.
              super().__setattr__(key, value)
          else:
              self[key] = value
  class Node:
      """
      Base class of the declarative config schema.

      A sub-class declares either `_children`, or `_type` for a leaf value:
      - `_children` is a dict mapping each allowed key to its Node class
        ('*' standing for any key), or a one-element list holding the Node
        class of every item.
      - `_type` is the tuple of Python types accepted for the leaf.
      """

      @staticmethod
      def fail(path, kind):
          """Raise a ByrdException: a *kind* value was expected at *path*."""
          msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
          raise ByrdException(msg % (kind, '->'.join(path)))

      @classmethod
      def parse(cls, cfg, path=tuple()):
          """
          Validate *cfg* (a value loaded from YAML) against this schema node.

          :param path: tuple of keys leading to *cfg*, for error messages.
          :returns: an ObjectDict for mappings, a list for sequences or *cfg*
              itself for leaves, as post-processed by `setup`.
          :raises ByrdException: on a type mismatch or an unknown key.
          """
          children = getattr(cls, '_children', None)
          if children:
              type_name = type(children).__name__
          else:
              type_name = ' or '.join(c.__name__ for c in cls._type)

          res = None
          if type_name == 'dict':
              if not isinstance(cfg, dict):
                  cls.fail(path, type_name)
              res = ObjectDict()
              if '*' in children:
                  assert len(children) == 1, "Don't mix '*' and other keys"
                  child_class = children['*']
                  for name, value in cfg.items():
                      # str(): YAML keys may be numbers, paths get joined
                      res[name] = child_class.parse(value, path + (str(name),))
              else:
                  # Enforce known pre-defined keys
                  for key in cfg:
                      if key not in children:
                          cls._reject_key(key, children, path)
                  # Known keys are stored in schema order
                  for name, child_class in children.items():
                      if name in cfg:
                          res[name] = child_class.parse(cfg[name],
                                                        path + (name,))
          elif type_name == 'list':
              if not isinstance(cfg, list):
                  cls.fail(path, type_name)
              child_class = children[0]
              res = [child_class.parse(item, path + ('[%s]' % pos,))
                     for pos, item in enumerate(cfg)]
          else:
              if not isinstance(cfg, cls._type):
                  cls.fail(path, type_name)
              res = cfg
          return cls.setup(res, path)

      @staticmethod
      def _reject_key(key, children, path):
          """Raise a ByrdException for the unknown *key*, with suggestions."""
          where = ' -> '.join(path)
          if where:
              msg = 'Attribute "%s" not understood in %s' % (key, where)
          else:
              msg = 'Top-level attribute "%s" not understood' % (key,)
          matches = spell(gen_candidates(children.keys()), str(key))
          if matches:
              # sorted: stable, reproducible suggestion order
              msg += ', try: %s' % ' or '.join(sorted(matches))
          raise ByrdException(msg)

      @classmethod
      def setup(cls, values, path):
          """Post-process a parsed value; mappings remember their path."""
          if isinstance(values, ObjectDict):
              values._path = '->'.join(path)
          return values
  class Atom(Node):
      """Leaf value: a string or a boolean."""
      _type = (str, bool)


  class AtomList(Node):
      """List of leaf values."""
      _children = [Atom]


  class Hosts(Node):
      """List of the host strings of a network."""
      _children = [Atom]


  class Auth(Node):
      """Authentication settings, any key (e.g. ssh_private_key)."""
      _children = {'*': Atom}


  class EnvNode(Node):
      """Environment variables: any key mapped to a leaf value."""
      _children = {'*': Atom}


  class HostGroup(Node):
      """A network: its list of hosts."""
      _children = dict(hosts=Hosts)


  class Network(Node):
      """The `networks` section: any name mapped to a HostGroup."""
      _children = {'*': HostGroup}


  class Multi(Node):
      """
      One step of a `multi` task: the name of a task to run (`task`), the
      variable receiving its output (`export`) and the network to run it on
      (`network`). Every Task attribute is accepted as well (the dict is
      extended once Task is defined).
      """
      _children = dict(task=Atom, export=Atom, network=Atom)


  class MultiList(Node):
      """The ordered steps of a `multi` task."""
      _children = [Multi]
  class Task(Node):
      """
      A task definition. Its action is one of:
      - `run`: remote shell command;
      - `local`: local shell command;
      - `python`: python code;
      - `send`/`to`: file upload;
      - `multi`: sequence of steps.
      """
      _children = {
          'desc': Atom,
          'local': Atom,
          'python': Atom,
          'once': Atom,
          'run': Atom,
          'sudo': Atom,
          'send': Atom,
          'to': Atom,
          'assert': Atom,
          'env': EnvNode,
          'multi': MultiList,
          'fmt': Atom,
      }

      @classmethod
      def setup(cls, values, path):
          """Name the task after the last *path* item; `desc` defaults to it."""
          name = path[-1] if path and path[-1] else ''
          values['name'] = name
          values.setdefault('desc', name)
          super().setup(values, path)
          return values


  # A multi step accepts every task attribute too, so a step can be an
  # inline task instead of a reference to a named one:
  Multi._children.update(Task._children)
  class TaskGroup(Node):
      """The `tasks` section: any name mapped to a Task."""
      _children = {'*': Task}


  class LoadNode(Node):
      """
      One `load` entry: a config `file` (relative to the loading config) or
      a bundled `pkg`; `as` overrides the name prefix given to its entries.
      """
      _children = {
          'file': Atom,
          'pkg': Atom,
          'as': Atom,
      }


  class LoadList(Node):
      """The `load` section: configs to include, in order."""
      _children = [LoadNode]


  class ConfigRoot(Node):
      """Top-level layout of a config file."""
      _children = {
          'networks': Network,
          'tasks': TaskGroup,
          'auth': Auth,
          'env': EnvNode,
          'load': LoadList,
      }
  class Env(ChainMap):
      """
      Layered mapping of the variables used to format commands and paths.

      Mappings given first take precedence; `None` layers are skipped.
      `fmt_kind` selects the template syntax: 'new' (`str.format`, "{var}")
      or 'old' (%-formatting, "%(var)s").
      """

      def __init__(self, *dicts):
          self.fmt_kind = 'new'
          super().__init__(*(d for d in dicts if d is not None))

      def fmt_env(self, child_env, kind=None):
          """
          Format each value of the mapping *child_env* with this env.

          :returns: an Env layering the values that changed on top of
              *child_env* itself.
          """
          new_env = {}
          for key, val in child_env.items():
              # Formatted with this (parent) env: a child variable can be
              # built from the parent's ones.
              new_val = self.fmt(val, kind=kind)
              if new_val != val:
                  new_env[key] = new_val
          return Env(new_env, child_env)

      def fmt_string(self, string, kind=None):
          """
          Format *string* with this env, in *kind* (or `fmt_kind`) syntax.

          :raises FmtException: on a missing variable (with suggestions), a
              positional field or a malformed template.
          """
          fmt_kind = kind or self.fmt_kind
          try:
              if fmt_kind == 'old':
                  return string % self
              return string.format(**self)
          except KeyError as exc:
              key = exc.args[0]
              msg = 'Unable to format "%s" (missing: "%s")' % (string, key)
              matches = spell(gen_candidates(self.keys()), key)
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              raise FmtException(msg)
          except IndexError:
              msg = 'Unable to format "%s", positional argument not supported'
              raise FmtException(msg % string)
          except ValueError as exc:
              # e.g. a lone "}" or an unsupported %-conversion character
              raise FmtException('Unable to format "%s": %s' % (string, exc))

      def fmt(self, what, kind=None):
          """
          Format *what*: strings via fmt_string, mappings via fmt_env.
          Other values (booleans from YAML, None, ...) are returned as is.
          """
          if isinstance(what, str):
              return self.fmt_string(what, kind=kind)
          if hasattr(what, 'items'):
              return self.fmt_env(what, kind=kind)
          return what
  class DummyClient:
      '''
      Stand-in for a paramiko SSH client, used for dry runs (handy in tests
      too): its only method, `open_sftp`, yields None instead of a session.
      '''

      @contextmanager
      def open_sftp(self):
          sftp = None
          yield sftp
  def get_secret(service, resource, resource_id=None):
      """
      Return the secret stored in the system keyring for *service*.

      The keyring entry is *resource_id* (defaulting to *resource*). When
      nothing is stored yet, the user is prompted (the prompt names
      *resource*) and the answer is saved for the next time.
      """
      entry = resource_id if resource_id else resource
      secret = keyring.get_password(service, entry)
      if secret:
          return secret
      secret = getpass('Password for %s: ' % resource)
      keyring.set_password(service, entry, secret)
      return secret
  def get_passphrase(key_path):
      """
      Return the passphrase of the SSH private key at *key_path*.

      The keyring entry is keyed on the MD5 hex digest of the key file
      content, so a replaced key file gets its own entry.
      """
      service = 'SSH private key'
      # Context manager: the key file is closed right after hashing
      with open(key_path, 'rb') as key_file:
          csum = md5(key_file.read()).hexdigest()
      return get_secret(service, key_path, csum)
  def get_password(host):
      """Return the SSH password of *host* (keyring, or prompt once)."""
      return get_secret('SSH password', host)
  def get_sudo_passwd():
      """Return the sudo password (keyring, or prompt once)."""
      return get_secret("Sudo password", 'sudo')
  # Open SSH clients, keyed by the host string given to `connect`.
  CONNECTION_CACHE = {}


  def connect(host, auth):
      """
      Return an SSH client connected to *host*, reusing a cached one.

      :param host: "user@hostname", or a bare "hostname" to let paramiko use
          its default user name.
      :param auth: the `auth` config section (may be None). With an
          `ssh_private_key` (a leading "~" is expanded), that key is used
          and its passphrase comes from get_passphrase; otherwise a password
          comes from get_password.
      :raises ByrdException: when the private key file does not exist.
      """
      if host in CONNECTION_CACHE:
          return CONNECTION_CACHE[host]
      private_key_file = None
      if auth and auth.get('ssh_private_key'):
          private_key_file = os.path.expanduser(auth.ssh_private_key)
          if not os.path.exists(private_key_file):
              msg = 'Private key file "%s" not found' % private_key_file
              raise ByrdException(msg)
          password = get_passphrase(private_key_file)
      else:
          password = get_password(host)
      if '@' in host:
          username, hostname = host.split('@', 1)
      else:
          # None lets paramiko fall back to the local user name
          username, hostname = None, host
      client = paramiko.SSHClient()
      # NOTE(review): unknown host keys are accepted without verification
      client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
      client.connect(hostname, username=username, password=password,
                     key_filename=private_key_file)
      CONNECTION_CACHE[host] = client
      return client
  def run_local(cmd, env, cli):
      """
      Run the shell command *cmd* locally, after formatting it with *env*.

      stderr is merged into stdout.

      :returns: None in dry-run mode, else an ObjectDict with `stdout`
          (decoded text) and `stderr` ('', since it is merged into stdout).
      :raises LocalException: with the command output on non-zero exit.
      """
      cmd = env.fmt(cmd)
      logger.info(env.fmt('{task_desc}', kind='new'))
      if cli.dry_run:
          logger.info('[dry-run] ' + cmd)
          return None
      logger.debug(TAB + TAB.join(cmd.splitlines()))
      process = subprocess.Popen(
          cmd, shell=True,
          stdout=subprocess.PIPE,
          stderr=subprocess.STDOUT,
          env=env,
      )
      raw_out, _ = process.communicate()
      # Text, like the python/remote runners: task asserts and exported
      # variables work on str, and stderr must not be None.
      stdout = raw_out.decode(errors='replace')
      if stdout:
          logger.debug(TAB + TAB.join(stdout.splitlines()))
      if process.returncode != 0:
          raise LocalException(stdout)
      return ObjectDict(stdout=stdout, stderr='')
  def run_python(task, env, cli):
      """
      Run the `python` code of *task* in a local python subprocess.

      The code is fed through stdin. With `sudo`, the interpreter runs as
      that user (root when `sudo` is simply true).

      :returns: None in dry-run mode, else an ObjectDict with the decoded
          `stdout` and `stderr`.
      :raises LocalException: with stdout + stderr when the code fails.
      """
      code = task.python
      logger.info(env.fmt('{task_desc}', kind='new'))
      if cli.dry_run:
          logger.info('[dry-run] ' + code)
          return None
      logger.debug(TAB + TAB.join(code.splitlines()))
      cmd = ['python', '-c', 'import sys;exec(sys.stdin.read())']
      if task.sudo:
          user = 'root' if task.sudo is True else task.sudo
          # Stay an argument list: Popen is not run through a shell here
          cmd = ['sudo', '-u', user, '--'] + cmd
      process = subprocess.Popen(
          cmd,
          stdout=subprocess.PIPE,
          stderr=subprocess.PIPE,
          stdin=subprocess.PIPE,
          env=env,
      )
      # communicate() writes stdin and drains both pipes concurrently until
      # the process exits: no dead-lock on a full pipe, no lost output.
      raw_out, raw_err = process.communicate(code.encode())
      out = raw_out.decode(errors='replace')
      err = raw_err.decode(errors='replace')
      if out:
          logger.debug(TAB + TAB.join(out.splitlines()))
      if process.returncode != 0:
          raise LocalException(out + err)
      return ObjectDict(stdout=out, stderr=err)
  def log_stream(stream, buff):
      """
      Copy *stream* into the text buffer *buff* from a background thread.

      *stream* may yield str or bytes; bytes are decoded as UTF-8 (invalid
      sequences replaced). Copying stops at end of stream or when the
      stream is closed. The thread is a daemon, so a stalled stream cannot
      keep the interpreter alive.

      :returns: the started thread, so callers can join() it.
      """
      import codecs

      # Incremental: a multi-byte character may straddle two reads
      decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')

      def _log():
          try:
              while True:
                  chunk = stream.readline(2048)
                  # Both '' (text) and b'' (binary) mean end of stream
                  if not chunk:
                      break
                  if isinstance(chunk, bytes):
                      chunk = decoder.decode(chunk)
                  buff.write(chunk)
              buff.write(decoder.decode(b'', final=True))
          except ValueError:
              # read raises a ValueError on closed stream
              pass

      t = threading.Thread(target=_log, daemon=True)
      t.start()
      return t
  def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
      '''
      Run *cmd* on the remote host of the paramiko *client*, capturing its
      output.

      :param env: variables sent along with the session.
      :param in_buff: text written to the command's stdin.
      :param sudo: when truthy, a `sudo -s` shell (`sudo -u <sudo> -s` when
          it is a user name) is started and *cmd* is written to its stdin,
          so it cannot be combined with *in_buff*.
      :returns: ObjectDict with the `stdout` and `stderr` text.
      :raises RemoteException: with stdout + stderr on non-zero exit status.
      '''
      chan = client.get_transport().open_session()
      if env:
          # NOTE(review): run_remote passes its whole Env chain, which in
          # main() includes os.environ, so every local variable is offered
          # to the server -- confirm this is intended.
          chan.update_environment(env)
      stdin = chan.makefile('wb')
      stdout = chan.makefile('r')
      stderr = chan.makefile_stderr('r')
      captured_out = io.StringIO()
      captured_err = io.StringIO()
      readers = (log_stream(stdout, captured_out),
                 log_stream(stderr, captured_err))
      if sudo:
          assert not in_buff, 'in_buff and sudo can not be combined'
          shell = 'sudo -u %s -s' % sudo if isinstance(sudo, str) else 'sudo -s'
          chan.exec_command(shell)
          # The command itself is fed to the sudo shell through stdin
          in_buff = cmd
      else:
          chan.exec_command(cmd)
      if in_buff:
          # XXX use a real buff (not a simple str) ?
          stdin.write(in_buff)
          stdin.flush()
      stdin.close()
      chan.shutdown_write()
      exit_ok = chan.recv_exit_status() == 0
      for reader in readers:
          reader.join()
      if not exit_ok:
          raise RemoteException(captured_out.getvalue()
                                + captured_err.getvalue())
      return ObjectDict(
          stdout=captured_out.getvalue(),
          stderr=captured_err.getvalue(),
      )
  def run_remote(task, host, env, cli):
      """
      Execute *task* on *host*: its `run` command or its `send`/`to` upload.

      Sets `host` (the hostname part) in *env*. In dry-run mode nothing is
      executed and a DummyClient stands in for the SSH connection.

      :returns: the run_helper result for `run` tasks, else None.
      :raises ByrdException: on a missing local path, a `send` without `to`
          or a task with nothing to run remotely.
      """
      res = None
      host = env.fmt(host)
      env.update({
          'host': extract_host(host),
      })
      client = DummyClient() if cli.dry_run else connect(host, cli.cfg.auth)
      if task.run:
          cmd = env.fmt(task.run)
          if task.sudo is True:
              prefix = '[sudo] '
          elif task.sudo:
              prefix = '[sudo as %s] ' % task.sudo
          else:
              prefix = ''
          logger.info(env.fmt(prefix + '{host}: {task_desc}', kind='new'))
          logger.debug(TAB + TAB.join(cmd.splitlines()))
          if cli.dry_run:
              logger.info('[dry-run] ' + cmd)
          else:
              res = run_helper(client, cmd, env=env, sudo=task.sudo)
      elif task.send:
          if not task.to:
              raise ByrdException(
                  'Task "%s": "send" requires a "to" destination' % task.name)
          local_path = env.fmt(task.send)
          if not os.path.exists(local_path):
              raise ByrdException('Path "%s" not found' % local_path)
          send(client, env, cli, task)
      else:
          raise ByrdException('Unable to run task "%s"' % task.name)
      if res and res.stdout:
          logger.debug(TAB + TAB.join(res.stdout.splitlines()))
      return res
  def send(client, env, cli, task):
      """
      Upload `task.send` (a file or a directory tree) to `task.to`.

      With `task.fmt` set, file contents are formatted with *env* before
      upload (see send_file). In dry-run mode nothing is transferred and no
      remote directory is created.

      :raises ByrdException: when the local path is neither a file nor a
          directory.
      """
      import shlex

      fmt = task.fmt and Env(env, {'fmt': 'new'}).fmt(task.fmt) or None
      local_path = env.fmt(task.send)
      remote_path = env.fmt(task.to)
      dry_run = cli.dry_run
      with client.open_sftp() as sftp:
          if os.path.isfile(local_path):
              send_file(sftp, os.path.abspath(local_path), remote_path, env,
                        dry_run=dry_run, fmt=fmt)
          elif os.path.isdir(local_path):
              for root, _subdirs, files in os.walk(local_path):
                  rel_dir = os.path.relpath(root, local_path)
                  # Rebuild the relative path with POSIX separators ('.' is
                  # the top directory itself)
                  parts = [] if rel_dir == os.curdir else rel_dir.split(os.sep)
                  rem_dir = posixpath.join(remote_path, *parts)
                  if dry_run:
                      # The dry-run client cannot execute commands
                      logger.info('[dry-run] mkdir -p ' + rem_dir)
                  else:
                      # Quoted: names may hold spaces or shell metacharacters
                      run_helper(client, 'mkdir -p %s' % shlex.quote(rem_dir))
                  for f in files:
                      send_file(sftp, os.path.abspath(os.path.join(root, f)),
                                posixpath.join(rem_dir, f), env,
                                dry_run=dry_run, fmt=fmt)
          else:
              msg = 'Unexpected path "%s" (not a file, not a directory)'
              raise ByrdException(msg % local_path)
  def send_file(sftp, local_path, remote_path, env, dry_run=False, fmt=None):
      """
      Upload one file through *sftp*.

      Without *fmt* the file is copied as is (binary content included).
      With *fmt*, its text is formatted with *env* first (%-style when fmt
      is 'old', str.format otherwise) and the result is written to
      *remote_path*. The first 30 lines are logged at debug level. Nothing
      is written in dry-run mode.
      """
      if not fmt:
          logger.info(f'[send] {local_path} -> {remote_path}')
          # errors='replace': the head is only logged, the file may be binary
          with open(local_path, errors='replace') as fh:
              head = [line.rstrip('\n') for line in islice(fh, 30)]
          logger.debug('File head:' + TAB + TAB.join(head))
          if not dry_run:
              sftp.put(local_path, remote_path)
          return
      # Format file content and save it on remote
      logger.info(f'[fmt] {local_path} -> {remote_path}')
      with open(local_path) as fh:
          content = env.fmt(fh.read(), kind=fmt)
      head = islice(content.splitlines(), 30)
      logger.debug('File head:' + TAB + TAB.join(head))
      if not dry_run:
          with sftp.open(remote_path, mode='w') as remote_fh:
              remote_fh.write(content)
  def run_task(task, host, cli, env=None):
      '''
      Execute one task on one host (or locally), then check its `assert`.

      The runner is chosen in this order: `local`, `python`, else remote
      (run/send on *host*). An `assert` expression is evaluated with the
      stripped `stdout` and `stderr` text in scope; it is skipped when the
      runner captured nothing (dry run, file upload).

      :returns: the runner result (ObjectDict with stdout/stderr) or None.
      :raises ByrdException: when the assertion is falsy.
      '''
      if task.local:
          res = run_local(task.local, env, cli)
      elif task.python:
          res = run_python(task, env, cli)
      else:
          res = run_remote(task, host, env, cli)
      if task.get('assert'):
          assert_ = env.fmt(task['assert'])
          if res is None:
              logger.info('Assert "%s" skipped (no output)' % assert_)
              return res

          def as_text(value):
              # Runners may hand back bytes or None: normalize to str
              if value is None:
                  return ''
              if isinstance(value, bytes):
                  return value.decode(errors='replace')
              return value

          eval_env = {
              'stdout': as_text(res.stdout).strip(),
              'stderr': as_text(res.stderr).strip(),
          }
          # NOTE: eval() of a config-provided expression; configs are
          # trusted anyway (they run arbitrary shell commands).
          if eval(assert_, eval_env):
              logger.info('Assert ok')
          else:
              raise ByrdException('Assert "%s" failed!' % assert_)
      return res
  def run_batch(task, hosts, cli, global_env=None):
      '''
      Run one task on a list of hosts.

      A `multi` task runs its steps in turn (recursively). Each step sees an
      env layering, from highest priority:
      - the outputs of the previous steps (the last one is `_`, `export`
        adds a named one);
      - the task env;
      - *global_env*.
      A step's `network` replaces *hosts* for that step only.

      :returns: the stripped stdout of the last execution as text, or ''.
      '''
      out = None
      export_env = {}
      task_env = global_env.fmt(task.get('env', {}))
      if task.get('multi'):
          parent_env = Env(export_env, task_env, global_env)
          parent_sudo = task.sudo
          for pos, step in enumerate(task.multi):
              task_name = step.task
              if task_name:
                  # _cfg contain "local" config wrt the task
                  siblings = task._cfg.tasks
                  spellcheck(siblings, task_name)
                  shared = siblings[task_name]
                  sudo = step.sudo or shared.sudo or parent_sudo
                  # Work on a copy so the sudo override below does not leak
                  # into the task definition used by other callers.
                  sub_task = shared.copy()
                  sub_task._cfg = shared._cfg
                  sub_task._path = shared._path
              else:
                  # Reify a task out of the step attributes; "task", "export"
                  # and "network" are step-only keys a Task would reject.
                  attrs = {key: val for key, val in step.items()
                           if key in Task._children}
                  sub_task = Task.parse(attrs)
                  sub_task._path = '%s->[%s]' % (task._path, pos)
                  # An inline (multi) step resolves task names like its parent
                  sub_task._cfg = task._cfg
                  sudo = sub_task.sudo or parent_sudo
              sub_task.sudo = sudo
              step_hosts = hosts
              network = step.get('network')
              if network:
                  spellcheck(cli.cfg.networks, network)
                  step_hosts = cli.cfg.networks[network].hosts
              child_env = parent_env.fmt(step.get('env', {}))
              out = run_batch(sub_task, step_hosts, cli,
                              Env(child_env, parent_env))
              export_env['_'] = out
              if step.export:
                  export_env[step.export] = out
      else:
          task_env.update({
              'task_desc': global_env.fmt(task.desc),
              'task_name': task.name,
          })
          parent_env = Env(task_env, global_env)
          if task.get('fmt'):
              parent_env.fmt_kind = task.fmt
          res = None
          if task.once and (task.local or task.python):
              # Local work run a single time, whatever the hosts
              res = run_task(task, None, cli, parent_env)
          elif hosts:
              for host in hosts:
                  parent_env.update({
                      'host': extract_host(host),
                  })
                  res = run_task(task, host, cli, parent_env)
                  if task.once:
                      break
          else:
              logger.warning('Nothing to do for task "%s"' % task._path)
          out = ''
          if res:
              stdout = res.stdout
              if isinstance(stdout, bytes):
                  stdout = stdout.decode(errors='replace')
              out = (stdout or '').strip()
      return out
  def extract_host(host_string):
      """Return the part after the last '@' ('' for an empty/None input)."""
      if not host_string:
          return ''
      return host_string.rsplit('@', 1)[-1]
  def abort(msg):
      """Log *msg* as an error and stop the program with exit status 1."""
      logger.error(msg)
      raise SystemExit(1)
  def load_cfg(path, prefix=None):
      """
      Load and validate the YAML config at *path*, following its `load`
      entries recursively.

      :param prefix: list of name parts; when given, every key of the
          networks/tasks/auth/env sections becomes "<prefix>/<key>".
      :returns: the config ObjectDict (`networks` and `tasks` always set).
      :raises ByrdException: when a config file is missing or a `load`
          entry has neither `file` nor `pkg`.
      """
      load_sections = ('networks', 'tasks', 'auth', 'env')
      if not os.path.isfile(path):
          raise ByrdException('Config file "%s" not found' % path)
      logger.debug('Load config %s' % path)
      with open(path) as stream:
          cfg = ConfigRoot.parse(yaml_load(stream))
      # Define useful defaults
      cfg.networks = cfg.networks or ObjectDict()
      cfg.tasks = cfg.tasks or ObjectDict()
      # Create backrefs between tasks to the local config, so multi steps
      # resolve task names against the file that defines them
      for task in cfg.tasks.values():
          task._cfg = ObjectDict(cfg.copy())
      if prefix:
          def key_fn(name):
              return '/'.join(prefix + [name])
          # Apply prefix
          for section in load_sections:
              if section not in cfg:
                  continue
              cfg[section] = {key_fn(k): v for k, v in cfg[section].items()}
      # Recursive load
      if cfg.load:
          cfg_dir = os.path.dirname(path)
          for item in cfg.load:
              if item.get('file'):
                  rel_path = item.file
                  child_path = os.path.join(cfg_dir, item.file)
              elif item.get('pkg'):
                  rel_path = item.pkg
                  child_path = os.path.join(PKG_DIR, item.pkg)
              else:
                  raise ByrdException(
                      'Load entry in "%s" needs a "file" or a "pkg"' % path)
              if item.get('as'):
                  child_prefix = item['as']
              else:
                  child_prefix, _ = os.path.splitext(rel_path)
              child_cfg = load_cfg(child_path, child_prefix.split('/'))
              for section in load_sections:
                  if section not in child_cfg:
                      continue
                  # Created when absent here, so a child's auth/env are not
                  # silently dropped
                  cfg.setdefault(section, ObjectDict()).update(
                      child_cfg[section])
      return cfg
  def load_cli(args=None):
      """
      Parse the command line, load the config and resolve hosts and tasks.

      :param args: argument list (argparse default: sys.argv[1:]).
      :returns: ObjectDict of the parsed options, plus:
          - `cfg`: the loaded config;
          - `hosts` and `tasks`;
          - `env`: the -e values as a dict.
      :raises ByrdException: on config errors or a -e value without "=".
      """
      parser = argparse.ArgumentParser()
      parser.add_argument('names', nargs='*',
                          help='Hosts and commands to run them on')
      parser.add_argument('-c', '--config', default='bd.yaml',
                          help='Config file')
      parser.add_argument('-R', '--run', nargs='*', default=[],
                          help='Run remote task')
      parser.add_argument('-L', '--run-local', nargs='*', default=[],
                          help='Run local task')
      parser.add_argument('-P', '--run-python', nargs='*', default=[],
                          help='Run python task')
      parser.add_argument('-d', '--dry-run', action='store_true',
                          help='Do not run actual tasks, just print them')
      parser.add_argument('-e', '--env', nargs='*', default=[],
                          help='Add value to execution environment '
                          '(ex: -e foo=bar "name=John Doe")')
      # NOTE(review): cli.sudo is not read anywhere else in this file
      parser.add_argument('-s', '--sudo', default='auto',
                          help='Enable sudo (auto|yes|no)')
      parser.add_argument('-v', '--verbose', action='count',
                          default=0, help='Increase verbosity')
      parser.add_argument('-q', '--quiet', action='count',
                          default=0, help='Decrease verbosity')
      parser.add_argument('-n', '--no-color', action='store_true',
                          help='Disable colored logs')
      parser.add_argument('-i', '--info', action='store_true',
                          help='Print info')
      cli = ObjectDict(vars(parser.parse_args(args=args)))
      # Load config
      cfg = load_cfg(cli.config)
      cli.cfg = cfg
      cli.update(get_hosts_and_tasks(cli, cfg))
      # Turn the "key=value" strings of -e into a dict; only the first "="
      # separates the key, so values may contain "=" themselves.
      env = {}
      for item in cli.env:
          key, sep, value = item.partition('=')
          if not sep:
              raise ByrdException(
                  'Invalid env value "%s" (expected key=value)' % item)
          env[key] = value
      cli.env = env
      return cli
  def get_hosts_and_tasks(cli, cfg):
      """
      Resolve the command line names into hosts and tasks.

      Each positional name must be a network (its hosts are collected) or a
      task. Commands given with -R/-L/-P become "Custom command" tasks
      (run/local/python), appended after the named tasks.

      :returns: dict with `hosts` (list of host strings) and `tasks`.
      :raises ByrdException: on a network/task name collision or an unknown
          name (with spelling suggestions).
      """
      # Make sure we don't have overlap between hosts and tasks
      collisions = set(cfg.networks) & set(cfg.tasks)
      if collisions:
          raise ByrdException('Name collision between tasks and networks: %s'
                              % ', '.join(sorted(collisions)))
      tasks = []
      networks = []
      for name in cli.names:
          if name in cfg.networks:
              networks.append(cfg.networks[name])
          elif name in cfg.tasks:
              tasks.append(cfg.tasks[name])
          else:
              msg = 'Name "%s" not understood' % name
              # spell() needs an index built by gen_candidates, not the
              # config sections themselves
              candidates = gen_candidates(list(cfg.networks) + list(cfg.tasks))
              matches = spell(candidates, name)
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              raise ByrdException(msg)
      # Collect custom tasks from cli (-R -> run, -L -> local, -P -> python)
      for cli_key in ('run', 'run_local', 'run_python'):
          cmd_key = cli_key.rsplit('_', 1)[-1]
          for command in cli[cli_key]:
              # Built directly rather than through YAML, so the command is
              # kept verbatim (": ", "{...}" or "true" would be reinterpreted)
              task = Task.parse({cmd_key: command})
              task.desc = 'Custom command'
              tasks.append(task)
      # A network without hosts contributes none
      hosts = list(chain.from_iterable(n.hosts or [] for n in networks))
      return dict(hosts=hosts, tasks=tasks)
  def info(cli):
      """
      Print each configured task with its kind, its description and the
      format variables ("{name}" fields) used in its attribute values.
      """
      formatter = Formatter()
      for name, attr in cli.cfg.tasks.items():
          if attr.python:
              kind = 'python'
          elif attr.local:
              kind = 'local'
          elif attr.multi:
              kind = 'multi'
          elif attr.send:
              kind = 'send file'
          else:
              kind = 'remote'
          print(f'{name} [{kind}]:\n\tDescription: {attr.desc}')
          # Flatten one level: lists contribute their items, dicts (env)
          # their values
          values = []
          for v in attr.values():
              if isinstance(v, list):
                  values.extend(v)
              elif isinstance(v, dict):
                  values.extend(v.values())
              else:
                  values.append(v)
          fmt_fields = set()
          for value in values:
              if not isinstance(value, str):
                  continue
              try:
                  fields = [parsed[1] for parsed in formatter.parse(value)
                            if parsed[1]]
              except ValueError:
                  # Not a valid str.format template (e.g. a lone "}" in a
                  # shell snippet): skip it instead of aborting the listing
                  continue
              fmt_fields.update(fields)
          if fmt_fields:
              variables = ', '.join(sorted(fmt_fields))
              print(f'\tVariables: {variables}')
  def main():
      """
      Command line entry point.

      A ByrdException is logged and ends the process with status 1, unless
      verbosity is above 2 (-vv), in which case it is re-raised with its
      traceback.
      """
      cli = None
      try:
          cli = load_cli()
          if not cli.no_color:
              enable_logging_color()
          # Verbosity starts at 1 (INFO); each -v adds one, each -q removes one
          cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
          levels = ('WARNING', 'INFO', 'DEBUG')
          level = levels[min(cli.verbose, len(levels) - 1)]
          log_handler.setLevel(level)
          logger.setLevel(level)
          if cli.info:
              info(cli)
              return
          # Layers from highest to lowest priority: command line, config, OS
          base_env = Env(cli.env, cli.cfg.get('env'), os.environ)
          for task in cli.tasks:
              run_batch(task, cli.hosts, cli, base_env)
      except ByrdException as e:
          if cli and cli.verbose > 2:
              raise
          abort(str(e))


  if __name__ == '__main__':
      main()