utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import sys
  2. from itertools import chain
  3. from collections import defaultdict, ChainMap
  4. from contextlib import contextmanager
  5. import logging
  # Log line layout: "LEVEL:<timestamp>: message". The asctime field is
  # truncated to 19 characters, which drops the default ",mmm" millisecond
  # suffix of logging's default time format.
  log_fmt = '%(levelname)s:%(asctime).19s: %(message)s'

  # Package-wide logger: INFO and above, emitted through a single stream
  # handler (StreamHandler writes to sys.stderr by default).
  logger = logging.getLogger('byrd')
  logger.setLevel(logging.INFO)

  log_handler = logging.StreamHandler()
  log_handler.setFormatter(logging.Formatter(log_fmt))
  log_handler.setLevel(logging.INFO)
  logger.addHandler(log_handler)
  class ByrdException(Exception):
      """Root of the byrd exception hierarchy."""


  class FmtException(ByrdException):
      """Raised when a string cannot be formatted against an environment."""


  class ExecutionException(ByrdException):
      """Common parent of RemoteException and LocalException."""


  class RemoteException(ExecutionException):
      """Execution error on the remote side (raised outside this module)."""


  class LocalException(ExecutionException):
      """Execution error on the local side (raised outside this module)."""
  def enable_logging_color():
      """Replace the plain 'byrd' log handler with a colorizing one.

      Does nothing when the optional ``colorama`` package is not installed.
      Otherwise INFO messages are wrapped in magenta and WARNING, ERROR and
      CRITICAL messages in red; other levels are left uncolored. The
      module-level ``log_handler`` is detached from ``logger``, and
      propagation to ancestor loggers is turned off.
      """
      try:
          import colorama
      except ImportError:
          return
      colorama.init()
      reset = colorama.Style.RESET_ALL
      level_colors = {
          'INFO': colorama.Fore.MAGENTA,
          'WARNING': colorama.Fore.RED,
          'ERROR': colorama.Fore.RED,
          'CRITICAL': colorama.Fore.RED,
      }

      # We define a custom handler ..
      class Handler(logging.StreamHandler):
          def format(self, record):
              color = level_colors.get(record.levelname)
              if color:
                  # Color a copy: the record object is shared with any other
                  # handler, and formatting the same record twice must not
                  # stack color codes. str() because msg may be any object
                  # (e.g. logger.error(exc)); concatenating a non-str msg
                  # would raise TypeError.
                  record = logging.makeLogRecord(record.__dict__)
                  record.msg = color + str(record.msg) + reset
              return super().format(record)

      # .. and plug it
      logger.removeHandler(log_handler)
      handler = Handler()
      handler.setFormatter(logging.Formatter(log_fmt))
      logger.addHandler(handler)
      logger.propagate = False
  def edits(word):
      """Yield *word* unchanged, then each single-character deletion of it.

      Deletions are produced left to right, one per position, so duplicates
      are possible (e.g. "aa" yields "a" twice).
      """
      yield word
      for pos in range(len(word)):
          yield word[:pos] + word[pos + 1:]
  def gen_candidates(wordlist):
      """Build a reverse index for fuzzy lookups.

      Every variant of each word obtained by deleting zero, one or two
      characters (two nested passes of ``edits``) is mapped to the set of
      original words it came from. Returns a ``defaultdict(set)``.
      """
      index = defaultdict(set)
      for original in wordlist:
          variants = (
              deep for shallow in edits(original) for deep in edits(shallow)
          )
          for variant in variants:
              index[variant].add(original)
      return index
  def spell(candidates, word):
      """Return the set of indexed words that *word* may be a misspelling of.

      *word* and each of its single-deletion variants are looked up in
      *candidates* (an index as built by ``gen_candidates``), and all hits
      are unioned. Membership is tested before indexing so that a
      defaultdict index is not grown by the lookups.
      """
      found = set()
      for variant in edits(word):
          if variant in candidates:
              found.update(candidates[variant])
      return found
  def _private_attr(obj, name, default=None):
      """Return the private attribute *name* of ObjectDict *obj*, or *default*.

      An unset private attribute may surface as KeyError (id()-keyed meta
      table) or as AttributeError (standard attribute protocol); both are
      treated as "unset" here.
      """
      try:
          return getattr(obj, name)
      except (AttributeError, KeyError):
          return default


  def spellcheck(objdict, word):
      """Look up *word* in *objdict*, suggesting close keys on a miss.

      Returns ``objdict[word]`` when the key exists. Otherwise raises
      ByrdException whose message names *word* and, when set, the ``_path``
      private attribute of *objdict*, followed by the keys that ``spell``
      reports as likely matches (in sorted order).

      The candidate index is cached on *objdict* as ``_candidates``, together
      with the key set it was built from (``_candidates_keys``). It is rebuilt
      whenever the keys have changed since it was built.
      """
      if word in objdict:
          return objdict[word]

      keys = frozenset(objdict)
      # Private attributes are not dict items, so they must be read through
      # attribute access; objdict.get('_candidates') would never find them.
      candidates = _private_attr(objdict, '_candidates')
      if (candidates is None
              or _private_attr(objdict, '_candidates_keys') != keys):
          candidates = gen_candidates(keys)
          objdict._candidates = candidates
          objdict._candidates_keys = keys

      path = _private_attr(objdict, '_path')
      if path is None:
          msg = '"%s" not found' % (word,)
      else:
          msg = '"%s" not found in %s' % (word, path)
      matches = spell(candidates, word)
      if matches:
          # sorted() keeps the suggestion order stable between runs.
          msg += ', try: %s' % ' or '.join(sorted(matches))
      raise ByrdException(msg)
  def abort(msg):
      """Report a fatal error and terminate the process.

      *msg* is logged at ERROR level on the ``byrd`` logger, then
      ``sys.exit`` is called with status code 1.
      """
      failure_status = 1
      logger.error(msg)
      sys.exit(failure_status)
  class _MissingPrivateAttribute(AttributeError, KeyError):
      """Raised when an unset private ('_'-prefixed) ObjectDict attribute is read.

      It subclasses AttributeError so that hasattr(), getattr(obj, name,
      default) and copy/deepcopy behave normally. It also subclasses KeyError
      so that callers written against the former KeyError behavior keep
      working.
      """


  class ObjectDict(dict):
      """
      Simple dict sub-class that allows attribute access to its keys,
      like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`.

      Public attribute names map to dict items; reading a missing one returns
      None. Names starting with '_' are private: they are stored on the
      instance itself, never appear among the dict items, are carried over
      by copy(), and raise an AttributeError (also a KeyError) when unset.
      """

      # Kept only for backward compatibility; it is no longer used. Private
      # attributes now live in each instance's __dict__. A global id()-keyed
      # table leaked entries, and could hand a dead object's attributes to a
      # new object that reused its id.
      _meta = {}

      def copy(self):
          """Return a shallow ObjectDict copy, including private attributes."""
          res = ObjectDict(super().copy())
          res.__dict__.update(self.__dict__)
          return res

      def __getattr__(self, key):
          # Only reached when normal lookup fails. Private names that were set
          # live in the instance __dict__ and are found before this runs.
          if key.startswith('_'):
              raise _MissingPrivateAttribute(key)
          if key in self:
              return self[key]
          return None

      def __setattr__(self, key, value):
          if key.startswith('_'):
              object.__setattr__(self, key, value)
          else:
              self[key] = value
  class DummyClient:
      """
      Stand-in for a Paramiko client, mainly useful for testing and dry
      runs. Only ``open_sftp`` is provided.
      """

      @contextmanager
      def open_sftp(self):
          """Context manager standing in for an SFTP session; yields None."""
          no_session = None
          yield no_session
  class Env(ChainMap):
      """Layered mapping used to expand placeholders in strings.

      Built on ChainMap: earlier dicts take precedence over later ones, and
      ``None`` entries passed to the constructor are skipped. ``fmt_kind``
      selects the default formatting style: ``'old'`` uses ``%``-formatting
      (``%(name)s``); any other value uses ``str.format`` (``{name}``).
      """

      def __init__(self, *dicts):
          self.fmt_kind = 'new'
          super().__init__(*(d for d in dicts if d is not None))

      def fmt_env(self, child_env, kind=None):
          """Format every value of the mapping *child_env* against this env.

          Returns a new Env whose first layer holds only the values that
          changed after formatting, layered over *child_env* itself.
          """
          new_env = {}
          for key, val in child_env.items():
              # Child values are expanded against self (the parent chain),
              # not against child_env itself.
              new_val = self.fmt(val, kind=kind)
              if new_val == val:
                  continue
              new_env[key] = new_val
          return Env(new_env, child_env)

      def fmt_string(self, string, kind=None):
          """Expand the placeholders in *string* using this env.

          *kind* overrides ``self.fmt_kind`` for this call.

          Raises FmtException when a referenced name is missing (with
          spelling suggestions when some keys are close) or when a positional
          placeholder such as ``{}`` or ``{0}`` is used.
          """
          fmt_kind = kind or self.fmt_kind
          try:
              if fmt_kind == 'old':
                  return string % self
              return string.format(**self)
          except KeyError as exc:
              key = exc.args[0]
              msg = 'Unable to format "%s" (missing: "%s")' % (string, key)
              matches = spell(gen_candidates(self.keys()), key)
              if matches:
                  # sorted() keeps the suggestion order stable between runs.
                  msg += ', try: %s' % ' or '.join(sorted(matches))
              raise FmtException(msg)
          except IndexError:
              msg = ('Unable to format "%s", positional argument not supported'
                     % string)
              raise FmtException(msg)

      def fmt(self, what, kind=None):
          """Format *what* against this env.

          Strings go through fmt_string and mappings through fmt_env. Lists
          and tuples are formatted element-wise, returning a list or tuple.
          Any other value (int, bool, None, ...) has no placeholders and is
          returned unchanged.
          """
          if isinstance(what, str):
              return self.fmt_string(what, kind=kind)
          if hasattr(what, 'items'):
              return self.fmt_env(what, kind=kind)
          if isinstance(what, list):
              return [self.fmt(item, kind=kind) for item in what]
          if isinstance(what, tuple):
              return tuple(self.fmt(item, kind=kind) for item in what)
          return what