2424 * THE SOFTWARE.
2525 */
2626
27+ // Include strchrnul()
28+ #define _GNU_SOURCE
29+
2730#include <stdarg.h>
2831#include <string.h>
2932
@@ -85,8 +88,8 @@ typedef struct {
8588 char destination [256 ];
8689 char header_key [64 ];
8790 char header_value [256 ];
88- // We store the origin so we can reply back with it.
89- char origin [64 ];
91+ char origin [ 64 ]; // We store the origin so we can reply back with it.
92+ char host [64 ]; // We store the host to check against origin.
9093 size_t content_length ;
9194 size_t offset ;
9295 uint64_t timestamp_ms ;
@@ -454,49 +457,33 @@ static bool _endswith(const char *str, const char *suffix) {
454457 return strcmp (str + (strlen (str ) - strlen (suffix )), suffix ) == 0 ;
455458}
456459
457- const char * ok_hosts [] = {
458- "127.0.0.1" ,
459- "localhost" ,
460- };
461-
462- static bool _origin_ok (const char * origin ) {
463- const char * http = "http://" ;
460+ const char http_scheme [] = "http://" ;
461+ #define PREFIX_HTTP_LEN (sizeof(http_scheme) - 1)
464462
465- // note: redirected requests send an Origin of "null" and will be caught by this
466- if (strncmp (origin , http , strlen (http )) != 0 ) {
467- return false;
468- }
469- // These are prefix checks up to : so that any port works.
470- // TODO: Support DHCP hostname in addition to MDNS.
471- const char * end ;
472- #if CIRCUITPY_MDNS
473- if (!common_hal_mdns_server_deinited (& mdns )) {
474- const char * local = ".local" ;
475- const char * hostname = common_hal_mdns_server_get_hostname (& mdns );
476- end = origin + strlen (http ) + strlen (hostname ) + strlen (local );
477- if (strncmp (origin + strlen (http ), hostname , strlen (hostname )) == 0 &&
478- strncmp (origin + strlen (http ) + strlen (hostname ), local , strlen (local )) == 0 &&
479- (end [0 ] == '\0' || end [0 ] == ':' )) {
480- return true;
481- }
463+ static bool _origin_ok (_request * request ) {
464+ // Origin may be 'null'
465+ if (request -> origin [0 ] == '\0' ) {
466+ return true;
482467 }
483- #endif
484-
485- _update_encoded_ip ();
486- end = origin + strlen (http ) + strlen (_our_ip_encoded );
487- if (strncmp (origin + strlen (http ), _our_ip_encoded , strlen (_our_ip_encoded )) == 0 &&
488- (end [0 ] == '\0' || end [0 ] == ':' )) {
468+ // Origin has http prefix?
469+ if (strncmp (request -> origin , http_scheme , PREFIX_HTTP_LEN ) != 0 ) {
470+ // Not HTTP scheme request - ok
471+ request -> origin [0 ] = '\0' ;
489472 return true;
490473 }
491-
492- for (size_t i = 0 ; i < MP_ARRAY_SIZE (ok_hosts ); i ++ ) {
493- // Allows any port
494- end = origin + strlen (http ) + strlen (ok_hosts [i ]);
495- if (strncmp (origin + strlen (http ), ok_hosts [i ], strlen (ok_hosts [i ])) == 0
496- && (end [0 ] == '\0' || end [0 ] == ':' )) {
474+ // Host given?
475+ if (request -> host [0 ] != '\0' ) {
476+ // OK if host and origin match (fqdn + port #)
477+ if (strcmp (request -> host , & request -> origin [PREFIX_HTTP_LEN ]) == 0 ) {
478+ return true;
479+ }
480+ // DEBUG: OK if origin is 'localhost' (ignoring port #)
481+ * strchrnul (& request -> origin [PREFIX_HTTP_LEN ], ':' ) = '\0' ;
482+ if (strcmp (& request -> origin [PREFIX_HTTP_LEN ], "localhost" ) == 0 ) {
497483 return true;
498484 }
499485 }
486+ // Otherwise deny request
500487 return false;
501488}
502489
@@ -517,8 +504,8 @@ static void _cors_header(socketpool_socket_obj_t *socket, _request *request) {
517504 _send_strs (socket ,
518505 "Access-Control-Allow-Credentials: true\r\n" ,
519506 "Vary: Origin, Accept, Upgrade\r\n" ,
520- "Access-Control-Allow-Origin: *\r\n " ,
521- NULL );
507+ "Access-Control-Allow-Origin: " ,
508+ ( request -> origin [ 0 ] == '\0' ) ? "*" : request -> origin , "\r\n" , NULL );
522509}
523510
524511static void _reply_continue (socketpool_socket_obj_t * socket , _request * request ) {
@@ -1086,11 +1073,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
10861073 #else
10871074 _reply_missing (socket , request );
10881075 #endif
1089-
1090- // For now until CORS is sorted, allow always the origin requester.
1091- // Note: caller knows who we are better than us. CORS is not security
1092- // unless browser cooperates. Do not rely on mDNS or IP.
1093- } else if (strlen (request -> origin ) > 0 && !_origin_ok (request -> origin )) {
1076+ } else if (!_origin_ok (request )) {
10941077 _reply_forbidden (socket , request );
10951078 } else if (strncmp (request -> path , "/fs/" , 4 ) == 0 ) {
10961079 if (strcasecmp (request -> method , "OPTIONS" ) == 0 ) {
@@ -1314,6 +1297,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
13141297static void _reset_request (_request * request ) {
13151298 request -> state = STATE_METHOD ;
13161299 request -> origin [0 ] = '\0' ;
1300+ request -> host [0 ] = '\0' ;
13171301 request -> content_length = 0 ;
13181302 request -> offset = 0 ;
13191303 request -> timestamp_ms = 0 ;
@@ -1340,6 +1324,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
13401324 if (len == 0 || len == - MP_ENOTCONN ) {
13411325 // Disconnect - clear 'in-progress'
13421326 _reset_request (request );
1327+ common_hal_socketpool_socket_close (socket );
13431328 }
13441329 break ;
13451330 }
@@ -1421,14 +1406,17 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
14211406 request -> redirect = strncmp (request -> header_value , cp_local , strlen (cp_local )) == 0 &&
14221407 (strlen (request -> header_value ) == strlen (cp_local ) ||
14231408 request -> header_value [strlen (cp_local )] == ':' );
1409+ strncpy (request -> host , request -> header_value , sizeof (request -> host ) - 1 );
1410+ request -> host [sizeof (request -> host ) - 1 ] = '\0' ;
14241411 } else if (strcasecmp (request -> header_key , "Content-Length" ) == 0 ) {
14251412 request -> content_length = strtoul (request -> header_value , NULL , 10 );
14261413 } else if (strcasecmp (request -> header_key , "Expect" ) == 0 ) {
14271414 request -> expect = strcmp (request -> header_value , "100-continue" ) == 0 ;
14281415 } else if (strcasecmp (request -> header_key , "Accept" ) == 0 ) {
14291416 request -> json = strcasecmp (request -> header_value , "application/json" ) == 0 ;
14301417 } else if (strcasecmp (request -> header_key , "Origin" ) == 0 ) {
1431- strcpy (request -> origin , request -> header_value );
1418+ strncpy (request -> origin , request -> header_value , sizeof (request -> origin ) - 1 );
1419+ request -> origin [sizeof (request -> origin ) - 1 ] = '\0' ;
14321420 } else if (strcasecmp (request -> header_key , "X-Timestamp" ) == 0 ) {
14331421 request -> timestamp_ms = strtoull (request -> header_value , NULL , 10 );
14341422 } else if (strcasecmp (request -> header_key , "Upgrade" ) == 0 ) {
0 commit comments