소스 검색

Add sudo directive, allow to inline sub-task

Bertrand Chenal 7 년 전
부모
커밋
64db6c6a46
1개의 변경된 파일46개의 추가작업 그리고 32개의 파일을 삭제
  1. 46 32
      baker.py

+ 46 - 32
baker.py

@@ -246,21 +246,20 @@ class Multi(Node):
     _children = {
         'task': Atom,
         'export': Atom,
-        'python': Atom,
         'network': Atom,
-        'env': EnvNode,
     }
 
 class MultiList(Node):
     _children = [Multi]
 
-class Command(Node):
+class Task(Node):
     _children = {
         'desc': Atom,
         'local': Atom,
         'python': Atom,
         'once': Atom,
         'run': Atom,
+        'sudo': Atom,
         'send': Atom,
         'to': Atom,
         'assert': Atom,
@@ -276,9 +275,10 @@ class Command(Node):
         super().setup(values, path)
         return values
 
-class Task(Node):
+
+class TaskGroup(Node):
     _children = {
-        '*': Command,
+        '*': Task,
     }
 
 class LoadNode(Node):
@@ -293,12 +293,16 @@ class LoadList(Node):
 class ConfigRoot(Node):
     _children = {
         'networks': Network,
-        'tasks': Task,
+        'tasks': TaskGroup,
         'auth': Auth,
         'env': EnvNode,
         'load': LoadList,
     }
 
+# Multi can also accept any task attribute:
+Multi._children.update(Task._children)
+
+
 
 class Env(ChainMap):
 
@@ -391,20 +395,24 @@ def run_local(cmd, env, cli):
     )
     stdout, stderr = process.communicate()
     success = process.returncode == 0
-    logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
+    if stdout:
+        logger.debug(TAB + TAB.join(stdout.decode().splitlines()))
     if not success:
         raise LocalException(stdout, stderr)
     return ObjectDict(stdout=stdout, stderr=stderr)
 
 
-def run_python(code, env, cli):
+def run_python(task, env, cli):
     # Execute a piece of python localy
+    code = task.python
     logger.info(env.fmt('{task_desc}'))
     if cli.dry_run:
         logger.info('[DRY-RUN] ' + code)
         return None
     logger.debug(TAB + TAB.join(code.splitlines()))
     cmd = 'python -c "import sys;exec(sys.stdin.read())"'
+    if task.sudo:
+        cmd = 'sudo -- ' + cmd
     process = subprocess.Popen(
         cmd,
         stdout=subprocess.PIPE,
@@ -425,7 +433,8 @@ def run_python(code, env, cli):
     process.stdout.close()
     process.stderr.close()
     out = out_buff.getvalue()
-    logger.debug(TAB + TAB.join(out.splitlines()))
+    if out:
+        logger.debug(TAB + TAB.join(out.splitlines()))
     if not success:
         raise LocalException(out + err_buff.getvalue())
     return ObjectDict(stdout=out, stderr=err_buff.getvalue())
@@ -448,13 +457,9 @@ def log_stream(stream, buff):
 
 
 def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
-    assert not sudo, 'Not implemented'
-
-    # stdin, stdout, stderr = client.exec_command(cmd)
     chan = client.get_transport().open_session()
     if env:
         chan.update_environment(env)
-    chan.exec_command(cmd)
 
     stdin = chan.makefile('wb')
     stdout = chan.makefile('r')
@@ -463,15 +468,24 @@ def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
     err_buff = io.StringIO()
     out_thread = log_stream(stdout, out_buff)
     err_thread = log_stream(stderr, err_buff)
+
+    if sudo:
+        assert not in_buff, 'in_buff and sudo can not be combined'
+        chan.exec_command('sudo -s')
+        in_buff = cmd
+    else:
+        chan.exec_command(cmd)
     if in_buff:
-        # XXX use a real buff (not a simple str)
+        # XXX use a real buff (not a simple str) ?
         stdin.write(in_buff)
+        stdin.flush()
         stdin.close()
+        chan.shutdown_write()
 
+    success = chan.recv_exit_status() == 0
     out_thread.join()
     err_thread.join()
 
-    success = chan.recv_exit_status() == 0
     if not success:
         raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
 
@@ -481,6 +495,7 @@ def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
     )
     return res
 
+
 def run_remote(task, host, env, cli):
     res = None
     host = env.fmt(host)
@@ -493,21 +508,14 @@ def run_remote(task, host, env, cli):
         client = connect(host, cli.cfg.auth)
     if task.run:
         cmd = env.fmt(task.run)
-        logger.info(env.fmt('{host}: {task_desc}'))
+        prefix = task.sudo and '[SUDO]' or ''
+        msg = prefix + '{host}: {task_desc}'
+        logger.info(env.fmt(msg))
         logger.debug(TAB + TAB.join(cmd.splitlines()))
         if cli.dry_run:
             logger.info('[DRY-RUN] ' + cmd)
         else:
-            res = run_helper(client, cmd, env=env)
-
-    elif task.sudo:
-        cmd = env.fmt(task.sudo)
-        logger.info(env.fmt('[SUDO] {host}: {task_desc}'))
-
-        if cli.dry_run:
-            logger.info('[DRY-RUN] %s' + cmd)
-        else:
-            res = run_helper(client, cmd, env=env, sudo=True)
+            res = run_helper(client, cmd, env=env, sudo=task.sudo)
 
     elif task.send:
         local_path = env.fmt(task.send)
@@ -519,7 +527,7 @@ def run_remote(task, host, env, cli):
         else:
             with client.open_sftp() as sftp:
                 if os.path.isfile(local_path):
-                    sftp.put(local_path, remote_path)
+                    sftp.put(os.path.abspath(local_path), remote_path)
                 else:
                     for root, subdirs, files in os.walk(local_path):
                         rel_dir = os.path.relpath(root, local_path)
@@ -532,7 +540,7 @@ def run_remote(task, host, env, cli):
     else:
         raise BakerException('Unable to run task "%s"' % task.name)
 
-    if res:
+    if res and res.stdout:
         logger.debug(TAB + TAB.join(res.stdout.splitlines()))
 
     return res
@@ -564,7 +572,7 @@ def run_task(task, host, cli, parent_env=None):
     if task.local:
         res = run_local(task.local, env, cli)
     elif task.python:
-        res = run_python(task.python, env, cli)
+        res = run_python(task, env, cli)
     else:
         res = run_remote(task, host, env, cli)
 
@@ -591,10 +599,16 @@ def run_batch(task, hosts, cli, env=None):
     env = Env(export_env, task.get('env'), env)
 
     if task.get('multi'):
+        sudo = task.sudo
         for multi in task.multi:
             task = multi.task
-            spellcheck(cli.cfg.tasks, task)
-            sub_task = cli.cfg.tasks[task]
+            if task:
+                spellcheck(cli.cfg.tasks, task)
+                sub_task = cli.cfg.tasks[task]
+            else:
+                sub_task = Task.parse(multi)
+            if multi.sudo is not None or sudo is not None:
+                sub_task.sudo = multi.sudo or sudo
             network = multi.get('network')
             if network:
                 spellcheck(cli.cfg.networks, network)
@@ -734,7 +748,7 @@ def get_hosts_and_tasks(cli, cfg):
         cmd_key = cli_key.rsplit('_', 1)[-1]
         customs.extend('%s: %s' % (cmd_key, ck) for ck in cli[cli_key])
     for custom_task in customs:
-        task = Command.parse(yaml_load(custom_task))
+        task = Task.parse(yaml_load(custom_task))
         task.desc = 'Custom command'
         tasks.append(task)