
Get rid of spur

Bertrand Chenal, 7 years ago
Parent
Current commit
228adb7dbc

+ 1 - 1
README.md

@@ -104,7 +104,7 @@ WHAT?
 
 # Roadmap
 
-- Document: assert, send dry-run and auth
+- Document: assert, send dry-run, python and auth
 - Add tests
 - Implement: import of other config file, add a contrib directory with
   ready-made tasks for common operations

+ 198 - 86
baker.py

@@ -3,13 +3,15 @@ from hashlib import md5
 from itertools import chain
 from collections import ChainMap, OrderedDict, defaultdict
 import argparse
+import io
 import logging
 import os
 import posixpath
-import shlex
+import subprocess
 import sys
+import threading
 
-import spur
+import paramiko
 import yaml
 
 
@@ -30,6 +32,23 @@ log_handler.setFormatter(logging.Formatter(log_fmt))
 logger.addHandler(log_handler)
 
 
+TAB = '\n    '
+
+class BakerException(Exception):
+    pass
+
+class FmtException(BakerException):
+    pass
+
+class ExecutionException(BakerException):
+    pass
+
+class RemoteException(ExecutionException):
+    pass
+
+class LocalException(ExecutionException):
+    pass
+
 def enable_logging_color():
     try:
         import colorama
@@ -161,7 +180,12 @@ class Node:
                 for key in cfg:
                     if key not in children:
                         path = ' -> '.join(path)
-                        msg = 'Attribute "%s" not understoodin %s' % (key, path)
+                        if path:
+                            msg = 'Attribute "%s" not understood in %s' % (
+                                key, path)
+                        else:
+                            msg = 'Top-level attribute "%s" not understood' % (
+                                key)
                         candidates = gen_candidates(children.keys())
                         matches = spell(candidates, key)
                         if matches:
@@ -246,10 +270,9 @@ class Command(Node):
 
     @classmethod
     def setup(cls, values, path):
-        if path:
-            values['name'] = path[-1]
+        values['name'] = path and path[-1] or ''
         if 'desc' not in values:
-            values['desc'] = values['name']
+            values['desc'] = values.get('name', '')
         super().setup(values, path)
         return values
 
@@ -290,12 +313,12 @@ class Env(ChainMap):
             candidates = gen_candidates(self.keys())
             key = exc.args[0]
             matches = spell(candidates, key)
-            if msg:
+            if matches:
                 msg += ', try: %s' % ' or '.join(matches)
-            abort(msg )
+            raise FmtException(msg)
         except IndexError as exc:
             msg = 'Unable to format "%s", positional argument not supported'
-            abort(msg)
+            raise FmtException(msg)
 
 
 def get_passphrase(key_path):
@@ -342,25 +365,15 @@ def connect(host, auth):
         password = get_password(host)
 
     username, hostname = host.split('@', 1)
-    shell = spur.SshShell(
-        hostname=hostname,
-        username=username,
-        password=password,
-        private_key_file=private_key_file,
-        missing_host_key=spur.ssh.MissingHostKey.accept,
-    )
-
-    CONNECTION_CACHE[host] = shell
-    return shell
 
+    client = paramiko.SSHClient()
+    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    client.connect(hostname, username=username, password=password,
+                   key_filename=private_key_file,
+    )
+    CONNECTION_CACHE[host] = client
+    return client
 
-def subshell(command, local=False):
-    if not isinstance(command, (list, tuple)):
-        command = list(shlex.shlex(command))
-    if local and sys.platform == 'win32':
-        shell = os.environ.get('COMSPEC', 'cmd.exe')
-        return [shell, '/c'] + command
-    return ['sh', '-c', command]
 
 def run_local(cmd, env, cli):
     # Run local task
@@ -369,12 +382,20 @@ def run_local(cmd, env, cli):
     if cli.dry_run:
         logger.info('[DRY-RUN] ' + cmd)
         return None
-    shell = spur.LocalShell()
-    logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
-    res = shell.run(subshell(cmd, local=True), update_env=env)
-    output = res.output.decode()
-    logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
-    return output
+    logger.debug(TAB + TAB.join(cmd.splitlines()))
+    process = subprocess.Popen(
+        cmd, shell=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        env=env,
+    )
+    stdout, stderr = process.communicate()
+    success = process.returncode == 0
+    logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
+    if not success:
+        raise LocalException(stdout, stderr)
+    return ObjectDict(stdout=stdout, stderr=stderr)
+
 
 def run_python(code, env, cli):
     # Execute a piece of python localy
@@ -382,16 +403,83 @@ def run_python(code, env, cli):
     if cli.dry_run:
         logger.info('[DRY-RUN] ' + code)
         return None
-    shell = spur.LocalShell()
-    logger.debug('\n\t' + '\n\t'.join(code.splitlines()))
-    cmd = subshell('python -c "import sys;exec(sys.stdin.read())"', local=True)
-    proc = shell.spawn(cmd, update_env=env)
-    proc.stdin_write(code.encode('utf-8'))
-    proc._process_stdin.close()
-    res = proc.wait_for_result()
-    output = res.output.decode()
-    logger.debug('\n\t' + '\n\t'.join(output.splitlines()))
-    return output
+    logger.debug(TAB + TAB.join(code.splitlines()))
+    cmd = 'python -c "import sys;exec(sys.stdin.read())"'
+    process = subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        stdin=subprocess.PIPE,
+        env=env,
+    )
+
+    # Plug io
+    out_buff = io.StringIO()
+    err_buff = io.StringIO()
+    log_stream(process.stdout, out_buff)
+    log_stream(process.stderr, err_buff)
+    process.stdin.write(code.encode())
+    process.stdin.flush()
+    process.stdin.close()
+    success = process.wait() == 0
+    process.stdout.close()
+    process.stderr.close()
+    out = out_buff.getvalue()
+    logger.debug(TAB + TAB.join(out.splitlines()))
+    if not success:
+        raise LocalException(out + err_buff.getvalue())
+    return ObjectDict(stdout=out, stderr=err_buff.getvalue())
+
+
+def log_stream(stream, buff):
+    def _log():
+        try:
+            for chunk in iter(lambda: stream.readline(2048), ""):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode()
+                buff.write(chunk)
+        except ValueError:
+            # read raises a ValueError on closed stream
+            pass
+
+    t = threading.Thread(target=_log)
+    t.start()
+    return t
+
+
+def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
+    assert not sudo, 'Not implemented'
+
+    # stdin, stdout, stderr = client.exec_command(cmd)
+    chan = client.get_transport().open_session()
+    if env:
+        chan.update_environment(env)
+    chan.exec_command(cmd)
+
+    stdin = chan.makefile('wb')
+    stdout = chan.makefile('r')
+    stderr = chan.makefile_stderr('r')
+    out_buff = io.StringIO()
+    err_buff = io.StringIO()
+    out_thread = log_stream(stdout, out_buff)
+    err_thread = log_stream(stderr, err_buff)
+    if in_buff:
+        # XXX use a real buff (not a simple str)
+        stdin.write(in_buff)
+        stdin.close()
+
+    out_thread.join()
+    err_thread.join()
+
+    success = chan.recv_exit_status() == 0
+    if not success:
+        raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
+
+    res = ObjectDict(
+        stdout = out_buff.getvalue(),
+        stderr = err_buff.getvalue(),
+    )
+    return res
 
 def run_remote(task, host, env, cli):
     res = None
@@ -399,15 +487,15 @@ def run_remote(task, host, env, cli):
     env.update({
         'host': host,
     })
-    shell = connect(host, cli.cfg.auth)
+    client = connect(host, cli.cfg.auth)
     if task.run:
         cmd = env.fmt(task.run)
         logger.info(env.fmt('{host}: {task_desc}'))
-        logger.debug('\n\t' + '\n\t'.join(cmd.splitlines()))
+        logger.debug(TAB + TAB.join(cmd.splitlines()))
         if cli.dry_run:
             logger.info('[DRY-RUN] ' + cmd)
         else:
-            res = shell.run(subshell(cmd), update_env=env)
+            res = run_helper(client, cmd, env=env)
 
     elif task.sudo:
         cmd = env.fmt(task.sudo)
@@ -416,7 +504,7 @@ def run_remote(task, host, env, cli):
         if cli.dry_run:
             logger.info('[DRY-RUN] %s' + cmd)
         else:
-            res = shell.sudo(cmd)
+            res = run_helper(client, cmd, env=env, sudo=True)
 
     elif task.send:
         local_path = env.fmt(task.send)
@@ -426,14 +514,14 @@ def run_remote(task, host, env, cli):
             logger.info('[DRY-RUN]')
             return
         else:
-            with shell._connect_sftp() as sftp:
+            with client.open_sftp() as sftp:
                 if os.path.isfile(local_path):
                     sftp.put(local_path, remote_path)
                 else:
                     for root, subdirs, files in os.walk(local_path):
                         rel_dir = os.path.relpath(root, local_path)
                         rem_dir = posixpath.join(remote_path, rel_dir)
-                        shell.run('mkdir -p {}'.format(rem_dir))
+                        run_helper(client, 'mkdir -p {}'.format(rem_dir))
                         for f in files:
                             rel_f = os.path.join(root, f)
                             rem_file = posixpath.join(rem_dir, f)
@@ -441,6 +529,9 @@ def run_remote(task, host, env, cli):
     else:
         abort('Unable to run task "%s"' % task.name)
 
+    if res:
+        logger.debug(TAB + TAB.join(res.stdout.splitlines()))
+
     return res
 
 
@@ -451,14 +542,14 @@ def run_task(task, host, cli, parent_env=None):
 
     # Prepare environment
     env = Env(
-        # Cli is top priority
-        dict(e.split('=') for e in cli.env),
-        # Then comes env from parent task
+        # Env from parent task
         parent_env,
         # Env on the task itself
         task.get('env'),
         # Top-level env
         cli.cfg.get('env'),
+        # OS env
+        os.environ,
     ).new_child()
 
     env.update({
@@ -476,8 +567,8 @@ def run_task(task, host, cli, parent_env=None):
 
     if task.get('assert'):
         env.update({
-            'stdout': res.output,
-            'stderr': res.stderr_output,
+            'stdout': res.stdout.strip(),
+            'stderr': res.stderr.strip(),
         })
         assert_ = env.fmt(task['assert'])
         ok = eval(assert_, dict(env))
@@ -493,7 +584,7 @@ def run_batch(task, hosts, cli, env=None):
     Run one task on a list of hosts
     '''
     env = Env(task.get('env'), env)
-    res = None
+    out = None
     export_env = {}
 
     if task.get('multi'):
@@ -510,26 +601,31 @@ def run_batch(task, hosts, cli, env=None):
                 # env wrap-around!
                 child_env[k] = env.fmt(child_env[k])
             run_env = Env(export_env, child_env, env)
-            res = run_batch(sub_task, hosts, cli, run_env)
+            out = run_batch(sub_task, hosts, cli, run_env)
+            out = out.decode() if isinstance(out, bytes) else out
+            export_env['_'] = out
             if multi.export:
-                export_env[multi.export] = res and res.output.strip() or ''
+                export_env[multi.export] = out
 
     else:
+        res = None
         if task.once and (task.local or task.python):
             res = run_task(task, None, cli, env)
-            return res
-        for host in hosts:
-            res = run_task(task, host, cli, env)
-            if task.once:
-                break
-    return res
+        else:
+            for host in hosts:
+                res = run_task(task, host, cli, env)
+                if task.once:
+                    break
+        out = res and res.stdout.strip() or ''
+    return out
 
 
 def abort(msg):
     logger.error(msg)
     sys.exit(1)
 
-def load(path, prefix=None):
+
+def load_cfg(path, prefix=None):
     load_sections = ('networks', 'tasks', 'auth', 'env')
 
     if os.path.isfile(path):
@@ -561,7 +657,7 @@ def load(path, prefix=None):
             else:
                 child_prefix, _ = os.path.splitext(item.file)
             child_path = os.path.join(cfg_path, item.file)
-            child_cfg = load(child_path, child_prefix.split('/'))
+            child_cfg = load_cfg(child_path, child_prefix.split('/'))
 
             for section in load_sections:
                 if not cfg.get(section):
@@ -570,14 +666,18 @@ def load(path, prefix=None):
     return cfg
 
 
-def base_cli(args=None):
+def load_cli(args=None):
     parser = argparse.ArgumentParser()
     parser.add_argument('names',  nargs='*',
                         help='Hosts and commands to run them on')
     parser.add_argument('-c', '--config', default='bk.yaml',
                         help='Config file')
-    parser.add_argument('-r', '--run', nargs='*', default=[],
-                        help='Run custom task')
+    parser.add_argument('-R', '--run', nargs='*', default=[],
+                        help='Run remote task')
+    parser.add_argument('-L', '--run-local', nargs='*', default=[],
+                        help='Run local task')
+    parser.add_argument('-P', '--run-python', nargs='*', default=[],
+                        help='Run python task')
     parser.add_argument('-d', '--dry-run', action='store_true',
                         help='Do not run actual tasks, just print them')
     parser.add_argument('-e', '--env', nargs='*', default=[],
@@ -592,22 +692,18 @@ def base_cli(args=None):
     parser.add_argument('-n', '--no-color', action='store_true',
                         help='Disable colored logs')
     cli = parser.parse_args(args=args)
-    return ObjectDict(vars(cli))
-
-
-def main():
-    cli = base_cli()
-    if not cli.no_color:
-        enable_logging_color()
-    cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
-    level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
-    log_handler.setLevel(level)
-    logger.setLevel(level)
+    cli = ObjectDict(vars(cli))
 
     # Load config
-    cfg = load(cli.config)
+    cfg = load_cfg(cli.config)
     cli.cfg = cfg
+    cli.update(get_hosts_and_tasks(cli, cfg))
+
+    # Transform env strings into a dict
+    cli.env = dict(e.split('=') for e in cli.env)
+    return cli
 
+def get_hosts_and_tasks(cli, cfg):
     # Make sure we don't have overlap between hosts and tasks
     items = list(cfg.networks) + list(cfg.tasks)
     msg = 'Name collision between tasks and networks'
@@ -628,20 +724,36 @@ def main():
             matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
             if matches:
                 msg += ', try: %s' % ' or '.join(matches)
-            abort(msg)
-
-    for custom_task in cli.run:
+            raise BakerException(msg)
+
+    # Collect custom tasks from cli
+    customs = []
+    for cli_key in ('run', 'run_local', 'run_python'):
+        cmd_key = cli_key.rsplit('_', 1)[-1]
+        customs.extend('%s: %s' % (cmd_key, ck) for ck in cli[cli_key])
+    for custom_task in customs:
         task = Command.parse(yaml_load(custom_task))
         task.desc = 'Custom command'
         tasks.append(task)
 
     hosts = list(chain.from_iterable(n.hosts for n in networks))
 
+    return dict(hosts=hosts, tasks=tasks)
+
+
+def main():
+    cli = load_cli()
+    if not cli.no_color:
+        enable_logging_color()
+    cli.verbose = max(0, 1 + cli.verbose - cli.quiet)
+    level = ['WARNING', 'INFO', 'DEBUG'][min(cli.verbose, 2)]
+    log_handler.setLevel(level)
+    logger.setLevel(level)
+
     try:
-        for task in tasks:
-            run_batch(task, hosts, cli)
-    except Exception as e:
-        # TODO intercept only spur exceptions
+        for task in cli.tasks:
+            run_batch(task, cli.hosts, cli, cli.env)
+    except BakerException as e:
         if cli.verbose > 2:
             raise
         abort(str(e))
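
A quick illustration of the CLI change above (a sketch only, not part of the commit; the cli dict stands in for the parsed arguments): the single --run flag is split into -R/--run, -L/--run-local and -P/--run-python, and each value is rewritten as a one-line task mapping before Command.parse consumes it.

    # Sketch: mirrors the customs-collection loop added to get_hosts_and_tasks().
    cli = {'run': ['uptime'], 'run_local': ['date -Iseconds'], 'run_python': []}

    customs = []
    for cli_key in ('run', 'run_local', 'run_python'):
        cmd_key = cli_key.rsplit('_', 1)[-1]      # 'run', 'local' or 'python'
        customs.extend('%s: %s' % (cmd_key, value) for value in cli[cli_key])

    print(customs)  # ['run: uptime', 'local: date -Iseconds']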

+ 5 - 0
examples/assert.yaml

@@ -0,0 +1,5 @@
+tasks:
+  one:
+    python: print('one')
+    assert: "stdout == 'one'"
+    once: true

+ 0 - 31
examples/bk1.yaml

@@ -1,31 +0,0 @@
-networks:
-  web:
-    hosts:
-      - web1.example.com
-      - web2.example.com
-  db:
-    hosts:
-      - db1.example.com
-      - db2.example.com
-
-tasks:
-  health:
-    desc: Get basic health info
-    run: uptime
-
-  time:
-    desc: Print current time (on local machine)
-    local: date -Iseconds
-    once: true
-
-  hello:
-    desc: Word count on "Hello"
-    local: echo 'Hello' | wc
-    once: true
-
-  hello-python:
-    desc: Says hello with python
-    python: |
-      for i in range(10):
-          print('hello')
-    once: true

+ 0 - 19
examples/bk2.yaml

@@ -1,19 +0,0 @@
-tasks:
-  echo:
-    desc: Simple echo
-    local: echo "{what}"
-    once: true
-    env:
-      what: "ECHO!"
-
-  echo-var:
-    desc: Echo an env variable
-    local: echo {my_var}
-    once: true
-    
-  both:
-    desc: Run both tasks
-    multi:
-      - task: echo
-        export: my_var
-      - task: echo-var

+ 6 - 0
examples/load.yaml

@@ -0,0 +1,6 @@
+load:
+  - file: network_only.yaml
+    as: net
+tasks:
+  echo-host:
+    local: echo {host}

+ 23 - 0
examples/multi.yaml

@@ -0,0 +1,23 @@
+tasks:
+  one:
+    python: print('one')
+    once: true
+  two:
+    python: print('two')
+    once: true
+  three:
+    python: print('three')
+    once: true
+  concat:
+    python: |
+      import os
+      print(os.environ['one'], os.environ['two'], os.environ['_'])
+    once: true
+  all:
+    multi:
+      - task: one
+        export: one
+      - task: two
+        export: two
+      - task: three
+      - task: concat
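
How this example feeds values from one sub-task to the next, sketched outside the commit (names simplified): run_batch now stores each sub-task's trimmed stdout in export_env, always under '_' and additionally under the name given by export, and that mapping is layered into the environment of later sub-tasks, which is why concat can read one, two and _ from os.environ.

    # Sketch: models the export_env bookkeeping added to run_batch().
    export_env = {}
    for export_name, out in [('one', 'one'), ('two', 'two'), (None, 'three')]:
        export_env['_'] = out              # '_' always holds the previous output
        if export_name:
            export_env[export_name] = out  # 'export: <name>' also pins it by name

    print(export_env)  # {'_': 'three', 'one': 'one', 'two': 'two'}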

+ 0 - 0
tests/network_only.yaml → examples/network_only.yaml


+ 7 - 0
examples/python.yaml

@@ -0,0 +1,7 @@
+tasks:
+  print:
+    desc: Print module type
+    python: |-
+      import os
+      print(type(os))
+    once: true

+ 0 - 0
tests/task_only.yaml → examples/task_only.yaml


+ 7 - 0
tests/assert_test.yaml

@@ -0,0 +1,7 @@
+cli: '-c examples/assert.yaml one'
+output: |
+  Load config examples/assert.yaml
+  one
+  print('one')
+  one
+  Assert ok

+ 29 - 7
tests/base_test.py

@@ -1,10 +1,32 @@
-from baker import base_cli, run_batch
+from shlex import shlex
 
+import pytest
 
-def test_all_conf(cfg, nominal_log, log_handler):
-    cli = base_cli(['--dry-run'])
-    cli.cfg = cfg
-    for task in cfg.tasks.values():
-        run_batch(task, [], cli)
+from baker import run_batch, load_cli
 
-    assert nominal_log == log_handler.getvalue()
+
+def test_all_conf(test_cfg, log_buff):
+    verbose = pytest.config.getoption('verbose', 0) > 0
+
+    args = []
+    if test_cfg.cli:
+        lexer = shlex(test_cfg.cli)
+        lexer.wordchars += '.!=<>:{}-/'
+        args = list(lexer)
+    cli = load_cli(args)
+
+    for task in cli.tasks:
+        run_batch(task, cli.hosts, cli, cli.env)
+
+    actual_lines = log_buff.getvalue().splitlines()
+    actual_lines = [l.strip() for l in filter(None, actual_lines)]
+    expected_output = test_cfg.output or ''
+    expected_lines = expected_output.splitlines()
+    expected_lines = [l.strip() for l in filter(None, expected_lines)]
+
+    # Join and make paths independent of the OS
+    act = '\n'.join(actual_lines).replace('\\', '/')
+    exp = '\n'.join(expected_lines).replace('\\', '/')
+    if verbose:
+        print(act)
+    assert act == exp
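
For context, a sketch (not from the commit) of what the shlex setup above does with a cli line taken from one of the new *_test.yaml files: the extra wordchars keep flags, paths and option values as single tokens so they can be handed straight to load_cli().

    # Sketch: same lexer configuration as in test_all_conf().
    from shlex import shlex

    lexer = shlex("--dry-run -c examples/task_only.yaml time")
    lexer.wordchars += '.!=<>:{}-/'
    print(list(lexer))  # ['--dry-run', '-c', 'examples/task_only.yaml', 'time']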

+ 21 - 12
tests/conftest.py

@@ -4,29 +4,38 @@ import logging
 
 import pytest
 
-from baker import load, logger, log_handler
+from baker import logger, log_handler, yaml_load, ObjectDict
 
 # Disable default handler
 logger.removeHandler(log_handler)
 
 
 @pytest.yield_fixture(scope='function')
-def log_handler(request):
+def log_buff(request):
     buff = io.StringIO()
     handler = logging.StreamHandler(buff)
+    handler.setLevel('DEBUG')
+    logger.setLevel('DEBUG')
     logger.addHandler(handler)
     yield buff
     logger.removeHandler(handler)
 
 
+def pytest_addoption(parser):
+    parser.addoption("-F", "--filter", help="Filter yaml files")
+
+
 def pytest_generate_tests(metafunc):
-    if 'cfg' in metafunc.fixturenames:
-        configs = []
-        logs = []
-        for name in glob.glob('tests/*yaml'):
-            print(name)
-            configs.append(load(name))
-            log_file = name.replace('.yaml', '.log')
-            logs.append(open(log_file).read())
-        metafunc.parametrize("cfg,nominal_log", zip(configs, logs))
-        # metafunc.parametrize("log", logs)
+    if not 'test_cfg' in metafunc.fixturenames:
+        return
+
+    test_configs = []
+    ids = []
+    fltr = metafunc.config.option.filter
+    for name in glob.glob('tests/*test.yaml'):
+        if fltr and not fltr in name:
+            continue
+        test_cfg = ObjectDict(yaml_load(open(name)))
+        ids.append(test_cfg.cfg)
+        test_configs.append(test_cfg)
+    metafunc.parametrize("test_cfg", test_configs, ids=ids)

+ 8 - 0
tests/load_test.yaml

@@ -0,0 +1,8 @@
+cli: '-c examples/load.yaml net/web echo-host --dry-run'
+output: |
+  Load config examples/load.yaml
+  Load config examples/network_only.yaml
+  echo-host
+  [DRY-RUN] echo web1.example.com
+  echo-host
+  [DRY-RUN] echo web2.example.com

+ 16 - 0
tests/multi_test.yaml

@@ -0,0 +1,16 @@
+cli: '-c examples/multi.yaml all'
+output: |
+  Load config examples/multi.yaml
+  one
+  print('one')
+  one
+  two
+  print('two')
+  two
+  three
+  print('three')
+  three
+  concat
+  import os
+  print(os.environ['one'], os.environ['two'], os.environ['_'])
+  one two three

+ 2 - 0
tests/network_only_test.yaml

@@ -0,0 +1,2 @@
+cli: "--dry-run -c examples/network_only.yaml"
+output: Load config examples/network_only.yaml

+ 7 - 0
tests/python_test.yaml

@@ -0,0 +1,7 @@
+cli: '-c examples/python.yaml print'
+output: |
+  Load config examples/python.yaml
+  Print module type
+  import os
+  print(type(os))
+  <class 'module'>

+ 5 - 0
tests/task_only_test.yaml

@@ -0,0 +1,5 @@
+cli: "--dry-run -c examples/task_only.yaml time"
+output: |
+  Load config examples/task_only.yaml
+  Print current time (on local machine)
+  [DRY-RUN] date -Iseconds