sha1 supprt checkpoint
[certmaster.git] / scripts / certmaster-sync
index bd27af5..8e6db44 100644 (file)
@@ -7,14 +7,33 @@
 
 import os
 import sys
-import sha
+import warning
+try:
+    import hashlib
+except ImportError:
+    # Python-2.4.z ... gah! (or even 2.3!)
+    import sha
+    class hashlib:
+        @staticmethod
+        def new(algo):
+            if algo == 'sha1':
+                warnings.warn('sha1 is deprecated',DeprecationWarning)
+                return sha.new()
+            raise ValueError, "Bad checksum type"
+
+
 import xmlrpclib
 from glob import glob
 from time import sleep
 from certmaster import certmaster as certmaster
-from func.overlord.client import Client
-from func.CommonErrors import Func_Client_Exception
-import func.jobthing as jobthing
+
+func_import_failure = None
+try:
+    from func.overlord.client import Client
+    from func.CommonErrors import Func_Client_Exception
+    import func.jobthing as jobthing
+except ImportError, e:
+    func_import_failure = str(e)
 
 def syncable(cert_list):
     """
@@ -59,7 +78,7 @@ def remote_peers(hosts):
 
 def local_certs():
     """
-    Returns (hostname, sha1) hash of local certs
+    Returns (hostname, hashval) hash of local certs
     """
     globby = '*.%s' % cm.cfg.cert_extension
     globby = os.path.join(cm.cfg.certroot, globby)
@@ -67,12 +86,13 @@ def local_certs():
     results = []
     for f in files:
         hostname = os.path.basename(f).replace('.' + cm.cfg.cert_extension, '')
-        digest = checksum(f)
-        results.append([hostname, digest])
+        dirname = os.path.dirname(f)
+        digest = checksum(f,cm.cfg.hashfunc)
+        results.append([hostname, digest, dirname])
     return results
 
-def checksum(f):
-    thissum = sha.new()
+def checksum(f,hashfunc):
+    thissum = hashlib.new(hashfunc)
     if os.path.exists(f):
         fo = open(f, 'r')
         data = fo.read()
@@ -107,7 +127,7 @@ def copy_updated_certs(local, remote):
         for cert in local:
             if cert not in peers:
                 cert_name = '%s.%s' % (cert[0], cm.cfg.cert_extension)
-                full_path = os.path.join(cm.cfg.certroot, cert_name)
+                full_path = os.path.join(cert[2], cert_name)
                 fd = open(full_path)
                 certblob = fd.read()
                 fd.close()
@@ -124,6 +144,11 @@ def main():
     if not cm.cfg.sync_certs and not forced:
         sys.exit(0)
 
+    # Don't complain about func not being available until you actually want it
+    if func_import_failure != None:
+        print >> sys.stderr,  "errors importing func: %s" % func_import_failure
+        sys.exit(1)
+
     certs = glob(os.path.join(cm.cfg.certroot,
                               '*.%s' % cm.cfg.cert_extension))
     hosts = syncable(certs)