Perform handshakes asynchronously

This commit is contained in:
2022-10-11 06:55:34 +10:30
parent 8046f04e33
commit 10ff26b17f
4 changed files with 45 additions and 14 deletions

View File

@@ -1,6 +1,6 @@
""" """
Usage: Usage:
certo [-vj] <hostnames>... [-d DAYS|--days-to-expiration=DAYS] [-t SECONDS|--timeout=SECONDS] certo [-vj] [-d DAYS|--days-to-expiration=DAYS] [-t SECONDS|--timeout=SECONDS] <hostnames>...
certo -h | --help certo -h | --help
Options: Options:
@@ -10,6 +10,7 @@ Options:
-d DAYS --days-to-expiration=DAYS Warn about near expiration if within DAYS of the cert's notAfter [default: 5]. -d DAYS --days-to-expiration=DAYS Warn about near expiration if within DAYS of the cert's notAfter [default: 5].
-t SECONDS --timeout=SECONDS Timeout for SSL Handshake [default: 5]. -t SECONDS --timeout=SECONDS Timeout for SSL Handshake [default: 5].
""" """
import asyncio
import logging import logging
from docopt import docopt from docopt import docopt
@@ -17,7 +18,8 @@ from docopt import docopt
from certo.checks.hostname import check_host_certificate_expiration from certo.checks.hostname import check_host_certificate_expiration
from certo.report import JSONReporter, DefaultReporter from certo.report import JSONReporter, DefaultReporter
if __name__ == "__main__":
async def main():
args = docopt(__doc__) args = docopt(__doc__)
output_as_json = args.get("-j") output_as_json = args.get("-j")
@@ -32,12 +34,19 @@ if __name__ == "__main__":
else: else:
reporter = DefaultReporter() reporter = DefaultReporter()
# @todo async jobs = {
for hs in hostnames: check_host_certificate_expiration(hs, days_to_expiration) for hs in hostnames
logging.info(f"Getting CERT from {hs}") }
reporter.add_check(check_host_certificate_expiration(hs, days_to_expiration)) checks = await asyncio.gather(*jobs)
for check in checks:
reporter.append(check)
if log := reporter.report(): if log := reporter.report():
print(log) print(log)
exit(reporter.num_failed()) exit(reporter.num_failed())
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,4 +1,6 @@
import asyncio
import datetime import datetime
import logging
from collections import namedtuple from collections import namedtuple
from dateutil.parser import parse as dtparse from dateutil.parser import parse as dtparse
@@ -11,18 +13,38 @@ CertCheckResult = namedtuple(
) )
def get_cert(hostname, timeout): # Unit of time slept asynchronously to simulate async socket handling
AWAIT_IOTA = 0.001
async def get_cert(hostname, timeout):
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
with ctx.wrap_socket(socket.socket(), server_hostname=hostname) as s: with ctx.wrap_socket(
socket.socket(), server_hostname=hostname, do_handshake_on_connect=False
) as s:
s.settimeout(timeout) s.settimeout(timeout)
# @todo simulate async connect
s.connect((hostname, 443)) s.connect((hostname, 443))
s.setblocking(False)
# Cannot await the handshake: simulate it with asyncio sleep
while "Handshake not finished":
try:
s.do_handshake()
break
except ssl.SSLWantReadError:
await asyncio.sleep(AWAIT_IOTA)
except ssl.SSLWantWriteError:
await asyncio.sleep(AWAIT_IOTA)
return s.getpeercert() return s.getpeercert()
def check_host_certificate_expiration(hostname, days_to_expiration, timeout=5): async def check_host_certificate_expiration(hostname, days_to_expiration, timeout=5):
logging.info(f"Getting CERT from {hostname}")
try: try:
cert = get_cert(hostname, timeout) cert = await get_cert(hostname, timeout)
except ssl.SSLCertVerificationError as e: except ssl.SSLCertVerificationError as e:
return CertCheckResult(hostname, False, None, e.strerror) return CertCheckResult(hostname, False, None, e.strerror)

View File

@@ -5,7 +5,7 @@ class CheckReporter:
def __init__(self): def __init__(self):
self.checks = list() self.checks = list()
def add_check(self, check): def append(self, check):
self.checks.append(check) self.checks.append(check)
def failed(self): def failed(self):
@@ -43,8 +43,8 @@ class JSONReporter(CheckReporter):
class DefaultReporter(CheckReporter): class DefaultReporter(CheckReporter):
def add_check(self, check): def append(self, check):
super().add_check(check) super().append(check)
result = f"[{'PASS' if check.check_successful else 'FAIL'}] Check host {check.hostname}" result = f"[{'PASS' if check.check_successful else 'FAIL'}] Check host {check.hostname}"
if check.debug: if check.debug:
result += f" - {check.debug}" result += f" - {check.debug}"

View File

@@ -1,7 +1,7 @@
[tool.poetry] [tool.poetry]
name = "certo" name = "certo"
version = "0.1.0" version = "0.1.0"
description = "" description = "A certificate expiration checker and reminder"
authors = ["Guilhem MARION <gmarion@netc.fr>"] authors = ["Guilhem MARION <gmarion@netc.fr>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]