diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c
index 3d913a396..750b4a76b 100644
--- a/ext/openssl/ossl_ssl.c
+++ b/ext/openssl/ossl_ssl.c
@@ -36,7 +36,7 @@ VALUE cSSLSocket;
static VALUE eSSLErrorWaitReadable;
static VALUE eSSLErrorWaitWritable;
-static ID id_call, ID_callback_state, id_npn_protocols_encoded, id_each;
+static ID id_call, id_npn_protocols_encoded, id_each;
static VALUE sym_exception, sym_wait_readable, sym_wait_writable;
static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode,
@@ -47,16 +47,22 @@ static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode,
id_i_session_remove_cb, id_i_npn_select_cb, id_i_npn_protocols,
id_i_alpn_select_cb, id_i_alpn_protocols, id_i_servername_cb,
id_i_verify_hostname, id_i_keylog_cb, id_i_tmp_dh_callback;
-static ID id_i_io, id_i_context, id_i_hostname, id_i_sync_close;
+static ID id_i_context, id_i_hostname, id_i_sync_close;
-static int ossl_ssl_ex_ptr_idx;
+static int ossl_ssl_ex_data_idx;
static int ossl_sslctx_ex_ptr_idx;
+struct ossl_ssl_data *
+ossl_ssl_data(const SSL *ssl)
+{
+ return SSL_get_ex_data(ssl, ossl_ssl_ex_data_idx);
+}
+
static void
ossl_sslctx_mark(void *ptr)
{
SSL_CTX *ctx = ptr;
- rb_gc_mark((VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx));
+ rb_gc_mark_movable((VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx));
}
static void
@@ -65,12 +71,25 @@ ossl_sslctx_free(void *ptr)
SSL_CTX_free(ptr);
}
+static void
+ossl_sslctx_compact(void *ptr)
+{
+ SSL_CTX *ctx = ptr;
+ VALUE self = (VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx);
+ if (self) {
+ (void)SSL_CTX_set_ex_data(ctx, ossl_sslctx_ex_ptr_idx,
+ (void *)rb_gc_location(self));
+ }
+}
+
static const rb_data_type_t ossl_sslctx_type = {
- "OpenSSL/SSL/CTX",
- {
- ossl_sslctx_mark, ossl_sslctx_free,
+ .wrap_struct_name = "OpenSSL/SSL/CTX",
+ .function = {
+ .dmark = ossl_sslctx_mark,
+ .dfree = ossl_sslctx_free,
+ .dcompact = ossl_sslctx_compact,
},
- 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
+ .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
};
static VALUE
@@ -97,37 +116,55 @@ ossl_sslctx_s_alloc(VALUE klass)
return obj;
}
+struct client_cert_cb_args {
+ VALUE ssl_obj;
+ VALUE cb;
+ X509 **x509;
+ EVP_PKEY **pkey;
+};
+
static VALUE
-ossl_call_client_cert_cb(VALUE obj)
+ossl_call_client_cert_cb(VALUE args_)
{
- VALUE ctx_obj, cb, ary, cert, key;
-
- ctx_obj = rb_attr_get(obj, id_i_context);
- cb = rb_attr_get(ctx_obj, id_i_client_cert_cb);
- if (NIL_P(cb))
- return Qnil;
-
- ary = rb_funcallv(cb, id_call, 1, &obj);
+ struct client_cert_cb_args *args = (struct client_cert_cb_args *)args_;
+ VALUE ary = rb_funcall(args->cb, id_call, 1, args->ssl_obj);
Check_Type(ary, T_ARRAY);
- GetX509CertPtr(cert = rb_ary_entry(ary, 0));
- GetPrivPKeyPtr(key = rb_ary_entry(ary, 1));
+ if (RARRAY_LEN(ary) != 2)
+ rb_raise(rb_eTypeError, "client_cert_cb must return [cert, key]");
+
+ X509 *cert = GetX509CertPtr(rb_ary_entry(ary, 0));
+ EVP_PKEY *pkey = GetPrivPKeyPtr(rb_ary_entry(ary, 1));
+ if (!X509_up_ref(cert))
+ ossl_raise(eSSLError, "X509_up_ref");
+ if (!EVP_PKEY_up_ref(pkey)) {
+ X509_free(cert);
+ ossl_raise(eSSLError, "EVP_PKEY_up_ref");
+ }
- return rb_ary_new3(2, cert, key);
+ *args->x509 = cert;
+ *args->pkey = pkey;
+ return Qnil;
}
static int
ossl_client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey)
{
- VALUE obj, ret;
-
- obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- ret = rb_protect(ossl_call_client_cert_cb, obj, NULL);
- if (NIL_P(ret))
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ if (p->cb_state)
return 0;
- *x509 = DupX509CertPtr(RARRAY_AREF(ret, 0));
- *pkey = DupPKeyPtr(RARRAY_AREF(ret, 1));
+ VALUE ctx_obj = rb_attr_get(p->self, id_i_context);
+ VALUE cb = rb_attr_get(ctx_obj, id_i_client_cert_cb);
+ if (NIL_P(cb))
+ return 0;
+ int state;
+ struct client_cert_cb_args args = { p->self, cb, x509, pkey };
+ rb_protect(ossl_call_client_cert_cb, (VALUE)&args, &state);
+ if (state) {
+ p->cb_state = state;
+ return 0;
+ }
return 1;
}
@@ -163,12 +200,15 @@ ossl_call_tmp_dh_callback(VALUE arg)
static DH *
ossl_tmp_dh_callback(SSL *ssl, int is_export, int keylength)
{
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ if (p->cb_state)
+ return NULL;
+
int state;
- VALUE rb_ssl = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- struct tmp_dh_callback_args args = {rb_ssl, is_export, keylength};
+ struct tmp_dh_callback_args args = {p->self, is_export, keylength};
VALUE ret = rb_protect(ossl_call_tmp_dh_callback, (VALUE)&args, &state);
if (state) {
- rb_ivar_set(rb_ssl, ID_callback_state, INT2NUM(state));
+ p->cb_state = state;
return NULL;
}
return (DH *)ret;
@@ -179,12 +219,13 @@ static VALUE
call_verify_certificate_identity(VALUE ctx_v)
{
X509_STORE_CTX *ctx = (X509_STORE_CTX *)ctx_v;
+ struct ossl_ssl_data *p;
SSL *ssl;
- VALUE ssl_obj, hostname, cert_obj;
+ VALUE hostname, cert_obj;
ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- hostname = rb_attr_get(ssl_obj, id_i_hostname);
+ p = ossl_ssl_data(ssl);
+ hostname = rb_attr_get(p->self, id_i_hostname);
if (!RTEST(hostname)) {
rb_warning("verify_hostname requires hostname to be set");
@@ -199,13 +240,20 @@ call_verify_certificate_identity(VALUE ctx_v)
static int
ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
{
- VALUE cb, ssl_obj, sslctx_obj, verify_hostname, ret;
+ VALUE cb, sslctx_obj, verify_hostname, ret;
SSL *ssl;
+ struct ossl_ssl_data *p;
int status;
ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- sslctx_obj = rb_attr_get(ssl_obj, id_i_context);
+ p = ossl_ssl_data(ssl);
+ if (p->cb_state) {
+ if (X509_STORE_CTX_get_error(ctx) == X509_V_OK)
+ X509_STORE_CTX_set_error(ctx, X509_V_ERR_UNSPECIFIED);
+ return 0;
+ }
+
+ sslctx_obj = rb_attr_get(p->self, id_i_context);
cb = rb_attr_get(sslctx_obj, id_i_verify_callback);
verify_hostname = rb_attr_get(sslctx_obj, id_i_verify_hostname);
@@ -213,7 +261,7 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
!X509_STORE_CTX_get_error_depth(ctx)) {
ret = rb_protect(call_verify_certificate_identity, (VALUE)ctx, &status);
if (status) {
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status));
+ p->cb_state = status;
return 0;
}
if (ret != Qtrue) {
@@ -225,82 +273,89 @@ ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
return ossl_verify_cb_call(cb, preverify_ok, ctx);
}
+struct sess_get_cb_args {
+ VALUE ssl_obj;
+ const unsigned char *data;
+ int len;
+};
+
static VALUE
-ossl_call_session_get_cb(VALUE ary)
+ossl_call_sess_get_cb(VALUE args_)
{
- VALUE ssl_obj, cb;
-
- Check_Type(ary, T_ARRAY);
- ssl_obj = rb_ary_entry(ary, 0);
+ struct sess_get_cb_args *args = (struct sess_get_cb_args *)args_;
+ VALUE ssl_obj = args->ssl_obj, sslctx_obj, cb;
- cb = rb_funcall(ssl_obj, rb_intern("session_get_cb"), 0);
+ sslctx_obj = rb_attr_get(ssl_obj, id_i_context);
+ cb = rb_attr_get(sslctx_obj, id_i_session_get_cb);
if (NIL_P(cb)) return Qnil;
- return rb_funcallv(cb, id_call, 1, &ary);
+ VALUE session_id = rb_str_new((const char *)args->data, args->len);
+ VALUE ret_obj = rb_funcall(cb, id_call, 1, rb_assoc_new(ssl_obj, session_id));
+ // XXX: Should we raise if ret is neither SSLSession nor nil?
+ if (!rb_obj_is_instance_of(ret_obj, cSSLSession))
+ return (VALUE)NULL;
+
+ SSL_SESSION *sess;
+ GetSSLSession(ret_obj, sess);
+ return (VALUE)sess;
}
static SSL_SESSION *
ossl_sslctx_session_get_cb(SSL *ssl, const unsigned char *buf, int len, int *copy)
{
- VALUE ary, ssl_obj, ret_obj;
- SSL_SESSION *sess;
- int state = 0;
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ if (p->cb_state)
+ return NULL;
OSSL_Debug("SSL SESSION get callback entered");
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- ary = rb_ary_new2(2);
- rb_ary_push(ary, ssl_obj);
- rb_ary_push(ary, rb_str_new((const char *)buf, len));
-
- ret_obj = rb_protect(ossl_call_session_get_cb, ary, &state);
+ struct sess_get_cb_args args = { p->self, buf, len };
+ int state;
+ VALUE ret = rb_protect(ossl_call_sess_get_cb, (VALUE)&args, &state);
if (state) {
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state));
+ p->cb_state = state;
return NULL;
}
- if (!rb_obj_is_instance_of(ret_obj, cSSLSession))
- return NULL;
-
- GetSSLSession(ret_obj, sess);
*copy = 1;
-
- return sess;
+ return (SSL_SESSION *)ret;
}
+struct sess_new_cb_args {
+ VALUE ssl_obj;
+ SSL_SESSION *sess;
+};
+
static VALUE
-ossl_call_session_new_cb(VALUE ary)
+ossl_call_session_new_cb(VALUE args_)
{
- VALUE ssl_obj, cb;
-
- Check_Type(ary, T_ARRAY);
- ssl_obj = rb_ary_entry(ary, 0);
+ struct sess_new_cb_args *args = (struct sess_new_cb_args *)args_;
+ VALUE ssl_obj = args->ssl_obj, sslctx_obj, cb;
- cb = rb_funcall(ssl_obj, rb_intern("session_new_cb"), 0);
+ sslctx_obj = rb_attr_get(ssl_obj, id_i_context);
+ cb = rb_attr_get(sslctx_obj, id_i_session_new_cb);
if (NIL_P(cb)) return Qnil;
- return rb_funcallv(cb, id_call, 1, &ary);
+ VALUE sess_obj = rb_obj_alloc(cSSLSession);
+ if (!SSL_SESSION_up_ref(args->sess))
+ ossl_raise(eSSLError, "SSL_SESSION_up_ref");
+ RTYPEDDATA_DATA(sess_obj) = args->sess;
+
+ return rb_funcall(cb, id_call, 1, rb_assoc_new(ssl_obj, sess_obj));
}
/* return 1 normal. return 0 removes the session */
static int
ossl_sslctx_session_new_cb(SSL *ssl, SSL_SESSION *sess)
{
- VALUE ary, ssl_obj, sess_obj;
- int state = 0;
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ if (p->cb_state)
+ return 0;
OSSL_Debug("SSL SESSION new callback entered");
-
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- sess_obj = rb_obj_alloc(cSSLSession);
- SSL_SESSION_up_ref(sess);
- DATA_PTR(sess_obj) = sess;
-
- ary = rb_ary_new2(2);
- rb_ary_push(ary, ssl_obj);
- rb_ary_push(ary, sess_obj);
-
- rb_protect(ossl_call_session_new_cb, ary, &state);
+ struct sess_new_cb_args args = { p->self, sess };
+ int state;
+ rb_protect(ossl_call_session_new_cb, (VALUE)&args, &state);
if (state) {
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state));
+ p->cb_state = state;
}
/*
@@ -344,43 +399,47 @@ ossl_call_keylog_cb(VALUE args_v)
static void
ossl_sslctx_keylog_cb(const SSL *ssl, const char *line)
{
- VALUE ssl_obj;
- struct ossl_call_keylog_cb_args args;
- int state = 0;
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ struct ossl_call_keylog_cb_args args = { p->self, line };
+ int state;
+ if (p->cb_state)
+ return;
OSSL_Debug("SSL keylog callback entered");
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- args.ssl_obj = ssl_obj;
- args.line = line;
-
rb_protect(ossl_call_keylog_cb, (VALUE)&args, &state);
if (state) {
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state));
+ p->cb_state = state;
}
}
#endif
+struct sess_remove_cb_args {
+ SSL_CTX *ctx;
+ SSL_SESSION *sess;
+};
+
static VALUE
-ossl_call_session_remove_cb(VALUE ary)
+ossl_call_session_remove_cb(VALUE args_)
{
+ struct sess_remove_cb_args *args = (struct sess_remove_cb_args *)args_;
VALUE sslctx_obj, cb;
- Check_Type(ary, T_ARRAY);
- sslctx_obj = rb_ary_entry(ary, 0);
-
+ sslctx_obj = (VALUE)SSL_CTX_get_ex_data(args->ctx, ossl_sslctx_ex_ptr_idx);
cb = rb_attr_get(sslctx_obj, id_i_session_remove_cb);
if (NIL_P(cb)) return Qnil;
- return rb_funcallv(cb, id_call, 1, &ary);
+ VALUE sess_obj = rb_obj_alloc(cSSLSession);
+ if (!SSL_SESSION_up_ref(args->sess))
+ ossl_raise(eSSLError, "SSL_SESSION_up_ref");
+ RTYPEDDATA_DATA(sess_obj) = args->sess;
+
+ return rb_funcall(cb, id_call, 1, rb_assoc_new(sslctx_obj, sess_obj));
}
static void
ossl_sslctx_session_remove_cb(SSL_CTX *ctx, SSL_SESSION *sess)
{
- VALUE ary, sslctx_obj, sess_obj;
- int state = 0;
-
/*
* This callback is also called for all sessions in the internal store
* when SSL_CTX_free() is called.
@@ -389,23 +448,16 @@ ossl_sslctx_session_remove_cb(SSL_CTX *ctx, SSL_SESSION *sess)
return;
OSSL_Debug("SSL SESSION remove callback entered");
-
- sslctx_obj = (VALUE)SSL_CTX_get_ex_data(ctx, ossl_sslctx_ex_ptr_idx);
- sess_obj = rb_obj_alloc(cSSLSession);
- SSL_SESSION_up_ref(sess);
- DATA_PTR(sess_obj) = sess;
-
- ary = rb_ary_new2(2);
- rb_ary_push(ary, sslctx_obj);
- rb_ary_push(ary, sess_obj);
-
- rb_protect(ossl_call_session_remove_cb, ary, &state);
+ struct sess_remove_cb_args args = { ctx, sess };
+ int state;
+ rb_protect(ossl_call_session_remove_cb, (VALUE)&args, &state);
if (state) {
-/*
- the SSL_CTX is frozen, nowhere to save state.
- there is no common accessor method to check it either.
- rb_ivar_set(sslctx_obj, ID_callback_state, INT2NUM(state));
-*/
+ /*
+ * the SSL_CTX is frozen, nowhere to save state.
+ * there is no common accessor method to check it either.
+ */
+ rb_warn("exception in session_remove_cb is ignored");
+ rb_set_errinfo(Qnil);
}
}
@@ -435,10 +487,10 @@ ossl_call_servername_cb(VALUE arg)
if (!servername)
return Qnil;
- VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- VALUE sslctx_obj = rb_attr_get(ssl_obj, id_i_context);
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ VALUE sslctx_obj = rb_attr_get(p->self, id_i_context);
VALUE cb = rb_attr_get(sslctx_obj, id_i_servername_cb);
- VALUE ary = rb_assoc_new(ssl_obj, rb_str_new_cstr(servername));
+ VALUE ary = rb_assoc_new(p->self, rb_str_new_cstr(servername));
VALUE ret_obj = rb_funcallv(cb, id_call, 1, &ary);
if (rb_obj_is_kind_of(ret_obj, cSSLContext)) {
@@ -447,7 +499,7 @@ ossl_call_servername_cb(VALUE arg)
GetSSLCTX(ret_obj, ctx2);
if (!SSL_set_SSL_CTX(ssl, ctx2))
ossl_raise(eSSLError, "SSL_set_SSL_CTX");
- rb_ivar_set(ssl_obj, id_i_context, ret_obj);
+ rb_ivar_set(p->self, id_i_context, ret_obj);
} else if (!NIL_P(ret_obj)) {
ossl_raise(rb_eArgError, "servername_cb must return an "
"OpenSSL::SSL::SSLContext object or nil");
@@ -459,29 +511,43 @@ ossl_call_servername_cb(VALUE arg)
static int
ssl_servername_cb(SSL *ssl, int *ad, void *arg)
{
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
int state;
+ if (p->cb_state)
+ return SSL_TLSEXT_ERR_ALERT_FATAL;
rb_protect(ossl_call_servername_cb, (VALUE)ssl, &state);
if (state) {
- VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(state));
+ p->cb_state = state;
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
return SSL_TLSEXT_ERR_OK;
}
-static void
-ssl_renegotiation_cb(const SSL *ssl)
+static VALUE
+call_renegotiation_cb(VALUE args_)
{
- VALUE ssl_obj, sslctx_obj, cb;
+ VALUE *args = (VALUE *)args_;
+ return rb_funcall(args[0], id_call, 1, args[1]);
+}
- ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
- sslctx_obj = rb_attr_get(ssl_obj, id_i_context);
- cb = rb_attr_get(sslctx_obj, id_i_renegotiation_cb);
- if (NIL_P(cb)) return;
+/* This function may serve as the entry point to support further callbacks. */
+static void
+ssl_info_cb(const SSL *ssl, int where, int val)
+{
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
- rb_funcallv(cb, id_call, 1, &ssl_obj);
+ if (p->cb_state)
+ return;
+ if (where & SSL_CB_HANDSHAKE_START && SSL_is_server(ssl)) {
+ VALUE sslctx_obj = rb_attr_get(p->self, id_i_context);
+ VALUE cb = rb_attr_get(sslctx_obj, id_i_renegotiation_cb);
+ if (!NIL_P(cb)) {
+ VALUE args[] = { cb, p->self };
+ rb_protect(call_renegotiation_cb, (VALUE)&args, &p->cb_state);
+ }
+ }
}
static VALUE
@@ -544,19 +610,21 @@ ssl_npn_select_cb_common(SSL *ssl, VALUE cb, const unsigned char **out,
unsigned char *outlen, const unsigned char *in,
unsigned int inlen)
{
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
VALUE selected;
- int status;
+ int state;
struct npn_select_cb_common_args args;
+ if (p->cb_state)
+ return SSL_TLSEXT_ERR_ALERT_FATAL;
+
args.cb = cb;
args.in = in;
args.inlen = inlen;
- selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &status);
- if (status) {
- VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);
-
- rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status));
+ selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &state);
+ if (state) {
+ p->cb_state = state;
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
@@ -605,17 +673,6 @@ ssl_alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen,
return ssl_npn_select_cb_common(ssl, cb, out, outlen, in, inlen);
}
-/* This function may serve as the entry point to support further callbacks. */
-static void
-ssl_info_cb(const SSL *ssl, int where, int val)
-{
- int is_server = SSL_is_server((SSL *)ssl);
-
- if (is_server && where & SSL_CB_HANDSHAKE_START) {
- ssl_renegotiation_cb(ssl);
- }
-}
-
/*
* call-seq:
* ctx.options -> integer
@@ -779,6 +836,9 @@ ossl_sslctx_setup(VALUE self)
val = rb_attr_get(self, id_i_verify_depth);
if(!NIL_P(val)) SSL_CTX_set_verify_depth(ctx, NUM2INT(val));
+ if (!NIL_P(rb_attr_get(self, id_i_renegotiation_cb)))
+ SSL_CTX_set_info_callback(ctx, ssl_info_cb);
+
#ifdef OSSL_USE_NEXTPROTONEG
val = rb_attr_get(self, id_i_npn_protocols);
if (!NIL_P(val)) {
@@ -1567,21 +1627,37 @@ static void
ossl_ssl_mark(void *ptr)
{
SSL *ssl = ptr;
- rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx));
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ rb_gc_mark_movable(p->self);
+ rb_gc_mark_movable(p->io);
}
static void
-ossl_ssl_free(void *ssl)
+ossl_ssl_free(void *ptr)
{
+ SSL *ssl = ptr;
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ ruby_xfree(p);
SSL_free(ssl);
}
+static void
+ossl_ssl_compact(void *ptr)
+{
+ SSL *ssl = ptr;
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
+ p->self = rb_gc_location(p->self);
+ p->io = rb_gc_location(p->io);
+}
+
const rb_data_type_t ossl_ssl_type = {
- "OpenSSL/SSL",
- {
- ossl_ssl_mark, ossl_ssl_free,
+ .wrap_struct_name = "OpenSSL/SSL",
+ .function = {
+ .dmark = ossl_ssl_mark,
+ .dfree = ossl_ssl_free,
+ .dcompact = ossl_ssl_compact,
},
- 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
+ .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
};
static VALUE
@@ -1593,8 +1669,9 @@ ossl_ssl_s_alloc(VALUE klass)
static VALUE
peer_ip_address(VALUE io)
{
- VALUE remote_address = rb_funcall(io, rb_intern("remote_address"), 0);
-
+ VALUE remote_address = rb_check_funcall(io, rb_intern("remote_address"), 0, NULL);
+ if (remote_address == Qundef)
+ return rb_str_new_cstr("(io.remote_address not supported)");
return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0);
}
@@ -1611,14 +1688,24 @@ peeraddr_ip_str(VALUE io)
rb_eSystemCallError, (VALUE)0);
}
+static int
+is_real_socket(VALUE io)
+{
+ // FIXME: DO NOT MERGE
+ return 0;
+ return RB_TYPE_P(io, T_FILE);
+}
+
/*
* call-seq:
* SSLSocket.new(io) => aSSLSocket
* SSLSocket.new(io, ctx) => aSSLSocket
* SSLSocket.new(io, ctx, sync_close:) => aSSLSocket
*
- * Creates a new SSL socket from _io_ which must be a real IO object (not an
- * IO-like object that responds to read/write).
+ * Creates a new SSL socket from the underlying socket _io_ and _ctx_.
+ *
+ * _io_ must be an IO object, typically a TCPSocket or Socket from the socket
+ * library, or an IO-like object that supports the typical IO methods.
*
* If _ctx_ is provided the SSL Sockets initial params will be taken from
* the context.
@@ -1631,6 +1718,22 @@ peeraddr_ip_str(VALUE io)
*
* This method will freeze the SSLContext if one is provided;
* however, session management is still allowed in the frozen SSLContext.
+ *
+ * == Support for IO-like objects
+ *
+ * Support for IO-like objects was added in version 4.1 and is considered
+ * experimental. The requirements for the objects may change in future versions.
+ *
+ * As of version 4.1, SSLSocket expects the following methods to be compatible
+ * with core IO objects:
+ *
+ * - write_nonblock with the exception: false option
+ * - read_nonblock with the exception: false option
+ * - wait_readable
+ * - wait_writable
+ * - close
+ * - closed?
+ * - sync and flush (optional)
*/
static VALUE
ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
@@ -1663,25 +1766,42 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
rb_ivar_set(self, id_i_context, v_ctx);
ossl_sslctx_setup(v_ctx);
- if (rb_respond_to(io, rb_intern("nonblock=")))
- rb_funcall(io, rb_intern("nonblock="), 1, Qtrue);
- Check_Type(io, T_FILE);
- rb_ivar_set(self, id_i_io, io);
+ struct ossl_ssl_data *p = RB_ZALLOC(struct ossl_ssl_data);
+ p->self = self;
+ p->io = io;
ssl = SSL_new(ctx);
- if (!ssl)
- ossl_raise(eSSLError, NULL);
- RTYPEDDATA_DATA(self) = ssl;
-
- if (!SSL_set_ex_data(ssl, ossl_ssl_ex_ptr_idx, (void *)self))
+ if (!ssl) {
+ ruby_xfree(p);
+ ossl_raise(eSSLError, "SSL_new");
+ }
+ if (!SSL_set_ex_data(ssl, ossl_ssl_ex_data_idx, p)) {
+ ruby_xfree(p);
+ SSL_free(ssl);
ossl_raise(eSSLError, "SSL_set_ex_data");
- SSL_set_info_callback(ssl, ssl_info_cb);
+ }
+ RTYPEDDATA_DATA(self) = ssl;
rb_call_super(0, NULL);
return self;
}
+/*
+ * call-seq:
+ * ssl.io -> io
+ * ssl.to_io -> io
+ *
+ * Returns the underlying IO object.
+ */
+static VALUE
+ossl_ssl_get_io(VALUE self)
+{
+ SSL *ssl;
+ GetSSL(self, ssl);
+ return ossl_ssl_data(ssl)->io;
+}
+
#ifndef HAVE_RB_IO_DESCRIPTOR
static int
io_descriptor_fallback(VALUE io)
@@ -1696,27 +1816,42 @@ io_descriptor_fallback(VALUE io)
static VALUE
ossl_ssl_setup(VALUE self)
{
- VALUE io;
SSL *ssl;
- rb_io_t *fptr;
+ struct ossl_ssl_data *p;
GetSSL(self, ssl);
+ p = ossl_ssl_data(ssl);
if (ssl_started(ssl))
return Qtrue;
- io = rb_attr_get(self, id_i_io);
- GetOpenFile(io, fptr);
- rb_io_check_readable(fptr);
- rb_io_check_writable(fptr);
- if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
- ossl_raise(eSSLError, "SSL_set_fd");
+ if (is_real_socket(p->io)) {
+ rb_io_t *fptr;
+ GetOpenFile(p->io, fptr);
+ rb_io_check_readable(fptr);
+ rb_io_check_writable(fptr);
+ rb_io_set_nonblock(fptr);
+ if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(p->io))))
+ ossl_raise(eSSLError, "SSL_set_fd");
+ }
+ else {
+ BIO *bio = ossl_ssl_bio_setup(p);
+ if (!BIO_up_ref(bio)) {
+ BIO_free(bio);
+ ossl_raise(eSSLError, "BIO_up_ref");
+ }
+ SSL_set_bio(ssl, bio, bio);
+ }
return Qtrue;
}
static int
-errno_mapped(void)
+errno_mapped(struct ossl_ssl_data *p)
{
+ /* ossl_ssl_bio_method -> errno must not be used */
+ if (!is_real_socket(p->io))
+ return 0;
+
#ifdef _WIN32
return rb_w32_map_errno(WSAGetLastError());
#else
@@ -1762,6 +1897,11 @@ no_exception_p(VALUE opts)
static void
io_wait_writable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_wait(io, INT2NUM(RUBY_IO_WRITABLE), RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
@@ -1776,6 +1916,11 @@ io_wait_writable(VALUE io)
static void
io_wait_readable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_wait(io, INT2NUM(RUBY_IO_READABLE), RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
@@ -1791,23 +1936,21 @@ static VALUE
ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
{
SSL *ssl;
- VALUE cb_state;
int nonblock = opts != Qfalse;
- rb_ivar_set(self, ID_callback_state, Qnil);
-
GetSSL(self, ssl);
+ struct ossl_ssl_data *p = ossl_ssl_data(ssl);
- VALUE io = rb_attr_get(self, id_i_io);
for (;;) {
int ret = func(ssl);
- int saved_errno = errno_mapped();
+ int saved_errno = errno_mapped(p);
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
+ if (p->cb_state) {
/* must cleanup OpenSSL error stack before re-raising */
ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
+ int state = p->cb_state;
+ p->cb_state = 0;
+ rb_jump_tag(state);
}
if (ret > 0)
@@ -1818,12 +1961,12 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
case SSL_ERROR_WANT_WRITE:
if (no_exception_p(opts)) { return sym_wait_writable; }
write_would_block(nonblock);
- io_wait_writable(io);
+ io_wait_writable(p->io);
continue;
case SSL_ERROR_WANT_READ:
if (no_exception_p(opts)) { return sym_wait_readable; }
read_would_block(nonblock);
- io_wait_readable(io);
+ io_wait_readable(p->io);
continue;
case SSL_ERROR_SYSCALL:
#ifdef __APPLE__
@@ -1856,7 +1999,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
code == SSL_ERROR_SYSCALL ? " SYSCALL" : "",
code,
saved_errno,
- peeraddr_ip_str(io),
+ peeraddr_ip_str(p->io),
SSL_state_string_long(ssl),
error_append);
}
@@ -1965,8 +2108,9 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
{
SSL *ssl;
int ilen;
- VALUE len, str, cb_state;
+ VALUE len, str;
VALUE opts = Qnil;
+ struct ossl_ssl_data *p;
if (nonblock) {
rb_scan_args(argc, argv, "11:", &len, &str, &opts);
@@ -1974,6 +2118,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
rb_scan_args(argc, argv, "11", &len, &str);
}
GetSSL(self, ssl);
+ p = ossl_ssl_data(ssl);
if (!ssl_started(ssl))
rb_raise(eSSLError, "SSL session is not started yet");
@@ -1993,19 +2138,17 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
return str;
}
- VALUE io = rb_attr_get(self, id_i_io);
-
for (;;) {
rb_str_locktmp(str);
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);
- int saved_errno = errno_mapped();
+ int saved_errno = errno_mapped(p);
rb_str_unlocktmp(str);
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
- rb_ivar_set(self, ID_callback_state, Qnil);
+ if (p->cb_state) {
ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
+ int state = p->cb_state;
+ p->cb_state = 0;
+ rb_jump_tag(state);
}
switch (SSL_get_error(ssl, nread)) {
@@ -2020,14 +2163,14 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
if (no_exception_p(opts)) { return sym_wait_writable; }
write_would_block(nonblock);
}
- io_wait_writable(io);
+ io_wait_writable(p->io);
break;
case SSL_ERROR_WANT_READ:
if (nonblock) {
if (no_exception_p(opts)) { return sym_wait_readable; }
read_would_block(nonblock);
}
- io_wait_readable(io);
+ io_wait_readable(p->io);
break;
case SSL_ERROR_SYSCALL:
if (!ERR_peek_error()) {
@@ -2099,17 +2242,14 @@ ossl_ssl_write_internal_safe(VALUE _args)
VALUE opts = args[2];
SSL *ssl;
- rb_io_t *fptr;
+ struct ossl_ssl_data *p;
int num, nonblock = opts != Qfalse;
- VALUE cb_state;
GetSSL(self, ssl);
+ p = ossl_ssl_data(ssl);
if (!ssl_started(ssl))
rb_raise(eSSLError, "SSL session is not started yet");
- VALUE io = rb_attr_get(self, id_i_io);
- GetOpenFile(io, fptr);
-
/* SSL_write(3ssl) manpage states num == 0 is undefined */
num = RSTRING_LENINT(str);
if (num == 0)
@@ -2117,13 +2257,13 @@ ossl_ssl_write_internal_safe(VALUE _args)
for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(str), num);
- int saved_errno = errno_mapped();
+ int saved_errno = errno_mapped(p);
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
- rb_ivar_set(self, ID_callback_state, Qnil);
+ if (p->cb_state) {
ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
+ int state = p->cb_state;
+ p->cb_state = 0;
+ rb_jump_tag(state);
}
switch (SSL_get_error(ssl, nwritten)) {
@@ -2132,12 +2272,12 @@ ossl_ssl_write_internal_safe(VALUE _args)
case SSL_ERROR_WANT_WRITE:
if (no_exception_p(opts)) { return sym_wait_writable; }
write_would_block(nonblock);
- io_wait_writable(io);
+ io_wait_writable(p->io);
continue;
case SSL_ERROR_WANT_READ:
if (no_exception_p(opts)) { return sym_wait_readable; }
read_would_block(nonblock);
- io_wait_readable(io);
+ io_wait_readable(p->io);
continue;
case SSL_ERROR_SYSCALL:
#ifdef __APPLE__
@@ -2224,12 +2364,27 @@ static VALUE
ossl_ssl_stop(VALUE self)
{
SSL *ssl;
+ struct ossl_ssl_data *p;
int ret;
GetSSL(self, ssl);
+ p = ossl_ssl_data(ssl);
if (!ssl_started(ssl))
return Qnil;
+
ret = SSL_shutdown(ssl);
+
+ /*
+ * XXX: SSLSocket#stop is supposed to ignore errors due to the underlying
+ * socket being unreadable/unwritable. We want to suppress exceptions raised
+ * by the underlying IO-like object, but not from any SSL callbacks. Can SSL
+ * callbacks be invoked during SSL_shutdown()? We assume no, for now.
+ */
+ if (!is_real_socket(p->io) && p->cb_state) {
+ p->cb_state = 0;
+ rb_set_errinfo(Qnil);
+ }
+
if (ret == 1) /* Have already received close_notify */
return Qnil;
if (ret == 0) /* Sent close_notify, but we don't wait for reply */
@@ -2738,11 +2893,8 @@ Init_ossl_ssl(void)
#endif
#ifndef OPENSSL_NO_SOCK
- id_call = rb_intern_const("call");
- ID_callback_state = rb_intern_const("callback_state");
-
- ossl_ssl_ex_ptr_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_ptr_idx", 0, 0, 0);
- if (ossl_ssl_ex_ptr_idx < 0)
+ ossl_ssl_ex_data_idx = SSL_get_ex_new_index(0, (void *)"ossl_ssl_ex_data_idx", 0, 0, 0);
+ if (ossl_ssl_ex_data_idx < 0)
ossl_raise(rb_eRuntimeError, "SSL_get_ex_new_index");
ossl_sslctx_ex_ptr_idx = SSL_CTX_get_ex_new_index(0, (void *)"ossl_sslctx_ex_ptr_idx", 0, 0, 0);
if (ossl_sslctx_ex_ptr_idx < 0)
@@ -2768,6 +2920,7 @@ Init_ossl_ssl(void)
rb_include_module(eSSLErrorWaitWritable, rb_mWaitWritable);
Init_ossl_ssl_session();
+ Init_ossl_ssl_bio();
/* Document-class: OpenSSL::SSL::SSLContext
*
@@ -3134,6 +3287,8 @@ Init_ossl_ssl(void)
rb_define_alloc_func(cSSLSocket, ossl_ssl_s_alloc);
rb_define_method(cSSLSocket, "initialize", ossl_ssl_initialize, -1);
rb_undef_method(cSSLSocket, "initialize_copy");
+ rb_define_method(cSSLSocket, "io", ossl_ssl_get_io, 0);
+ rb_define_alias(cSSLSocket, "to_io", "io");
rb_define_method(cSSLSocket, "connect", ossl_ssl_connect, 0);
rb_define_method(cSSLSocket, "connect_nonblock", ossl_ssl_connect_nonblock, -1);
rb_define_method(cSSLSocket, "accept", ossl_ssl_accept, 0);
@@ -3294,6 +3449,7 @@ Init_ossl_ssl(void)
sym_wait_readable = ID2SYM(rb_intern_const("wait_readable"));
sym_wait_writable = ID2SYM(rb_intern_const("wait_writable"));
+ id_call = rb_intern_const("call");
id_npn_protocols_encoded = rb_intern_const("npn_protocols_encoded");
id_each = rb_intern_const("each");
@@ -3326,7 +3482,6 @@ Init_ossl_ssl(void)
DefIVarID(keylog_cb);
DefIVarID(tmp_dh_callback);
- DefIVarID(io);
DefIVarID(context);
DefIVarID(hostname);
DefIVarID(sync_close);
diff --git a/ext/openssl/ossl_ssl.h b/ext/openssl/ossl_ssl.h
index a87e62d45..88eadf709 100644
--- a/ext/openssl/ossl_ssl.h
+++ b/ext/openssl/ossl_ssl.h
@@ -30,7 +30,20 @@ extern VALUE mSSL;
extern VALUE cSSLSocket;
extern VALUE cSSLSession;
+struct ossl_ssl_data {
+ VALUE self;
+ VALUE io;
+ int cb_state;
+
+ /* Used by ossl_ssl_bio_method */
+ int bio_eof;
+};
+struct ossl_ssl_data *ossl_ssl_data(const SSL *ssl);
+
+BIO *ossl_ssl_bio_setup(struct ossl_ssl_data *p);
+
void Init_ossl_ssl(void);
void Init_ossl_ssl_session(void);
+void Init_ossl_ssl_bio(void);
#endif /* _OSSL_SSL_H_ */
diff --git a/ext/openssl/ossl_ssl_bio.c b/ext/openssl/ossl_ssl_bio.c
new file mode 100644
index 000000000..25a0ca661
--- /dev/null
+++ b/ext/openssl/ossl_ssl_bio.c
@@ -0,0 +1,241 @@
+/*
+ * Ruby/OpenSSL Project
+ * Copyright (C) 2025 Kazuki Yamaguchi
+ */
+#include "ossl.h"
+
+static BIO_METHOD *ossl_bio_meth;
+static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable;
+
+BIO *
+ossl_ssl_bio_setup(struct ossl_ssl_data *p)
+{
+ BIO *bio = BIO_new(ossl_bio_meth);
+ if (!bio)
+ ossl_raise(eOSSLError, "BIO_new");
+
+ BIO_set_data(bio, p);
+ BIO_set_init(bio, 1);
+ return bio;
+}
+
+struct call0_args {
+ VALUE (*func)(VALUE);
+ VALUE args;
+ VALUE ret;
+};
+
+static VALUE
+do_nothing(VALUE _)
+{
+ return Qnil;
+}
+
+static VALUE
+call_protect1(VALUE args_)
+{
+ struct call0_args *args = (void *)args_;
+ rb_set_errinfo(Qnil);
+ args->ret = args->func(args->args);
+ return Qnil;
+}
+
+static VALUE
+call_protect0(VALUE args_)
+{
+ rb_ensure(do_nothing, Qnil, call_protect1, args_);
+ return Qnil;
+}
+
+static VALUE
+call_protect(VALUE (*func)(VALUE), VALUE args, int *state, int current)
+{
+ if (!current)
+ return rb_protect(func, args, state);
+
+ VALUE errinfo = rb_errinfo();
+ struct call0_args call0_args = { func, args, Qnil };
+ rb_protect(call_protect0, (VALUE)&call0_args, state);
+ if (*state) {
+ if (!rb_obj_is_kind_of(errinfo, rb_eException))
+ errinfo = rb_str_new_cstr("(unknown)");
+ rb_warn("BIO callback raised an exception, pending jump suppressed: " \
+ "state=%d, errinfo=%+"PRIsVALUE, current, errinfo);
+ }
+ return call0_args.ret;
+}
+
+
+struct bwrite_args {
+ struct ossl_ssl_data *p;
+ BIO *bio;
+ const char *data;
+ int dlen;
+ int written;
+};
+
+static VALUE
+bio_bwrite0(VALUE args_)
+{
+ struct bwrite_args *args = (void *)args_;
+ struct ossl_ssl_data *p = args->p;
+
+ BIO_clear_retry_flags(args->bio);
+
+ VALUE fargs[] = { rb_str_new_static(args->data, args->dlen), nonblock_kwargs };
+ VALUE ret = rb_funcallv_kw(p->io, rb_intern("write_nonblock"),
+ 2, fargs, RB_PASS_KEYWORDS);
+
+ if (RB_INTEGER_TYPE_P(ret)) {
+ args->written = NUM2INT(ret);
+ return Qtrue;
+ }
+ else if (ret == sym_wait_readable) {
+ BIO_set_retry_read(args->bio);
+ return Qfalse;
+ }
+ else if (ret == sym_wait_writable) {
+ BIO_set_retry_write(args->bio);
+ return Qfalse;
+ }
+ else {
+ rb_raise(rb_eTypeError, "write_nonblock must return an Integer, "
+ ":wait_readable, or :wait_writable");
+ }
+}
+
+static int
+bio_bwrite(BIO *bio, const char *data, int dlen)
+{
+ struct ossl_ssl_data *p = BIO_get_data(bio);
+
+ struct bwrite_args args = { p, bio, data, dlen, 0 };
+ int state;
+ VALUE ok = call_protect(bio_bwrite0, (VALUE)&args, &state, p->cb_state);
+ if (state) {
+ p->cb_state = state;
+ return -1;
+ }
+ if (RTEST(ok))
+ return args.written;
+ return -1;
+}
+
+struct bread_args {
+ struct ossl_ssl_data *p;
+ BIO *bio;
+ char *data;
+ int dlen;
+ int readbytes;
+};
+
+static VALUE
+bio_bread0(VALUE args_)
+{
+ struct bread_args *args = (void *)args_;
+ struct ossl_ssl_data *p = args->p;
+
+ BIO_clear_retry_flags(args->bio);
+
+ VALUE fargs[] = { INT2NUM(args->dlen), nonblock_kwargs };
+ VALUE ret = rb_funcallv_kw(p->io, rb_intern("read_nonblock"),
+ 2, fargs, RB_PASS_KEYWORDS);
+
+ if (RB_TYPE_P(ret, T_STRING)) {
+ int len = RSTRING_LENINT(ret);
+ if (len > args->dlen)
+ rb_raise(rb_eTypeError, "read_nonblock returned too much data");
+ memcpy(args->data, RSTRING_PTR(ret), len);
+ args->readbytes = len;
+ return Qtrue;
+ }
+ else if (NIL_P(ret)) {
+ // In OpenSSL 3.0 or later: BIO_set_flags(args->bio, BIO_FLAGS_IN_EOF);
+ p->bio_eof = 1;
+ return Qtrue;
+ }
+ else if (ret == sym_wait_readable) {
+ BIO_set_retry_read(args->bio);
+ return Qfalse;
+ }
+ else if (ret == sym_wait_writable) {
+ BIO_set_retry_write(args->bio);
+ return Qfalse;
+ }
+ else {
+ rb_raise(rb_eTypeError, "write_nonblock must return an Integer, "
+ ":wait_readable, or :wait_writable");
+ }
+}
+
+static int
+bio_bread(BIO *bio, char *data, int dlen)
+{
+ struct ossl_ssl_data *p = BIO_get_data(bio);
+
+ struct bread_args args = { p, bio, data, dlen, 0 };
+ int state;
+ VALUE ok = call_protect(bio_bread0, (VALUE)&args, &state, p->cb_state);
+ if (state) {
+ p->cb_state = state;
+ return -1;
+ }
+ if (RTEST(ok))
+ return args.readbytes;
+ return -1;
+}
+
+static VALUE
+bio_flush0(VALUE p_)
+{
+ struct ossl_ssl_data *p = (void *)p_;
+ /*
+ * If the underlying IO-like object does not respond to flush, let's just
+ * assume that it does not need to be flushed.
+ */
+ return rb_check_funcall(p->io, rb_intern("flush"), 0, NULL);
+}
+
+static long
+bio_ctrl(BIO *bio, int cmd, long larg, void *parg)
+{
+ struct ossl_ssl_data *p = BIO_get_data(bio);
+ int state;
+
+ switch (cmd) {
+ case BIO_CTRL_EOF:
+ return p->bio_eof;
+ case BIO_CTRL_FLUSH:
+ call_protect(bio_flush0, (VALUE)p, &state, p->cb_state);
+ if (state) {
+ p->cb_state = state;
+ return 0;
+ }
+ return 1;
+ default:
+ return 0;
+ }
+}
+
+void
+Init_ossl_ssl_bio(void)
+{
+ ossl_bio_meth = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "Ruby IO-like object");
+ if (!ossl_bio_meth)
+ ossl_raise(eOSSLError, "BIO_meth_new");
+ if (!BIO_meth_set_write(ossl_bio_meth, bio_bwrite) ||
+ !BIO_meth_set_read(ossl_bio_meth, bio_bread) ||
+ !BIO_meth_set_ctrl(ossl_bio_meth, bio_ctrl)) {
+ BIO_meth_free(ossl_bio_meth);
+ ossl_bio_meth = NULL;
+ ossl_raise(eOSSLError, "BIO_meth_set_*");
+ }
+
+ nonblock_kwargs = rb_hash_new();
+ rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern_const("exception")), Qfalse);
+ rb_global_variable(&nonblock_kwargs);
+
+ sym_wait_readable = ID2SYM(rb_intern_const("wait_readable"));
+ sym_wait_writable = ID2SYM(rb_intern_const("wait_writable"));
+}
+
diff --git a/ext/openssl/ossl_x509store.c b/ext/openssl/ossl_x509store.c
index bed124dc3..ee9f3d0e3 100644
--- a/ext/openssl/ossl_x509store.c
+++ b/ext/openssl/ossl_x509store.c
@@ -116,10 +116,9 @@ static void
ossl_x509store_mark(void *ptr)
{
X509_STORE *store = ptr;
- // Note: this reference is stored as @verify_callback so we don't need to mark it.
- // However we do need to ensure GC compaction won't move it, hence why
- // we call rb_gc_mark here.
- rb_gc_mark((VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx));
+ VALUE verify_cb =
+ (VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx);
+ rb_gc_mark_movable(verify_cb);
}
static void
@@ -128,12 +127,26 @@ ossl_x509store_free(void *ptr)
X509_STORE_free(ptr);
}
+static void
+ossl_x509store_compact(void *ptr)
+{
+ X509_STORE *store = ptr;
+ VALUE verify_cb =
+ (VALUE)X509_STORE_get_ex_data(store, store_ex_verify_cb_idx);
+ if (verify_cb) {
+ (void)X509_STORE_set_ex_data(store, store_ex_verify_cb_idx,
+ (void *)rb_gc_location(verify_cb));
+ }
+}
+
static const rb_data_type_t ossl_x509store_type = {
- "OpenSSL/X509/STORE",
- {
- ossl_x509store_mark, ossl_x509store_free,
+ .wrap_struct_name = "OpenSSL/X509/STORE",
+ .function = {
+ .dmark = ossl_x509store_mark,
+ .dfree = ossl_x509store_free,
+ .dcompact = ossl_x509store_compact,
},
- 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
+ .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
};
/*
@@ -570,10 +583,9 @@ static void
ossl_x509stctx_mark(void *ptr)
{
X509_STORE_CTX *ctx = ptr;
- // Note: this reference is stored as @verify_callback so we don't need to mark it.
- // However we do need to ensure GC compaction won't move it, hence why
- // we call rb_gc_mark here.
- rb_gc_mark((VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx));
+ VALUE verify_cb =
+ (VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx);
+ rb_gc_mark_movable(verify_cb);
}
static void
@@ -585,12 +597,26 @@ ossl_x509stctx_free(void *ptr)
X509_STORE_CTX_free(ctx);
}
+static void
+ossl_x509stctx_compact(void *ptr)
+{
+ X509_STORE_CTX *ctx = ptr;
+ VALUE verify_cb =
+ (VALUE)X509_STORE_CTX_get_ex_data(ctx, stctx_ex_verify_cb_idx);
+ if (verify_cb) {
+ (void)X509_STORE_CTX_set_ex_data(ctx, stctx_ex_verify_cb_idx,
+ (void *)rb_gc_location(verify_cb));
+ }
+}
+
static const rb_data_type_t ossl_x509stctx_type = {
- "OpenSSL/X509/STORE_CTX",
- {
- ossl_x509stctx_mark, ossl_x509stctx_free,
+ .wrap_struct_name = "OpenSSL/X509/STORE_CTX",
+ .function = {
+ .dmark = ossl_x509stctx_mark,
+ .dfree = ossl_x509stctx_free,
+ .dcompact = ossl_x509stctx_compact,
},
- 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
+ .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
};
static VALUE
diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb
index 1464a4292..e8e99868c 100644
--- a/lib/openssl/buffering.rb
+++ b/lib/openssl/buffering.rb
@@ -60,7 +60,7 @@ def initialize(*)
super
@eof = false
@rbuffer = Buffer.new
- @sync = @io.sync
+ @sync = to_io.respond_to?(:sync) ? to_io.sync : true
end
#
diff --git a/lib/openssl/ssl.rb b/lib/openssl/ssl.rb
index dccc11a55..23bd2bdab 100644
--- a/lib/openssl/ssl.rb
+++ b/lib/openssl/ssl.rb
@@ -336,10 +336,6 @@ class SSLSocket
attr_reader :hostname
- # The underlying IO object.
- attr_reader :io
- alias :to_io :io
-
# The SSLContext object used in this connection.
attr_reader :context
@@ -428,18 +424,6 @@ def using_anon_cipher?
ctx.ciphers.include?(cipher)
end
- def client_cert_cb
- @context.client_cert_cb
- end
-
- def session_new_cb
- @context.session_new_cb
- end
-
- def session_get_cb
- @context.session_get_cb
- end
-
class << self
# call-seq:
diff --git a/test/openssl/test_buffering.rb b/test/openssl/test_buffering.rb
index 466bbcfa2..95361d440 100644
--- a/test/openssl/test_buffering.rb
+++ b/test/openssl/test_buffering.rb
@@ -7,7 +7,8 @@ class OpenSSL::TestBuffering < OpenSSL::TestCase
class IO
include OpenSSL::Buffering
- attr_accessor :sync
+ attr_reader :io
+ alias to_io io
def initialize
@io = Buffer.new
diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb
index 10942191d..dde849b3e 100644
--- a/test/openssl/test_pair.rb
+++ b/test/openssl/test_pair.rb
@@ -67,6 +67,34 @@ def create_tcp_client(host, port)
end
end
+module OpenSSL::SSLPairIOish
+ include OpenSSL::SSLPairM
+
+ def create_tcp_server(host, port)
+ Addrinfo.tcp(host, port).listen
+ end
+
+ class TCPSocketWrapper
+ attr_reader :io
+ alias to_io io
+
+ def initialize(io) @io = io end
+ def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end
+ def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end
+ def wait_readable() @io.wait_readable end
+ def wait_writable() @io.wait_writable end
+ def close() @io.close end
+ def closed?() @io.closed? end
+
+ # Only used within test_pair.rb
+ def write(*args) @io.write(*args) end
+ end
+
+ def create_tcp_client(host, port)
+ TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect)
+ end
+end
+
module OpenSSL::TestEOF1M
def open_file(content)
ssl_pair { |s1, s2|
@@ -518,6 +546,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF1M
end
+class OpenSSL::TestEOF1IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF1M
+end
+
class OpenSSL::TestEOF2 < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPair
@@ -530,6 +564,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF2M
end
+class OpenSSL::TestEOF2IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF2M
+end
+
class OpenSSL::TestPair < OpenSSL::TestCase
include OpenSSL::SSLPair
include OpenSSL::TestPairM
@@ -540,4 +580,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestPairM
end
+class OpenSSL::TestPairIOish < OpenSSL::TestCase
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestPairM
+end
+
end
diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb
index e4fd58107..ff2821d1b 100644
--- a/test/openssl/test_ssl.rb
+++ b/test/openssl/test_ssl.rb
@@ -4,17 +4,6 @@
if defined?(OpenSSL::SSL)
class OpenSSL::TestSSL < OpenSSL::SSLTestCase
- def test_bad_socket
- bad_socket = Struct.new(:sync).new
- assert_raise TypeError do
- socket = OpenSSL::SSL::SSLSocket.new bad_socket
- # if the socket is not a T_FILE, `connect` will segv because it tries
- # to get the underlying file descriptor but the API it calls assumes
- # the object type is T_FILE
- socket.connect
- end
- end
-
def test_ctx_setup
ctx = OpenSSL::SSL::SSLContext.new
assert_equal true, ctx.setup
@@ -170,6 +159,103 @@ def test_socket_close_write
end
end
+ def test_synthetic_io
+ start_server do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ obj.define_singleton_method(:read_nonblock) { |maxlen, exception:|
+ tcp.read_nonblock(maxlen, exception: exception) }
+ obj.define_singleton_method(:write_nonblock) { |str, exception:|
+ tcp.write_nonblock(str, exception: exception) }
+ obj.define_singleton_method(:wait_readable) { tcp.wait_readable }
+ obj.define_singleton_method(:wait_writable) { tcp.wait_writable }
+ obj.define_singleton_method(:flush) { tcp.flush }
+ obj.define_singleton_method(:closed?) { tcp.closed? }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_same obj, ssl.to_io
+
+ ssl.connect
+ ssl.puts "abc"; assert_equal "abc\n", ssl.gets
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
+ def test_synthetic_io_write_nonblock_exception
+ start_server(ignore_listener_error: true) do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ [:read_nonblock, :wait_readable, :wait_writable, :closed?].each do |name|
+ obj.define_singleton_method(name) { |*args, **kwargs|
+ tcp.__send__(name, *args, **kwargs) }
+ end
+
+ # SSLSocket#connect calls write_nonblock at least twice: to write
+ # ClientHello and Finished. Let's raise an exception in the 2nd call.
+ called = 0
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ raise "foo" if (called += 1) == 2
+ tcp.write_nonblock(*args, **kwargs)
+ }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_raise_with_message(RuntimeError, "foo") { ssl.connect }
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
+ def test_synthetic_io_error_in_cb_then_error_in_write
+ # If SSLContext#servername_cb fails, it must send the "unrecognized_name"
+ # alert. If another error occurs while writing the alert to the underlying
+ # socket, the original exception from the servername_cb is suppressed and
+ # the new exception is raised.
+ sock1, sock2 = socketpair
+
+ t = Thread.new {
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ s1.hostname = "localhost"
+ begin
+ s1.connect
+ rescue
+ end
+ }
+
+ called = []
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.servername_cb = lambda { |args|
+ called << :servername_cb
+ raise "servername_cb"
+ }
+ obj = Object.new
+ obj.define_singleton_method(:method_missing) { |name, *args, **kwargs|
+ sock2.__send__(name, *args, **kwargs)
+ }
+ obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs|
+ sock2.respond_to?(name, *args, **kwargs)
+ }
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ called << :write_nonblock
+ throw :throw_from, :write_nonblock
+ }
+ s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2)
+
+ ret = assert_warning(/servername_cb/) {
+ catch(:throw_from) { s2.accept }
+ }
+ assert_equal(:write_nonblock, ret)
+ assert_equal([:servername_cb, :write_nonblock], called)
+ sock2.close
+ assert t.join
+ ensure
+ sock1.close
+ sock2.close
+ t.kill.join
+ end
+
def test_add_certificate
ctx_proc = -> ctx {
# Unset values set by start_server
@@ -478,18 +564,29 @@ def test_client_auth_success
}
end
- def test_client_cert_cb_ignore_error
+ def test_client_cert_cb_bad_return
vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT
start_server(verify_mode: vflag, ignore_listener_error: true) do |port|
ctx = OpenSSL::SSL::SSLContext.new
ctx.client_cert_cb = -> ssl {
- raise "exception in client_cert_cb must be suppressed"
+ assert_kind_of(OpenSSL::SSL::SSLSocket, ssl)
+ [@cli_cert, OpenSSL::PKey.read(@cli_key.public_to_der)]
}
- # 1. Exception in client_cert_cb is suppressed
- # 2. No client certificate will be sent to the server
- # 3. SSL_VERIFY_FAIL_IF_NO_PEER_CERT causes the handshake to fail
- assert_handshake_error {
- server_connect(port, ctx) { |ssl| ssl.puts("abc"); ssl.gets }
+ assert_raise_with_message(ArgumentError, /private key/) {
+ server_connect(port, ctx) { raise "unreachable" }
+ }
+ end
+ end
+
+ def test_client_cert_cb_error
+ vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT
+ start_server(verify_mode: vflag, ignore_listener_error: true) do |port|
+ ctx = OpenSSL::SSL::SSLContext.new
+ ctx.client_cert_cb = -> ssl {
+ raise "exception in client_cert_cb"
+ }
+ assert_raise_with_message(RuntimeError, /exception in client_cert_cb/) {
+ server_connect(port, ctx) { raise "unreachable" }
}
end
end
@@ -1583,7 +1680,12 @@ def test_options_disable_versions
# Client only supports TLS 1.3
ctx2 = OpenSSL::SSL::SSLContext.new
ctx2.min_version = ctx2.max_version = OpenSSL::SSL::TLS1_3_VERSION
- assert_nothing_raised { server_connect(port, ctx2) { } }
+ assert_nothing_raised {
+ server_connect(port, ctx2) { |ssl|
+ # Ensure SSL_accept() finishes successfully
+ ssl.puts("abc"); ssl.gets
+ }
+ }
}
# Server only supports TLS 1.2
@@ -1616,7 +1718,10 @@ def test_ssl_methods_constant
def test_renegotiation_cb
num_handshakes = 0
- renegotiation_cb = Proc.new { |ssl| num_handshakes += 1 }
+ renegotiation_cb = lambda { |ssl|
+ assert_kind_of(OpenSSL::SSL::SSLSocket, ssl)
+ num_handshakes += 1
+ }
ctx_proc = Proc.new { |ctx| ctx.renegotiation_cb = renegotiation_cb }
start_server(ctx_proc: ctx_proc) { |port|
server_connect(port) { |ssl|
@@ -1624,6 +1729,27 @@ def test_renegotiation_cb
ssl.puts "abc"; assert_equal "abc\n", ssl.gets
}
}
+
+ sock1, sock2 = socketpair
+ th = Thread.new {
+ ssl2 = OpenSSL::SSL::SSLSocket.new(sock2)
+ begin
+ ssl2.connect_nonblock(exception: false)
+ rescue OpenSSL::SSL::SSLError
+ end
+ }
+ ctx1 = OpenSSL::SSL::SSLContext.new
+ ctx1.renegotiation_cb = lambda { |ssl| raise "in renegotiation_cb" }
+ ctx1.add_certificate(@svr_cert, @svr_key)
+ ssl1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1)
+ assert_raise_with_message(RuntimeError, "in renegotiation_cb") {
+ ssl1.accept
+ }
+ th.join
+ ensure
+ th&.kill&.join
+ sock1&.close
+ sock2&.close
end
def test_alpn_protocol_selection_ary