byrd.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  1. from contextlib import contextmanager
  2. from getpass import getpass
  3. from hashlib import md5
  4. from itertools import chain
  5. from collections import ChainMap, OrderedDict, defaultdict
  6. from itertools import islice
  7. from string import Formatter
  8. import argparse
  9. import io
  10. import logging
  11. import os
  12. import posixpath
  13. import subprocess
  14. import sys
  15. import threading
  16. try:
  17. # This file is imported by setup.py at install time
  18. import keyring
  19. import paramiko
  20. import yaml
  21. except ImportError:
  22. pass
  23. __version__ = '0.0'
  24. log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  25. logger = logging.getLogger('byrd')
  26. logger.setLevel(logging.INFO)
  27. log_handler = logging.StreamHandler()
  28. log_handler.setLevel(logging.INFO)
  29. log_handler.setFormatter(logging.Formatter(log_fmt))
  30. logger.addHandler(log_handler)
  31. basedir, _ = os.path.split(__file__)
  32. PKG_DIR = os.path.join(basedir, 'pkg')
  33. TAB = '\n '
  34. class ByrdException(Exception):
  35. pass
  36. class FmtException(ByrdException):
  37. pass
  38. class ExecutionException(ByrdException):
  39. pass
  40. class RemoteException(ExecutionException):
  41. pass
  42. class LocalException(ExecutionException):
  43. pass
  44. def enable_logging_color():
  45. try:
  46. import colorama
  47. except ImportError:
  48. return
  49. colorama.init()
  50. MAGENTA = colorama.Fore.MAGENTA
  51. RED = colorama.Fore.RED
  52. RESET = colorama.Style.RESET_ALL
  53. # We define custom handler ..
  54. class Handler(logging.StreamHandler):
  55. def format(self, record):
  56. if record.levelname == 'INFO':
  57. record.msg = MAGENTA + record.msg + RESET
  58. elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
  59. record.msg = RED + record.msg + RESET
  60. return super(Handler, self).format(record)
  61. # .. and plug it
  62. logger.removeHandler(log_handler)
  63. handler = Handler()
  64. handler.setFormatter(logging.Formatter(log_fmt))
  65. logger.addHandler(handler)
  66. logger.propagate = 0
  67. def yaml_load(stream):
  68. class OrderedLoader(yaml.Loader):
  69. pass
  70. def construct_mapping(loader, node):
  71. loader.flatten_mapping(node)
  72. return OrderedDict(loader.construct_pairs(node))
  73. OrderedLoader.add_constructor(
  74. yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
  75. construct_mapping)
  76. return yaml.load(stream, OrderedLoader)
  77. def edits(word):
  78. yield word
  79. splits = ((word[:i], word[i:]) for i in range(len(word) + 1))
  80. for left, right in splits:
  81. if right:
  82. yield left + right[1:]
  83. def gen_candidates(wordlist):
  84. candidates = defaultdict(set)
  85. for word in wordlist:
  86. for ed1 in edits(word):
  87. for ed2 in edits(ed1):
  88. candidates[ed2].add(word)
  89. return candidates
  90. def spell(candidates, word):
  91. matches = set(chain.from_iterable(
  92. candidates[ed] for ed in edits(word) if ed in candidates
  93. ))
  94. return matches
  95. def spellcheck(objdict, word):
  96. if word in objdict:
  97. return
  98. candidates = objdict.get('_candidates')
  99. if not candidates:
  100. candidates = gen_candidates(list(objdict))
  101. objdict._candidates = candidates
  102. msg = '"%s" not found in %s' % (word, objdict._path)
  103. matches = spell(candidates, word)
  104. if matches:
  105. msg += ', try: %s' % ' or '.join(matches)
  106. raise ByrdException(msg)
  107. class ObjectDict(dict):
  108. """
  109. Simple objet sub-class that allows to transform a dict into an
  110. object, like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`
  111. """
  112. _meta = {}
  113. def copy(self):
  114. res = ObjectDict(super().copy())
  115. ObjectDict._meta[id(res)] = ObjectDict._meta.get(id(self), {}).copy()
  116. return res
  117. def __getattr__(self, key):
  118. if key.startswith('_'):
  119. return ObjectDict._meta[id(self), key]
  120. if key in self:
  121. return self[key]
  122. else:
  123. return None
  124. def __setattr__(self, key, value):
  125. if key.startswith('_'):
  126. ObjectDict._meta[id(self), key] = value
  127. else:
  128. self[key] = value
  129. class Node:
  130. @staticmethod
  131. def fail(path, kind):
  132. msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
  133. raise ByrdException(msg % (kind, '->'.join(path)))
  134. @classmethod
  135. def parse(cls, cfg, path=tuple()):
  136. children = getattr(cls, '_children', None)
  137. type_name = children and type(children).__name__ \
  138. or ' or '.join((c.__name__ for c in cls._type))
  139. res = None
  140. if type_name == 'dict':
  141. if not isinstance(cfg, dict):
  142. cls.fail(path, type_name)
  143. res = ObjectDict()
  144. if '*' in children:
  145. assert len(children) == 1, "Don't mix '*' and other keys"
  146. child_class = children['*']
  147. for name, value in cfg.items():
  148. res[name] = child_class.parse(value, path + (name,))
  149. else:
  150. # Enforce known pre-defined
  151. for key in cfg:
  152. if key not in children:
  153. path = ' -> '.join(path)
  154. if path:
  155. msg = 'Attribute "%s" not understood in %s' % (
  156. key, path)
  157. else:
  158. msg = 'Top-level attribute "%s" not understood' % (
  159. key)
  160. candidates = gen_candidates(children.keys())
  161. matches = spell(candidates, key)
  162. if matches:
  163. msg += ', try: %s' % ' or '.join(matches)
  164. raise ByrdException(msg)
  165. for name, child_class in children.items():
  166. if name not in cfg:
  167. continue
  168. res[name] = child_class.parse(cfg[name], path + (name,))
  169. elif type_name == 'list':
  170. if not isinstance(cfg, list):
  171. cls.fail(path, type_name)
  172. child_class = children[0]
  173. res = [child_class.parse(c, path+ ('[%s]' % pos,))
  174. for pos, c in enumerate(cfg)]
  175. else:
  176. if not isinstance(cfg, cls._type):
  177. cls.fail(path, type_name)
  178. res = cfg
  179. return cls.setup(res, path)
  180. @classmethod
  181. def setup(cls, values, path):
  182. if isinstance(values, ObjectDict):
  183. values._path = '->'.join(path)
  184. return values
  185. class Atom(Node):
  186. _type = (str, bool)
  187. class AtomList(Node):
  188. _children = [Atom]
  189. class Hosts(Node):
  190. _children = [Atom]
  191. class Auth(Node):
  192. _children = {'*': Atom}
  193. class EnvNode(Node):
  194. _children = {'*': Atom}
  195. class HostGroup(Node):
  196. _children = {
  197. 'hosts': Hosts,
  198. }
  199. class Network(Node):
  200. _children = {
  201. '*': HostGroup,
  202. }
  203. class Multi(Node):
  204. _children = {
  205. 'task': Atom,
  206. 'export': Atom,
  207. 'network': Atom,
  208. }
  209. class MultiList(Node):
  210. _children = [Multi]
  211. class Task(Node):
  212. _children = {
  213. 'desc': Atom,
  214. 'local': Atom,
  215. 'python': Atom,
  216. 'once': Atom,
  217. 'run': Atom,
  218. 'sudo': Atom,
  219. 'send': Atom,
  220. 'to': Atom,
  221. 'assert': Atom,
  222. 'env': EnvNode,
  223. 'multi': MultiList,
  224. 'fmt': Atom,
  225. }
  226. @classmethod
  227. def setup(cls, values, path):
  228. values['name'] = path and path[-1] or ''
  229. if 'desc' not in values:
  230. values['desc'] = values.get('name', '')
  231. super().setup(values, path)
  232. return values
  233. # Multi can also accept any task attribute:
  234. Multi._children.update(Task._children)
  235. class TaskGroup(Node):
  236. _children = {
  237. '*': Task,
  238. }
  239. class LoadNode(Node):
  240. _children = {
  241. 'file': Atom,
  242. 'pkg': Atom,
  243. 'as': Atom,
  244. }
  245. class LoadList(Node):
  246. _children = [LoadNode]
  247. class ConfigRoot(Node):
  248. _children = {
  249. 'networks': Network,
  250. 'tasks': TaskGroup,
  251. 'auth': Auth,
  252. 'env': EnvNode,
  253. 'load': LoadList,
  254. }
  255. class Env(ChainMap):
  256. def __init__(self, *dicts):
  257. self.fmt_kind = 'new'
  258. return super().__init__(*filter(lambda x: x is not None, dicts))
  259. def fmt_env(self, child_env, kind=None):
  260. new_env = {}
  261. for key, val in child_env.items():
  262. # env wrap-around!
  263. new_val = self.fmt(val, kind=kind)
  264. if new_val == val:
  265. continue
  266. new_env[key] = new_val
  267. return Env(new_env, child_env)
  268. def fmt_string(self, string, kind=None):
  269. fmt_kind = kind or self.fmt_kind
  270. try:
  271. if fmt_kind == 'old':
  272. return string % self
  273. else:
  274. return string.format(**self)
  275. except KeyError as exc:
  276. msg = 'Unable to format "%s" (missing: "%s")'% (string, exc.args[0])
  277. candidates = gen_candidates(self.keys())
  278. key = exc.args[0]
  279. matches = spell(candidates, key)
  280. if matches:
  281. msg += ', try: %s' % ' or '.join(matches)
  282. raise FmtException(msg )
  283. except IndexError as exc:
  284. msg = 'Unable to format "%s", positional argument not supported'
  285. raise FmtException(msg)
  286. def fmt(self, what, kind=None):
  287. if isinstance(what, str):
  288. return self.fmt_string(what, kind=kind)
  289. return self.fmt_env(what, kind=kind)
  290. class DummyClient:
  291. '''
  292. Dummy Paramiko client, mainly usefull for testing & dry runs
  293. '''
  294. @contextmanager
  295. def open_sftp(self):
  296. yield None
  297. def get_secret(service, resource, resource_id=None):
  298. resource_id = resource_id or resource
  299. secret = keyring.get_password(service, resource_id)
  300. if not secret:
  301. secret = getpass('Password for %s: ' % resource)
  302. keyring.set_password(service, resource_id, secret)
  303. return secret
  304. def get_passphrase(key_path):
  305. service = 'SSH private key'
  306. csum = md5(open(key_path, 'rb').read()).digest().hex()
  307. return get_secret(service, key_path, csum)
  308. def get_password(host):
  309. service = 'SSH password'
  310. return get_secret(service, host)
  311. def get_sudo_passwd():
  312. service = "Sudo password"
  313. return get_secret(service, 'sudo')
  314. CONNECTION_CACHE = {}
  315. def connect(host, auth):
  316. if host in CONNECTION_CACHE:
  317. return CONNECTION_CACHE[host]
  318. private_key_file = password = None
  319. if auth and auth.get('ssh_private_key'):
  320. private_key_file = auth.ssh_private_key
  321. if not os.path.exists(auth.ssh_private_key):
  322. msg = 'Private key file "%s" not found' % auth.ssh_private_key
  323. raise ByrdException(msg)
  324. password = get_passphrase(auth.ssh_private_key)
  325. else:
  326. password = get_password(host)
  327. username, hostname = host.split('@', 1)
  328. client = paramiko.SSHClient()
  329. client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  330. client.connect(hostname, username=username, password=password,
  331. key_filename=private_key_file,
  332. )
  333. CONNECTION_CACHE[host] = client
  334. return client
  335. def run_local(cmd, env, cli):
  336. # Run local task
  337. cmd = env.fmt(cmd)
  338. logger.info(env.fmt('{task_desc}', kind='new'))
  339. if cli.dry_run:
  340. logger.info('[dry-run] ' + cmd)
  341. return None
  342. logger.debug(TAB + TAB.join(cmd.splitlines()))
  343. process = subprocess.Popen(
  344. cmd, shell=True,
  345. stdout=subprocess.PIPE,
  346. stderr=subprocess.STDOUT,
  347. env=env,
  348. )
  349. stdout, stderr = process.communicate()
  350. success = process.returncode == 0
  351. if stdout:
  352. logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
  353. if not success:
  354. raise LocalException(stdout, stderr)
  355. return ObjectDict(stdout=stdout, stderr=stderr)
  356. def run_python(task, env, cli):
  357. # Execute a piece of python localy
  358. code = task.python
  359. logger.info(env.fmt('{task_desc}', kind='new'))
  360. if cli.dry_run:
  361. logger.info('[dry-run] ' + code)
  362. return None
  363. logger.debug(TAB + TAB.join(code.splitlines()))
  364. cmd = ['python', '-c', 'import sys;exec(sys.stdin.read())']
  365. if task.sudo:
  366. user = 'root' if task.sudo is True else task.sudo
  367. cmd = 'sudo -u {} -- {}'.format(user, cmd)
  368. process = subprocess.Popen(
  369. cmd,
  370. stdout=subprocess.PIPE,
  371. stderr=subprocess.PIPE,
  372. stdin=subprocess.PIPE,
  373. env=env,
  374. )
  375. # Plug io
  376. out_buff = io.StringIO()
  377. err_buff = io.StringIO()
  378. log_stream(process.stdout, out_buff)
  379. log_stream(process.stderr, err_buff)
  380. process.stdin.write(code.encode())
  381. process.stdin.flush()
  382. process.stdin.close()
  383. success = process.wait() == 0
  384. process.stdout.close()
  385. process.stderr.close()
  386. out = out_buff.getvalue()
  387. if out:
  388. logger.debug(TAB + TAB.join(out.splitlines()))
  389. if not success:
  390. raise LocalException(out + err_buff.getvalue())
  391. return ObjectDict(stdout=out, stderr=err_buff.getvalue())
  392. def log_stream(stream, buff):
  393. def _log():
  394. try:
  395. for chunk in iter(lambda: stream.readline(2048), ""):
  396. if isinstance(chunk, bytes):
  397. chunk = chunk.decode()
  398. buff.write(chunk)
  399. except ValueError:
  400. # read raises a ValueError on closed stream
  401. pass
  402. t = threading.Thread(target=_log)
  403. t.start()
  404. return t
  405. def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
  406. '''
  407. Helper function to run `cmd` command on remote host
  408. '''
  409. chan = client.get_transport().open_session()
  410. if env:
  411. chan.update_environment(env)
  412. stdin = chan.makefile('wb')
  413. stdout = chan.makefile('r')
  414. stderr = chan.makefile_stderr('r')
  415. out_buff = io.StringIO()
  416. err_buff = io.StringIO()
  417. out_thread = log_stream(stdout, out_buff)
  418. err_thread = log_stream(stderr, err_buff)
  419. if sudo:
  420. assert not in_buff, 'in_buff and sudo can not be combined'
  421. if isinstance(sudo, str):
  422. sudo_cmd = 'sudo -u %s -s' % sudo
  423. else:
  424. sudo_cmd = 'sudo -s'
  425. chan.exec_command(sudo_cmd)
  426. in_buff = cmd
  427. else:
  428. chan.exec_command(cmd)
  429. if in_buff:
  430. # XXX use a real buff (not a simple str) ?
  431. stdin.write(in_buff)
  432. stdin.flush()
  433. stdin.close()
  434. chan.shutdown_write()
  435. success = chan.recv_exit_status() == 0
  436. out_thread.join()
  437. err_thread.join()
  438. if not success:
  439. raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
  440. res = ObjectDict(
  441. stdout = out_buff.getvalue(),
  442. stderr = err_buff.getvalue(),
  443. )
  444. return res
  445. def run_remote(task, host, env, cli):
  446. res = None
  447. host = env.fmt(host)
  448. env.update({
  449. 'host': extract_host(host),
  450. })
  451. if cli.dry_run:
  452. client = DummyClient()
  453. else:
  454. client = connect(host, cli.cfg.auth)
  455. if task.run:
  456. cmd = env.fmt(task.run)
  457. prefix = ''
  458. if task.sudo:
  459. if task.sudo is True:
  460. prefix = '[sudo] '
  461. else:
  462. prefix = '[sudo as %s] ' % task.sudo
  463. msg = prefix + '{host}: {task_desc}'
  464. logger.info(env.fmt(msg, kind='new'))
  465. logger.debug(TAB + TAB.join(cmd.splitlines()))
  466. if cli.dry_run:
  467. logger.info('[dry-run] ' + cmd)
  468. else:
  469. res = run_helper(client, cmd, env=env, sudo=task.sudo)
  470. elif task.send:
  471. local_path = env.fmt(task.send)
  472. remote_path = env.fmt(task.to)
  473. if not os.path.exists(local_path):
  474. raise ByrdException('Path "%s" not found' % local_path)
  475. else:
  476. send(client, env, cli, task)
  477. else:
  478. raise ByrdException('Unable to run task "%s"' % task.name)
  479. if res and res.stdout:
  480. logger.debug(TAB + TAB.join(res.stdout.splitlines()))
  481. return res
  482. def send(client, env, cli, task):
  483. fmt = task.fmt and Env(env, {'fmt': 'new'}).fmt(task.fmt) or None
  484. local_path = env.fmt(task.send)
  485. remote_path = env.fmt(task.to)
  486. dry_run = cli.dry_run
  487. with client.open_sftp() as sftp:
  488. if os.path.isfile(local_path):
  489. send_file(sftp, os.path.abspath(local_path), remote_path, env,
  490. dry_run=dry_run, fmt=fmt)
  491. elif os.path.isdir(local_path):
  492. for root, subdirs, files in os.walk(local_path):
  493. rel_dir = os.path.relpath(root, local_path)
  494. rel_dirs = os.path.split(rel_dir)
  495. rem_dir = posixpath.join(remote_path, *rel_dirs)
  496. run_helper(client, 'mkdir -p {}'.format(rem_dir))
  497. for f in files:
  498. rel_f = os.path.join(root, f)
  499. rem_file = posixpath.join(rem_dir, f)
  500. send_file(sftp, os.path.abspath(rel_f), rem_file, env,
  501. dry_run=dry_run, fmt=fmt)
  502. else:
  503. msg = 'Unexpected path "%s" (not a file, not a directory)'
  504. raise ByrdException(msg % local_path)
  505. def send_file(sftp, local_path, remote_path, env, dry_run=False, fmt=None):
  506. if not fmt:
  507. logger.info(f'[send] {local_path} -> {remote_path}')
  508. lines = islice(open(local_path), 30)
  509. logger.debug('File head:' + TAB.join(lines))
  510. if not dry_run:
  511. sftp.put(local_path, remote_path)
  512. return
  513. # Format file content and save it on remote
  514. logger.info(f'[fmt] {local_path} -> {remote_path}')
  515. content = env.fmt(open(local_path).read(), kind=fmt)
  516. lines = islice(content.splitlines(), 30)
  517. logger.debug('File head:' + TAB.join(lines))
  518. if not dry_run:
  519. fh = sftp.open(remote_path, mode='w')
  520. fh.write(content)
  521. fh.close()
  522. def run_task(task, host, cli, env=None):
  523. '''
  524. Execute one task on one host (or locally)
  525. '''
  526. if task.local:
  527. res = run_local(task.local, env, cli)
  528. elif task.python:
  529. res = run_python(task, env, cli)
  530. else:
  531. res = run_remote(task, host, env, cli)
  532. if task.get('assert'):
  533. eval_env = {
  534. 'stdout': res.stdout.strip(),
  535. 'stderr': res.stderr.strip(),
  536. }
  537. assert_ = env.fmt(task['assert'])
  538. ok = eval(assert_, eval_env)
  539. if ok:
  540. logger.info('Assert ok')
  541. else:
  542. raise ByrdException('Assert "%s" failed!' % assert_)
  543. return res
  544. def run_batch(task, hosts, cli, global_env=None):
  545. '''
  546. Run one task on a list of hosts
  547. '''
  548. out = None
  549. export_env = {}
  550. task_env = global_env.fmt(task.get('env', {}))
  551. if task.get('multi'):
  552. parent_env = Env(export_env, task_env, global_env)
  553. parent_sudo = task.sudo
  554. for pos, step in enumerate(task.multi):
  555. task_name = step.task
  556. if task_name:
  557. # _cfg contain "local" config wrt the task
  558. siblings = task._cfg.tasks
  559. spellcheck(siblings, task_name)
  560. sub_task = siblings[task_name]
  561. sudo = step.sudo or sub_task.sudo or parent_sudo
  562. else:
  563. # reify a task out of attributes
  564. sub_task = Task.parse(step)
  565. sub_task._path = '%s->[%s]' % (task._path, pos)
  566. sudo = sub_task.sudo or parent_sudo
  567. sub_task.sudo = sudo
  568. network = step.get('network')
  569. if network:
  570. spellcheck(cli.cfg.networks, network)
  571. hosts = cli.cfg.networks[network].hosts
  572. child_env = step.get('env', {})
  573. child_env = parent_env.fmt(child_env)
  574. out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
  575. out = out.decode() if isinstance(out, bytes) else out
  576. export_env['_'] = out
  577. if step.export:
  578. export_env[step.export] = out
  579. else:
  580. task_env.update({
  581. 'task_desc': global_env.fmt(task.desc),
  582. 'task_name': task.name,
  583. })
  584. parent_env = Env(task_env, global_env)
  585. if task.get('fmt'):
  586. parent_env.fmt_kind = task.fmt
  587. res = None
  588. if task.once and (task.local or task.python):
  589. res = run_task(task, None, cli, parent_env)
  590. elif hosts:
  591. for host in hosts:
  592. env_host = extract_host(host)
  593. parent_env.update({
  594. 'host': env_host,
  595. })
  596. res = run_task(task, host, cli, parent_env)
  597. if task.once:
  598. break
  599. else:
  600. logger.warning('Nothing to do for task "%s"' % task._path)
  601. out = res and res.stdout.strip() or ''
  602. return out
  603. def extract_host(host_string):
  604. return host_string and host_string.split('@')[-1] or ''
  605. def abort(msg):
  606. logger.error(msg)
  607. sys.exit(1)
  608. def load_cfg(path, prefix=None):
  609. load_sections = ('networks', 'tasks', 'auth', 'env')
  610. if os.path.isfile(path):
  611. logger.debug('Load config %s' % path)
  612. cfg = yaml_load(open(path))
  613. cfg = ConfigRoot.parse(cfg)
  614. else:
  615. raise ByrdException('Config file "%s" not found' % path)
  616. # Define useful defaults
  617. cfg.networks = cfg.networks or ObjectDict()
  618. cfg.tasks = cfg.tasks or ObjectDict()
  619. # Create backrefs between tasks to the local config
  620. if cfg.get('tasks'):
  621. items = cfg['tasks'].items()
  622. for k, v in items:
  623. v._cfg = ObjectDict(cfg.copy())
  624. if prefix:
  625. key_fn = lambda x: '/'.join(prefix + [x])
  626. # Apply prefix
  627. for section in load_sections:
  628. if not section in cfg:
  629. continue
  630. items = cfg[section].items()
  631. cfg[section] = {key_fn(k): v for k, v in items}
  632. # Recursive load
  633. if cfg.load:
  634. cfg_path = os.path.dirname(path)
  635. for item in cfg.load:
  636. if item.get('file'):
  637. rel_path = item.file
  638. child_path = os.path.join(cfg_path, item.file)
  639. elif item.get('pkg'):
  640. rel_path = item.pkg
  641. child_path = os.path.join(PKG_DIR, item.pkg)
  642. if item.get('as'):
  643. child_prefix = item['as']
  644. else:
  645. child_prefix, _ = os.path.splitext(rel_path)
  646. child_cfg = load_cfg(child_path, child_prefix.split('/'))
  647. for section in load_sections:
  648. if not section in cfg:
  649. continue
  650. cfg[section].update(child_cfg.get(section, {}))
  651. return cfg
  652. def load_cli(args=None):
  653. parser = argparse.ArgumentParser()
  654. parser.add_argument('names', nargs='*',
  655. help='Hosts and commands to run them on')
  656. parser.add_argument('-c', '--config', default='bd.yaml',
  657. help='Config file')
  658. parser.add_argument('-R', '--run', nargs='*', default=[],
  659. help='Run remote task')
  660. parser.add_argument('-L', '--run-local', nargs='*', default=[],
  661. help='Run local task')
  662. parser.add_argument('-P', '--run-python', nargs='*', default=[],
  663. help='Run python task')
  664. parser.add_argument('-d', '--dry-run', action='store_true',
  665. help='Do not run actual tasks, just print them')
  666. parser.add_argument('-e', '--env', nargs='*', default=[],
  667. help='Add value to execution environment '
  668. '(ex: -e foo=bar "name=John Doe")')
  669. parser.add_argument('-s', '--sudo', default='auto',
  670. help='Enable sudo (auto|yes|no')
  671. parser.add_argument('-v', '--verbose', action='count',
  672. default=0, help='Increase verbosity')
  673. parser.add_argument('-q', '--quiet', action='count',
  674. default=0, help='Decrease verbosity')
  675. parser.add_argument('-n', '--no-color', action='store_true',
  676. help='Disable colored logs')
  677. parser.add_argument('-i', '--info', action='store_true',
  678. help='Print info')
  679. cli = parser.parse_args(args=args)
  680. cli = ObjectDict(vars(cli))
  681. # Load config
  682. cfg = load_cfg(cli.config)
  683. cli.cfg = cfg
  684. cli.update(get_hosts_and_tasks(cli, cfg))
  685. # Transformt env string into dict
  686. cli.env = dict(e.split('=') for e in cli.env)
  687. return cli
  688. def get_hosts_and_tasks(cli, cfg):
  689. # Make sure we don't have overlap between hosts and tasks
  690. items = list(cfg.networks) + list(cfg.tasks)
  691. msg = 'Name collision between tasks and networks'
  692. assert len(set(items)) == len(items), msg
  693. # Build task list
  694. tasks = []
  695. networks = []
  696. for name in cli.names:
  697. if name in cfg.networks:
  698. host = cfg.networks[name]
  699. networks.append(host)
  700. elif name in cfg.tasks:
  701. task = cfg.tasks[name]
  702. tasks.append(task)
  703. else:
  704. msg = 'Name "%s" not understood' % name
  705. matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
  706. if matches:
  707. msg += ', try: %s' % ' or '.join(matches)
  708. raise ByrdException (msg)
  709. # Collect custom tasks from cli
  710. customs = []
  711. for cli_key in ('run', 'run_local', 'run_python'):
  712. cmd_key = cli_key.rsplit('_', 1)[-1]
  713. customs.extend('%s: %s' % (cmd_key, ck) for ck in cli[cli_key])
  714. for custom_task in customs:
  715. task = Task.parse(yaml_load(custom_task))
  716. task.desc = 'Custom command'
  717. tasks.append(task)
  718. hosts = list(chain.from_iterable(n.hosts for n in networks))
  719. return dict(hosts=hosts, tasks=tasks)
  720. def info(cli):
  721. formatter = Formatter()
  722. for name, attr in cli.cfg.tasks.items():
  723. kind = 'remote'
  724. if attr.python:
  725. kind = 'python'
  726. elif attr.local:
  727. kind = 'local'
  728. elif attr.multi:
  729. kind = 'multi'
  730. elif attr.send:
  731. kind = 'send file'
  732. print(f'{name} [{kind}]:\n\tDescription: {attr.desc}')
  733. values = []
  734. for v in attr.values():
  735. if isinstance(v, list):
  736. values.extend(v)
  737. elif isinstance(v, dict):
  738. values.extend(v.values())
  739. else:
  740. values.append(v)
  741. values = filter(lambda x: isinstance(x, str), values)
  742. fmt_fields = [i[1] for v in values for i in formatter.parse(v) if i[1]]
  743. if fmt_fields:
  744. variables = ', '.join(sorted(set(fmt_fields)))
  745. else:
  746. variables = None
  747. if variables:
  748. print(f'\tVariables: {variables}')
  749. def main():
  750. cli = None
  751. try:
  752. cli = load_cli()
  753. if not cli.no_color:
  754. enable_logging_color()
  755. cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
  756. level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
  757. log_handler.setLevel(level)
  758. logger.setLevel(level)
  759. if cli.info:
  760. info(cli)
  761. return
  762. base_env = Env(
  763. cli.env, # Highest-priority
  764. cli.cfg.get('env'),
  765. os.environ, # Lowest
  766. )
  767. for task in cli.tasks:
  768. run_batch(task, cli.hosts, cli, base_env)
  769. except ByrdException as e:
  770. if cli and cli.verbose > 2:
  771. raise
  772. abort(str(e))
  773. if __name__ == '__main__':
  774. main()