Skip to content

Commit 6108a9d

Browse files
committed
undo some of the call-sites in SSLSocket - account for sub-classes
1 parent 0fbba2f commit 6108a9d

2 files changed

Lines changed: 93 additions & 53 deletions

File tree

src/main/java/org/jruby/ext/openssl/SSLSocket.java

Lines changed: 79 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public IRubyObject allocate(Ruby runtime, RubyClass klass) {
7878
private enum CallSiteIndex {
7979

8080
// self
81-
hostname("hostname"),
81+
//hostname("hostname"),
8282
//sync_close("sync_close"),
8383
//sync_close_w("sync_close="),
8484
// io
@@ -87,8 +87,8 @@ private enum CallSiteIndex {
8787
sync("sync"),
8888
sync_w("sync="),
8989
flush("flush"),
90-
close("close"),
91-
closed_p("closed?"),
90+
//close("close"),
91+
//closed_p("closed?"),
9292
// ssl_context
9393
verify_mode("verify_mode");
9494

@@ -145,8 +145,8 @@ private static RaiseException newSSLErrorFromHandshake(Ruby runtime, SSLHandshak
145145
return SSL.newSSLError(runtime, cause);
146146
}
147147

148-
private CallSite callSite(final CallSiteIndex index) {
149-
return getMetaClass().getExtraCallSites()[ index.ordinal() ];
148+
private static CallSite callSite(final CallSite[] sites, final CallSiteIndex index) {
149+
return sites[ index.ordinal() ];
150150
}
151151

152152
private SSLContext sslContext;
@@ -192,9 +192,18 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
192192

193193
private IRubyObject set_io_nonblock_checked(final ThreadContext context, RubyBoolean value) {
194194
// @io.nonblock = true if @io.respond_to?(:nonblock=)
195-
IRubyObject respond = callSite(CallSiteIndex._respond_to_nonblock_w).call(context, io, io, context.runtime.newSymbol("nonblock="));
195+
final CallSite[] sites = getMetaClass().getExtraCallSites();
196+
if (sites == null) return fallback_set_io_nonblock_checked(context, value);
197+
IRubyObject respond = callSite(sites, CallSiteIndex._respond_to_nonblock_w).call(context, io, io, context.runtime.newSymbol("nonblock="));
196198
if (respond.isTrue()) {
197-
return callSite(CallSiteIndex.nonblock_w).call(context, io, io, value);
199+
return callSite(sites, CallSiteIndex.nonblock_w).call(context, io, io, value);
200+
}
201+
return context.nil;
202+
}
203+
204+
private IRubyObject fallback_set_io_nonblock_checked(ThreadContext context, RubyBoolean value) {
205+
if (io.respondsTo("nonblock=")) {
206+
return io.callMethod(context, "nonblock=", value);
198207
}
199208
return context.nil;
200209
}
@@ -206,8 +215,8 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context)
206215

207216
// Server Name Indication (SNI) RFC 3546
208217
// SNI support will not be attempted unless hostname is explicitly set by the caller
209-
IRubyObject hostname = callSite(CallSiteIndex.hostname).call(context, this, this); // self.hostname
210-
String peerHost = hostname.toString();
218+
IRubyObject hostname = getInstanceVariable("@hostname");
219+
String peerHost = hostname == null ? null : hostname.toString();
211220
final int peerPort = socketChannelImpl().getRemotePort();
212221
engine = sslContext.createSSLEngine(peerHost, peerPort);
213222

@@ -220,7 +229,7 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context)
220229
netData.limit(0);
221230
dummy = ByteBuffer.allocate(0);
222231
this.engine = engine;
223-
copySessionSetupIfSet();
232+
copySessionSetupIfSet(context);
224233
return engine;
225234
}
226235

@@ -232,12 +241,24 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context)
232241

233242
@JRubyMethod(name = "sync")
234243
public IRubyObject sync(final ThreadContext context) {
235-
return callSite(CallSiteIndex.sync).call(context, io, io); // io.sync
244+
final CallSite[] sites = getMetaClass().getExtraCallSites();
245+
if (sites == null) return fallback_sync(context);
246+
return callSite(sites, CallSiteIndex.sync).call(context, io, io); // io.sync
247+
}
248+
249+
private IRubyObject fallback_sync(final ThreadContext context) {
250+
return io.callMethod(context, "sync");
236251
}
237252

238253
@JRubyMethod(name = "sync=")
239254
public IRubyObject set_sync(final ThreadContext context, final IRubyObject sync) {
240-
return callSite(CallSiteIndex.sync_w).call(context, io, io, sync); // io.sync = sync
255+
final CallSite[] sites = getMetaClass().getExtraCallSites();
256+
if (sites == null) return fallback_set_sync(context, sync);
257+
return callSite(sites, CallSiteIndex.sync_w).call(context, io, io, sync); // io.sync = sync
258+
}
259+
260+
private IRubyObject fallback_set_sync(final ThreadContext context, final IRubyObject sync) {
261+
return io.callMethod(context, "sync=", sync);
241262
}
242263

243264
@JRubyMethod
@@ -335,7 +356,7 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki
335356
if ( ! initialHandshake ) {
336357
final SSLEngine engine = ossl_ssl_setup(context);
337358
engine.setUseClientMode(false);
338-
final IRubyObject verify_mode = callSite(CallSiteIndex.verify_mode).call(context, sslContext, sslContext);
359+
final IRubyObject verify_mode = verify_mode(context);
339360
if ( verify_mode != context.nil ) {
340361
final int verify = RubyNumeric.fix2int(verify_mode);
341362
if ( verify == 0 ) { // VERIFY_NONE
@@ -395,6 +416,16 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki
395416
return this;
396417
}
397418

419+
final IRubyObject verify_mode(final ThreadContext context) {
420+
final CallSite[] sites = getMetaClass().getExtraCallSites();
421+
if (sites == null) return fallback_verify_mode(context);
422+
return callSite(sites, CallSiteIndex.verify_mode).call(context, sslContext, sslContext);
423+
}
424+
425+
private IRubyObject fallback_verify_mode(final ThreadContext context) {
426+
return sslContext.callMethod(context, "verify_mode");
427+
}
428+
398429
@JRubyMethod
399430
public IRubyObject verify_result(final ThreadContext context) {
400431
final Ruby runtime = context.runtime;
@@ -896,7 +927,7 @@ private IRubyObject syswriteImpl(final ThreadContext context,
896927
written = write(buff, blocking);
897928
}
898929

899-
callSite(CallSiteIndex.flush).call(context, io, io); // io.flush
930+
io_flush(context); // io.flush
900931

901932
return runtime.newFixnum(written);
902933
}
@@ -905,6 +936,16 @@ private IRubyObject syswriteImpl(final ThreadContext context,
905936
}
906937
}
907938

939+
private IRubyObject io_flush(final ThreadContext context) {
940+
final CallSite[] sites = getMetaClass().getExtraCallSites();
941+
if (sites == null) return fallback_io_flush(context);
942+
return callSite(sites, CallSiteIndex.flush).call(context, io, io);
943+
}
944+
945+
private IRubyObject fallback_io_flush(final ThreadContext context) {
946+
return io.callMethod(context, "flush");
947+
}
948+
908949
@JRubyMethod
909950
public IRubyObject syswrite(ThreadContext context, IRubyObject arg) {
910951
return syswriteImpl(context, arg, true, true);
@@ -963,29 +1004,15 @@ private void close(boolean force) {
9631004
}
9641005

9651006
@JRubyMethod
966-
public IRubyObject sysclose(final ThreadContext context) {
967-
if ( io_closed_p(context).isTrue() ) return context.nil;
968-
1007+
public IRubyObject stop(final ThreadContext context) {
9691008
// no need to try shutdown when it's a server
9701009
close( sslContext.isProtocolForClient() );
971-
972-
if ( getInstanceVariable("@sync_close").isTrue() ) return io_close(context);
973-
9741010
return context.nil;
9751011
}
9761012

977-
private IRubyObject io_closed_p(final ThreadContext context) { // io.closed?
978-
return callSite(CallSiteIndex.closed_p).call(context, io, io);
979-
}
980-
981-
private IRubyObject io_close(final ThreadContext context) { // io.close
982-
return callSite(CallSiteIndex.close).call(context, io, io);
983-
}
984-
9851013
@JRubyMethod
9861014
public IRubyObject cert(final ThreadContext context) {
987-
final Ruby runtime = context.runtime;
988-
if ( engine == null ) return runtime.getNil();
1015+
if ( engine == null ) return context.nil;
9891016

9901017
try {
9911018
Certificate[] cert = engine.getSession().getLocalCertificates();
@@ -994,20 +1021,19 @@ public IRubyObject cert(final ThreadContext context) {
9941021
}
9951022
}
9961023
catch (CertificateEncodingException e) {
997-
throw X509Cert.newCertificateError(runtime, e);
1024+
throw X509Cert.newCertificateError(context.runtime, e);
9981025
}
999-
return runtime.getNil();
1026+
return context.nil;
10001027
}
10011028

1002-
// @Deprecated
1029+
@Deprecated
10031030
public final IRubyObject cert() {
10041031
return cert(getRuntime().getCurrentContext());
10051032
}
10061033

10071034
@JRubyMethod
10081035
public IRubyObject peer_cert(final ThreadContext context) {
1009-
final Ruby runtime = context.runtime;
1010-
if ( engine == null ) return runtime.getNil();
1036+
if ( engine == null ) return context.nil;
10111037

10121038
try {
10131039
Certificate[] cert = engine.getSession().getPeerCertificates();
@@ -1016,17 +1042,17 @@ public IRubyObject peer_cert(final ThreadContext context) {
10161042
}
10171043
}
10181044
catch (CertificateEncodingException e) {
1019-
throw X509Cert.newCertificateError(runtime, e);
1045+
throw X509Cert.newCertificateError(context.runtime, e);
10201046
}
10211047
catch (SSLPeerUnverifiedException e) {
1022-
if (OpenSSL.isDebug(runtime)) {
1023-
runtime.getWarnings().warning(String.format("%s: %s", e.getClass().getName(), e.getMessage()));
1048+
if (OpenSSL.isDebug(context.runtime)) {
1049+
context.runtime.getWarnings().warning(String.format("%s: %s", e.getClass().getName(), e.getMessage()));
10241050
}
10251051
}
1026-
return runtime.getNil();
1052+
return context.nil;
10271053
}
10281054

1029-
// @Deprecated
1055+
@Deprecated
10301056
public final IRubyObject peer_cert() {
10311057
return peer_cert(getRuntime().getCurrentContext());
10321058
}
@@ -1055,7 +1081,7 @@ public IRubyObject peer_cert_chain(final ThreadContext context) {
10551081
return runtime.getNil();
10561082
}
10571083

1058-
// @Deprecated
1084+
@Deprecated
10591085
public final IRubyObject peer_cert_chain() {
10601086
return peer_cert_chain(getRuntime().getCurrentContext());
10611087
}
@@ -1134,22 +1160,26 @@ private SSLSession getSession(final Ruby runtime) {
11341160
private transient SSLSession setSession = null;
11351161

11361162
@JRubyMethod(name = "session=")
1137-
public IRubyObject set_session(IRubyObject session) {
1163+
public IRubyObject set_session(final ThreadContext context, IRubyObject session) {
11381164
// NOTE: we can not fully support this without the SSL provider internals
11391165
// but we can assume setting a session= is meant as a forced session re-use
11401166
if ( session instanceof SSLSession ) {
11411167
setSession = (SSLSession) session;
1142-
if ( engine != null ) copySessionSetupIfSet();
1168+
if ( engine != null ) copySessionSetupIfSet(context);
11431169
}
11441170
//warn(context, "WARNING: SSLSocket#session= has not effect");
1145-
return getRuntime().getNil();
1171+
return context.nil;
11461172
}
11471173

1148-
private void copySessionSetupIfSet() {
1174+
@Deprecated
1175+
public IRubyObject set_session(IRubyObject session) {
1176+
return set_session(getRuntime().getCurrentContext(), session);
1177+
}
1178+
1179+
private void copySessionSetupIfSet(final ThreadContext context) {
11491180
if ( setSession != null ) {
11501181
if ( reusableSSLEngine() ) {
11511182
engine.setEnableSessionCreation(false);
1152-
final ThreadContext context = getRuntime().getCurrentContext();
11531183
if ( ! setSession.equals( getSession(context.runtime) ) ) {
11541184
getSession(context.runtime).set_timeout(context, setSession.timeout(context));
11551185
}
@@ -1158,9 +1188,9 @@ private void copySessionSetupIfSet() {
11581188
}
11591189

11601190
@JRubyMethod
1161-
public IRubyObject ssl_version() {
1162-
if ( engine == null ) return getRuntime().getNil();
1163-
return getRuntime().newString( engine.getSession().getProtocol() );
1191+
public IRubyObject ssl_version(ThreadContext context) {
1192+
if ( engine == null ) return context.nil;
1193+
return context.runtime.newString( engine.getSession().getProtocol() );
11641194
}
11651195

11661196
private transient SocketChannelImpl socketChannel;

src/test/ruby/ssl/test_socket.rb

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,29 @@ def test_connect_nonblock
136136
end
137137
end if RUBY_VERSION > '2.2'
138138

139+
def test_inherited_socket; require 'socket'
140+
inheritedSSLSocket = Class.new(OpenSSL::SSL::SSLSocket)
141+
142+
io_stub = STDERR.dup
143+
ctx = OpenSSL::SSL::SSLContext.new
144+
145+
assert socket = inheritedSSLSocket.new(io_stub, ctx) # does not raise
146+
assert socket.io.nonblock? if STDERR.respond_to?(:nonblock=) # >= 2.3
147+
socket.sync = true
148+
assert_equal true, socket.sync
149+
end
150+
139151
private
140152

141-
def server
142-
require 'socket'
153+
def server; require 'socket'
143154
host = "127.0.0.1"; port = 0
144155
ctx = OpenSSL::SSL::SSLContext.new()
145156
ctx.ciphers = "ADH"
146157
server = TCPServer.new(host, port)
147158
OpenSSL::SSL::SSLServer.new(server, ctx)
148159
end
149160

150-
def client(port)
151-
require 'socket'
161+
def client(port); require 'socket'
152162
host = "127.0.0.1"
153163
ctx = OpenSSL::SSL::SSLContext.new()
154164
ctx.ciphers = "ADH"

0 commit comments

Comments
 (0)