Pārlūkot izejas kodu

Fix host parsing and improve PEP8 in main.py

Host parsing enforced user@host syntax. Patched to allow just host
and use current user on the master machine.
Aurelien 5 gadi atpakaļ
vecāks
revīzija
dcb808a9b5
1 mainītis faili ar 17 papildinājumiem un 12 dzēšanām
  1. 17 12
      byrd/main.py

+ 17 - 12
byrd/main.py

@@ -6,12 +6,13 @@ from string import Formatter
 import io
 import os
 import posixpath
+import pwd
 import subprocess
 import threading
 
 from .config import Task, yaml_load
 from .utils import (ByrdException, LocalException, ObjectDict, RemoteException,
-                   DummyClient, Env, spellcheck, spell, logger)
+                    DummyClient, Env, spellcheck, spell, logger)
 
 try:
     # This file is imported by setup.py at install time
@@ -24,7 +25,6 @@ __version__ = '0.0.2'
 TAB = '\n    '
 
 
-
 def get_secret(service, resource, resource_id=None):
     resource_id = resource_id or resource
     secret = keyring.get_password(service, resource_id)
@@ -51,6 +51,8 @@ def get_sudo_passwd():
 
 
 CONNECTION_CACHE = {}
+
+
 def connect(host, auth):
     if host in CONNECTION_CACHE:
         return CONNECTION_CACHE[host]
@@ -65,13 +67,17 @@ def connect(host, auth):
     else:
         password = get_password(host)
 
-    username, hostname = host.split('@', 1)
+    host_info = host.split('@')
+    hostname = host_info[-1]
+    if len(host_info) > 1:
+        username = host_info[0]
+    else:
+        username = pwd.getpwuid(os.getuid()).pw_name
 
     client = paramiko.SSHClient()
     client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
     client.connect(hostname, username=username, password=password,
-                   key_filename=private_key_file,
-    )
+                   key_filename=private_key_file)
     logger.debug(f'Connected to {hostname} as {username}')
     CONNECTION_CACHE[host] = client
     return client
@@ -197,8 +203,8 @@ def run_helper(client, cmd, env=None, in_buff=None, sudo=False):
         raise RemoteException(out_buff.getvalue() + err_buff.getvalue())
 
     res = ObjectDict(
-        stdout = out_buff.getvalue(),
-        stderr = err_buff.getvalue(),
+        stdout=out_buff.getvalue(),
+        stderr=err_buff.getvalue(),
     )
     return res
 
@@ -221,7 +227,7 @@ def run_remote(task, host, env, cli):
             if task.sudo is True:
                 prefix = '[sudo] '
             else:
-                 prefix = '[sudo as %s] ' % task.sudo
+                prefix = '[sudo as %s] ' % task.sudo
         msg = prefix + '{host}: {task_desc}'
         logger.info(env.fmt(msg, kind='new'))
         logger.debug(TAB + TAB.join(cmd.splitlines()))
@@ -233,7 +239,7 @@ def run_remote(task, host, env, cli):
     elif task.send:
         local_path = env.fmt(task.send)
         if not os.path.exists(local_path):
-            raise ByrdException('Path "%s" not found'  % local_path)
+            raise ByrdException('Path "%s" not found' % local_path)
         else:
             send(client, env, cli, task)
 
@@ -245,6 +251,7 @@ def run_remote(task, host, env, cli):
 
     return res
 
+
 def send(client, env, cli, task):
     fmt = task.fmt and Env(env, {'fmt': 'new'}).fmt(task.fmt) or None
     local_path = env.fmt(task.send)
@@ -389,8 +396,6 @@ def extract_host(host_string):
     return host_string and host_string.split('@')[-1] or ''
 
 
-
-
 def get_hosts_and_tasks(cli, cfg):
     # Make sure we don't have overlap between hosts and tasks
     items = list(cfg.networks) + list(cfg.tasks)
@@ -412,7 +417,7 @@ def get_hosts_and_tasks(cli, cfg):
             matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
             if matches:
                 msg += ', try: %s' % ' or '.join(matches)
-            raise ByrdException (msg)
+            raise ByrdException(msg)
 
     # Collect custom tasks from cli
     customs = []