2727 */
2828
2929#include "shared-bindings/ssl/SSLSocket.h"
30- #include "shared-bindings/socketpool/Socket.h"
3130#include "shared-bindings/ssl/SSLContext.h"
32- #include "shared-bindings/socketpool/SocketPool.h"
33- #include "shared-bindings/socketpool/Socket.h"
3431
3532#include "shared/runtime/interrupt_char.h"
33+ #include "shared/netutils/netutils.h"
3634#include "py/mperrno.h"
3735#include "py/mphal.h"
3836#include "py/objstr.h"
@@ -108,11 +106,72 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
108106 #endif
109107}
110108
109+ STATIC int call_method_errno (size_t n_args , const mp_obj_t * args ) {
110+ nlr_buf_t nlr ;
111+ mp_int_t result = - MP_EINVAL ;
112+ if (nlr_push (& nlr ) == 0 ) {
113+ mp_obj_t obj_result = mp_call_method_n_kw (n_args , 0 , args );
114+ result = (obj_result == mp_const_none ) ? 0 : mp_obj_get_int (obj_result );
115+ nlr_pop ();
116+ return result ;
117+ } else {
118+ mp_obj_t exc = MP_OBJ_FROM_PTR (nlr .ret_val );
119+ if (nlr_push (& nlr ) == 0 ) {
120+ result = - mp_obj_get_int (mp_load_attr (exc , MP_QSTR_errno ));
121+ nlr_pop ();
122+ }
123+ }
124+ return result ;
125+ }
126+
127+ static int ssl_socket_send (ssl_sslsocket_obj_t * self , const byte * buf , size_t len ) {
128+ mp_obj_array_t mv ;
129+ mp_obj_memoryview_init (& mv , 'B' , 0 , len , (void * )buf );
130+
131+ self -> send_args [2 ] = MP_OBJ_FROM_PTR (& mv );
132+ return call_method_errno (1 , self -> send_args );
133+ }
134+
135+ static int ssl_socket_recv_into (ssl_sslsocket_obj_t * self , byte * buf , size_t len ) {
136+ mp_obj_array_t mv ;
137+ mp_obj_memoryview_init (& mv , 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW , 0 , len , buf );
138+
139+ self -> recv_into_args [2 ] = MP_OBJ_FROM_PTR (& mv );
140+ return call_method_errno (1 , self -> recv_into_args );
141+ }
142+
143+ static int ssl_socket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
144+ self -> connect_args [2 ] = addr_in ;
145+ return call_method_errno (1 , self -> connect_args );
146+ }
147+
148+ static int ssl_socket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
149+ self -> bind_args [2 ] = addr_in ;
150+ return call_method_errno (1 , self -> bind_args );
151+ }
152+
153+ static int ssl_socket_close (ssl_sslsocket_obj_t * self ) {
154+ return call_method_errno (0 , self -> close_args );
155+ }
156+
157+ static int ssl_socket_settimeout (ssl_sslsocket_obj_t * self , mp_int_t timeout_ms ) {
158+ self -> settimeout_args [2 ] = mp_obj_new_float (timeout_ms * MICROPY_FLOAT_CONST (1e-3 ));
159+ return call_method_errno (1 , self -> settimeout_args );
160+ }
161+
162+ static int ssl_socket_listen (ssl_sslsocket_obj_t * self , mp_int_t backlog ) {
163+ self -> listen_args [2 ] = MP_OBJ_NEW_SMALL_INT (backlog );
164+ return call_method_errno (1 , self -> listen_args );
165+ }
166+
167+ static mp_obj_t ssl_socket_accept (ssl_sslsocket_obj_t * self ) {
168+ return mp_call_method_n_kw (0 , 0 , self -> accept_args );
169+ }
170+
111171STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
112- mp_obj_t sock = * ( mp_obj_t * )ctx ;
172+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
113173
114- // mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
115- mp_int_t out_sz = socketpool_socket_send (sock , buf , len );
174+ mp_int_t out_sz = ssl_socket_send (self , buf , len );
116175 DEBUG_PRINT ("socket_send() -> %d" , out_sz );
117176 if (out_sz < 0 ) {
118177 int err = - out_sz ;
@@ -128,9 +187,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
128187
129188// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
130189STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
131- mp_obj_t sock = * ( mp_obj_t * )ctx ;
190+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
132191
133- mp_int_t out_sz = socketpool_socket_recv_into ( sock , buf , len );
192+ mp_int_t out_sz = ssl_socket_recv_into ( self , buf , len );
134193 DEBUG_PRINT ("socket_recv() -> %d" , out_sz );
135194 if (out_sz < 0 ) {
136195 int err = - out_sz ;
@@ -155,16 +214,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
155214#endif
156215
157216ssl_sslsocket_obj_t * common_hal_ssl_sslcontext_wrap_socket (ssl_sslcontext_obj_t * self ,
158- socketpool_socket_obj_t * socket , bool server_side , const char * server_hostname ) {
217+ mp_obj_t socket , bool server_side , const char * server_hostname ) {
159218
160- if (socket -> type != SOCKETPOOL_SOCK_STREAM ) {
219+ mp_int_t socket_type = mp_obj_get_int (mp_load_attr (socket , MP_QSTR_type ));
220+ if (socket_type != SOCKETPOOL_SOCK_STREAM ) {
161221 mp_raise_RuntimeError (MP_ERROR_TEXT ("Invalid socket for TLS" ));
162222 }
163223
164224 ssl_sslsocket_obj_t * o = m_new_obj_with_finaliser (ssl_sslsocket_obj_t );
165225 o -> base .type = & ssl_sslsocket_type ;
166226 o -> ssl_context = self ;
167- o -> sock = socket ;
227+ o -> sock_obj = socket ;
228+
229+ mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
230+ mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
231+ mp_load_method (socket , MP_QSTR_close , o -> close_args );
232+ mp_load_method (socket , MP_QSTR_connect , o -> connect_args );
233+ mp_load_method (socket , MP_QSTR_listen , o -> listen_args );
234+ mp_load_method (socket , MP_QSTR_recv_into , o -> recv_into_args );
235+ mp_load_method (socket , MP_QSTR_send , o -> send_args );
236+ mp_load_method (socket , MP_QSTR_settimeout , o -> settimeout_args );
168237
169238 mbedtls_ssl_init (& o -> ssl );
170239 mbedtls_ssl_config_init (& o -> conf );
@@ -223,7 +292,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
223292 }
224293 }
225294
226- mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
295+ mbedtls_ssl_set_bio (& o -> ssl , o , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
227296
228297 if (self -> cert_buf .buf != NULL ) {
229298 #if MBEDTLS_VERSION_MAJOR >= 3
@@ -292,13 +361,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
292361 mbedtls_raise_error (ret );
293362}
294363
295- size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
296- return common_hal_socketpool_socket_bind (self -> sock , host , hostlen , port );
364+ size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
365+ return ssl_socket_bind (self , addr_in );
297366}
298367
299368void common_hal_ssl_sslsocket_close (ssl_sslsocket_obj_t * self ) {
300369 self -> closed = true;
301- common_hal_socketpool_socket_close (self -> sock );
370+ ssl_socket_close (self );
302371 mbedtls_pk_free (& self -> pkey );
303372 mbedtls_x509_crt_free (& self -> cert );
304373 mbedtls_x509_crt_free (& self -> cacert );
@@ -344,8 +413,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
344413 }
345414}
346415
347- void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
348- common_hal_socketpool_socket_connect (self -> sock , host , hostlen , port );
416+ void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
417+ ssl_socket_connect (self , addr_in );
349418 do_handshake (self );
350419}
351420
@@ -358,16 +427,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
358427}
359428
360429bool common_hal_ssl_sslsocket_listen (ssl_sslsocket_obj_t * self , int backlog ) {
361- return common_hal_socketpool_socket_listen (self -> sock , backlog );
430+ return ssl_socket_listen (self , backlog );
362431}
363432
364- ssl_sslsocket_obj_t * common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self , uint8_t * ip , uint32_t * port ) {
365- socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept (self -> sock , ip , port );
433+ mp_obj_t common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self ) {
434+ mp_obj_t accepted = ssl_socket_accept (self );
435+ mp_obj_t sock = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
366436 ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslcontext_wrap_socket (self -> ssl_context , sock , true, NULL );
367437 do_handshake (sslsock );
368- return sslsock ;
438+ mp_obj_t peer = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
439+ mp_obj_t tuple_contents [2 ];
440+ tuple_contents [0 ] = MP_OBJ_FROM_PTR (sslsock );
441+ tuple_contents [1 ] = peer ;
442+ return mp_obj_new_tuple (2 , tuple_contents );
369443}
370444
371445void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , uint32_t timeout_ms ) {
372- common_hal_socketpool_socket_settimeout (self -> sock , timeout_ms );
446+ ssl_socket_settimeout (self , timeout_ms );
373447}
0 commit comments