@ -218,7 +218,7 @@ def testing_context(server_cert=SIGNED_CERTFILE):
server_context = ssl . SSLContext ( ssl . PROTOCOL_TLS_SERVER )
server_context . load_cert_chain ( server_cert )
client _context. load_verify_locations ( SIGNING_CA )
server _context. load_verify_locations ( SIGNING_CA )
return client_context , server_context , hostname
@ -2262,6 +2262,23 @@ class ThreadedEchoServer(threading.Thread):
sys . stdout . write ( " server: read CB tls-unique from client, sending our CB data... \n " )
data = self . sslconn . get_channel_binding ( " tls-unique " )
self . write ( repr ( data ) . encode ( " us-ascii " ) + b " \n " )
elif stripped == b ' PHA ' :
if support . verbose and self . server . connectionchatty :
sys . stdout . write ( " server: initiating post handshake auth \n " )
try :
self . sslconn . verify_client_post_handshake ( )
except ssl . SSLError as e :
self . write ( repr ( e ) . encode ( " us-ascii " ) + b " \n " )
else :
self . write ( b " OK \n " )
elif stripped == b ' HASCERT ' :
if self . sslconn . getpeercert ( ) is not None :
self . write ( b ' TRUE \n ' )
else :
self . write ( b ' FALSE \n ' )
elif stripped == b ' GETCERT ' :
cert = self . sslconn . getpeercert ( )
self . write ( repr ( cert ) . encode ( " us-ascii " ) + b " \n " )
else :
if ( support . verbose and
self . server . connectionchatty ) :
@ -4148,6 +4165,179 @@ class ThreadedTests(unittest.TestCase):
' Session refers to a different SSLContext. ' )
@unittest.skipUnless ( ssl . HAS_TLSv1_3 , " Test needs TLS 1.3 " )
class TestPostHandshakeAuth ( unittest . TestCase ) :
def test_pha_setter ( self ) :
protocols = [
ssl . PROTOCOL_TLS , ssl . PROTOCOL_TLS_SERVER , ssl . PROTOCOL_TLS_CLIENT
]
for protocol in protocols :
ctx = ssl . SSLContext ( protocol )
self . assertEqual ( ctx . post_handshake_auth , False )
ctx . post_handshake_auth = True
self . assertEqual ( ctx . post_handshake_auth , True )
ctx . verify_mode = ssl . CERT_REQUIRED
self . assertEqual ( ctx . verify_mode , ssl . CERT_REQUIRED )
self . assertEqual ( ctx . post_handshake_auth , True )
ctx . post_handshake_auth = False
self . assertEqual ( ctx . verify_mode , ssl . CERT_REQUIRED )
self . assertEqual ( ctx . post_handshake_auth , False )
ctx . verify_mode = ssl . CERT_OPTIONAL
ctx . post_handshake_auth = True
self . assertEqual ( ctx . verify_mode , ssl . CERT_OPTIONAL )
self . assertEqual ( ctx . post_handshake_auth , True )
def test_pha_required ( self ) :
client_context , server_context , hostname = testing_context ( )
server_context . post_handshake_auth = True
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . post_handshake_auth = True
client_context . load_cert_chain ( SIGNED_CERTFILE )
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' FALSE \n ' )
s . write ( b ' PHA ' )
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' TRUE \n ' )
# PHA method just returns true when cert is already available
s . write ( b ' PHA ' )
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
s . write ( b ' GETCERT ' )
cert_text = s . recv ( 4096 ) . decode ( ' us-ascii ' )
self . assertIn ( ' Python Software Foundation CA ' , cert_text )
def test_pha_required_nocert ( self ) :
client_context , server_context , hostname = testing_context ( )
server_context . post_handshake_auth = True
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . post_handshake_auth = True
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
s . write ( b ' PHA ' )
# receive CertificateRequest
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
# send empty Certificate + Finish
s . write ( b ' HASCERT ' )
# receive alert
with self . assertRaisesRegex (
ssl . SSLError ,
' tlsv13 alert certificate required ' ) :
s . recv ( 1024 )
def test_pha_optional ( self ) :
if support . verbose :
sys . stdout . write ( " \n " )
client_context , server_context , hostname = testing_context ( )
server_context . post_handshake_auth = True
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . post_handshake_auth = True
client_context . load_cert_chain ( SIGNED_CERTFILE )
# check CERT_OPTIONAL
server_context . verify_mode = ssl . CERT_OPTIONAL
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' FALSE \n ' )
s . write ( b ' PHA ' )
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' TRUE \n ' )
def test_pha_optional_nocert ( self ) :
if support . verbose :
sys . stdout . write ( " \n " )
client_context , server_context , hostname = testing_context ( )
server_context . post_handshake_auth = True
server_context . verify_mode = ssl . CERT_OPTIONAL
client_context . post_handshake_auth = True
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' FALSE \n ' )
s . write ( b ' PHA ' )
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
# optional doens't fail when client does not have a cert
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' FALSE \n ' )
def test_pha_no_pha_client ( self ) :
client_context , server_context , hostname = testing_context ( )
server_context . post_handshake_auth = True
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . load_cert_chain ( SIGNED_CERTFILE )
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
with self . assertRaisesRegex ( ssl . SSLError , ' not server ' ) :
s . verify_client_post_handshake ( )
s . write ( b ' PHA ' )
self . assertIn ( b ' extension not received ' , s . recv ( 1024 ) )
def test_pha_no_pha_server ( self ) :
# server doesn't have PHA enabled, cert is requested in handshake
client_context , server_context , hostname = testing_context ( )
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . post_handshake_auth = True
client_context . load_cert_chain ( SIGNED_CERTFILE )
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' TRUE \n ' )
# PHA doesn't fail if there is already a cert
s . write ( b ' PHA ' )
self . assertEqual ( s . recv ( 1024 ) , b ' OK \n ' )
s . write ( b ' HASCERT ' )
self . assertEqual ( s . recv ( 1024 ) , b ' TRUE \n ' )
def test_pha_not_tls13 ( self ) :
# TLS 1.2
client_context , server_context , hostname = testing_context ( )
server_context . verify_mode = ssl . CERT_REQUIRED
client_context . maximum_version = ssl . TLSVersion . TLSv1_2
client_context . post_handshake_auth = True
client_context . load_cert_chain ( SIGNED_CERTFILE )
server = ThreadedEchoServer ( context = server_context , chatty = False )
with server :
with client_context . wrap_socket ( socket . socket ( ) ,
server_hostname = hostname ) as s :
s . connect ( ( HOST , server . port ) )
# PHA fails for TLS != 1.3
s . write ( b ' PHA ' )
self . assertIn ( b ' WRONG_SSL_VERSION ' , s . recv ( 1024 ) )
def test_main ( verbose = False ) :
if support . verbose :
import warnings
@ -4183,6 +4373,7 @@ def test_main(verbose=False):
tests = [
ContextTests , BasicSocketTests , SSLErrorTests , MemoryBIOTests ,
SSLObjectTests , SimpleBackgroundTests , ThreadedTests ,
TestPostHandshakeAuth
]
if support . is_resource_enabled ( ' network ' ) :