test_fakeopensslclasses.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. '''
  2. Unit tests for the FakeOpenSSL classes
  3. '''
  4. import os
  5. import subprocess
  6. import sys
  7. import pytest
  8. MODULE_PATH = os.path.realpath(os.path.join(__file__, os.pardir, os.pardir, 'library'))
  9. sys.path.insert(1, MODULE_PATH)
  10. # pylint: disable=import-error,wrong-import-position,missing-docstring
  11. # pylint: disable=invalid-name,redefined-outer-name
  12. from openshift_cert_expiry import FakeOpenSSLCertificate # noqa: E402
  13. @pytest.fixture(scope='module')
  14. def fake_valid_cert(valid_cert):
  15. cmd = ['openssl', 'x509', '-in', str(valid_cert['cert_file']), '-text']
  16. cert = subprocess.check_output(cmd)
  17. return FakeOpenSSLCertificate(cert.decode('utf8'))
  18. def test_not_after(valid_cert, fake_valid_cert):
  19. ''' Validate value returned back from get_notAfter() '''
  20. real_cert = valid_cert['cert']
  21. # Internal representation of pyOpenSSL is bytes, while FakeOpenSSLCertificate
  22. # is text, so decode the result from pyOpenSSL prior to comparing
  23. assert real_cert.get_notAfter().decode('utf8') == fake_valid_cert.get_notAfter()
  24. def test_serial(valid_cert, fake_valid_cert):
  25. ''' Validate value returned back form get_serialnumber() '''
  26. real_cert = valid_cert['cert']
  27. assert real_cert.get_serial_number() == fake_valid_cert.get_serial_number()
  28. def test_get_subject(valid_cert, fake_valid_cert):
  29. ''' Validate the certificate subject '''
  30. # Gather the subject components and create a list of colon separated strings.
  31. # Since the internal representation of pyOpenSSL uses bytes, we need to decode
  32. # the results before comparing.
  33. c_subjects = valid_cert['cert'].get_subject().get_components()
  34. c_subj = ', '.join(['{}:{}'.format(x.decode('utf8'), y.decode('utf8')) for x, y in c_subjects])
  35. f_subjects = fake_valid_cert.get_subject().get_components()
  36. f_subj = ', '.join(['{}:{}'.format(x, y) for x, y in f_subjects])
  37. assert c_subj == f_subj
  38. def get_san_extension(cert):
  39. # Internal representation of pyOpenSSL is bytes, while FakeOpenSSLCertificate
  40. # is text, so we need to set the value to search for accordingly.
  41. if isinstance(cert, FakeOpenSSLCertificate):
  42. san_short_name = 'subjectAltName'
  43. else:
  44. san_short_name = b'subjectAltName'
  45. for i in range(cert.get_extension_count()):
  46. ext = cert.get_extension(i)
  47. if ext.get_short_name() == san_short_name:
  48. # return the string representation to compare the actual SAN
  49. # values instead of the data types
  50. return str(ext)
  51. return None
  52. def test_subject_alt_names(valid_cert, fake_valid_cert):
  53. real_cert = valid_cert['cert']
  54. san = get_san_extension(real_cert)
  55. f_san = get_san_extension(fake_valid_cert)
  56. assert san == f_san
  57. # If there are either dns or ip sans defined, verify common_name present
  58. if valid_cert['ip'] or valid_cert['dns']:
  59. assert 'DNS:' + valid_cert['common_name'] in f_san
  60. # Verify all ip sans are present
  61. for ip in valid_cert['ip']:
  62. assert 'IP Address:' + ip in f_san
  63. # Verify all dns sans are present
  64. for name in valid_cert['dns']:
  65. assert 'DNS:' + name in f_san