baker.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict
  5. import argparse
  6. import logging
  7. import os
  8. import posixpath
  9. import sys
  10. from fabric import Connection, Config
  11. from invoke import run
  12. from tanker import yaml_load
  13. import invoke
  14. import yaml
  # Optional dependency used to store and look up passwords. When the
  # package is not installed the name is bound to None, so code using it
  # has to check for that first.
  try:
      import keyring
  except ImportError:
      keyring = None
  __version__ = '0.0'

  # Module logger; records of level INFO and above are emitted by default.
  logger = logging.getLogger('baker')
  logger.setLevel(logging.INFO)

  # Log line layout: level name, the timestamp cut to its first 19
  # characters (date and time without milliseconds), then the message.
  fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  logging.basicConfig(format=fmt)
  # Coloured console output is optional: it is only set up when colorama
  # can be imported. Only the import sits inside the try, so unrelated
  # ImportErrors are not silently swallowed.
  try:
      import colorama
  except ImportError:
      colorama = None

  if colorama is not None:
      colorama.init()
      MAGENTA = colorama.Fore.MAGENTA
      RED = colorama.Fore.RED
      RESET = colorama.Style.RESET_ALL

      class Handler(logging.StreamHandler):
          """Stream handler that colours the message part of each record.

          INFO messages are magenta; WARNING, ERROR and CRITICAL are red;
          other levels are left as-is.
          """

          def format(self, record):
              if record.levelname == 'INFO':
                  color = MAGENTA
              elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
                  color = RED
              else:
                  color = None
              if color is not None:
                  # Colour a copy: mutating the shared record would make the
                  # escape codes stick to it (and accumulate if it is
                  # formatted again). str() also copes with non-string msgs.
                  record = logging.makeLogRecord(record.__dict__)
                  record.msg = color + str(record.msg) + RESET
              return super().format(record)

      handler = Handler()
      handler.setFormatter(logging.Formatter(fmt))
      logger.addHandler(handler)
      # Keep records away from the root handler installed by basicConfig,
      # otherwise every line would be printed twice.
      logger.propagate = False
  def yaml_load(stream):
      """Parse YAML from *stream*, building every mapping as an OrderedDict.

      A throw-away subclass of ``yaml.Loader`` is created on each call and
      given a constructor for the default mapping tag, so the shared
      ``yaml.Loader`` class itself is never modified.

      NOTE(review): ``yaml.Loader`` is PyYAML's full (unsafe) loader and can
      construct arbitrary Python objects from tagged input; only use it on
      trusted files, or consider ``yaml.SafeLoader``.
      This definition shadows the ``yaml_load`` imported from ``tanker`` at
      the top of the file.
      """
      def ordered_mapping(loader, node):
          # Expand merge keys before collecting the (key, value) pairs.
          loader.flatten_mapping(node)
          pairs = loader.construct_pairs(node)
          return OrderedDict(pairs)

      loader_cls = type('OrderedLoader', (yaml.Loader,), {})
      loader_cls.add_constructor(
          yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, ordered_mapping)
      return yaml.load(stream, Loader=loader_cls)
  class ObjectDict(dict):
      """
      Dict sub-class whose keys can also be read and written as attributes,
      e.g. ``ObjectDict({'ham': 'spam'}).ham == 'spam'``.

      Reading an attribute that is not a key returns ``None`` instead of
      raising ``AttributeError``; setting an attribute stores a key.
      """

      def __getattr__(self, key):
          # Only called once normal attribute lookup has failed.
          return self.get(key)

      def __setattr__(self, key, value):
          self.__setitem__(key, value)
  class Node:
      """Base class of the declarative config schema.

      A subclass describes the shape it accepts with either:
      * ``_children``: a dict mapping key -> Node subclass, where ``'*'``
        matches every key not listed explicitly, or a one-element list
        ``[NodeClass]`` describing the items of a list; or
      * ``_type``: a tuple of Python types accepted for a leaf value.
      """

      @staticmethod
      def fail(path, kind):
          """Log a schema error located at *path* and exit with status 1."""
          msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
          logger.error(msg % (kind, ' -> '.join(map(str, path))))
          sys.exit(1)

      @classmethod
      def parse(cls, cfg, path=tuple()):
          """Validate *cfg* against this node and return the parsed value.

          Mappings become ``ObjectDict`` (keys not covered by ``_children``
          are dropped), lists become lists of parsed items and leaves are
          returned unchanged; the result then goes through ``setup``.
          *path* is the tuple of keys leading here, used in error messages.
          *cfg* itself is not modified.
          """
          children = getattr(cls, '_children', None)
          if children:
              type_name = type(children).__name__
          else:
              type_name = ' or '.join(c.__name__ for c in cls._type)

          if type_name == 'dict':
              if not isinstance(cfg, dict):
                  cls.fail(path, type_name)
              # Shallow copy: explicitly declared keys are removed so the
              # '*' wildcard only sees the rest, without touching the
              # caller's mapping.
              remaining = dict(cfg)
              res = ObjectDict()
              for name, child_class in children.items():
                  if name == '*' or name not in remaining:
                      continue
                  res[name] = child_class.parse(
                      remaining.pop(name), path + (name,))
              if '*' in children:
                  child_class = children['*']
                  for name, value in remaining.items():
                      res[name] = child_class.parse(value, path + (name,))
          elif type_name == 'list':
              if not isinstance(cfg, list):
                  cls.fail(path, type_name)
              child_class = children[0]
              res = [child_class.parse(c, path + ('[]',)) for c in cfg]
          else:
              if not isinstance(cfg, cls._type):
                  cls.fail(path, type_name)
              res = cfg
          return cls.setup(res, path)

      @classmethod
      def setup(cls, values, path):
          """Post-processing hook for parsed *values*; identity by default."""
          return values
  class Atom(Node):
      """Leaf value, accepted when it is a ``str`` or a ``bool``."""
      _type = (str, bool)


  class AtomList(Node):
      """List whose items are all atoms."""
      _children = [Atom]


  class Hosts(Node):
      """The ``hosts`` list of a host group: a list of atoms."""
      _children = [Atom]


  class Auth(Node):
      """``auth`` section: any key, atom values (e.g. ``ssh_private_key``)."""
      _children = {'*': Atom}


  class EnvNode(Node):
      """``env`` mapping: any variable name, atom values."""
      _children = {'*': Atom}


  class HostGroup(Node):
      """One named network; only its ``hosts`` list is kept."""
      _children = dict(hosts=Hosts)


  class Network(Node):
      """``networks`` section: network name -> host group."""
      _children = {'*': HostGroup}


  class Multi(Node):
      """One step of a ``multi`` task.

      Atom keys (``task``, ``network``, ``export``, ...) plus an optional
      ``env`` mapping.
      """
      _children = {'*': Atom, 'env': EnvNode}


  class MultiList(Node):
      """The ``multi`` list of steps."""
      _children = [Multi]
  class Command(Node):
      """A single task definition.

      Any key is accepted as an atom (``run``, ``sudo``, ``local``,
      ``send``/``to``, ``once``, ``desc``, ``assert``, ...); ``env`` is a
      mapping of atoms and ``multi`` a list of sub-task steps.
      """
      _children = {
          '*': Atom,
          'env': EnvNode,
          'multi': MultiList,
      }

      @classmethod
      def setup(cls, values, path):
          """Name the command after its config key and default its ``desc``.

          Without a *path* (e.g. an ad-hoc task given on the command line)
          there is no key to name it after; ``desc`` then falls back to an
          explicit ``name`` entry, or to an empty string, instead of raising
          ``KeyError``.
          """
          if path:
              values['name'] = path[-1]
          if 'desc' not in values:
              values['desc'] = values.get('name', '')
          return values
  class Task(Node):
      """``tasks`` section: task name -> command."""
      _children = {'*': Command}


  class LoadNode(Node):
      """One ``load`` entry: the ``file`` to include and an optional ``as``."""
      _children = {
          'file': Atom,
          'as': Atom,
      }


  class LoadList(Node):
      """The ``load`` list of included config files."""
      _children = [LoadNode]


  class ConfigRoot(Node):
      """Top level of a config file; unknown top-level keys are dropped."""
      _children = {
          'networks': Network,
          'tasks': Task,
          'auth': Auth,
          'env': EnvNode,
          'load': LoadList,
      }
  class Env(ChainMap):
      """Layered variables used to format task strings.

      Earlier mappings take priority over later ones. ``None`` layers are
      skipped so optional config sections can be passed directly.
      """

      def __init__(self, *dicts):
          # Drop only missing (None) layers. Empty mappings must be kept:
          # ChainMap.new_child() builds ``self.__class__({}, *self.maps)``,
          # and discarding that empty child would make writes land in the
          # first *shared* mapping (config section, parent env, os.environ).
          super().__init__(*(d for d in dicts if d is not None))

      def fmt(self, string):
          """Return *string* with its ``{name}`` fields filled from this env.

          Logs an error and exits with status 1 when a field is missing,
          when a positional field (``{}``/``{0}``) is used, or when the
          template is malformed (e.g. an unmatched brace).
          """
          try:
              return string.format_map(self)
          except KeyError as exc:
              msg = 'Unable to format "%s" (missing: "%s")'
              logger.error(msg % (string, exc.args[0]))
          except IndexError:
              msg = 'Unable to format "%s", positional argument not supported'
              logger.error(msg % string)
          except ValueError as exc:
              msg = 'Unable to format "%s" (%s)'
              logger.error(msg % (string, exc))
          sys.exit(1)
  def get_passphrase(key_path):
      """Return the passphrase of the SSH private key at *key_path*.

      The keyring entry is looked up by the MD5 of the key file's content.
      If there is no entry, the user is prompted and the answer is stored.
      Without the optional ``keyring`` package the user is simply prompted.
      """
      if keyring is None:
          return getpass('Password for %s: ' % key_path)
      service = 'SSH private key'
      with open(key_path, 'rb') as key_file:
          # md5 is only used as a stable lookup id, not for security.
          csum = md5(key_file.read()).hexdigest()
      ssh_pass = keyring.get_password(service, csum)
      if not ssh_pass:
          ssh_pass = getpass('Password for %s: ' % key_path)
          keyring.set_password(service, csum, ssh_pass)
      return ssh_pass
  def get_sudo_passwd():
      """Return the sudo password.

      It is read from the keyring, or prompted for and stored there. Without
      the optional ``keyring`` package the user is simply prompted.
      """
      if keyring is None:
          return getpass('Sudo password:')
      service = "Sudo password"
      passwd = keyring.get_password(service, '-')
      if not passwd:
          passwd = getpass('Sudo password:')
          keyring.set_password(service, '-', passwd)
      return passwd
  # Connections already created, keyed by (host, with_sudo).
  CONNECTION_CACHE = {}


  def connect(host, auth, with_sudo=False):
      """Return a fabric Connection to *host*, reusing a cached one if any.

      *auth* is the parsed ``auth`` config section (may be empty or None).
      When it names an ``ssh_private_key``, that key file is used and its
      passphrase is obtained via ``get_passphrase``. With *with_sudo* the
      sudo password is put into the connection config. Logs an error and
      exits with status 1 if the key file does not exist.
      """
      cache_key = (host, with_sudo)
      if cache_key in CONNECTION_CACHE:
          return CONNECTION_CACHE[cache_key]

      connect_kwargs = {}
      # .get() works whether auth is an ObjectDict or a plain dict.
      key_path = auth.get('ssh_private_key') if auth else None
      if key_path:
          # Allow "~/.ssh/id_rsa"-style paths in the config.
          key_path = os.path.expanduser(key_path)
          if not os.path.exists(key_path):
              logger.error('Private key file "%s" not found' % key_path)
              sys.exit(1)
          connect_kwargs['key_filename'] = key_path
          # NOTE(review): passed as 'password', which paramiko uses to
          # decrypt the key when no 'passphrase' is given; confirm against
          # the installed paramiko version.
          connect_kwargs['password'] = get_passphrase(key_path)

      if with_sudo:
          config = Config(overrides={
              'sudo': {
                  'password': get_sudo_passwd()
              }
          })
      else:
          config = None
      con = Connection(host, config=config, connect_kwargs=connect_kwargs)
      CONNECTION_CACHE[cache_key] = con
      return con
  def run_local(task, env, cli):
      """Run ``task.local`` on this machine through invoke.

      The command is formatted with *env*. In dry-run mode it is only logged
      and ``None`` is returned; otherwise invoke's result is returned.
      """
      cmd = env.fmt(task.local)
      logger.info(env.fmt('RUN {task_name} locally'))
      if cli.dry_run:
          logger.info('[DRY-RUN] ' + cmd)
          return None
      # The env may hold non-string values (host=None for "once" local
      # tasks, booleans from the config), which a child process environment
      # cannot take: drop None and stringify everything else.
      proc_env = {k: str(v) for k, v in env.items() if v is not None}
      return run(cmd, env=proc_env)
  def run_remote(task, host, env, cli):
      """Execute *task* on the remote *host* through fabric.

      The first action found is used:
      * ``run``: a plain command;
      * ``sudo``: a command run with sudo;
      * ``send``: upload the local file or directory ``send`` to the remote
        path ``to``.

      Strings are formatted with *env* extended by ``host``. Returns the
      fabric result of ``run``/``sudo``, or ``None`` for uploads and dry
      runs. Logs an error and exits with status 1 when the task has none of
      these actions, or when the path to send does not exist.
      """
      import shlex

      res = None
      host = env.fmt(host)
      env = env.new_child({
          'host': host,
      })
      # A dry run executes nothing, so it should not prompt for the key
      # passphrase or the sudo password either.
      if cli.dry_run:
          con = None
      else:
          con = connect(host, cli.cfg.auth, bool(task.sudo))

      if task.run:
          cmd = env.fmt(task.run)
          logger.info(env.fmt('RUN {task_name} ON {host}'))
          if cli.dry_run:
              logger.info('[DRY-RUN] ' + cmd)
          else:
              res = con.run(cmd, pty=True, env=env)
      elif task.sudo:
          cmd = env.fmt(task.sudo)
          logger.info(env.fmt('SUDO {task_name} ON {host}'))
          if cli.dry_run:
              logger.info('[DRY-RUN] ' + cmd)
          else:
              res = con.sudo(cmd)
      elif task.send:
          local_path = env.fmt(task.send)
          remote_path = env.fmt(task.to)
          logger.info(f'SEND {local_path} TO {host}:{remote_path}')
          if cli.dry_run:
              logger.info('[DRY-RUN]')
          elif os.path.isfile(local_path):
              con.put(local_path, remote=remote_path)
          elif os.path.isdir(local_path):
              for root, _subdirs, files in os.walk(local_path):
                  # relpath uses the local separator; remote paths are POSIX.
                  rel_dir = os.path.relpath(root, local_path)
                  rel_dir = rel_dir.replace(os.sep, posixpath.sep)
                  rem_dir = posixpath.join(remote_path, rel_dir)
                  # The path goes through the remote shell: quote it.
                  con.run('mkdir -p {}'.format(shlex.quote(rem_dir)))
                  for name in files:
                      local_file = os.path.abspath(os.path.join(root, name))
                      con.put(local_file, remote=posixpath.join(rem_dir, name))
          else:
              logger.error('Local path "%s" not found' % local_path)
              sys.exit(1)
      else:
          logger.error('Unable to run task "%s"' % task.name)
          sys.exit(1)
      return res
  def run_task(task, host, cli, parent_env=None):
      '''
      Execute one task on one host (or locally).

      The formatting environment is layered, highest priority first:
      1. ``-e`` values from the command line;
      2. *parent_env*;
      3. the task's own ``env``;
      4. the top-level config ``env``;
      5. ``os.environ``.

      ``task_desc``, ``task_name`` and ``host`` are added in a fresh layer
      on top. If the task has an ``assert`` expression, it is evaluated
      after the run with ``stdout``/``stderr`` available. A failed assertion
      or a malformed ``-e`` value exits with status 1. Returns the run
      result (``None`` in dry-run mode).
      '''
      cli_env = {}
      for item in cli.env:
          # Split on the first '=' only, so values may contain '='.
          key, sep, value = item.partition('=')
          if not sep:
              logger.error('Invalid env value "%s" (expected key=value)' % item)
              sys.exit(1)
          cli_env[key] = value

      base_env = Env(
          cli_env,
          parent_env,
          task.get('env'),
          cli.cfg.get('env'),
          os.environ,
      )
      # Per-run values live in their own non-empty top layer, so they never
      # leak into the shared mappings below (config sections, parent env,
      # os.environ).
      env = base_env.new_child({
          'task_desc': base_env.fmt(task.desc),
          'task_name': task.name,
          'host': host,
      })

      if task.local:
          res = run_local(task, env, cli)
      else:
          res = run_remote(task, host, env, cli)

      if task.get('assert'):
          if res is None:
              # Dry run or upload: there is no command output to check.
              logger.info('No command output, assert skipped')
              return res
          env['stdout'] = res.stdout
          env['stderr'] = res.stderr
          assert_ = env.fmt(task['assert'])
          # NOTE(review): eval() of an expression taken from the config;
          # only acceptable for trusted config files.
          ok = eval(assert_, dict(env))
          if ok:
              logger.info('Assert ok')
          else:
              logger.error('Assert "%s" failed!' % assert_)
              sys.exit(1)
      return res
  def run_batch(task, hosts, cli, env=None):
      '''
      Run one task on a list of hosts.

      A task with a ``multi`` list is a pipeline. Each step runs the named
      sub-task, on the step's ``network`` if one is given, else on *hosts*.
      The step's ``env`` values are pre-formatted against the current env.
      A step's ``export`` stores the stripped stdout of its result under
      that name for the following steps.

      A plain task runs once locally when both ``once`` and ``local`` are
      set. Otherwise it runs on each host in turn (only the first host when
      ``once`` is set).

      Returns the last result (``None`` if nothing ran). Exits with status 1
      on an unknown sub-task or network name.
      '''
      env = Env(task.get('env'), env)
      res = None

      if task.get('multi'):
          export_env = {}
          for step in task.multi:
              sub_name = step.task
              if sub_name not in cli.cfg.tasks:
                  logger.error('Unknown task "%s" in multi' % sub_name)
                  sys.exit(1)
              sub_task = cli.cfg.tasks[sub_name]
              network = step.get('network')
              if network and network not in cli.cfg.networks:
                  logger.error('Unknown network "%s" in multi' % network)
                  sys.exit(1)
              # A step's network applies to that step only.
              step_hosts = cli.cfg.networks[network].hosts if network else hosts
              # env wrap-around: step values may reference the current env.
              child_env = {
                  k: env.fmt(v) for k, v in step.get('env', {}).items()
              }
              run_env = Env(export_env, child_env, env)
              res = run_batch(sub_task, step_hosts, cli, run_env)
              if step.export:
                  export_env[step.export] = res.stdout.strip() if res else ''
          return res

      if task.once and task.local:
          return run_task(task, None, cli, env)
      for host in hosts:
          res = run_task(task, host, cli, env)
          if task.once:
              break
      return res
  def base_cli(args=None):
      """Parse command-line *args* (``sys.argv[1:]`` when None).

      Returns the options as an ``ObjectDict`` so they can be read as
      attributes (``cli.dry_run``) and extended later (``cli.cfg``).
      NOTE(review): ``--sudo`` is not read anywhere in the code visible here.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument('names', nargs='*',
                          help='Hosts and commands to run them on')
      parser.add_argument('-c', '--config', default='bk.yaml',
                          help='Config file')
      parser.add_argument('-r', '--run', nargs='*', default=[],
                          help='Run custom task')
      parser.add_argument('-d', '--dry-run', action='store_true',
                          help='Do not run actual tasks, just print them')
      parser.add_argument('-e', '--env', nargs='*', default=[],
                          help='Add value to execution environment '
                          '(ex: -e foo=bar "name=John Doe")')
      parser.add_argument('-s', '--sudo', default='auto',
                          help='Enable sudo (auto|yes|no)')
      parser.add_argument('-v', '--verbose', action='count',
                          default=0, help='Increase verbosity')
      cli = parser.parse_args(args=args)
      return ObjectDict(vars(cli))
  def load(path, prefix=None):
      """Load the config file at *path* plus every file it ``load``s.

      ``networks`` and ``tasks`` always come back as mappings (empty if
      absent). When *prefix* (a list of path components) is given, every key
      of the networks/tasks/auth/env sections is renamed to
      ``'prefix/.../key'``. Files listed under ``load`` are resolved
      relative to this file's directory and namespaced by their ``as``
      value, or by their file name without extension. Their sections are
      merged into this file's sections.

      NOTE(review): this file's prefix is applied before its own includes
      are merged, so keys from nested includes do not get this prefix;
      confirm that is intended.
      """
      load_sections = ('networks', 'tasks', 'auth', 'env')
      logger.info('Load config %s' % path)
      with open(path) as stream:
          cfg = yaml_load(stream)
      cfg = ConfigRoot.parse(cfg)
      # Define useful defaults
      cfg.networks = cfg.networks or {}
      cfg.tasks = cfg.tasks or {}

      if prefix:
          def prefixed(key):
              return '/'.join(prefix + [key])

          for section in load_sections:
              if cfg.get(section):
                  items = cfg[section].items()
                  cfg[section] = {prefixed(k): v for k, v in items}

      # Recursive load
      if cfg.load:
          cfg_dir = os.path.dirname(path)
          for item in cfg.load:
              child_prefix = item.get('as') or os.path.splitext(item.file)[0]
              child_path = os.path.join(cfg_dir, item.file)
              child_cfg = load(child_path, child_prefix.split('/'))
              for section in load_sections:
                  if not cfg.get(section):
                      cfg[section] = {}
                  cfg[section].update(child_cfg.get(section, {}))
      return cfg
  def main():
      """Command-line entry point: parse args, load the config, run tasks.

      Positional names are resolved against networks (whose hosts are
      pooled) and tasks; ``-r`` adds ad-hoc tasks given as YAML. Each task
      is then run in order on the pooled hosts. Errors, including command
      failures reported by invoke, exit with status 1.
      """
      cli = base_cli()
      if cli.verbose:
          level = 'INFO' if cli.verbose == 1 else 'DEBUG'
          logger.setLevel(level)
          logger.info('Log level set to: %s' % level)

      # Load config
      cfg = load(cli.config)
      cli.cfg = cfg

      # Command-line names may refer to either a network or a task, so the
      # two namespaces must not overlap. An explicit check rather than
      # assert, which python -O would strip.
      collisions = set(cfg.networks) & set(cfg.tasks)
      if collisions:
          names = ', '.join(sorted(map(str, collisions)))
          logger.error('Name collision between tasks and networks: %s' % names)
          sys.exit(1)

      # Build task list
      tasks = []
      networks = []
      for name in cli.names:
          if name in cfg.networks:
              networks.append(cfg.networks[name])
          elif name in cfg.tasks:
              tasks.append(cfg.tasks[name])
          else:
              logger.error('Name "%s" not understood' % name)
              sys.exit(1)
      for custom_task in cli.run:
          # Parse with a path so Command.setup can name the ad-hoc task.
          task = Command.parse(yaml_load(custom_task), ('custom',))
          task.desc = 'Custom command'
          tasks.append(task)

      # A host group without a 'hosts' key parses to hosts=None.
      hosts = list(chain.from_iterable(n.hosts or [] for n in networks))
      try:
          for task in tasks:
              run_batch(task, hosts, cli)
      except invoke.exceptions.Failure as e:
          logger.error(str(e))
          sys.exit(1)


  if __name__ == '__main__':
      main()