
Get rid of Fabric & Invoke

Bertrand Chenal, 7 years ago
parent
revision
6efd1e6471
4 changed files with 166 additions and 74 deletions
  1. baker.py (+127, -68)
  2. examples/bk1.yaml (+12, -0)
  3. tests/base_test.py (+3, -2)
  4. tests/conftest.py (+24, -4)

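The heart of the change is replacing Fabric's Connection with spur's SshShell and passing commands as an explicit argv list. A minimal sketch of the new spur-based call path (host, credentials and command below are placeholders, not values taken from any config):

    import spur

    # Remote execution now goes through spur instead of Fabric.
    # 'example.com', 'admin' and the command are hypothetical placeholders.
    shell = spur.SshShell(
        hostname='example.com',
        username='admin',
        password='secret',
        missing_host_key=spur.ssh.MissingHostKey.accept,
    )
    # spur expects an argv list, so baker wraps shell strings as ['sh', '-c', cmd]
    result = shell.run(['sh', '-c', 'uname -a'], update_env={'LANG': 'C'})
    print(result.output.decode())  # spur returns bytes, hence the decode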
baker.py (+127, -68)

@@ -6,11 +6,10 @@ import argparse
 import logging
 import os
 import posixpath
+import shlex
 import sys
 
-from fabric import Connection, Config
-from invoke import run
-import invoke
+import spur
 import yaml
 
 
@@ -21,13 +20,22 @@ except ImportError:
 
 __version__ = '0.0'
 
-fmt = '%(levelname)s:%(asctime).19s: %(message)s'
-logging.basicConfig(format=fmt)
+
+log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'
 logger = logging.getLogger('baker')
 logger.setLevel(logging.INFO)
+log_handler = logging.StreamHandler()
+log_handler.setLevel(logging.INFO)
+log_handler.setFormatter(logging.Formatter(log_fmt))
+logger.addHandler(log_handler)
+
+
+def enable_logging_color():
+    try:
+        import colorama
+    except ImportError:
+        return
 
-try:
-    import colorama
     colorama.init()
     MAGENTA = colorama.Fore.MAGENTA
     RED = colorama.Fore.RED
@@ -43,12 +51,11 @@ try:
             return super(Handler, self).format(record)
 
     #  .. and plug it
+    logger.removeHandler(log_handler)
     handler = Handler()
-    handler.setFormatter(logging.Formatter(fmt))
+    handler.setFormatter(logging.Formatter(log_fmt))
     logger.addHandler(handler)
     logger.propagate = 0
-except ImportError:
-    pass
 
 
 def yaml_load(stream):
@@ -215,6 +222,7 @@ class Multi(Node):
     _children = {
         'task': Atom,
         'export': Atom,
+        'python': Atom,
         'network': Atom,
         'env': EnvNode,
     }
@@ -226,6 +234,7 @@ class Command(Node):
     _children = {
         'desc': Atom,
         'local': Atom,
+        'python': Atom,
         'once': Atom,
         'run': Atom,
         'send': Atom,
@@ -299,6 +308,15 @@ def get_passphrase(key_path):
     return ssh_pass
 
 
+def get_password(host):
+    service = 'SSH password'
+    ssh_pass = keyring.get_password(service, host)
+    if not ssh_pass:
+        ssh_pass = getpass('Password for %s: ' % host)
+        keyring.set_password(service, host, ssh_pass)
+    return ssh_pass
+
+
 def get_sudo_passwd():
     service = "Sudo password"
     passwd = keyring.get_password(service, '-')
@@ -309,86 +327,117 @@ def get_sudo_passwd():
 
 
 CONNECTION_CACHE = {}
-def connect(host, auth, with_sudo=False):
-    if (host, with_sudo) in CONNECTION_CACHE:
-        return CONNECTION_CACHE[host, with_sudo]
+def connect(host, auth):
+    if host in CONNECTION_CACHE:
+        return CONNECTION_CACHE[host]
 
-    connect_kwargs = {}
+    private_key_file = password = None
     if auth and auth.get('ssh_private_key'):
-        connect_kwargs['key_filename'] = auth.ssh_private_key
+        private_key_file = auth.ssh_private_key
         if not os.path.exists(auth.ssh_private_key):
             msg = 'Private key file "%s" not found' % auth.ssh_private_key
             abort(msg)
-        ssh_pass = get_passphrase(auth.ssh_private_key)
-        connect_kwargs['password'] = ssh_pass
-
-    if with_sudo:
-        config = Config(overrides={
-            'sudo': {
-                'password': get_sudo_passwd()
-            }
-        })
+        password = get_passphrase(auth.ssh_private_key)
     else:
-        config = None
-
-    con = Connection(host, config=config, connect_kwargs=connect_kwargs)
-    CONNECTION_CACHE[host, with_sudo] = con
-    return con
-
-
-def run_local(task, env, cli):
+        password = get_password(host)
+
+    username, hostname = host.split('@', 1)
+    shell = spur.SshShell(
+        hostname=hostname,
+        username=username,
+        password=password,
+        private_key_file=private_key_file,
+        missing_host_key=spur.ssh.MissingHostKey.accept,
+    )
+
+    CONNECTION_CACHE[host] = shell
+    return shell
+
+
+def subshell(command, local=False):
+    # Wrap a command so it runs through the platform shell; spur expects an argv list
+    if isinstance(command, (list, tuple)):
+        command = ' '.join(shlex.quote(c) for c in command)
+    if local and sys.platform == 'win32':
+        shell = os.environ.get('COMSPEC', 'cmd.exe')
+        return [shell, '/c', command]
+    return ['sh', '-c', command]
+
+def run_local(cmd, env, cli):
     # Run local task
-    cmd = env.fmt(task.local)
-    # TODO log only task_desc and let desc contains env info like {host}
-    logger.info(env.fmt('RUN {task_name} locally'))
+    cmd = env.fmt(cmd)
+    logger.info(env.fmt('{task_desc}'))
     if cli.dry_run:
         logger.info('[DRY-RUN] ' + cmd)
         return None
-    res = run(cmd, env=env)
-    return res
-
+    shell = spur.LocalShell()
+    logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
+    res = shell.run(subshell(cmd, local=True), update_env=env)
+    output = res.output.decode()
+    logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
+    return output
+
+def run_python(code, env, cli):
+    # Execute a piece of Python locally
+    logger.info(env.fmt('{task_desc}'))
+    if cli.dry_run:
+        logger.info('[DRY-RUN] ' + code)
+        return None
+    shell = spur.LocalShell()
+    logger.debug('\n\t' + '\n\t'.join(code.splitlines()))
+    cmd = subshell('python -c "import sys;exec(sys.stdin.read())"', local=True)
+    proc = shell.spawn(cmd, update_env=env)
+    proc.stdin_write(code.encode('utf-8'))
+    proc._process_stdin.close()
+    res = proc.wait_for_result()
+    output = res.output.decode()
+    logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
+    return output
 
 def run_remote(task, host, env, cli):
     res = None
     host = env.fmt(host)
-    env = env.new_child({
+    env.update({
         'host': host,
     })
-    con = connect(host, cli.cfg.auth, bool(task.sudo))
+    shell = connect(host, cli.cfg.auth)
     if task.run:
         cmd = env.fmt(task.run)
-        logger.info(env.fmt('RUN {task_name} ON {host}'))
+        logger.info(env.fmt('{host}: {task_desc}'))
+        logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
         if cli.dry_run:
             logger.info('[DRY-RUN] ' + cmd)
         else:
-            res = con.run(cmd, pty=True, env=env)
+            res = shell.run(subshell(cmd), update_env=env)
 
     elif task.sudo:
         cmd = env.fmt(task.sudo)
-        logger.info(env.fmt('SUDO {task_name} ON {host}'))
+        logger.info(env.fmt('[SUDO] {host}: {task_desc}'))
 
         if cli.dry_run:
             logger.info('[DRY-RUN] %s' + cmd)
         else:
-            res = con.sudo(cmd)
+            res = shell.sudo(cmd)
 
     elif task.send:
         local_path = env.fmt(task.send)
         remote_path = env.fmt(task.to)
-        logger.info(f'SEND {local_path} TO {host}:{remote_path}')
+        logger.info(f'[SEND] {local_path} -> {host}:{remote_path}')
         if cli.dry_run:
             logger.info('[DRY-RUN]')
-        elif os.path.isfile(local_path):
-            con.put(local_path, remote=remote_path)
+            return
         else:
-            for root, subdirs, files in os.walk(local_path):
-                rel_dir = os.path.relpath(root, local_path)
-                rem_dir = posixpath.join(remote_path, rel_dir)
-                con.run('mkdir -p {}'.format(rem_dir))
-                for f in files:
-                    rel_f = os.path.join(root, f)
-                    rem_file = posixpath.join(rem_dir, f)
-                    con.put(os.path.abspath(rel_f), remote=rem_file)
+            with shell._connect_sftp() as sftp:
+                if os.path.isfile(local_path):
+                    sftp.put(local_path, remote_path)
+                else:
+                    for root, subdirs, files in os.walk(local_path):
+                        rel_dir = os.path.relpath(root, local_path)
+                        rem_dir = posixpath.join(remote_path, rel_dir)
+                        shell.run(['mkdir', '-p', rem_dir])
+                        for f in files:
+                            rel_f = os.path.join(root, f)
+                            rem_file = posixpath.join(rem_dir, f)
+                            sftp.put(os.path.abspath(rel_f), rem_file)
     else:
         abort('Unable to run task "%s"' % task.name)
 
@@ -410,8 +459,6 @@ def run_task(task, host, cli, parent_env=None):
         task.get('env'),
         # Top-level env
         cli.cfg.get('env'),
-        # OS env
-        os.environ,
     ).new_child()
 
     env.update({
@@ -419,15 +466,18 @@ def run_task(task, host, cli, parent_env=None):
         'task_name': task.name,
         'host': host or '',
     })
+
     if task.local:
-        res = run_local(task, env, cli)
+        res = run_local(task.local, env, cli)
+    elif task.python:
+        res = run_python(task.python, env, cli)
     else:
         res = run_remote(task, host, env, cli)
 
     if task.get('assert'):
         env.update({
-            'stdout': res.stdout,
-            'stderr': res.stderr,
+            'stdout': res.output,
+            'stderr': res.stderr_output,
         })
         assert_ = env.fmt(task['assert'])
         ok = eval(assert_, dict(env))
@@ -462,10 +512,10 @@ def run_batch(task, hosts, cli, env=None):
             run_env = Env(export_env, child_env, env)
             res = run_batch(sub_task, hosts, cli, run_env)
             if multi.export:
-                export_env[multi.export] = res and res.stdout.strip() or ''
+                export_env[multi.export] = res and res.output.strip() or ''
 
     else:
-        if task.once and task.local:
+        if task.once and (task.local or task.python):
             res = run_task(task, None, cli, env)
             return res
         for host in hosts:
@@ -537,16 +587,22 @@ def base_cli(args=None):
                         help='Enable sudo (auto|yes|no')
     parser.add_argument('-v', '--verbose', action='count',
                         default=0, help='Increase verbosity')
+    parser.add_argument('-q', '--quiet', action='count',
+                        default=0, help='Decrease verbosity')
+    parser.add_argument('-n', '--no-color', action='store_true',
+                        help='Disable colored logs')
     cli = parser.parse_args(args=args)
     return ObjectDict(vars(cli))
 
 
 def main():
     cli = base_cli()
-    if cli.verbose:
-        level = 'INFO' if cli.verbose == 1 else 'DEBUG'
-        logger.setLevel(level)
-        logger.info('Log level set to: %s' % level)
+    if not cli.no_color:
+        enable_logging_color()
+    cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
+    level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
+    log_handler.setLevel(level)
+    logger.setLevel(level)
 
     # Load config
     cfg = load(cli.config)
@@ -584,8 +640,11 @@ def main():
     try:
         for task in tasks:
             run_batch(task, hosts, cli)
-    except invoke.exceptions.Failure as e:
-        logger.error(str(e))
+    except Exception as e:
+        # TODO intercept only spur exceptions
+        if cli.verbose > 2:
+            raise
+        abort(str(e))
 
 
 if __name__ == '__main__':

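The new python: tasks are handled by run_python above, which pipes the task's code into a Python subprocess over stdin instead of putting it on the command line. A standalone sketch of that mechanism on a POSIX machine (the code string is a toy example):

    import spur

    shell = spur.LocalShell()
    # Same trick as run_python: exec whatever arrives on stdin
    proc = shell.spawn(['sh', '-c', 'python -c "import sys;exec(sys.stdin.read())"'])
    proc.stdin_write(b"for i in range(3):\n    print('hello')\n")
    # spur has no public API to close stdin, hence the private attribute (as in run_python)
    proc._process_stdin.close()
    result = proc.wait_for_result()
    print(result.output.decode())  # prints 'hello' three times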
examples/bk1.yaml (+12, -0)

@@ -17,3 +17,15 @@ tasks:
     desc: Print current time (on local machine)
     local: date -Iseconds
     once: true
+
+  hello:
+    desc: Word count on "Hello"
+    local: echo 'Hello' | wc
+    once: true
+
+  hello-python:
+    desc: Says hello with python
+    python: |
+      for i in range(10):
+          print('hello')
+    once: true

tests/base_test.py (+3, -2)

@@ -1,9 +1,10 @@
 from baker import base_cli, run_batch
 
 
-def test_all_conf(cfg):
+def test_all_conf(cfg, nominal_log, log_handler):
     cli = base_cli(['--dry-run'])
     cli.cfg = cfg
-    import pdb;pdb.set_trace()
     for task in cfg.tasks.values():
         run_batch(task, [], cli)
+
+    assert nominal_log == log_handler.getvalue()

tests/conftest.py (+24, -4)

@@ -1,12 +1,32 @@
 import glob
+import io
+import logging
 
-from baker import load
+import pytest
+
+from baker import load, logger, log_handler
+
+# Disable default handler
+logger.removeHandler(log_handler)
+
+
+@pytest.yield_fixture(scope='function')
+def log_handler(request):
+    buff = io.StringIO()
+    handler = logging.StreamHandler(buff)
+    logger.addHandler(handler)
+    yield buff
+    logger.removeHandler(handler)
 
 
 def pytest_generate_tests(metafunc):
     if 'cfg' in metafunc.fixturenames:
         configs = []
+        logs = []
         for name in glob.glob('tests/*yaml'):
-            cfg = load(name)
-            configs.append(cfg)
-        metafunc.parametrize("cfg", configs)
+            print(name)
+            configs.append(load(name))
+            log_file = name.replace('.yaml', '.log')
+            logs.append(open(log_file).read())
+        metafunc.parametrize("cfg,nominal_log", zip(configs, logs))
+        # metafunc.parametrize("log", logs)
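
The expected ("nominal") logs live next to each YAML config as a .log file; the fixture above swaps baker's default handler for a StringIO-backed one so the test can compare the captured output against that file. The capture pattern in isolation (the logged message and expected string are invented for illustration):

    import io
    import logging

    logger = logging.getLogger('baker')
    buff = io.StringIO()
    handler = logging.StreamHandler(buff)
    logger.addHandler(handler)
    logger.warning('RUN some-task locally')  # stand-in for what run_batch logs
    logger.removeHandler(handler)
    assert buff.getvalue() == 'RUN some-task locally\n'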