baker.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict, defaultdict
  5. import argparse
  6. import logging
  7. import os
  8. import posixpath
  9. import sys
  10. from fabric import Connection, Config
  11. from invoke import run
  12. import invoke
  13. import yaml
  14. try:
  15. import keyring
  16. except ImportError:
  17. keyring = None
  __version__ = '0.0'

  # Shared log layout; `.19s` truncates asctime to its first 19 characters.
  fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  logging.basicConfig(format=fmt)
  logger = logging.getLogger('baker')
  logger.setLevel(logging.INFO)

  try:
      import colorama
  except ImportError:
      # colorama is optional: without it the logger keeps propagating to
      # the root handler configured above.
      pass
  else:
      colorama.init()
      MAGENTA = colorama.Fore.MAGENTA
      RED = colorama.Fore.RED
      RESET = colorama.Style.RESET_ALL

      # We define custom handler ..
      class Handler(logging.StreamHandler):
          """Stream handler that colours the message according to its level."""

          def format(self, record):
              """
              Return the formatted record, with the message wrapped in
              MAGENTA for INFO and RED for WARNING/ERROR/CRITICAL.
              """
              if record.levelname == 'INFO':
                  color = MAGENTA
              elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
                  color = RED
              else:
                  return super().format(record)
              original = record.msg
              # str() so that non-string messages (e.g. exceptions) can
              # be coloured too.
              record.msg = color + str(original) + RESET
              try:
                  return super().format(record)
              finally:
                  # Do not leak colour codes to other users of the record
                  record.msg = original

      # .. and plug it
      handler = Handler()
      handler.setFormatter(logging.Formatter(fmt))
      logger.addHandler(handler)
      logger.propagate = 0
  def yaml_load(stream):
      """
      Parse the YAML document in `stream`, building every mapping as an
      `OrderedDict` so that keys keep their document order.
      """
      # NOTE(review): the loader derives from the full `yaml.Loader`;
      # only feed it trusted documents (consider `yaml.SafeLoader`).
      class OrderedLoader(yaml.Loader):
          pass

      def ordered_mapping(loader, node):
          loader.flatten_mapping(node)
          pairs = loader.construct_pairs(node)
          return OrderedDict(pairs)

      mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
      OrderedLoader.add_constructor(mapping_tag, ordered_mapping)
      return yaml.load(stream, OrderedLoader)
  def edits(word):
      """
      Yield `word` itself, then every string obtained by deleting
      exactly one character from it (left to right).
      """
      yield word
      for position in range(len(word)):
          yield word[:position] + word[position + 1:]
  def gen_candidates(wordlist):
      """
      Build the lookup index used by `spell`: maps every variant reachable
      by applying `edits` twice to a word onto the set of words that
      produce it.
      """
      index = defaultdict(set)
      for original in wordlist:
          variants = chain.from_iterable(
              edits(first) for first in edits(original)
          )
          for variant in variants:
              index[variant].add(original)
      return index
  def spell(candidates, word):
      """
      Return the set of known words whose variants (as indexed in
      `candidates`) match `word` or one of its single-deletion edits.
      """
      found = set()
      for variant in edits(word):
          if variant in candidates:
              found.update(candidates[variant])
      return found
  def spellcheck(objdict, word):
      """
      Abort with a "not found" message (and spelling suggestions, if any)
      when `word` is not a key of `objdict`; return None otherwise.

      The candidate index is cached on `objdict` as `_candidates` when
      the object accepts attributes.
      """
      if word in objdict:
          return

      # The cache lives in an attribute, not in a dict item, so it must be
      # read with attribute access (`objdict.get` never finds it).
      try:
          candidates = objdict._candidates
      except (AttributeError, KeyError):
          candidates = None
      if not candidates:
          candidates = gen_candidates(list(objdict))
          try:
              objdict._candidates = candidates
          except AttributeError:
              # Plain dict: nowhere to cache, recompute next time
              pass

      try:
          path = objdict._path
      except (AttributeError, KeyError):
          path = 'config'

      msg = '"%s" not found in %s' % (word, path)
      matches = spell(candidates, word)
      if matches:
          msg += ', try: %s' % ' or '.join(sorted(matches))
      abort(msg)
  class ObjectDict(dict):
      """
      Simple dict sub-class that allows to transform a dict into an
      object, like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`

      Reading a missing public attribute returns None. Names starting
      with an underscore are kept as regular instance attributes (they
      never become dict items) and raise `AttributeError` when missing.
      """

      def __getattr__(self, key):
          # Only called when the regular attribute lookup failed
          if key.startswith('_'):
              # AttributeError (not KeyError) keeps hasattr(), getattr()
              # with a default and the copy module working.
              raise AttributeError(key)
          return self.get(key)

      def __setattr__(self, key, value):
          if key.startswith('_'):
              # Stored on the instance itself, so the metadata lives and
              # dies with the object.
              super().__setattr__(key, value)
          else:
              self[key] = value
  class Node:
      """
      Base class of the config schema.

      A sub-class either defines `_children` (a dict mapping allowed keys
      to node classes, with '*' meaning "any key", or a one-item list
      holding the node class of every list element) or `_type` (a tuple
      of accepted Python types for a leaf value).
      """

      @staticmethod
      def fail(path, kind):
          """Abort, reporting that `kind` was expected at `path`."""
          msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
          abort(msg % (kind, '->'.join(path)))

      @staticmethod
      def _unknown_key(key, children, path):
          """Abort on a key absent from `children`, suggesting close names."""
          msg = 'Attribute "%s" not understood in %s' % (
              key, ' -> '.join(path))
          matches = spell(gen_candidates(children.keys()), key)
          if matches:
              msg += ', try: %s' % ' or '.join(sorted(matches))
          abort(msg)

      @classmethod
      def parse(cls, cfg, path=tuple()):
          """
          Validate `cfg` against this node class and return the parsed
          value: an `ObjectDict` for mappings, a list for lists, the value
          itself for leaves. `path` is the tuple of keys leading to `cfg`,
          used in error messages. Aborts on any mismatch.
          """
          children = getattr(cls, '_children', None)
          if children:
              type_name = type(children).__name__
          else:
              type_name = ' or '.join(c.__name__ for c in cls._type)

          res = None
          if type_name == 'dict':
              if not isinstance(cfg, dict):
                  cls.fail(path, type_name)
              res = ObjectDict()
              if '*' in children:
                  assert len(children) == 1, "Don't mix '*' and other keys"
                  child_class = children['*']
                  for name, value in cfg.items():
                      res[name] = child_class.parse(value, path + (name,))
              else:
                  # Enforce known pre-defined
                  for key in cfg:
                      if key not in children:
                          cls._unknown_key(key, children, path)
                  # Read (do not pop) so the caller's config is left intact
                  for name, child_class in children.items():
                      if name in cfg:
                          res[name] = child_class.parse(
                              cfg[name], path + (name,))
          elif type_name == 'list':
              if not isinstance(cfg, list):
                  cls.fail(path, type_name)
              child_class = children[0]
              res = [child_class.parse(c, path + ('[]',)) for c in cfg]
          else:
              if not isinstance(cfg, cls._type):
                  cls.fail(path, type_name)
              res = cfg

          return cls.setup(res, path)

      @classmethod
      def setup(cls, values, path):
          """Post-process parsed `values`; records `_path` on ObjectDicts."""
          if isinstance(values, ObjectDict):
              values._path = '->'.join(path)
          return values
  class Atom(Node):
      """Leaf value: a string or a boolean, returned as-is by `parse`."""

      _type = (str, bool)
  class AtomList(Node):
      """List whose items are all `Atom` values."""

      _children = [Atom]
  class Hosts(Node):
      """List of hosts, each one an `Atom`."""

      _children = [Atom]
  class Auth(Node):
      """Mapping of free-form keys to `Atom` values (the `auth` section)."""

      _children = {'*': Atom}
  class EnvNode(Node):
      """Mapping of free-form variable names to `Atom` values."""

      _children = {'*': Atom}
  class HostGroup(Node):
      """Mapping whose only accepted key is `hosts` (a `Hosts` list)."""

      _children = {'hosts': Hosts}
  class Network(Node):
      """Mapping of free-form network names to `HostGroup` nodes."""

      _children = {'*': HostGroup}
  class Multi(Node):
      """
      One step of a `multi` command: accepts the keys `task`, `export`,
      `network` and `env`.
      """

      _children = {
          'task': Atom,
          'export': Atom,
          'network': Atom,
          'env': EnvNode,
      }
  class MultiList(Node):
      """List of `Multi` steps."""

      _children = [Multi]
  class Command(Node):
      """
      One task definition. After parsing, `name` is set from the last
      element of the config path and `desc` defaults to that name.
      """

      _children = {
          'desc': Atom,
          'local': Atom,
          'once': Atom,
          'run': Atom,
          # `run_remote` executes `task.sudo`, so the key must be accepted
          'sudo': Atom,
          'send': Atom,
          'to': Atom,
          'assert': Atom,
          'env': EnvNode,
          'multi': MultiList,
      }

      @classmethod
      def setup(cls, values, path):
          """Fill in `name` and the default `desc`, then delegate to Node."""
          # Commands parsed without a path (custom tasks given on the
          # command line) still need a name for logging and for `desc`.
          values['name'] = path[-1] if path else 'custom'
          if 'desc' not in values:
              values['desc'] = values['name']
          super().setup(values, path)
          return values
  class Task(Node):
      """Mapping of free-form task names to `Command` nodes."""

      _children = {'*': Command}
  class LoadNode(Node):
      """One entry of the `load` section: accepts `file` and `as`."""

      _children = {
          'file': Atom,
          'as': Atom,
      }
  class LoadList(Node):
      """List of `LoadNode` entries."""

      _children = [LoadNode]
  class ConfigRoot(Node):
      """
      Top level of a config file: accepts the sections `networks`,
      `tasks`, `auth`, `env` and `load`.
      """

      _children = {
          'networks': Network,
          'tasks': Task,
          'auth': Auth,
          'env': EnvNode,
          'load': LoadList,
      }
  class Env(ChainMap):
      """
      Layered environment: a `ChainMap` built from the given mappings,
      earlier ones taking precedence. `None` arguments are skipped.
      """

      def __init__(self, *dicts):
          super().__init__(*(d for d in dicts if d is not None))

      def fmt(self, string):
          """
          Return `string` formatted with the environment values.

          Aborts when a placeholder is missing (suggesting close names)
          or when a positional placeholder is used.
          """
          try:
              return string.format(**self)
          except KeyError as exc:
              key = exc.args[0]
              msg = 'Unable to format "%s" (missing: "%s")' % (string, key)
              matches = spell(gen_candidates(self.keys()), key)
              # Only add the hint when there is something to suggest
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              abort(msg)
          except IndexError:
              msg = 'Unable to format "%s", positional argument not supported'
              abort(msg % string)
  def get_passphrase(key_path):
      """
      Return the passphrase for the private key file `key_path`.

      When `keyring` is available the passphrase is looked up (and, after
      prompting, stored) under the MD5 checksum of the key file;
      otherwise the user is simply prompted.
      """
      prompt = 'Password for %s: ' % key_path
      if keyring is None:
          # Optional dependency missing: nothing to cache with
          return getpass(prompt)

      service = 'SSH private key'
      with open(key_path, 'rb') as key_file:
          csum = md5(key_file.read()).hexdigest()
      ssh_pass = keyring.get_password(service, csum)
      if not ssh_pass:
          ssh_pass = getpass(prompt)
          keyring.set_password(service, csum, ssh_pass)
      return ssh_pass
  def get_sudo_passwd():
      """
      Return the sudo password, read from `keyring` when available (and
      stored there after prompting), otherwise asked to the user.
      """
      prompt = 'Sudo password:'
      if keyring is None:
          # Optional dependency missing: nothing to cache with
          return getpass(prompt)

      service = "Sudo password"
      passwd = keyring.get_password(service, '-')
      if not passwd:
          passwd = getpass(prompt)
          keyring.set_password(service, '-', passwd)
      return passwd
  # Connections already built, keyed by (host, with_sudo)
  CONNECTION_CACHE = {}


  def connect(host, auth, with_sudo=False):
      """
      Return a (cached) fabric `Connection` to `host`.

      `auth` is the `auth` config section (or None); its
      `ssh_private_key` entry, when present, selects the key file. With
      `with_sudo` the connection config carries the sudo password.
      """
      cache_key = (host, with_sudo)
      if cache_key in CONNECTION_CACHE:
          return CONNECTION_CACHE[cache_key]

      connect_kwargs = {}
      # Item access only: `auth` may be a plain dict after config merging
      key_path = auth.get('ssh_private_key') if auth else None
      if key_path:
          if not os.path.exists(key_path):
              abort('Private key file "%s" not found' % key_path)
          connect_kwargs['key_filename'] = key_path
          # NOTE(review): presumably the SSH layer uses `password` to
          # decrypt the key when no `passphrase` is given -- confirm.
          connect_kwargs['password'] = get_passphrase(key_path)

      config = None
      if with_sudo:
          config = Config(overrides={
              'sudo': {
                  'password': get_sudo_passwd()
              }
          })

      con = Connection(host, config=config, connect_kwargs=connect_kwargs)
      CONNECTION_CACHE[cache_key] = con
      return con
  def run_local(task, env, cli):
      """
      Run the `local` command of `task` on this machine and return its
      result; with `cli.dry_run` the command is only logged and None is
      returned.
      """
      command = env.fmt(task.local)
      # TODO log only task_desc and let desc contains env info like {host}
      banner = env.fmt('RUN {task_name} locally')
      logger.info(banner)
      if not cli.dry_run:
          return run(command, env=env)
      logger.info('[DRY-RUN] ' + command)
      return None
  def _send_tree(con, local_path, remote_path):
      """Upload the content of directory `local_path` below `remote_path`."""
      for root, _subdirs, files in os.walk(local_path):
          rel_dir = os.path.relpath(root, local_path)
          rem_dir = posixpath.join(remote_path, rel_dir)
          con.run('mkdir -p {}'.format(rem_dir))
          for fname in files:
              local_file = os.path.abspath(os.path.join(root, fname))
              con.put(local_file, remote=posixpath.join(rem_dir, fname))


  def run_remote(task, host, env, cli):
      """
      Execute `task` on `host`: its `run` command, else its `sudo`
      command, else its `send` upload. Aborts when the task has none.

      Returns the command result, or None for uploads and dry runs.
      """
      res = None
      host = env.fmt(host)
      env = env.new_child({
          'host': host,
      })
      con = connect(host, cli.cfg.auth, bool(task.sudo))
      if task.run:
          cmd = env.fmt(task.run)
          logger.info(env.fmt('RUN {task_name} ON {host}'))
          if cli.dry_run:
              logger.info('[DRY-RUN] ' + cmd)
          else:
              res = con.run(cmd, pty=True, env=env)

      elif task.sudo:
          cmd = env.fmt(task.sudo)
          logger.info(env.fmt('SUDO {task_name} ON {host}'))
          if cli.dry_run:
              logger.info('[DRY-RUN] ' + cmd)
          else:
              res = con.sudo(cmd)

      elif task.send:
          if not task.to:
              abort('Task "%s": "send" requires "to"' % task.name)
          local_path = env.fmt(task.send)
          remote_path = env.fmt(task.to)
          logger.info(f'SEND {local_path} TO {host}:{remote_path}')
          if cli.dry_run:
              logger.info('[DRY-RUN]')
          elif os.path.isfile(local_path):
              con.put(local_path, remote=remote_path)
          else:
              _send_tree(con, local_path, remote_path)
      else:
          abort('Unable to run task "%s"' % task.name)

      return res
  def run_task(task, host, cli, parent_env=None):
      '''
      Execute one task on one host (or locally)

      Builds the task environment, runs the task, then evaluates its
      `assert` expression (if any) against the command output. Returns
      the command result (None for dry runs and uploads).
      '''
      # Split on the first '=' only, so values may contain '='
      cli_env = {}
      for item in cli.env or ():
          key, sep, value = item.partition('=')
          if not sep:
              abort('Invalid env value "%s", expecting key=value' % item)
          cli_env[key] = value

      # Prepare environment
      env = Env(
          # Cli is top priority
          cli_env,
          # Then comes env from parent task
          parent_env,
          # Env on the task itself
          task.get('env'),
          # Top-level env
          cli.cfg.get('env'),
          # OS env
          os.environ,
      ).new_child()

      env.update({
          'task_desc': env.fmt(task.desc),
          'task_name': task.name,
          'host': host or '',
      })

      if task.local:
          res = run_local(task, env, cli)
      else:
          res = run_remote(task, host, env, cli)

      if task.get('assert'):
          if res is None:
              # Dry run or upload: there is no output to check
              logger.info('Assert skipped (no command result)')
              return res
          env.update({
              'stdout': res.stdout,
              'stderr': res.stderr,
          })
          assert_ = env.fmt(task['assert'])
          # NOTE(review): eval() runs arbitrary code taken from the
          # config; only use configs you trust.
          ok = eval(assert_, dict(env))
          if ok:
              logger.info('Assert ok')
          else:
              abort('Assert "%s" failed!' % assert_)
      return res
  def run_batch(task, hosts, cli, env=None):
      '''
      Run one task on a list of hosts

      A task with `multi` steps runs each referenced task in turn; a
      plain task runs on every host (only the first one with `once`).
      Returns the result of the last task run.
      '''
      env = Env(task.get('env'), env)

      if not task.get('multi'):
          if task.once and task.local:
              return run_task(task, None, cli, env)
          last = None
          for host in hosts:
              last = run_task(task, host, cli, env)
              if task.once:
                  break
          return last

      last = None
      exported = {}
      for step in task.multi:
          task_name = step.task
          spellcheck(cli.cfg.tasks, task_name)
          sub_task = cli.cfg.tasks[task_name]

          net_name = step.get('network')
          if net_name:
              spellcheck(cli.cfg.networks, net_name)
              # NOTE(review): this also changes the hosts used by the
              # following steps that name no network -- confirm intended.
              hosts = cli.cfg.networks[net_name].hosts

          # env wrap-around!
          step_env = {
              name: env.fmt(value)
              for name, value in step.get('env', {}).items()
          }
          run_env = Env(exported, step_env, env)
          last = run_batch(sub_task, hosts, cli, run_env)
          if step.export:
              exported[step.export] = last.stdout.strip() if last else ''
      return last
  def abort(msg):
      """Log `msg` as an error, then exit the program with status 1."""
      logger.error(msg)
      raise SystemExit(1)
  def _empty_section(name):
      """Return an empty ObjectDict labelled `name` for error messages."""
      section = ObjectDict()
      section._path = name
      return section


  def load(path, prefix=None):
      """
      Load the config file `path` and, recursively, the files listed in
      its `load` section, whose content is merged into the result.

      `prefix` (a list of name parts) is prepended, '/'-joined, to every
      key of the `networks`, `tasks`, `auth` and `env` sections. Aborts
      when a file is missing.
      """
      load_sections = ('networks', 'tasks', 'auth', 'env')

      if not os.path.isfile(path):
          abort('Config file "%s" not found' % path)
      logger.info('Load config %s' % path)
      with open(path) as stream:
          cfg = ConfigRoot.parse(yaml_load(stream))

      # Define useful defaults
      cfg.networks = cfg.networks or _empty_section('networks')
      cfg.tasks = cfg.tasks or _empty_section('tasks')

      if prefix:
          # Apply prefix
          for section in load_sections:
              if not cfg.get(section):
                  continue
              prefixed = _empty_section(section)
              for name, value in cfg[section].items():
                  prefixed['/'.join(prefix + [name])] = value
              cfg[section] = prefixed

      # Recursive load
      if cfg.load:
          cfg_path = os.path.dirname(path)
          for item in cfg.load:
              if not item.file:
                  abort('Missing "file" in load section of "%s"' % path)
              if item.get('as'):
                  child_prefix = item['as']
              else:
                  child_prefix, _ = os.path.splitext(item.file)
              child_path = os.path.join(cfg_path, item.file)
              child_cfg = load(child_path, child_prefix.split('/'))

              for section in load_sections:
                  if not cfg.get(section):
                      # Keep sections as ObjectDict: later code relies on
                      # attribute access and on `_path`.
                      cfg[section] = _empty_section(section)
                  cfg[section].update(child_cfg.get(section) or {})
      return cfg
  def base_cli(args=None):
      """
      Parse the command line (`args`, or `sys.argv` when None) and return
      the options as an `ObjectDict`.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument('names',  nargs='*',
                          help='Hosts and commands to run them on')
      parser.add_argument('-c', '--config', default='bk.yaml',
                          help='Config file')
      parser.add_argument('-r', '--run', nargs='*', default=[],
                          help='Run custom task')
      parser.add_argument('-d', '--dry-run', action='store_true',
                          help='Do not run actual tasks, just print them')
      parser.add_argument('-e', '--env', nargs='*', default=[],
                          help='Add value to execution environment '
                          '(ex: -e foo=bar "name=John Doe")')
      parser.add_argument('-s', '--sudo', default='auto',
                          help='Enable sudo (auto|yes|no)')
      parser.add_argument('-v', '--verbose', action='count',
                          default=0, help='Increase verbosity')
      cli = parser.parse_args(args=args)
      return ObjectDict(vars(cli))
  def main():
      """
      Command-line entry point: load the config, resolve the given names
      into networks and tasks, then run every task on the selected hosts.
      Exits with status 1 on any error, including a failed command.
      """
      cli = base_cli()
      if cli.verbose:
          level = 'INFO' if cli.verbose == 1 else 'DEBUG'
          logger.setLevel(level)
          logger.info('Log level set to: %s' % level)

      # Load config
      cfg = load(cli.config)
      cli.cfg = cfg

      # Make sure we don't have overlap between hosts and tasks
      # (explicit check: an assert would vanish under `python -O`)
      collisions = set(cfg.networks) & set(cfg.tasks)
      if collisions:
          abort('Name collision between tasks and networks: %s'
                % ', '.join(sorted(collisions)))

      # Build task list
      tasks = []
      networks = []
      for name in cli.names:
          if name in cfg.networks:
              networks.append(cfg.networks[name])
          elif name in cfg.tasks:
              tasks.append(cfg.tasks[name])
          else:
              msg = 'Name "%s" not understood' % name
              # spell() needs the candidate index, not the raw sections
              known = gen_candidates(chain(cfg.networks, cfg.tasks))
              matches = spell(known, name)
              if matches:
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              abort(msg)

      for custom_task in cli.run:
          task = Command.parse(yaml_load(custom_task))
          task.desc = 'Custom command'
          tasks.append(task)

      hosts = list(chain.from_iterable(n.hosts for n in networks))

      try:
          for task in tasks:
              run_batch(task, hosts, cli)
      except invoke.exceptions.Failure as e:
          # Report the failure through the exit status too
          abort(str(e))


  if __name__ == '__main__':
      main()