~ chicken-core (chicken-5) 18e75a43ad70ff3088135fbe45930bbb340944b9


commit 18e75a43ad70ff3088135fbe45930bbb340944b9
Author:     Peter Bex <peter.bex@xs4all.nl>
AuthorDate: Wed Nov 20 23:05:40 2013 +0100
Commit:     Christian Kellermann <ckeen@pestilenz.org>
CommitDate: Thu Dec 19 15:48:02 2013 +0100

    Several Windows-related fixes and one race condition-related fix for TCP.
    
    - Fix nonblocking socket behaviour on Windows by actually marking it nonblocking.
    - Fix socket error handling in Windows by using WSAGetLastError() instead of checking errno.
    - Declare that the tcp unit runs with interrupts disabled, to prevent race conditions in which another thread triggers a TCP error (or, on UNIX, any error that overwrites errno) before the current thread has read its own error code.
    
    Signed-off-by: Christian Kellermann <ckeen@pestilenz.org>

diff --git a/NEWS b/NEWS
index 8f1f6fb6..7592ce4b 100644
--- a/NEWS
+++ b/NEWS
@@ -19,6 +19,8 @@
      prematurely woken up by a signal.
   - unsetenv has been fixed on Windows.
   - The process procedure has been fixed on Windows.
+  - Nonblocking behaviour on sockets has been fixed on Windows.
+  - Possible race condition while handling TCP errors has been fixed.
   - The posix unit will no longer hang upon any error in Windows.
 
 - Platform support
diff --git a/tcp.scm b/tcp.scm
index bba60c42..40dcd8ec 100644
--- a/tcp.scm
+++ b/tcp.scm
@@ -28,11 +28,11 @@
 (declare
   (unit tcp)
   (uses extras scheduler)
+  (disable-interrupts) ; Avoid race conditions around errno/WSAGetLastError
   (export tcp-close tcp-listen tcp-connect tcp-accept tcp-accept-ready? ##sys#tcp-port->fileno tcp-listener? tcp-addresses
 	  tcp-abandon-port tcp-listener-port tcp-listener-fileno tcp-port-numbers tcp-buffer-size
 	  tcp-read-timeout tcp-write-timeout tcp-accept-timeout tcp-connect-timeout)
   (foreign-declare #<<EOF
-#include <errno.h>
 #ifdef _WIN32
 # if (defined(HAVE_WINSOCK2_H) && defined(HAVE_WS2TCPIP_H))
 #  include <winsock2.h>
@@ -41,21 +41,50 @@
 #  include <winsock.h>
 # endif
 /* Beware: winsock2.h must come BEFORE windows.h */
-# define socklen_t       int
+# define socklen_t	 int
 static WSADATA wsa;
-# define fcntl(a, b, c)  0
-# ifndef EWOULDBLOCK
-#  define EWOULDBLOCK     0
+# ifndef SHUT_RD
+#  define SHUT_RD	  SD_RECEIVE
 # endif
-# ifndef EINPROGRESS
-#  define EINPROGRESS     0
-# endif
-# ifndef EAGAIN
-#  define EAGAIN          0
+# ifndef SHUT_WR
+#  define SHUT_WR	  SD_SEND
 # endif
+
 # define typecorrect_getsockopt(socket, level, optname, optval, optlen)	\
     getsockopt(socket, level, optname, (char *)optval, optlen)
+
+static C_word make_socket_nonblocking (C_word sock) {
+  int fd = C_unfix(sock);
+  C_return(C_mk_bool(ioctlsocket(fd, FIONBIO, (void *)&fd) != SOCKET_ERROR)) ;
+}
+
+/* This is a bit of a hack, but it keeps things simple */
+static C_TLS char *last_wsa_errorstring = NULL;
+
+static char *errormsg_from_code(int code) {
+  int bufsize;
+  if (last_wsa_errorstring != NULL) {
+    LocalFree(last_wsa_errorstring);
+    last_wsa_errorstring = NULL;
+  }
+  bufsize = FormatMessage(
+	FORMAT_MESSAGE_ALLOCATE_BUFFER |
+	FORMAT_MESSAGE_FROM_SYSTEM |
+	FORMAT_MESSAGE_IGNORE_INSERTS,
+	NULL, code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
+	(LPTSTR) &last_wsa_errorstring, 0, NULL);
+  if (bufsize == 0) return "ERROR WHILE FETCHING ERROR";
+  return last_wsa_errorstring;
+}
+
+# define get_last_socket_error()  WSAGetLastError()
+# define should_retry_call()      (WSAGetLastError() == WSAEWOULDBLOCK)
+/* Not EINPROGRESS in winsock.  Nonblocking connect returns EWOULDBLOCK... */
+# define call_in_progress()       (WSAGetLastError() == WSAEWOULDBLOCK)
+# define call_was_interrupted()   (WSAGetLastError() == WSAEINTR) /* ? */
+
 #else
+# include <errno.h>
 # include <fcntl.h>
 # include <sys/socket.h>
 # include <sys/time.h>
@@ -64,12 +93,22 @@ static WSADATA wsa;
 # include <signal.h>
 # define closesocket     close
 # define INVALID_SOCKET  -1
+# define SOCKET_ERROR    -1
 # define typecorrect_getsockopt getsockopt
-#endif
 
-#ifndef SD_RECEIVE
-# define SD_RECEIVE      0
-# define SD_SEND         1
+static C_word make_socket_nonblocking (C_word sock) {
+  int fd = C_unfix(sock);
+  int val = fcntl(fd, F_GETFL, 0);
+  if(val == -1) C_return(C_SCHEME_FALSE);
+  C_return(C_mk_bool(fcntl(fd, F_SETFL, val | O_NONBLOCK) != -1));
+}
+
+# define get_last_socket_error()  errno
+# define errormsg_from_code(e)    strerror(e)
+
+# define should_retry_call()      (errno == EAGAIN || errno == EWOULDBLOCK)
+# define call_was_interrupted()   (errno == EINTR)
+# define call_in_progress()       (errno == EINPROGRESS)
 #endif
 
 #ifdef ECOS
@@ -88,9 +127,6 @@ EOF
 
 (register-feature! 'tcp)
 
-(define-foreign-variable errno int "errno")
-(define-foreign-variable strerror c-string "strerror(errno)")
-
 (define-foreign-type sockaddr* (pointer "struct sockaddr"))
 (define-foreign-type sockaddr_in* (pointer "struct sockaddr_in"))
 
@@ -99,15 +135,18 @@ EOF
 (define-foreign-variable _sock_dgram int "SOCK_DGRAM")
 (define-foreign-variable _sockaddr_size int "sizeof(struct sockaddr)")
 (define-foreign-variable _sockaddr_in_size int "sizeof(struct sockaddr_in)")
-(define-foreign-variable _sd_receive int "SD_RECEIVE")
-(define-foreign-variable _sd_send int "SD_SEND")
+(define-foreign-variable _shut_rd int "SHUT_RD")
+(define-foreign-variable _shut_wr int "SHUT_WR")
 (define-foreign-variable _ipproto_tcp int "IPPROTO_TCP")
 (define-foreign-variable _invalid_socket int "INVALID_SOCKET")
-(define-foreign-variable _ewouldblock int "EWOULDBLOCK")
-(define-foreign-variable _eagain int "EAGAIN")
-(define-foreign-variable _eintr int "EINTR")
-(define-foreign-variable _einprogress int "EINPROGRESS")
-
+(define-foreign-variable _socket_error int "SOCKET_ERROR")
+
+(define ##net#last-error-code (foreign-lambda int "get_last_socket_error"))
+(define ##net#error-code->message
+  (foreign-lambda c-string "errormsg_from_code" int))
+(define ##net#retry? (foreign-lambda bool "should_retry_call"))
+(define ##net#in-progress? (foreign-lambda bool "call_in_progress"))
+(define ##net#interrupted? (foreign-lambda bool "call_was_interrupted"))
 (define ##net#socket (foreign-lambda int "socket" int int int))
 (define ##net#bind (foreign-lambda int "bind" int scheme-pointer int))
 (define ##net#listen (foreign-lambda int "listen" int int))
@@ -123,12 +162,6 @@ EOF
       int ((int s) (scheme-pointer msg) (int offset) (int len) (int flags))
     "C_return(send(s, (char *)msg+offset, len, flags));"))
 
-(define ##net#make-nonblocking
-  (foreign-lambda* bool ((int fd))
-    "int val = fcntl(fd, F_GETFL, 0);"
-    "if(val == -1) C_return(0);"
-    "C_return(fcntl(fd, F_SETFL, val | O_NONBLOCK) != -1);") )
-
 (define ##net#getsockname 
   (foreign-lambda* c-string ((int s))
     "struct sockaddr_in sa;"
@@ -197,21 +230,21 @@ EOF
 (define-syntax network-error
   (syntax-rules ()
     ((_ loc msg . args)
-     (network-error/errno loc (##sys#update-errno) msg . args))))
+     (network-error/code loc (##net#last-error-code) msg . args))))
 
 (define-syntax network-error/close
   (syntax-rules ()
     ((_ loc msg socket . args)
-     (let ((errno (##sys#update-errno)))
+     (let ((error-code (##net#last-error-code)))
        (##net#close socket)
-       (network-error/errno loc errno msg socket . args)))))
+       (network-error/code loc error-code msg socket . args)))))
 
-(define-syntax network-error/errno
+(define-syntax network-error/code
   (syntax-rules ()
-    ((_ loc errno msg . args)
+    ((_ loc error-code msg . args)
      (##sys#signal-hook #:network-error loc
 			(string-append (string-append msg " - ")
-				       (general-strerror errno))
+				       (##net#error-code->message error-code))
 			. args))))
 
 (define ##net#parse-host
@@ -250,15 +283,15 @@ EOF
 	(##net#fresh-addr addr port) )
     (let ((s (##net#socket _af_inet style 0)))
       (when (eq? _invalid_socket s)
-	(##sys#update-errno)
 	(##sys#error "cannot create socket") )
       ;; PLT makes this an optional arg to tcp-listen. Should we as well?
-      (when (eq? -1 ((foreign-lambda* int ((int socket)) 
-		       "int yes = 1; 
+      (when (eq? _socket_error
+		 ((foreign-lambda* int ((int socket))
+		    "int yes = 1;
 		      C_return(setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&yes, sizeof(int)));") 
-		     s) )
+		  s) )
 	(network-error/close 'tcp-listen "error while setting up socket" s) )
-      (when (eq? -1 (##net#bind s addr _sockaddr_in_size))
+      (when (eq? _socket_error (##net#bind s addr _sockaddr_in_size))
 	(network-error/close 'tcp-listen "cannot bind to socket" s host port) )
       s)) )
 
@@ -270,7 +303,7 @@ EOF
     (##sys#signal-hook #:domain-error 'tcp-listen "invalid port number" port) )
   (##sys#check-exact backlog)
   (let ((s (##net#bind-socket _sock_stream host port)))
-    (when (eq? -1 (##net#listen s backlog))
+    (when (eq? _socket_error (##net#listen s backlog))
       (network-error/close 'tcp-listen "cannot listen on socket" s port) )
     (##sys#make-structure 'tcp-listener s) ) )
 
@@ -281,7 +314,7 @@ EOF
 (define (tcp-close tcpl)
   (##sys#check-structure tcpl 'tcp-listener)
   (let ((s (##sys#slot tcpl 1)))
-    (when (fx= -1 (##net#close s))
+    (when (eq? _socket_error (##net#close s))
       (network-error 'tcp-close "cannot close TCP socket" tcpl) ) ) )
 
 (define-constant +input-buffer-size+ 1024)
@@ -306,7 +339,7 @@ EOF
 (define ##net#io-ports
   (let ((tbs tcp-buffer-size))
     (lambda (loc fd)
-      (unless (##net#make-nonblocking fd)
+      (unless (##core#inline "make_socket_nonblocking" fd)
 	(network-error/close loc "cannot create TCP ports" fd) )
       (let* ((buf (make-string +input-buffer-size+))
 	     (data (vector fd #f #f buf 0))
@@ -318,13 +351,12 @@ EOF
 	     (outbuf (and outbufsize (fx> outbufsize 0) ""))
 	     (read-input
 	      (lambda ()
-                (let* ((tmr (tcp-read-timeout))
-                       (dlr (and tmr (+ (current-milliseconds) tmr))))
+		(let* ((tmr (tcp-read-timeout))
+		       (dlr (and tmr (+ (current-milliseconds) tmr))))
 		  (let loop ()
 		    (let ((n (##net#recv fd buf +input-buffer-size+ 0)))
-		      (cond ((eq? -1 n)
-			     (cond ((or (eq? errno _ewouldblock)
-					(eq? errno _eagain))
+		      (cond ((eq? _socket_error n)
+			     (cond ((##net#retry?)
 				    (when dlr
 				      (##sys#thread-block-for-timeout!
 				       ##sys#current-thread dlr) )
@@ -335,7 +367,7 @@ EOF
 				       #:network-timeout-error
 				       "read operation timed out" tmr fd) )
 				    (loop) )
-				   ((eq? errno _eintr)
+				   ((##net#interrupted?)
 				    (##sys#dispatch-interrupt loop))
 				   (else
 				    (network-error #f "cannot read from socket" fd) ) ) )
@@ -355,15 +387,17 @@ EOF
 		       c) ) )
 	       (lambda ()
 		 (or (fx< bufindex buflen)
+		     ;; XXX: This "knows" that check_fd_ready is
+		     ;; implemented using a winsock2 call on Windows
 		     (let ((f (##net#check-fd-ready fd)))
-		       (when (eq? f -1)
+		       (when (eq? _socket_error f)
 			 (network-error #f "cannot check socket for input" fd) )
 		       (eq? f 1) ) ) )
 	       (lambda ()
 		 (unless iclosed
 		   (set! iclosed #t)
-		   (unless (##sys#slot data 1) (##net#shutdown fd _sd_receive))
-		   (when (and oclosed (eq? -1 (##net#close fd)))
+		   (unless (##sys#slot data 1) (##net#shutdown fd _shut_rd))
+		   (when (and oclosed (eq? _socket_error (##net#close fd)))
 		     (network-error #f "cannot close socket input port" fd) ) ) )
 	       (lambda ()
 		 (when (fx>= bufindex buflen)
@@ -431,9 +465,8 @@ EOF
 			     (dlw (and tmw (+ (current-milliseconds) tmw))))
 		    (let* ((count (fxmin +output-chunk-size+ len))
 			   (n (##net#send fd s offset count 0)) )
-		      (cond ((eq? -1 n)
-			     (cond ((or (eq? errno _ewouldblock)
-					(eq? errno _eagain))
+		      (cond ((eq? _socket_error n)
+			     (cond ((##net#retry?)
 				    (when dlw
 				      (##sys#thread-block-for-timeout!
 				       ##sys#current-thread dlw) )
@@ -444,7 +477,7 @@ EOF
 				       #:network-timeout-error
 				       "write operation timed out" tmw fd) )
 				    (loop len offset dlw) )
-				   ((eq? errno _eintr)
+				   ((##net#interrupted?)
 				    (##sys#dispatch-interrupt
 				     (cut loop len offset dlw)))
 				   (else
@@ -472,8 +505,8 @@ EOF
 		   (when (and outbuf (fx> (##sys#size outbuf) 0))
 		     (output outbuf)
 		     (set! outbuf "") )
-		   (unless (##sys#slot data 2) (##net#shutdown fd _sd_send))
-		   (when (and iclosed (eq? -1 (##net#close fd)))
+		   (unless (##sys#slot data 2) (##net#shutdown fd _shut_wr))
+		   (when (and iclosed (eq? _socket_error (##net#close fd)))
 		     (network-error #f "cannot close socket output port" fd) ) ) )
 	       (and outbuf
 		    (lambda ()
@@ -491,11 +524,11 @@ EOF
 (define (tcp-accept tcpl)
   (##sys#check-structure tcpl 'tcp-listener)
   (let* ((fd (##sys#slot tcpl 1))
-         (tma (tcp-accept-timeout))
-         (dla (and tma (+ tma (current-milliseconds)))))
+	 (tma (tcp-accept-timeout))
+	 (dla (and tma (+ tma (current-milliseconds)))))
     (let loop ()
       (when dla
-        (##sys#thread-block-for-timeout! ##sys#current-thread dla) )
+	(##sys#thread-block-for-timeout! ##sys#current-thread dla) )
       (##sys#thread-block-for-i/o! ##sys#current-thread fd #:input)
       (##sys#thread-yield!)
       (if (##sys#slot ##sys#current-thread 13)
@@ -504,16 +537,18 @@ EOF
 	   'tcp-accept
 	   "accept operation timed out" tma fd) )
       (let ((fd (##net#accept fd #f #f)))
-	(cond ((not (eq? -1 fd)) (##net#io-ports 'tcp-accept fd))
-	      ((eq? errno _eintr)
+	(cond ((not (eq? _invalid_socket fd))
+	       (##net#io-ports 'tcp-accept fd))
+	      ((##net#interrupted?)
 	       (##sys#dispatch-interrupt loop))
 	      (else
 	       (network-error 'tcp-accept "could not accept from listener" tcpl)))) ) ) )
 
 (define (tcp-accept-ready? tcpl)
   (##sys#check-structure tcpl 'tcp-listener 'tcp-accept-ready?)
+  ;; XXX: This "knows" that check_fd_ready is implemented using a winsock2 call
   (let ((f (##net#check-fd-ready (##sys#slot tcpl 1))))
-    (when (eq? -1 f)
+    (when (eq? _socket_error f)
       (network-error 'tcp-accept-ready? "cannot check socket for input" tcpl) )
     (eq? 1 f) ) )
 
@@ -521,17 +556,15 @@ EOF
   (foreign-lambda* int ((int socket))
     "int err, optlen;"
     "optlen = sizeof(err);"
-    "if (typecorrect_getsockopt(socket, SOL_SOCKET, SO_ERROR, &err, (socklen_t *)&optlen) == -1)"
-    "  C_return(-1);"
+    "if (typecorrect_getsockopt(socket, SOL_SOCKET, SO_ERROR, &err, (socklen_t *)&optlen) == SOCKET_ERROR)"
+    "  C_return(SOCKET_ERROR);"
     "C_return(err);"))
 
-(define general-strerror (foreign-lambda c-string "strerror" int))
-
 (define (tcp-connect host . more)
   (let* ((port (optional more #f))
-         (tmc (tcp-connect-timeout))
-         (dlc (and tmc (+ (current-milliseconds) tmc)))
-         (addr (make-string _sockaddr_in_size)))
+	 (tmc (tcp-connect-timeout))
+	 (dlc (and tmc (+ (current-milliseconds) tmc)))
+	 (addr (make-string _sockaddr_in_size)))
     (##sys#check-string host)
     (unless port
       (set!-values (host port) (##net#parse-host host "tcp"))
@@ -540,28 +573,28 @@ EOF
     (unless (##net#gethostaddr addr host port)
       (##sys#signal-hook #:network-error 'tcp-connect "cannot find host address" host) )
     (let ((s (##net#socket _af_inet _sock_stream 0)) )
-      (when (eq? -1 s)
+      (when (eq? _invalid_socket s)
 	(network-error 'tcp-connect "cannot create socket" host port) )
-      (unless (##net#make-nonblocking s)
+      (unless (##core#inline "make_socket_nonblocking" s)
 	(network-error/close 'tcp-connect "fcntl() failed" s) )
       (let loop ()
-	(when (eq? -1 (##net#connect s addr _sockaddr_in_size))
-	  (cond ((eq? errno _einprogress)
+	(when (eq? _socket_error (##net#connect s addr _sockaddr_in_size))
+	  (cond ((##net#in-progress?) ; Wait till it's available via select/poll
 		 (when dlc
 		   (##sys#thread-block-for-timeout! ##sys#current-thread dlc))
-		 (##sys#thread-block-for-i/o! ##sys#current-thread s #:all)
-                 (##sys#thread-yield!))
-		((eq? errno _eintr)
+		 (##sys#thread-block-for-i/o! ##sys#current-thread s #:output)
+		 (##sys#thread-yield!)) ; Don't loop: it's connected now
+		((##net#interrupted?)
 		 (##sys#dispatch-interrupt loop))
 		(else
 		 (network-error/close
-                  'tcp-connect "cannot connect to socket" s host port)))))
+		  'tcp-connect "cannot connect to socket" s host port)))))
       (let ((err (get-socket-error s)))
-	(cond ((fx= err -1)
-               (network-error/close 'tcp-connect "getsockopt() failed" s))
+	(cond ((eq? _socket_error err)
+	       (network-error/close 'tcp-connect "getsockopt() failed" s))
 	      ((fx> err 0)
 	       (##net#close s)
-	       (network-error/errno 'tcp-connect err "cannot create socket"))))
+	       (network-error/code 'tcp-connect err "cannot create socket"))))
       (##net#io-ports 'tcp-connect s) ) ) )
 
 (define (##sys#tcp-port->fileno p)
Trap