3131#include "py/mperrno.h"
3232#include "py/runtime.h"
3333#include "shared-bindings/socketpool/SocketPool.h"
34+ #include "shared-bindings/ssl/SSLSocket.h"
35+ #include "common-hal/ssl/SSLSocket.h"
3436#include "supervisor/port.h"
3537#include "supervisor/shared/tick.h"
3638#include "supervisor/workflow.h"
4446StackType_t socket_select_stack [2 * configMINIMAL_STACK_SIZE ];
4547
4648STATIC int open_socket_fds [CONFIG_LWIP_MAX_SOCKETS ];
47- STATIC bool user_socket [CONFIG_LWIP_MAX_SOCKETS ];
49+ STATIC socketpool_socket_obj_t * user_socket [CONFIG_LWIP_MAX_SOCKETS ];
4850StaticTask_t socket_select_task_handle ;
4951STATIC int socket_change_fd = -1 ;
5052
@@ -117,7 +119,7 @@ void socket_user_reset(void) {
117119
118120 for (size_t i = 0 ; i < MP_ARRAY_SIZE (open_socket_fds ); i ++ ) {
119121 open_socket_fds [i ] = -1 ;
120- user_socket [i ] = false ;
122+ user_socket [i ] = NULL ;
121123 }
122124 socket_change_fd = eventfd (0 , 0 );
123125 // Run this at the same priority as CP so that the web workflow background task can be
@@ -134,12 +136,13 @@ void socket_user_reset(void) {
134136
135137 for (size_t i = 0 ; i < MP_ARRAY_SIZE (open_socket_fds ); i ++ ) {
136138 if (open_socket_fds [i ] >= 0 && user_socket [i ]) {
139+ common_hal_socketpool_socket_close (user_socket [i ]);
137140 int num = open_socket_fds [i ];
138141 // Close automatically clears socket handle
139142 lwip_shutdown (num , SHUT_RDWR );
140143 lwip_close (num );
141144 open_socket_fds [i ] = -1 ;
142- user_socket [i ] = false ;
145+ user_socket [i ] = NULL ;
143146 }
144147 }
145148}
@@ -171,10 +174,10 @@ STATIC void unregister_open_socket(int fd) {
171174 }
172175}
173176
174- STATIC void mark_user_socket (int fd ) {
177+ STATIC void mark_user_socket (int fd , socketpool_socket_obj_t * obj ) {
175178 for (size_t i = 0 ; i < MP_ARRAY_SIZE (open_socket_fds ); i ++ ) {
176179 if (open_socket_fds [i ] == fd ) {
177- user_socket [i ] = true ;
180+ user_socket [i ] = obj ;
178181 return ;
179182 }
180183 }
@@ -236,7 +239,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket(socketpool_socketpool_obj_
236239 if (!socketpool_socket (self , family , type , sock )) {
237240 mp_raise_RuntimeError (translate ("Out of sockets" ));
238241 }
239- mark_user_socket (sock -> num );
242+ mark_user_socket (sock -> num , sock );
240243 return sock ;
241244}
242245
@@ -292,12 +295,12 @@ int socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_
292295
293296socketpool_socket_obj_t * common_hal_socketpool_socket_accept (socketpool_socket_obj_t * self ,
294297 uint8_t * ip , uint32_t * port ) {
298+ socketpool_socket_obj_t * sock = m_new_obj_with_finaliser (socketpool_socket_obj_t );
295299 int newsoc = socketpool_socket_accept (self , ip , port , NULL );
296300
297301 if (newsoc > 0 ) {
298- mark_user_socket (newsoc );
299302 // Create the socket
300- socketpool_socket_obj_t * sock = m_new_obj_with_finaliser ( socketpool_socket_obj_t );
303+ mark_user_socket ( newsoc , sock );
301304 sock -> base .type = & socketpool_socket_type ;
302305 sock -> num = newsoc ;
303306 sock -> pool = self -> pool ;
@@ -338,6 +341,12 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
338341}
339342
340343void socketpool_socket_close (socketpool_socket_obj_t * self ) {
344+ if (self -> ssl_socket ) {
345+ ssl_sslsocket_obj_t * ssl_socket = self -> ssl_socket ;
346+ self -> ssl_socket = NULL ;
347+ common_hal_ssl_sslsocket_close (ssl_socket );
348+ return ;
349+ }
341350 self -> connected = false;
342351 if (self -> num >= 0 ) {
343352 lwip_shutdown (self -> num , SHUT_RDWR );
0 commit comments