From 96e1a83fb282021448460c42f92ebbcc6078952b Mon Sep 17 00:00:00 2001 From: Pavel Belevich Date: Fri, 7 May 2021 13:34:43 -0700 Subject: [PATCH] Add Gloo TCP_TLS transport (#56442) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56442 Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D27896285 Pulled By: pbelevich fbshipit-source-id: 589af59ca4c7c9bab2329f079382c09b71cfcf9e --- .jenkins/pytorch/build.sh | 4 ++ .jenkins/pytorch/create_test_cert.py | 96 ++++++++++++++++++++++++++++ .jenkins/pytorch/run_glootls_test.sh | 18 ++++++ .jenkins/pytorch/test.sh | 8 +++ CMakeLists.txt | 7 ++ tools/print_test_stats.py | 4 +- torch/lib/c10d/GlooDeviceFactory.cpp | 33 ++++++++++ 7 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 .jenkins/pytorch/create_test_cert.py create mode 100755 .jenkins/pytorch/run_glootls_test.sh diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index 2f37a51c105..bc309b8a54d 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -204,6 +204,10 @@ if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then ./xla/scripts/apply_patches.sh fi +if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-build || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-build ]]; then + export USE_GLOO_WITH_OPENSSL=ON +fi + if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then set -e diff --git a/.jenkins/pytorch/create_test_cert.py b/.jenkins/pytorch/create_test_cert.py new file mode 100644 index 00000000000..d3ead7ae259 --- /dev/null +++ b/.jenkins/pytorch/create_test_cert.py @@ -0,0 +1,96 @@ +from datetime import datetime, timedelta +from tempfile import mkdtemp +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes + +temp_dir = mkdtemp() +print(temp_dir) + + +def genrsa(path): + key 
= rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + with open(path, "wb") as f: + f.write(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + )) + return key + + +def create_cert(path, C, ST, L, O, key): + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, C), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST), + x509.NameAttribute(NameOID.LOCALITY_NAME, L), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, O), + ]) + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.utcnow() + ).not_valid_after( + # Our certificate will be valid for 10 days + datetime.utcnow() + timedelta(days=10) + ).add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True, + ).sign(key, hashes.SHA256()) + # Write our certificate out to disk. + with open(path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + return cert + + +def create_req(path, C, ST, L, O, key): + csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([ + # Provide various details about who we are. 
+ x509.NameAttribute(NameOID.COUNTRY_NAME, C), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST), + x509.NameAttribute(NameOID.LOCALITY_NAME, L), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, O), + ])).sign(key, hashes.SHA256()) + with open(path, "wb") as f: + f.write(csr.public_bytes(serialization.Encoding.PEM)) + return csr + + +def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key): + cert = x509.CertificateBuilder().subject_name( + csr_cert.subject + ).issuer_name( + ca_cert.subject + ).public_key( + csr_cert.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.utcnow() + ).not_valid_after( + # Our certificate will be valid for 10 days + datetime.utcnow() + timedelta(days=10) + # Sign our certificate with our private key + ).sign(private_ca_key, hashes.SHA256()) + with open(path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + return cert + + +ca_key = genrsa(temp_dir + "/ca.key") +ca_cert = create_cert(temp_dir + "/ca.pem", u"US", u"New York", u"New York", u"Gloo Certificate Authority", ca_key) + +pkey = genrsa(temp_dir + "/pkey.key") +csr = create_req(temp_dir + "/csr.csr", u"US", u"California", u"San Francisco", u"Gloo Testing Company", pkey) + +cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key) diff --git a/.jenkins/pytorch/run_glootls_test.sh b/.jenkins/pytorch/run_glootls_test.sh new file mode 100755 index 00000000000..fa48066f613 --- /dev/null +++ b/.jenkins/pytorch/run_glootls_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +CREATE_TEST_CERT="$(dirname "${BASH_SOURCE[0]}")/create_test_cert.py" +TMP_CERT_DIR=$(python "$CREATE_TEST_CERT") + +openssl verify -CAfile "${TMP_CERT_DIR}/ca.pem" "${TMP_CERT_DIR}/cert.pem" + +export GLOO_DEVICE_TRANSPORT=TCP_TLS +export GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY=${TMP_CERT_DIR}/pkey.key +export GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT=${TMP_CERT_DIR}/cert.pem +export 
GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE=${TMP_CERT_DIR}/ca.pem + +time python test/run_test.py --include distributed/test_c10d_gloo --verbose --determine-from="$DETERMINE_FROM" -- ProcessGroupGlooTest + +unset GLOO_DEVICE_TRANSPORT +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 6795c519cf2..2005ffc1a64 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -153,6 +153,11 @@ test_python() { assert_git_not_dirty } +test_python_gloo_with_tls() { + source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh" + assert_git_not_dirty +} + test_aten() { # Test ATen @@ -478,6 +483,9 @@ else test_distributed test_benchmarks test_rpc + if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc7-test || "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3.6-gcc5.4-test ]]; then + test_python_gloo_with_tls + fi fi if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then diff --git a/CMakeLists.txt b/CMakeLists.txt index 790feb2cb91..4ad16a8a301 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -277,6 +277,9 @@ cmake_dependent_option( cmake_dependent_option( USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) +cmake_dependent_option( + USE_GLOO_WITH_OPENSSL "Use Gloo with OpenSSL. Only available if USE_GLOO is on." OFF + "USE_GLOO AND LINUX AND NOT INTERN_BUILD_MOBILE" OFF) cmake_dependent_option( USE_TENSORPIPE "Use TensorPipe. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) @@ -327,6 +330,10 @@ if(WIN32) endif() endif() +if(USE_GLOO_WITH_OPENSSL) + set(USE_TCP_OPENSSL_LOAD ON CACHE STRING "") +endif() + # Linux distributions do not want too many embedded sources, in that sense we # need to be able to build pytorch with an (almost) empty third_party # directory. 
diff --git a/tools/print_test_stats.py b/tools/print_test_stats.py index 1d48f9d5adc..7d7d27f2c4b 100755 --- a/tools/print_test_stats.py +++ b/tools/print_test_stats.py @@ -631,12 +631,14 @@ class TestFile: self.test_suites[suite_name] = TestSuite(suite_name) if test_case.name in self.test_suites[suite_name].test_cases: # We expect duplicate tests for test_cpp_extensions_aot, distributed/test_distributed_fork, - # and distributed/test_distributed_spawn. In these cases, we store the test case that took the longest, + # distributed/test_distributed_spawn, and distributed/test_c10d_gloo. + # In these cases, we store the test case that took the longest, # as in these jobs, the duplicate tests are run in parallel. # For other unexpected cases, we should raise a warning. if self.name == 'test_cpp_extensions_aot' or \ self.name == 'distributed/test_distributed_fork' or \ self.name == 'distributed/test_distributed_spawn' or \ + self.name == 'distributed/test_c10d_gloo' or \ self.name == 'cpp': # The caffe2 cpp tests spawn duplicate test cases as well. time_difference = self.test_suites[suite_name].replace(test_case) self.total_time += time_difference diff --git a/torch/lib/c10d/GlooDeviceFactory.cpp b/torch/lib/c10d/GlooDeviceFactory.cpp index dca6b03eb9d..5177e6815ef 100644 --- a/torch/lib/c10d/GlooDeviceFactory.cpp +++ b/torch/lib/c10d/GlooDeviceFactory.cpp @@ -8,6 +8,10 @@ #include <gloo/transport/tcp/device.h> #endif +#if GLOO_HAVE_TRANSPORT_TCP_TLS +#include <gloo/transport/tcp/tls/device.h> +#endif + #if GLOO_HAVE_TRANSPORT_UV #include <gloo/transport/uv/device.h> #endif @@ -59,6 +63,35 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice); C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice); #endif +#if GLOO_HAVE_TRANSPORT_TCP_TLS +static std::string cstr_to_std_string(const char* chars) { + return std::string (chars != nullptr ? 
chars : ""); +} + +static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( + const std::string& interface, + const std::string& hostname) { + TORCH_CHECK( + !interface.empty() || !hostname.empty(), + "GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname " + "can't be empty"); + + ::gloo::transport::tcp::attr attr; + if (!interface.empty()) { + attr.iface = interface; + } else { + attr.hostname = hostname; + } + const auto pkey = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY")); + const auto cert = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT")); + const auto caFile = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE")); + const auto caPath = cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH")); + return ::gloo::transport::tcp::tls::CreateDevice(attr, pkey, cert, caFile, caPath); +} + +C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice); +#endif + #if GLOO_HAVE_TRANSPORT_UV static std::shared_ptr<::gloo::transport::Device> makeUVDevice( const std::string& interfaceName,