baker.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict, defaultdict
  5. import argparse
  6. import logging
  7. import os
  8. import posixpath
  9. import shlex
  10. import sys
  11. import spur
  12. import yaml
  13. try:
  14. import keyring
  15. except ImportError:
  16. keyring = None
  17. __version__ = '0.0'
  18. log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  19. logger = logging.getLogger('baker')
  20. logger.setLevel(logging.INFO)
  21. log_handler = logging.StreamHandler()
  22. log_handler.setLevel(logging.INFO)
  23. log_handler.setFormatter(logging.Formatter(log_fmt))
  24. logger.addHandler(log_handler)
  25. def enable_logging_color():
  26. try:
  27. import colorama
  28. except ImportError:
  29. return
  30. colorama.init()
  31. MAGENTA = colorama.Fore.MAGENTA
  32. RED = colorama.Fore.RED
  33. RESET = colorama.Style.RESET_ALL
  34. # We define custom handler ..
  35. class Handler(logging.StreamHandler):
  36. def format(self, record):
  37. if record.levelname == 'INFO':
  38. record.msg = MAGENTA + record.msg + RESET
  39. elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
  40. record.msg = RED + record.msg + RESET
  41. return super(Handler, self).format(record)
  42. # .. and plug it
  43. logger.removeHandler(log_handler)
  44. handler = Handler()
  45. handler.setFormatter(logging.Formatter(log_fmt))
  46. logger.addHandler(handler)
  47. logger.propagate = 0
  48. def yaml_load(stream):
  49. class OrderedLoader(yaml.Loader):
  50. pass
  51. def construct_mapping(loader, node):
  52. loader.flatten_mapping(node)
  53. return OrderedDict(loader.construct_pairs(node))
  54. OrderedLoader.add_constructor(
  55. yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
  56. construct_mapping)
  57. return yaml.load(stream, OrderedLoader)
  58. def edits(word):
  59. yield word
  60. splits = ((word[:i], word[i:]) for i in range(len(word) + 1))
  61. for left, right in splits:
  62. if right:
  63. yield left + right[1:]
  64. def gen_candidates(wordlist):
  65. candidates = defaultdict(set)
  66. for word in wordlist:
  67. for ed1 in edits(word):
  68. for ed2 in edits(ed1):
  69. candidates[ed2].add(word)
  70. return candidates
  71. def spell(candidates, word):
  72. matches = set(chain.from_iterable(
  73. candidates[ed] for ed in edits(word) if ed in candidates
  74. ))
  75. return matches
  76. def spellcheck(objdict, word):
  77. if word in objdict:
  78. return
  79. candidates = objdict.get('_candidates')
  80. if not candidates:
  81. candidates = gen_candidates(list(objdict))
  82. objdict._candidates = candidates
  83. msg = '"%s" not found in %s' % (word, objdict._path)
  84. matches = spell(candidates, word)
  85. if matches:
  86. msg += ', try: %s' % ' or '.join(matches)
  87. abort(msg)
  88. class ObjectDict(dict):
  89. """
  90. Simple objet sub-class that allows to transform a dict into an
  91. object, like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`
  92. """
  93. _meta = {}
  94. def __getattr__(self, key):
  95. if key.startswith('_'):
  96. return ObjectDict._meta[id(self), key]
  97. if key in self:
  98. return self[key]
  99. else:
  100. return None
  101. def __setattr__(self, key, value):
  102. if key.startswith('_'):
  103. ObjectDict._meta[id(self), key] = value
  104. else:
  105. self[key] = value
  106. class Node:
  107. @staticmethod
  108. def fail(path, kind):
  109. msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
  110. abort(msg % (kind, '->'.join(path)))
  111. @classmethod
  112. def parse(cls, cfg, path=tuple()):
  113. children = getattr(cls, '_children', None)
  114. type_name = children and type(children).__name__ \
  115. or ' or '.join((c.__name__ for c in cls._type))
  116. res = None
  117. if type_name == 'dict':
  118. if not isinstance(cfg, dict):
  119. cls.fail(path, type_name)
  120. res = ObjectDict()
  121. if '*' in children:
  122. assert len(children) == 1, "Don't mix '*' and other keys"
  123. child_class = children['*']
  124. for name, value in cfg.items():
  125. res[name] = child_class.parse(value, path + (name,))
  126. else:
  127. # Enforce known pre-defined
  128. for key in cfg:
  129. if key not in children:
  130. path = ' -> '.join(path)
  131. msg = 'Attribute "%s" not understoodin %s' % (key, path)
  132. candidates = gen_candidates(children.keys())
  133. matches = spell(candidates, key)
  134. if matches:
  135. msg += ', try: %s' % ' or '.join(matches)
  136. abort(msg)
  137. for name, child_class in children.items():
  138. if name not in cfg:
  139. continue
  140. res[name] = child_class.parse(cfg.pop(name), path + (name,))
  141. elif type_name == 'list':
  142. if not isinstance(cfg, list):
  143. cls.fail(path, type_name)
  144. child_class = children[0]
  145. res = [child_class.parse(c, path+ ('[]',)) for c in cfg]
  146. else:
  147. if not isinstance(cfg, cls._type):
  148. cls.fail(path, type_name)
  149. res = cfg
  150. return cls.setup(res, path)
  151. @classmethod
  152. def setup(cls, values, path):
  153. if isinstance(values, ObjectDict):
  154. values._path = '->'.join(path)
  155. return values
  156. class Atom(Node):
  157. _type = (str, bool)
  158. class AtomList(Node):
  159. _children = [Atom]
  160. class Hosts(Node):
  161. _children = [Atom]
  162. class Auth(Node):
  163. _children = {'*': Atom}
  164. class EnvNode(Node):
  165. _children = {'*': Atom}
  166. class HostGroup(Node):
  167. _children = {
  168. 'hosts': Hosts,
  169. }
  170. class Network(Node):
  171. _children = {
  172. '*': HostGroup,
  173. }
  174. class Multi(Node):
  175. _children = {
  176. 'task': Atom,
  177. 'export': Atom,
  178. 'python': Atom,
  179. 'network': Atom,
  180. 'env': EnvNode,
  181. }
  182. class MultiList(Node):
  183. _children = [Multi]
  184. class Command(Node):
  185. _children = {
  186. 'desc': Atom,
  187. 'local': Atom,
  188. 'python': Atom,
  189. 'once': Atom,
  190. 'run': Atom,
  191. 'send': Atom,
  192. 'to': Atom,
  193. 'assert': Atom,
  194. 'env': EnvNode,
  195. 'multi': MultiList,
  196. }
  197. @classmethod
  198. def setup(cls, values, path):
  199. if path:
  200. values['name'] = path[-1]
  201. if 'desc' not in values:
  202. values['desc'] = values['name']
  203. super().setup(values, path)
  204. return values
  205. class Task(Node):
  206. _children = {
  207. '*': Command,
  208. }
  209. class LoadNode(Node):
  210. _children = {
  211. 'file': Atom,
  212. 'as': Atom,
  213. }
  214. class LoadList(Node):
  215. _children = [LoadNode]
  216. class ConfigRoot(Node):
  217. _children = {
  218. 'networks': Network,
  219. 'tasks': Task,
  220. 'auth': Auth,
  221. 'env': EnvNode,
  222. 'load': LoadList,
  223. }
  224. class Env(ChainMap):
  225. def __init__(self, *dicts):
  226. return super().__init__(*filter(lambda x: x is not None, dicts))
  227. def fmt(self, string):
  228. try:
  229. return string.format(**self)
  230. except KeyError as exc:
  231. msg = 'Unable to format "%s" (missing: "%s")'% (string, exc.args[0])
  232. candidates = gen_candidates(self.keys())
  233. key = exc.args[0]
  234. matches = spell(candidates, key)
  235. if msg:
  236. msg += ', try: %s' % ' or '.join(matches)
  237. abort(msg )
  238. except IndexError as exc:
  239. msg = 'Unable to format "%s", positional argument not supported'
  240. abort(msg)
  241. def get_passphrase(key_path):
  242. service = 'SSH private key'
  243. csum = md5(open(key_path, 'rb').read()).digest().hex()
  244. ssh_pass = keyring.get_password(service, csum)
  245. if not ssh_pass:
  246. ssh_pass = getpass('Password for %s: ' % key_path)
  247. keyring.set_password(service, csum, ssh_pass)
  248. return ssh_pass
  249. def get_password(host):
  250. service = 'SSH password'
  251. ssh_pass = keyring.get_password(service, host)
  252. if not ssh_pass:
  253. ssh_pass = getpass('Password for %s: ' % host)
  254. keyring.set_password(service, host, ssh_pass)
  255. return ssh_pass
  256. def get_sudo_passwd():
  257. service = "Sudo password"
  258. passwd = keyring.get_password(service, '-')
  259. if not passwd:
  260. passwd = getpass('Sudo password:')
  261. keyring.set_password(service, '-', passwd)
  262. return passwd
  263. CONNECTION_CACHE = {}
  264. def connect(host, auth):
  265. if host in CONNECTION_CACHE:
  266. return CONNECTION_CACHE[host]
  267. private_key_file = password = None
  268. if auth and auth.get('ssh_private_key'):
  269. private_key_file = auth.ssh_private_key
  270. if not os.path.exists(auth.ssh_private_key):
  271. msg = 'Private key file "%s" not found' % auth.ssh_private_key
  272. abort(msg)
  273. password = get_passphrase(auth.ssh_private_key)
  274. else:
  275. password = get_password(host)
  276. username, hostname = host.split('@', 1)
  277. shell = spur.SshShell(
  278. hostname=hostname,
  279. username=username,
  280. password=password,
  281. private_key_file=private_key_file,
  282. missing_host_key=spur.ssh.MissingHostKey.accept,
  283. )
  284. CONNECTION_CACHE[host] = shell
  285. return shell
  286. def subshell(command, local=False):
  287. if not isinstance(command, (list, tuple)):
  288. command = list(shlex.shlex(command))
  289. if local and sys.platform == 'win32':
  290. shell = os.environ.get('COMSPEC', 'cmd.exe')
  291. return [shell, '/c'] + command
  292. return ['sh', '-c', command]
  293. def run_local(cmd, env, cli):
  294. # Run local task
  295. cmd = env.fmt(cmd)
  296. logger.info(env.fmt('{task_desc}'))
  297. if cli.dry_run:
  298. logger.info('[DRY-RUN] ' + cmd)
  299. return None
  300. shell = spur.LocalShell()
  301. logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
  302. res = shell.run(subshell(cmd, local=True), update_env=env)
  303. output = res.output.decode()
  304. logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
  305. return output
  306. def run_python(code, env, cli):
  307. # Execute a piece of python localy
  308. logger.info(env.fmt('{task_desc}'))
  309. if cli.dry_run:
  310. logger.info('[DRY-RUN] ' + code)
  311. return None
  312. shell = spur.LocalShell()
  313. logger.debug('\n\t' + '\n\t'.join(code.splitlines()))
  314. cmd = subshell('python -c "import sys;exec(sys.stdin.read())"', local=True)
  315. proc = shell.spawn(cmd, update_env=env)
  316. proc.stdin_write(code.encode('utf-8'))
  317. proc._process_stdin.close()
  318. res = proc.wait_for_result()
  319. output = res.output.decode()
  320. logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
  321. return output
  322. def run_remote(task, host, env, cli):
  323. res = None
  324. host = env.fmt(host)
  325. env.update({
  326. 'host': host,
  327. })
  328. shell = connect(host, cli.cfg.auth)
  329. if task.run:
  330. cmd = env.fmt(task.run)
  331. logger.info(env.fmt('{host}: {task_desc}'))
  332. logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
  333. if cli.dry_run:
  334. logger.info('[DRY-RUN] ' + cmd)
  335. else:
  336. res = shell.run(subshell(cmd), update_env=env)
  337. elif task.sudo:
  338. cmd = env.fmt(task.sudo)
  339. logger.info(env.fmt('[SUDO] {host}: {task_desc}'))
  340. if cli.dry_run:
  341. logger.info('[DRY-RUN] %s' + cmd)
  342. else:
  343. res = shell.sudo(cmd)
  344. elif task.send:
  345. local_path = env.fmt(task.send)
  346. remote_path = env.fmt(task.to)
  347. logger.info(f'[SEND] {local_path} -> {host}:{remote_path}')
  348. if cli.dry_run:
  349. logger.info('[DRY-RUN]')
  350. return
  351. else:
  352. with shell._connect_sftp() as sftp:
  353. if os.path.isfile(local_path):
  354. sftp.put(local_path, remote_path)
  355. else:
  356. for root, subdirs, files in os.walk(local_path):
  357. rel_dir = os.path.relpath(root, local_path)
  358. rem_dir = posixpath.join(remote_path, rel_dir)
  359. shell.run('mkdir -p {}'.format(rem_dir))
  360. for f in files:
  361. rel_f = os.path.join(root, f)
  362. rem_file = posixpath.join(rem_dir, f)
  363. sftp.put(os.path.abspath(rel_f), rem_file)
  364. else:
  365. abort('Unable to run task "%s"' % task.name)
  366. return res
  367. def run_task(task, host, cli, parent_env=None):
  368. '''
  369. Execute one task on one host (or locally)
  370. '''
  371. # Prepare environment
  372. env = Env(
  373. # Cli is top priority
  374. dict(e.split('=') for e in cli.env),
  375. # Then comes env from parent task
  376. parent_env,
  377. # Env on the task itself
  378. task.get('env'),
  379. # Top-level env
  380. cli.cfg.get('env'),
  381. ).new_child()
  382. env.update({
  383. 'task_desc': env.fmt(task.desc),
  384. 'task_name': task.name,
  385. 'host': host or '',
  386. })
  387. if task.local:
  388. res = run_local(task.local, env, cli)
  389. elif task.python:
  390. res = run_python(task.python, env, cli)
  391. else:
  392. res = run_remote(task, host, env, cli)
  393. if task.get('assert'):
  394. env.update({
  395. 'stdout': res.output,
  396. 'stderr': res.stderr_output,
  397. })
  398. assert_ = env.fmt(task['assert'])
  399. ok = eval(assert_, dict(env))
  400. if ok:
  401. logger.info('Assert ok')
  402. else:
  403. abort('Assert "%s" failed!' % assert_)
  404. return res
  405. def run_batch(task, hosts, cli, env=None):
  406. '''
  407. Run one task on a list of hosts
  408. '''
  409. env = Env(task.get('env'), env)
  410. res = None
  411. export_env = {}
  412. if task.get('multi'):
  413. for multi in task.multi:
  414. task = multi.task
  415. spellcheck(cli.cfg.tasks, task)
  416. sub_task = cli.cfg.tasks[task]
  417. network = multi.get('network')
  418. if network:
  419. spellcheck(cli.cfg.networks, network)
  420. hosts = cli.cfg.networks[network].hosts
  421. child_env = multi.get('env', {}).copy()
  422. for k, v in child_env.items():
  423. # env wrap-around!
  424. child_env[k] = env.fmt(child_env[k])
  425. run_env = Env(export_env, child_env, env)
  426. res = run_batch(sub_task, hosts, cli, run_env)
  427. if multi.export:
  428. export_env[multi.export] = res and res.output.strip() or ''
  429. else:
  430. if task.once and (task.local or task.python):
  431. res = run_task(task, None, cli, env)
  432. return res
  433. for host in hosts:
  434. res = run_task(task, host, cli, env)
  435. if task.once:
  436. break
  437. return res
  438. def abort(msg):
  439. logger.error(msg)
  440. sys.exit(1)
  441. def load(path, prefix=None):
  442. load_sections = ('networks', 'tasks', 'auth', 'env')
  443. if os.path.isfile(path):
  444. logger.info('Load config %s' % path)
  445. cfg = yaml_load(open(path))
  446. cfg = ConfigRoot.parse(cfg)
  447. else:
  448. abort('Config file "%s" not found' % path)
  449. # Define useful defaults
  450. cfg.networks = cfg.networks or ObjectDict()
  451. cfg.tasks = cfg.tasks or ObjectDict()
  452. if prefix:
  453. fn = lambda x: '/'.join(prefix + [x])
  454. # Apply prefix
  455. for section in load_sections:
  456. if not cfg.get(section):
  457. continue
  458. items = cfg[section].items()
  459. cfg[section] = {fn(k): v for k, v in items}
  460. # Recursive load
  461. if cfg.load:
  462. cfg_path = os.path.dirname(path)
  463. for item in cfg.load:
  464. if item.get('as'):
  465. child_prefix = item['as']
  466. else:
  467. child_prefix, _ = os.path.splitext(item.file)
  468. child_path = os.path.join(cfg_path, item.file)
  469. child_cfg = load(child_path, child_prefix.split('/'))
  470. for section in load_sections:
  471. if not cfg.get(section):
  472. cfg[section] = {}
  473. cfg[section].update(child_cfg.get(section, {}))
  474. return cfg
  475. def base_cli(args=None):
  476. parser = argparse.ArgumentParser()
  477. parser.add_argument('names', nargs='*',
  478. help='Hosts and commands to run them on')
  479. parser.add_argument('-c', '--config', default='bk.yaml',
  480. help='Config file')
  481. parser.add_argument('-r', '--run', nargs='*', default=[],
  482. help='Run custom task')
  483. parser.add_argument('-d', '--dry-run', action='store_true',
  484. help='Do not run actual tasks, just print them')
  485. parser.add_argument('-e', '--env', nargs='*', default=[],
  486. help='Add value to execution environment '
  487. '(ex: -e foo=bar "name=John Doe")')
  488. parser.add_argument('-s', '--sudo', default='auto',
  489. help='Enable sudo (auto|yes|no')
  490. parser.add_argument('-v', '--verbose', action='count',
  491. default=0, help='Increase verbosity')
  492. parser.add_argument('-q', '--quiet', action='count',
  493. default=0, help='Decrease verbosity')
  494. parser.add_argument('-n', '--no-color', action='store_true',
  495. help='Disable colored logs')
  496. cli = parser.parse_args(args=args)
  497. return ObjectDict(vars(cli))
  498. def main():
  499. cli = base_cli()
  500. if not cli.no_color:
  501. enable_logging_color()
  502. cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
  503. level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
  504. log_handler.setLevel(level)
  505. logger.setLevel(level)
  506. # Load config
  507. cfg = load(cli.config)
  508. cli.cfg = cfg
  509. # Make sure we don't have overlap between hosts and tasks
  510. items = list(cfg.networks) + list(cfg.tasks)
  511. msg = 'Name collision between tasks and networks'
  512. assert len(set(items)) == len(items), msg
  513. # Build task list
  514. tasks = []
  515. networks = []
  516. for name in cli.names:
  517. if name in cfg.networks:
  518. host = cfg.networks[name]
  519. networks.append(host)
  520. elif name in cfg.tasks:
  521. task = cfg.tasks[name]
  522. tasks.append(task)
  523. else:
  524. msg = 'Name "%s" not understood' % name
  525. matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
  526. if matches:
  527. msg += ', try: %s' % ' or '.join(matches)
  528. abort(msg)
  529. for custom_task in cli.run:
  530. task = Command.parse(yaml_load(custom_task))
  531. task.desc = 'Custom command'
  532. tasks.append(task)
  533. hosts = list(chain.from_iterable(n.hosts for n in networks))
  534. try:
  535. for task in tasks:
  536. run_batch(task, hosts, cli)
  537. except Exception as e:
  538. # TODO intercept only spur exceptions
  539. if cli.verbose > 2:
  540. raise
  541. abort(str(e))
  542. if __name__ == '__main__':
  543. main()