diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 3d913a396..750b4a76b 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -36,7 +36,7 @@ VALUE cSSLSocket; static VALUE eSSLErrorWaitReadable; static VALUE eSSLErrorWaitWritable; -static ID id_call, ID_callback_state, id_npn_protocols_encoded, id_each; +static ID id_call, id_npn_protocols_encoded, id_each; static VALUE sym_exception, sym_wait_readable, sym_wait_writable; static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, @@ -47,16 +47,22 @@ static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, id_i_session_remove_cb, id_i_npn_select_cb, id_i_npn_protocols, id_i_alpn_select_cb, id_i_alpn_protocols, id_i_servername_cb, id_i_verify_hostname, id_i_keylog_cb, id_i_tmp_dh_callback; -static ID id_i_io, id_i_context, id_i_hostname, id_i_sync_close; +static ID id_i_context, id_i_hostname, id_i_sync_close; -static int ossl_ssl_ex_ptr_idx; +static int ossl_ssl_ex_data_idx; static int ossl_sslctx_ex_ptr_idx; +struct ossl_ssl_data * +ossl_ssl_data(const SSL *ssl) +{ + return SSL_get_ex_data(ssl, ossl_ssl_ex_data_idx); +} + static void ossl_sslctx_mark(void *ptr) { SSL_CTX *ctx = ptr; - rb_gc_mark((VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx)); + rb_gc_mark_movable((VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx)); } static void @@ -65,12 +71,25 @@ ossl_sslctx_free(void *ptr) SSL_CTX_free(ptr); } +static void +ossl_sslctx_compact(void *ptr) +{ + SSL_CTX *ctx = ptr; + VALUE self = (VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx); + if (self) { + (void)SSL_CTX_set_ex_data(ctx, ossl_sslctx_ex_ptr_idx, + (void *)rb_gc_location(self)); + } +} + static const rb_data_type_t ossl_sslctx_type = { - "OpenSSL/SSL/CTX", - { - ossl_sslctx_mark, ossl_sslctx_free, + .wrap_struct_name = "OpenSSL/SSL/CTX", + .function = { + .dmark = ossl_sslctx_mark, + .dfree = ossl_sslctx_free, + .dcompact = ossl_sslctx_compact, }, - 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, + .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; static VALUE @@ -97,37 +116,55 @@ ossl_sslctx_s_alloc(VALUE klass) return obj; } +struct client_cert_cb_args { + VALUE ssl_obj; + VALUE cb; + X509 **x509; + EVP_PKEY **pkey; +}; + static VALUE -ossl_call_client_cert_cb(VALUE obj) +ossl_call_client_cert_cb(VALUE args_) { - VALUE ctx_obj, cb, ary, cert, key; - - ctx_obj = rb_attr_get(obj, id_i_context); - cb = rb_attr_get(ctx_obj, id_i_client_cert_cb); - if (NIL_P(cb)) - return Qnil; - - ary = rb_funcallv(cb, id_call, 1, &obj); + struct client_cert_cb_args *args = (struct client_cert_cb_args *)args_; + VALUE ary = rb_funcall(args->cb, id_call, 1, args->ssl_obj); Check_Type(ary, T_ARRAY); - GetX509CertPtr(cert = rb_ary_entry(ary, 0)); - GetPrivPKeyPtr(key = rb_ary_entry(ary, 1)); + if (RARRAY_LEN(ary) != 2) + rb_raise(rb_eTypeError, "client_cert_cb must return [cert, key]"); + + X509 *cert = GetX509CertPtr(rb_ary_entry(ary, 0)); + EVP_PKEY *pkey = GetPrivPKeyPtr(rb_ary_entry(ary, 1)); + if (!X509_up_ref(cert)) + ossl_raise(eSSLError, "X509_up_ref"); + if (!EVP_PKEY_up_ref(pkey)) { + X509_free(cert); + ossl_raise(eSSLError, "EVP_PKEY_up_ref"); + } - return rb_ary_new3(2, cert, key); + *args->x509 = cert; + *args->pkey = pkey; + return Qnil; } static int ossl_client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey) { - VALUE obj, ret; - - obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - ret = rb_protect(ossl_call_client_cert_cb, obj, NULL); - if (NIL_P(ret)) + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + if (p->cb_state) return 0; - *x509 = DupX509CertPtr(RARRAY_AREF(ret, 0)); - *pkey = DupPKeyPtr(RARRAY_AREF(ret, 1)); + VALUE ctx_obj = rb_attr_get(p->self, id_i_context); + VALUE cb = rb_attr_get(ctx_obj, id_i_client_cert_cb); + if (NIL_P(cb)) + return 0; + int state; + struct client_cert_cb_args args = { p->self, cb, x509, pkey }; + rb_protect(ossl_call_client_cert_cb, (VALUE)&args, &state); + if (state) { + p->cb_state = state; + return 0; + } return 1; } @@ -163,12 +200,15 @@ ossl_call_tmp_dh_callback(VALUE arg) static DH * ossl_tmp_dh_callback(SSL *ssl, int is_export, int keylength) { + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + if (p->cb_state) + return NULL; + int state; - VALUE rb_ssl = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - struct tmp_dh_callback_args args = {rb_ssl, is_export, keylength}; + struct tmp_dh_callback_args args = {p->self, is_export, keylength}; VALUE ret = rb_protect(ossl_call_tmp_dh_callback, (VALUE)&args, &state); if (state) { - rb_ivar_set(rb_ssl, ID_callback_state, INT2NUM(state)); + p->cb_state = state; return NULL; } return (DH *)ret; @@ -179,12 +219,13 @@ static VALUE call_verify_certificate_identity(VALUE ctx_v) { X509_STORE_CTX *ctx = (X509_STORE_CTX *)ctx_v; + struct ossl_ssl_data *p; SSL *ssl; - VALUE ssl_obj, hostname, cert_obj; + VALUE hostname, cert_obj; ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - hostname = rb_attr_get(ssl_obj, id_i_hostname); + p = ossl_ssl_data(ssl); + hostname = rb_attr_get(p->self, id_i_hostname); if (!RTEST(hostname)) { rb_warning("verify_hostname requires hostname to be set"); @@ -199,13 +240,20 @@ call_verify_certificate_identity(VALUE ctx_v) static int ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { - VALUE cb, ssl_obj, sslctx_obj, verify_hostname, ret; + VALUE cb, sslctx_obj, verify_hostname, ret; SSL *ssl; + struct ossl_ssl_data *p; int status; ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + p = ossl_ssl_data(ssl); + if (p->cb_state) { + if (X509_STORE_CTX_get_error(ctx) == X509_V_OK) + X509_STORE_CTX_set_error(ctx, X509_V_ERR_UNSPECIFIED); + return 0; + } + + sslctx_obj = rb_attr_get(p->self, id_i_context); cb = rb_attr_get(sslctx_obj, id_i_verify_callback); verify_hostname = rb_attr_get(sslctx_obj, id_i_verify_hostname); @@ -213,7 +261,7 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) !X509_STORE_CTX_get_error_depth(ctx)) { ret = rb_protect(call_verify_certificate_identity, (VALUE)ctx, &status); if (status) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status)); + p->cb_state = status; return 0; } if (ret != Qtrue) { @@ -225,82 +273,89 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) return ossl_verify_cb_call(cb, preverify_ok, ctx); } +struct sess_get_cb_args { + VALUE ssl_obj; + const unsigned char *data; + int len; +}; + static VALUE -ossl_call_session_get_cb(VALUE ary) +ossl_call_sess_get_cb(VALUE args_) { - VALUE ssl_obj, cb; - - Check_Type(ary, T_ARRAY); - ssl_obj = rb_ary_entry(ary, 0); + struct sess_get_cb_args *args = (struct sess_get_cb_args *)args_; + VALUE ssl_obj = args->ssl_obj, sslctx_obj, cb; - cb = rb_funcall(ssl_obj, rb_intern("session_get_cb"), 0); + sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + cb = rb_attr_get(sslctx_obj, id_i_session_get_cb); if (NIL_P(cb)) return Qnil; - return rb_funcallv(cb, id_call, 1, &ary); + VALUE session_id = rb_str_new((const char *)args->data, args->len); + VALUE ret_obj = rb_funcall(cb, id_call, 1, rb_assoc_new(ssl_obj, session_id)); + // XXX: Should we raise if ret is neither SSLSession nor nil? + if (!rb_obj_is_instance_of(ret_obj, cSSLSession)) + return (VALUE)NULL; + + SSL_SESSION *sess; + GetSSLSession(ret_obj, sess); + return (VALUE)sess; } static SSL_SESSION * ossl_sslctx_session_get_cb(SSL *ssl, const unsigned char *buf, int len, int *copy) { - VALUE ary, ssl_obj, ret_obj; - SSL_SESSION *sess; - int state = 0; + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + if (p->cb_state) + return NULL; OSSL_Debug("SSL SESSION get callback entered"); - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - ary = rb_ary_new2(2); - rb_ary_push(ary, ssl_obj); - rb_ary_push(ary, rb_str_new((const char *)buf, len)); - - ret_obj = rb_protect(ossl_call_session_get_cb, ary, &state); + struct sess_get_cb_args args = { p->self, buf, len }; + int state; + VALUE ret = rb_protect(ossl_call_sess_get_cb, (VALUE)&args, &state); if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + p->cb_state = state; return NULL; } - if (!rb_obj_is_instance_of(ret_obj, cSSLSession)) - return NULL; - - GetSSLSession(ret_obj, sess); *copy = 1; - - return sess; + return (SSL_SESSION *)ret; } +struct sess_new_cb_args { + VALUE ssl_obj; + SSL_SESSION *sess; +}; + static VALUE -ossl_call_session_new_cb(VALUE ary) +ossl_call_session_new_cb(VALUE args_) { - VALUE ssl_obj, cb; - - Check_Type(ary, T_ARRAY); - ssl_obj = rb_ary_entry(ary, 0); + struct sess_new_cb_args *args = (struct sess_new_cb_args *)args_; + VALUE ssl_obj = args->ssl_obj, sslctx_obj, cb; - cb = rb_funcall(ssl_obj, rb_intern("session_new_cb"), 0); + sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + cb = rb_attr_get(sslctx_obj, id_i_session_new_cb); if (NIL_P(cb)) return Qnil; - return rb_funcallv(cb, id_call, 1, &ary); + VALUE sess_obj = rb_obj_alloc(cSSLSession); + if (!SSL_SESSION_up_ref(args->sess)) + ossl_raise(eSSLError, "SSL_SESSION_up_ref"); + RTYPEDDATA_DATA(sess_obj) = args->sess; + + return rb_funcall(cb, id_call, 1, rb_assoc_new(ssl_obj, sess_obj)); } /* return 1 normal. return 0 removes the session */ static int ossl_sslctx_session_new_cb(SSL *ssl, SSL_SESSION *sess) { - VALUE ary, ssl_obj, sess_obj; - int state = 0; + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + if (p->cb_state) + return 0; OSSL_Debug("SSL SESSION new callback entered"); - - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - sess_obj = rb_obj_alloc(cSSLSession); - SSL_SESSION_up_ref(sess); - DATA_PTR(sess_obj) = sess; - - ary = rb_ary_new2(2); - rb_ary_push(ary, ssl_obj); - rb_ary_push(ary, sess_obj); - - rb_protect(ossl_call_session_new_cb, ary, &state); + struct sess_new_cb_args args = { p->self, sess }; + int state; + rb_protect(ossl_call_session_new_cb, (VALUE)&args, &state); if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + p->cb_state = state; } /* @@ -344,43 +399,47 @@ ossl_call_keylog_cb(VALUE args_v) static void ossl_sslctx_keylog_cb(const SSL *ssl, const char *line) { - VALUE ssl_obj; - struct ossl_call_keylog_cb_args args; - int state = 0; + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + struct ossl_call_keylog_cb_args args = { p->self, line }; + int state; + if (p->cb_state) + return; OSSL_Debug("SSL keylog callback entered"); - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - args.ssl_obj = ssl_obj; - args.line = line; - rb_protect(ossl_call_keylog_cb, (VALUE)&args, &state); if (state) { - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + p->cb_state = state; } } #endif +struct sess_remove_cb_args { + SSL_CTX *ctx; + SSL_SESSION *sess; +}; + static VALUE -ossl_call_session_remove_cb(VALUE ary) +ossl_call_session_remove_cb(VALUE args_) { + struct sess_remove_cb_args *args = (struct sess_remove_cb_args *)args_; VALUE sslctx_obj, cb; - Check_Type(ary, T_ARRAY); - sslctx_obj = rb_ary_entry(ary, 0); - + sslctx_obj = (VALUE)SSL_CTX_get_ex_data(args->ctx, ossl_sslctx_ex_ptr_idx); cb = rb_attr_get(sslctx_obj, id_i_session_remove_cb); if (NIL_P(cb)) return Qnil; - return rb_funcallv(cb, id_call, 1, &ary); + VALUE sess_obj = rb_obj_alloc(cSSLSession); + if (!SSL_SESSION_up_ref(args->sess)) + ossl_raise(eSSLError, "SSL_SESSION_up_ref"); + RTYPEDDATA_DATA(sess_obj) = args->sess; + + return rb_funcall(cb, id_call, 1, rb_assoc_new(sslctx_obj, sess_obj)); } static void ossl_sslctx_session_remove_cb(SSL_CTX *ctx, SSL_SESSION *sess) { - VALUE ary, sslctx_obj, sess_obj; - int state = 0; - /* * This callback is also called for all sessions in the internal store * when SSL_CTX_free() is called. @@ -389,23 +448,16 @@ ossl_sslctx_session_remove_cb(SSL_CTX *ctx, SSL_SESSION *sess) return; OSSL_Debug("SSL SESSION remove callback entered"); - - sslctx_obj = (VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx); - sess_obj = rb_obj_alloc(cSSLSession); - SSL_SESSION_up_ref(sess); - DATA_PTR(sess_obj) = sess; - - ary = rb_ary_new2(2); - rb_ary_push(ary, sslctx_obj); - rb_ary_push(ary, sess_obj); - - rb_protect(ossl_call_session_remove_cb, ary, &state); + struct sess_remove_cb_args args = { ctx, sess }; + int state; + rb_protect(ossl_call_session_remove_cb, (VALUE)&args, &state); if (state) { -/* - the SSL_CTX is frozen, nowhere to save state. - there is no common accessor method to check it either. - rb_ivar_set(sslctx_obj, ID_callback_state, INT2NUM(state)); -*/ + /* + * the SSL_CTX is frozen, nowhere to save state. + * there is no common accessor method to check it either. + */ + rb_warn("exception in session_remove_cb is ignored"); + rb_set_errinfo(Qnil); } } @@ -435,10 +487,10 @@ ossl_call_servername_cb(VALUE arg) if (!servername) return Qnil; - VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - VALUE sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + VALUE sslctx_obj = rb_attr_get(p->self, id_i_context); VALUE cb = rb_attr_get(sslctx_obj, id_i_servername_cb); - VALUE ary = rb_assoc_new(ssl_obj, rb_str_new_cstr(servername)); + VALUE ary = rb_assoc_new(p->self, rb_str_new_cstr(servername)); VALUE ret_obj = rb_funcallv(cb, id_call, 1, &ary); if (rb_obj_is_kind_of(ret_obj, cSSLContext)) { @@ -447,7 +499,7 @@ ossl_call_servername_cb(VALUE arg) GetSSLCTX(ret_obj, ctx2); if (!SSL_set_SSL_CTX(ssl, ctx2)) ossl_raise(eSSLError, "SSL_set_SSL_CTX"); - rb_ivar_set(ssl_obj, id_i_context, ret_obj); + rb_ivar_set(p->self, id_i_context, ret_obj); } else if (!NIL_P(ret_obj)) { ossl_raise(rb_eArgError, "servername_cb must return an " "OpenSSL::SSL::SSLContext object or nil"); @@ -459,29 +511,43 @@ ossl_call_servername_cb(VALUE arg) static int ssl_servername_cb(SSL *ssl, int *ad, void *arg) { + struct ossl_ssl_data *p = ossl_ssl_data(ssl); int state; + if (p->cb_state) + return SSL_TLSEXT_ERR_ALERT_FATAL; rb_protect(ossl_call_servername_cb, (VALUE)ssl, &state); if (state) { - VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state)); + p->cb_state = state; return SSL_TLSEXT_ERR_ALERT_FATAL; } return SSL_TLSEXT_ERR_OK; } -static void -ssl_renegotiation_cb(const SSL *ssl) +static VALUE +call_renegotiation_cb(VALUE args_) { - VALUE ssl_obj, sslctx_obj, cb; + VALUE *args = (VALUE *)args_; + return rb_funcall(args[0], id_call, 1, args[1]); +} - ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - sslctx_obj = rb_attr_get(ssl_obj, id_i_context); - cb = rb_attr_get(sslctx_obj, id_i_renegotiation_cb); - if (NIL_P(cb)) return; +/* This function may serve as the entry point to support further callbacks. */ +static void +ssl_info_cb(const SSL *ssl, int where, int val) +{ + struct ossl_ssl_data *p = ossl_ssl_data(ssl); - rb_funcallv(cb, id_call, 1, &ssl_obj); + if (p->cb_state) + return; + if (where & SSL_CB_HANDSHAKE_START && SSL_is_server(ssl)) { + VALUE sslctx_obj = rb_attr_get(p->self, id_i_context); + VALUE cb = rb_attr_get(sslctx_obj, id_i_renegotiation_cb); + if (!NIL_P(cb)) { + VALUE args[] = { cb, p->self }; + rb_protect(call_renegotiation_cb, (VALUE)&args, &p->cb_state); + } + } } static VALUE @@ -544,19 +610,21 @@ ssl_npn_select_cb_common(SSL *ssl, VALUE cb, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen) { + struct ossl_ssl_data *p = ossl_ssl_data(ssl); VALUE selected; - int status; + int state; struct npn_select_cb_common_args args; + if (p->cb_state) + return SSL_TLSEXT_ERR_ALERT_FATAL; + args.cb = cb; args.in = in; args.inlen = inlen; - selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &status); - if (status) { - VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - - rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status)); + selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &state); + if (state) { + p->cb_state = state; return SSL_TLSEXT_ERR_ALERT_FATAL; } @@ -605,17 +673,6 @@ ssl_alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, return ssl_npn_select_cb_common(ssl, cb, out, outlen, in, inlen); } -/* This function may serve as the entry point to support further callbacks. */ -static void -ssl_info_cb(const SSL *ssl, int where, int val) -{ - int is_server = SSL_is_server((SSL *)ssl); - - if (is_server && where & SSL_CB_HANDSHAKE_START) { - ssl_renegotiation_cb(ssl); - } -} - /* * call-seq: * ctx.options -> integer @@ -779,6 +836,9 @@ ossl_sslctx_setup(VALUE self) val = rb_attr_get(self, id_i_verify_depth); if(!NIL_P(val)) SSL_CTX_set_verify_depth(ctx, NUM2INT(val)); + if (!NIL_P(rb_attr_get(self, id_i_renegotiation_cb))) + SSL_CTX_set_info_callback(ctx, ssl_info_cb); + #ifdef OSSL_USE_NEXTPROTONEG val = rb_attr_get(self, id_i_npn_protocols); if (!NIL_P(val)) { @@ -1567,21 +1627,37 @@ static void ossl_ssl_mark(void *ptr) { SSL *ssl = ptr; - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx)); + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + rb_gc_mark_movable(p->self); + rb_gc_mark_movable(p->io); } static void -ossl_ssl_free(void *ssl) +ossl_ssl_free(void *ptr) { + SSL *ssl = ptr; + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + ruby_xfree(p); SSL_free(ssl); } +static void +ossl_ssl_compact(void *ptr) +{ + SSL *ssl = ptr; + struct ossl_ssl_data *p = ossl_ssl_data(ssl); + p->self = rb_gc_location(p->self); + p->io = rb_gc_location(p->io); +} + const rb_data_type_t ossl_ssl_type = { - "OpenSSL/SSL", - { - ossl_ssl_mark, ossl_ssl_free, + .wrap_struct_name = "OpenSSL/SSL", + .function = { + .dmark = ossl_ssl_mark, + .dfree = ossl_ssl_free, + .dcompact = ossl_ssl_compact, }, - 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, + .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; static VALUE @@ -1593,8 +1669,9 @@ ossl_ssl_s_alloc(VALUE klass) static VALUE peer_ip_address(VALUE io) { - VALUE remote_address = rb_funcall(io, rb_intern("remote_address"), 0); - + VALUE remote_address = rb_check_funcall(io, rb_intern("remote_address"), 0, NULL); + if (remote_address == Qundef) + return rb_str_new_cstr("(io.remote_address not supported)"); return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0); } @@ -1611,14 +1688,24 @@ peeraddr_ip_str(VALUE io) rb_eSystemCallError, (VALUE)0); } +static int +is_real_socket(VALUE io) +{ + // FIXME: DO NOT MERGE + return 0; + return RB_TYPE_P(io, T_FILE); +} + /* * call-seq: * SSLSocket.new(io) => aSSLSocket * SSLSocket.new(io, ctx) => aSSLSocket * SSLSocket.new(io, ctx, sync_close:) => aSSLSocket * - * Creates a new SSL socket from _io_ which must be a real IO object (not an - * IO-like object that responds to read/write). + * Creates a new SSL socket from the underlying socket _io_ and _ctx_. + * + * _io_ must be an IO object, typically a TCPSocket or Socket from the socket + * library, or an IO-like object that supports the typical IO methods. * * If _ctx_ is provided the SSL Sockets initial params will be taken from * the context. @@ -1631,6 +1718,22 @@ peeraddr_ip_str(VALUE io) * * This method will freeze the SSLContext if one is provided; * however, session management is still allowed in the frozen SSLContext. + * + * == Support for IO-like objects + * + * Support for IO-like objects was added in version 4.1 and is considered + * experimental. The requirements for the objects may change in future versions. + * + * As of version 4.1, SSLSocket expects the following methods to be compatible + * with core IO objects: + * + * - write_nonblock with the exception: false option + * - read_nonblock with the exception: false option + * - wait_readable + * - wait_writable + * - close + * - closed? + * - sync and flush (optional) */ static VALUE ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) @@ -1663,25 +1766,42 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) rb_ivar_set(self, id_i_context, v_ctx); ossl_sslctx_setup(v_ctx); - if (rb_respond_to(io, rb_intern("nonblock="))) - rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); - rb_ivar_set(self, id_i_io, io); + struct ossl_ssl_data *p = RB_ZALLOC(struct ossl_ssl_data); + p->self = self; + p->io = io; ssl = SSL_new(ctx); - if (!ssl) - ossl_raise(eSSLError, NULL); - RTYPEDDATA_DATA(self) = ssl; - - if (!SSL_set_ex_data(ssl, ossl_ssl_ex_ptr_idx, (void *)self)) + if (!ssl) { + ruby_xfree(p); + ossl_raise(eSSLError, "SSL_new"); + } + if (!SSL_set_ex_data(ssl, ossl_ssl_ex_data_idx, p)) { + ruby_xfree(p); + SSL_free(ssl); ossl_raise(eSSLError, "SSL_set_ex_data"); - SSL_set_info_callback(ssl, ssl_info_cb); + } + RTYPEDDATA_DATA(self) = ssl; rb_call_super(0, NULL); return self; } +/* + * call-seq: + * ssl.io -> io + * ssl.to_io -> io + * + * Returns the underlying IO object. + */ +static VALUE +ossl_ssl_get_io(VALUE self) +{ + SSL *ssl; + GetSSL(self, ssl); + return ossl_ssl_data(ssl)->io; +} + #ifndef HAVE_RB_IO_DESCRIPTOR static int io_descriptor_fallback(VALUE io) @@ -1696,27 +1816,42 @@ io_descriptor_fallback(VALUE io) static VALUE ossl_ssl_setup(VALUE self) { - VALUE io; SSL *ssl; - rb_io_t *fptr; + struct ossl_ssl_data *p; GetSSL(self, ssl); + p = ossl_ssl_data(ssl); if (ssl_started(ssl)) return Qtrue; - io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + if (is_real_socket(p->io)) { + rb_io_t *fptr; + GetOpenFile(p->io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + rb_io_set_nonblock(fptr); + if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(p->io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } + else { + BIO *bio = ossl_ssl_bio_setup(p); + if (!BIO_up_ref(bio)) { + BIO_free(bio); + ossl_raise(eSSLError, "BIO_up_ref"); + } + SSL_set_bio(ssl, bio, bio); + } return Qtrue; } static int -errno_mapped(void) +errno_mapped(struct ossl_ssl_data *p) { + /* ossl_ssl_bio_method -> errno must not be used */ + if (!is_real_socket(p->io)) + return 0; + #ifdef _WIN32 return rb_w32_map_errno(WSAGetLastError()); #else @@ -1762,6 +1897,11 @@ no_exception_p(VALUE opts) static void io_wait_writable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_wait(io, INT2NUM(RUBY_IO_WRITABLE), RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); @@ -1776,6 +1916,11 @@ io_wait_writable(VALUE io) static void io_wait_readable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_wait(io, INT2NUM(RUBY_IO_READABLE), RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); @@ -1791,23 +1936,21 @@ static VALUE ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) { SSL *ssl; - VALUE cb_state; int nonblock = opts != Qfalse; - rb_ivar_set(self, ID_callback_state, Qnil); - GetSSL(self, ssl); + struct ossl_ssl_data *p = ossl_ssl_data(ssl); - VALUE io = rb_attr_get(self, id_i_io); for (;;) { int ret = func(ssl); - int saved_errno = errno_mapped(); + int saved_errno = errno_mapped(p); - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { + if (p->cb_state) { /* must cleanup OpenSSL error stack before re-raising */ ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); + int state = p->cb_state; + p->cb_state = 0; + rb_jump_tag(state); } if (ret > 0) @@ -1818,12 +1961,12 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) case SSL_ERROR_WANT_WRITE: if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); - io_wait_writable(io); + io_wait_writable(p->io); continue; case SSL_ERROR_WANT_READ: if (no_exception_p(opts)) { return sym_wait_readable; } read_would_block(nonblock); - io_wait_readable(io); + io_wait_readable(p->io); continue; case SSL_ERROR_SYSCALL: #ifdef __APPLE__ @@ -1856,7 +1999,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) code == SSL_ERROR_SYSCALL ? " SYSCALL" : "", code, saved_errno, - peeraddr_ip_str(io), + peeraddr_ip_str(p->io), SSL_state_string_long(ssl), error_append); } @@ -1965,8 +2108,9 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) { SSL *ssl; int ilen; - VALUE len, str, cb_state; + VALUE len, str; VALUE opts = Qnil; + struct ossl_ssl_data *p; if (nonblock) { rb_scan_args(argc, argv, "11:", &len, &str, &opts); @@ -1974,6 +2118,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_scan_args(argc, argv, "11", &len, &str); } GetSSL(self, ssl); + p = ossl_ssl_data(ssl); if (!ssl_started(ssl)) rb_raise(eSSLError, "SSL session is not started yet"); @@ -1993,19 +2138,17 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) return str; } - VALUE io = rb_attr_get(self, id_i_io); - for (;;) { rb_str_locktmp(str); int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); - int saved_errno = errno_mapped(); + int saved_errno = errno_mapped(p); rb_str_unlocktmp(str); - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); + if (p->cb_state) { ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); + int state = p->cb_state; + p->cb_state = 0; + rb_jump_tag(state); } switch (SSL_get_error(ssl, nread)) { @@ -2020,14 +2163,14 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); } - io_wait_writable(io); + io_wait_writable(p->io); break; case SSL_ERROR_WANT_READ: if (nonblock) { if (no_exception_p(opts)) { return sym_wait_readable; } read_would_block(nonblock); } - io_wait_readable(io); + io_wait_readable(p->io); break; case SSL_ERROR_SYSCALL: if (!ERR_peek_error()) { @@ -2099,17 +2242,14 @@ ossl_ssl_write_internal_safe(VALUE _args) VALUE opts = args[2]; SSL *ssl; - rb_io_t *fptr; + struct ossl_ssl_data *p; int num, nonblock = opts != Qfalse; - VALUE cb_state; GetSSL(self, ssl); + p = ossl_ssl_data(ssl); if (!ssl_started(ssl)) rb_raise(eSSLError, "SSL session is not started yet"); - VALUE io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - /* SSL_write(3ssl) manpage states num == 0 is undefined */ num = RSTRING_LENINT(str); if (num == 0) @@ -2117,13 +2257,13 @@ ossl_ssl_write_internal_safe(VALUE _args) for (;;) { int nwritten = SSL_write(ssl, RSTRING_PTR(str), num); - int saved_errno = errno_mapped(); + int saved_errno = errno_mapped(p); - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); + if (p->cb_state) { ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); + int state = p->cb_state; + p->cb_state = 0; + rb_jump_tag(state); } switch (SSL_get_error(ssl, nwritten)) { @@ -2132,12 +2272,12 @@ ossl_ssl_write_internal_safe(VALUE _args) case SSL_ERROR_WANT_WRITE: if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); - io_wait_writable(io); + io_wait_writable(p->io); continue; case SSL_ERROR_WANT_READ: if (no_exception_p(opts)) { return sym_wait_readable; } read_would_block(nonblock); - io_wait_readable(io); + io_wait_readable(p->io); continue; case SSL_ERROR_SYSCALL: #ifdef __APPLE__ @@ -2224,12 +2364,27 @@ static VALUE ossl_ssl_stop(VALUE self) { SSL *ssl; + struct ossl_ssl_data *p; int ret; GetSSL(self, ssl); + p = ossl_ssl_data(ssl); if (!ssl_started(ssl)) return Qnil; + ret = SSL_shutdown(ssl); + + /* + * XXX: SSLSocket#stop is supposed to ignore errors due to the underlying + * socket being unreadable/unwritable. We want to suppress exceptions raised + * by the underlying IO-like object, but not from any SSL callbacks. Can SSL + * callbacks be invoked during SSL_shutdown()? We assume no, for now. + */ + if (!is_real_socket(p->io) && p->cb_state) { + p->cb_state = 0; + rb_set_errinfo(Qnil); + } + if (ret == 1) /* Have already received close_notify */ return Qnil; if (ret == 0) /* Sent close_notify, but we don't wait for reply */ @@ -2738,11 +2893,8 @@ Init_ossl_ssl(void) #endif #ifndef OPENSSL_NO_SOCK - id_call = rb_intern_const("call"); - ID_callback_state = rb_intern_const("callback_state"); - - ossl_ssl_ex_ptr_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_ptr_idx", 0, 0, 0); - if (ossl_ssl_ex_ptr_idx < 0) + ossl_ssl_ex_data_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_data_idx", 0, 0, 0); + if (ossl_ssl_ex_data_idx < 0) ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index"); ossl_sslctx_ex_ptr_idx = SSL_CTX_get_ex_new_index(0, (void *)"ossl_sslctx_ex_ptr_idx", 0, 0, 0); if (ossl_sslctx_ex_ptr_idx < 0) @@ -2768,6 +2920,7 @@ Init_ossl_ssl(void) rb_include_module(eSSLErrorWaitWritable, rb_mWaitWritable); Init_ossl_ssl_session(); + Init_ossl_ssl_bio(); /* Document-class: OpenSSL::SSL::SSLContext * @@ -3134,6 +3287,8 @@ Init_ossl_ssl(void) rb_define_alloc_func(cSSLSocket, ossl_ssl_s_alloc); rb_define_method(cSSLSocket, "initialize", ossl_ssl_initialize, -1); rb_undef_method(cSSLSocket, "initialize_copy"); + rb_define_method(cSSLSocket, "io", ossl_ssl_get_io, 0); + rb_define_alias(cSSLSocket, "to_io", "io"); rb_define_method(cSSLSocket, "connect", ossl_ssl_connect, 0); rb_define_method(cSSLSocket, "connect_nonblock", ossl_ssl_connect_nonblock, -1); rb_define_method(cSSLSocket, "accept", ossl_ssl_accept, 0); @@ -3294,6 +3449,7 @@ Init_ossl_ssl(void) sym_wait_readable = ID2SYM(rb_intern_const("wait_readable")); sym_wait_writable = ID2SYM(rb_intern_const("wait_writable")); + id_call = rb_intern_const("call"); id_npn_protocols_encoded = rb_intern_const("npn_protocols_encoded"); id_each = rb_intern_const("each"); @@ -3326,7 +3482,6 @@ Init_ossl_ssl(void) DefIVarID(keylog_cb); DefIVarID(tmp_dh_callback); - DefIVarID(io); DefIVarID(context); DefIVarID(hostname); DefIVarID(sync_close); diff --git a/ext/openssl/ossl_ssl.h b/ext/openssl/ossl_ssl.h index a87e62d45..88eadf709 100644 --- a/ext/openssl/ossl_ssl.h +++ b/ext/openssl/ossl_ssl.h @@ -30,7 +30,20 @@ extern VALUE mSSL; extern VALUE cSSLSocket; extern VALUE cSSLSession; +struct ossl_ssl_data { + VALUE self; + VALUE io; + int cb_state; + + /* Used by ossl_ssl_bio_method */ + int bio_eof; +}; +struct ossl_ssl_data *ossl_ssl_data(const SSL *ssl); + +BIO *ossl_ssl_bio_setup(struct ossl_ssl_data *p); + void Init_ossl_ssl(void); void Init_ossl_ssl_session(void); +void Init_ossl_ssl_bio(void); #endif /* _OSSL_SSL_H_ */ diff --git a/ext/openssl/ossl_ssl_bio.c b/ext/openssl/ossl_ssl_bio.c new file mode 100644 index 000000000..25a0ca661 --- /dev/null +++ b/ext/openssl/ossl_ssl_bio.c @@ -0,0 +1,241 @@ +/* + * Ruby/OpenSSL Project + * Copyright (C) 2025 Kazuki Yamaguchi + */ +#include "ossl.h" + +static BIO_METHOD *ossl_bio_meth; +static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable; + +BIO * +ossl_ssl_bio_setup(struct ossl_ssl_data *p) +{ + BIO *bio = BIO_new(ossl_bio_meth); + if (!bio) + ossl_raise(eOSSLError, "BIO_new"); + + BIO_set_data(bio, p); + BIO_set_init(bio, 1); + return bio; +} + +struct call0_args { + VALUE (*func)(VALUE); + VALUE args; + VALUE ret; +}; + +static VALUE +do_nothing(VALUE _) +{ + return Qnil; +} + +static VALUE +call_protect1(VALUE args_) +{ + struct call0_args *args = (void *)args_; + rb_set_errinfo(Qnil); + args->ret = args->func(args->args); + return Qnil; +} + +static VALUE +call_protect0(VALUE args_) +{ + rb_ensure(do_nothing, Qnil, call_protect1, args_); + return Qnil; +} + +static VALUE +call_protect(VALUE (*func)(VALUE), VALUE args, int *state, int current) +{ + if (!current) + return rb_protect(func, args, state); + + VALUE errinfo = rb_errinfo(); + struct call0_args call0_args = { func, args, Qnil }; + rb_protect(call_protect0, (VALUE)&call0_args, state); + if (*state) { + if (!rb_obj_is_kind_of(errinfo, rb_eException)) + errinfo = rb_str_new_cstr("(unknown)"); + rb_warn("BIO callback raised an exception, pending jump suppressed: " \ + "state=%d, errinfo=%+"PRIsVALUE, current, errinfo); + } + return call0_args.ret; +} + + +struct bwrite_args { + struct ossl_ssl_data *p; + BIO *bio; + const char *data; + int dlen; + int written; +}; + +static VALUE +bio_bwrite0(VALUE args_) +{ + struct bwrite_args *args = (void *)args_; + struct ossl_ssl_data *p = args->p; + + BIO_clear_retry_flags(args->bio); + + VALUE fargs[] = { rb_str_new_static(args->data, args->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_kw(p->io, rb_intern("write_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_INTEGER_TYPE_P(ret)) { + args->written = NUM2INT(ret); + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(args->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(args->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +static int +bio_bwrite(BIO *bio, const char *data, int dlen) +{ + struct ossl_ssl_data *p = BIO_get_data(bio); + + struct bwrite_args args = { p, bio, data, dlen, 0 }; + int state; + VALUE ok = call_protect(bio_bwrite0, (VALUE)&args, &state, p->cb_state); + if (state) { + p->cb_state = state; + return -1; + } + if (RTEST(ok)) + return args.written; + return -1; +} + +struct bread_args { + struct ossl_ssl_data *p; + BIO *bio; + char *data; + int dlen; + int readbytes; +}; + +static VALUE +bio_bread0(VALUE args_) +{ + struct bread_args *args = (void *)args_; + struct ossl_ssl_data *p = args->p; + + BIO_clear_retry_flags(args->bio); + + VALUE fargs[] = { INT2NUM(args->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_kw(p->io, rb_intern("read_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_TYPE_P(ret, T_STRING)) { + int len = RSTRING_LENINT(ret); + if (len > args->dlen) + rb_raise(rb_eTypeError, "read_nonblock returned too much data"); + memcpy(args->data, RSTRING_PTR(ret), len); + args->readbytes = len; + return Qtrue; + } + else if (NIL_P(ret)) { + // In OpenSSL 3.0 or later: BIO_set_flags(args->bio, BIO_FLAGS_IN_EOF); + p->bio_eof = 1; + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(args->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(args->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +static int +bio_bread(BIO *bio, char *data, int dlen) +{ + struct ossl_ssl_data *p = BIO_get_data(bio); + + struct bread_args args = { p, bio, data, dlen, 0 }; + int state; + VALUE ok = call_protect(bio_bread0, (VALUE)&args, &state, p->cb_state); + if (state) { + p->cb_state = state; + return -1; + } + if (RTEST(ok)) + return args.readbytes; + return -1; +} + +static VALUE +bio_flush0(VALUE p_) +{ + struct ossl_ssl_data *p = (void *)p_; + /* + * If the underlying IO-like object does not respond to flush, let's just + * assume that it does not need to be flushed. + */ + return rb_check_funcall(p->io, rb_intern("flush"), 0, NULL); +} + +static long +bio_ctrl(BIO *bio, int cmd, long larg, void *parg) +{ + struct ossl_ssl_data *p = BIO_get_data(bio); + int state; + + switch (cmd) { + case BIO_CTRL_EOF: + return p->bio_eof; + case BIO_CTRL_FLUSH: + call_protect(bio_flush0, (VALUE)p, &state, p->cb_state); + if (state) { + p->cb_state = state; + return 0; + } + return 1; + default: + return 0; + } +} + +void +Init_ossl_ssl_bio(void) +{ + ossl_bio_meth = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "Ruby IO-like object"); + if (!ossl_bio_meth) + ossl_raise(eOSSLError, "BIO_meth_new"); + if (!BIO_meth_set_write(ossl_bio_meth, bio_bwrite) || + !BIO_meth_set_read(ossl_bio_meth, bio_bread) || + !BIO_meth_set_ctrl(ossl_bio_meth, bio_ctrl)) { + BIO_meth_free(ossl_bio_meth); + ossl_bio_meth = NULL; + ossl_raise(eOSSLError, "BIO_meth_set_*"); + } + + nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern_const("exception")), Qfalse); + rb_global_variable(&nonblock_kwargs); + + sym_wait_readable = ID2SYM(rb_intern_const("wait_readable")); + sym_wait_writable = ID2SYM(rb_intern_const("wait_writable")); +} + diff --git a/ext/openssl/ossl_x509store.c b/ext/openssl/ossl_x509store.c index bed124dc3..ee9f3d0e3 100644 --- a/ext/openssl/ossl_x509store.c +++ b/ext/openssl/ossl_x509store.c @@ -116,10 +116,9 @@ static void ossl_x509store_mark(void *ptr) { X509_STORE *store = ptr; - // Note: this reference is stored as @verify_callback so we don't need to mark it. - // However we do need to ensure GC compaction won't move it, hence why - // we call rb_gc_mark here. - rb_gc_mark((VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx)); + VALUE verify_cb = + (VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx); + rb_gc_mark_movable(verify_cb); } static void @@ -128,12 +127,26 @@ ossl_x509store_free(void *ptr) X509_STORE_free(ptr); } +static void +ossl_x509store_compact(void *ptr) +{ + X509_STORE *store = ptr; + VALUE verify_cb = + (VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx); + if (verify_cb) { + (void)X509_STORE_set_ex_data(store, store_ex_verify_cb_idx, + (void *)rb_gc_location(verify_cb)); + } +} + static const rb_data_type_t ossl_x509store_type = { - "OpenSSL/X509/STORE", - { - ossl_x509store_mark, ossl_x509store_free, + .wrap_struct_name = "OpenSSL/X509/STORE", + .function = { + .dmark = ossl_x509store_mark, + .dfree = ossl_x509store_free, + .dcompact = ossl_x509store_compact, }, - 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, + .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; /* @@ -570,10 +583,9 @@ static void ossl_x509stctx_mark(void *ptr) { X509_STORE_CTX *ctx = ptr; - // Note: this reference is stored as @verify_callback so we don't need to mark it. - // However we do need to ensure GC compaction won't move it, hence why - // we call rb_gc_mark here. - rb_gc_mark((VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx)); + VALUE verify_cb = + (VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx); + rb_gc_mark_movable(verify_cb); } static void @@ -585,12 +597,26 @@ ossl_x509stctx_free(void *ptr) X509_STORE_CTX_free(ctx); } +static void +ossl_x509stctx_compact(void *ptr) +{ + X509_STORE_CTX *ctx = ptr; + VALUE verify_cb = + (VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx); + if (verify_cb) { + (void)X509_STORE_CTX_set_ex_data(ctx, stctx_ex_verify_cb_idx, + (void *)rb_gc_location(verify_cb)); + } +} + static const rb_data_type_t ossl_x509stctx_type = { - "OpenSSL/X509/STORE_CTX", - { - ossl_x509stctx_mark, ossl_x509stctx_free, + .wrap_struct_name = "OpenSSL/X509/STORE_CTX", + .function = { + .dmark = ossl_x509stctx_mark, + .dfree = ossl_x509stctx_free, + .dcompact = ossl_x509stctx_compact, }, - 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, + .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; static VALUE diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb index 1464a4292..e8e99868c 100644 --- a/lib/openssl/buffering.rb +++ b/lib/openssl/buffering.rb @@ -60,7 +60,7 @@ def initialize(*) super @eof = false @rbuffer = Buffer.new - @sync = @io.sync + @sync = to_io.respond_to?(:sync) ? to_io.sync : true end # diff --git a/lib/openssl/ssl.rb b/lib/openssl/ssl.rb index dccc11a55..23bd2bdab 100644 --- a/lib/openssl/ssl.rb +++ b/lib/openssl/ssl.rb @@ -336,10 +336,6 @@ class SSLSocket attr_reader :hostname - # The underlying IO object. - attr_reader :io - alias :to_io :io - # The SSLContext object used in this connection. attr_reader :context @@ -428,18 +424,6 @@ def using_anon_cipher? ctx.ciphers.include?(cipher) end - def client_cert_cb - @context.client_cert_cb - end - - def session_new_cb - @context.session_new_cb - end - - def session_get_cb - @context.session_get_cb - end - class << self # call-seq: diff --git a/test/openssl/test_buffering.rb b/test/openssl/test_buffering.rb index 466bbcfa2..95361d440 100644 --- a/test/openssl/test_buffering.rb +++ b/test/openssl/test_buffering.rb @@ -7,7 +7,8 @@ class OpenSSL::TestBuffering < OpenSSL::TestCase class IO include OpenSSL::Buffering - attr_accessor :sync + attr_reader :io + alias to_io io def initialize @io = Buffer.new diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 10942191d..dde849b3e 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -67,6 +67,34 @@ def create_tcp_client(host, port) end end +module OpenSSL::SSLPairIOish + include OpenSSL::SSLPairM + + def create_tcp_server(host, port) + Addrinfo.tcp(host, port).listen + end + + class TCPSocketWrapper + attr_reader :io + alias to_io io + + def initialize(io) @io = io end + def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end + def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end + def wait_readable() @io.wait_readable end + def wait_writable() @io.wait_writable end + def close() @io.close end + def closed?() @io.closed? end + + # Only used within test_pair.rb + def write(*args) @io.write(*args) end + end + + def create_tcp_client(host, port) + TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect) + end +end + module OpenSSL::TestEOF1M def open_file(content) ssl_pair { |s1, s2| @@ -518,6 +546,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF1M end +class OpenSSL::TestEOF1IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF1M +end + class OpenSSL::TestEOF2 < OpenSSL::TestCase include OpenSSL::TestEOF include OpenSSL::SSLPair @@ -530,6 +564,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF2M end +class OpenSSL::TestEOF2IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF2M +end + class OpenSSL::TestPair < OpenSSL::TestCase include OpenSSL::SSLPair include OpenSSL::TestPairM @@ -540,4 +580,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase include OpenSSL::TestPairM end +class OpenSSL::TestPairIOish < OpenSSL::TestCase + include OpenSSL::SSLPairIOish + include OpenSSL::TestPairM +end + end diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index e4fd58107..ff2821d1b 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -4,17 +4,6 @@ if defined?(OpenSSL::SSL) class OpenSSL::TestSSL < OpenSSL::SSLTestCase - def test_bad_socket - bad_socket = Struct.new(:sync).new - assert_raise TypeError do - socket = OpenSSL::SSL::SSLSocket.new bad_socket - # if the socket is not a T_FILE, `connect` will segv because it tries - # to get the underlying file descriptor but the API it calls assumes - # the object type is T_FILE - socket.connect - end - end - def test_ctx_setup ctx = OpenSSL::SSL::SSLContext.new assert_equal true, ctx.setup @@ -170,6 +159,103 @@ def test_socket_close_write end end + def test_synthetic_io + start_server do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |maxlen, exception:| + tcp.read_nonblock(maxlen, exception: exception) } + obj.define_singleton_method(:write_nonblock) { |str, exception:| + tcp.write_nonblock(str, exception: exception) } + obj.define_singleton_method(:wait_readable) { tcp.wait_readable } + obj.define_singleton_method(:wait_writable) { tcp.wait_writable } + obj.define_singleton_method(:flush) { tcp.flush } + obj.define_singleton_method(:closed?) { tcp.closed? } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_same obj, ssl.to_io + + ssl.connect + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_write_nonblock_exception + start_server(ignore_listener_error: true) do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + [:read_nonblock, :wait_readable, :wait_writable, :closed?].each do |name| + obj.define_singleton_method(name) { |*args, **kwargs| + tcp.__send__(name, *args, **kwargs) } + end + + # SSLSocket#connect calls write_nonblock at least twice: to write + # ClientHello and Finished. Let's raise an exception in the 2nd call. + called = 0 + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + raise "foo" if (called += 1) == 2 + tcp.write_nonblock(*args, **kwargs) + } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_raise_with_message(RuntimeError, "foo") { ssl.connect } + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_error_in_cb_then_error_in_write + # If SSLContext#servername_cb fails, it must send the "unrecognized_name" + # alert. If another error occurs while writing the alert to the underlying + # socket, the original exception from the servername_cb is suppressed and + # the new exception is raised. + sock1, sock2 = socketpair + + t = Thread.new { + s1 = OpenSSL::SSL::SSLSocket.new(sock1) + s1.hostname = "localhost" + begin + s1.connect + rescue + end + } + + called = [] + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.servername_cb = lambda { |args| + called << :servername_cb + raise "servername_cb" + } + obj = Object.new + obj.define_singleton_method(:method_missing) { |name, *args, **kwargs| + sock2.__send__(name, *args, **kwargs) + } + obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs| + sock2.respond_to?(name, *args, **kwargs) + } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + called << :write_nonblock + throw :throw_from, :write_nonblock + } + s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2) + + ret = assert_warning(/servername_cb/) { + catch(:throw_from) { s2.accept } + } + assert_equal(:write_nonblock, ret) + assert_equal([:servername_cb, :write_nonblock], called) + sock2.close + assert t.join + ensure + sock1.close + sock2.close + t.kill.join + end + def test_add_certificate ctx_proc = -> ctx { # Unset values set by start_server @@ -478,18 +564,29 @@ def test_client_auth_success } end - def test_client_cert_cb_ignore_error + def test_client_cert_cb_bad_return vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT start_server(verify_mode: vflag, ignore_listener_error: true) do |port| ctx = OpenSSL::SSL::SSLContext.new ctx.client_cert_cb = -> ssl { - raise "exception in client_cert_cb must be suppressed" + assert_kind_of(OpenSSL::SSL::SSLSocket, ssl) + [@cli_cert, OpenSSL::PKey.read(@cli_key.public_to_der)] } - # 1. Exception in client_cert_cb is suppressed - # 2. No client certificate will be sent to the server - # 3. SSL_VERIFY_FAIL_IF_NO_PEER_CERT causes the handshake to fail - assert_handshake_error { - server_connect(port, ctx) { |ssl| ssl.puts("abc"); ssl.gets } + assert_raise_with_message(ArgumentError, /private key/) { + server_connect(port, ctx) { raise "unreachable" } + } + end + end + + def test_client_cert_cb_error + vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT + start_server(verify_mode: vflag, ignore_listener_error: true) do |port| + ctx = OpenSSL::SSL::SSLContext.new + ctx.client_cert_cb = -> ssl { + raise "exception in client_cert_cb" + } + assert_raise_with_message(RuntimeError, /exception in client_cert_cb/) { + server_connect(port, ctx) { raise "unreachable" } } end end @@ -1583,7 +1680,12 @@ def test_options_disable_versions # Client only supports TLS 1.3 ctx2 = OpenSSL::SSL::SSLContext.new ctx2.min_version = ctx2.max_version = OpenSSL::SSL::TLS1_3_VERSION - assert_nothing_raised { server_connect(port, ctx2) { } } + assert_nothing_raised { + server_connect(port, ctx2) { |ssl| + # Ensure SSL_accept() finishes successfully + ssl.puts("abc"); ssl.gets + } + } } # Server only supports TLS 1.2 @@ -1616,7 +1718,10 @@ def test_ssl_methods_constant def test_renegotiation_cb num_handshakes = 0 - renegotiation_cb = Proc.new { |ssl| num_handshakes += 1 } + renegotiation_cb = lambda { |ssl| + assert_kind_of(OpenSSL::SSL::SSLSocket, ssl) + num_handshakes += 1 + } ctx_proc = Proc.new { |ctx| ctx.renegotiation_cb = renegotiation_cb } start_server(ctx_proc: ctx_proc) { |port| server_connect(port) { |ssl| @@ -1624,6 +1729,27 @@ def test_renegotiation_cb ssl.puts "abc"; assert_equal "abc\n", ssl.gets } } + + sock1, sock2 = socketpair + th = Thread.new { + ssl2 = OpenSSL::SSL::SSLSocket.new(sock2) + begin + ssl2.connect_nonblock(exception: false) + rescue OpenSSL::SSL::SSLError + end + } + ctx1 = OpenSSL::SSL::SSLContext.new + ctx1.renegotiation_cb = lambda { |ssl| raise "in renegotiation_cb" } + ctx1.add_certificate(@svr_cert, @svr_key) + ssl1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) + assert_raise_with_message(RuntimeError, "in renegotiation_cb") { + ssl1.accept + } + th.join + ensure + th&.kill&.join + sock1&.close + sock2&.close end def test_alpn_protocol_selection_ary