1 module hio.tls.tls_impl;
2 
3 import hio.tls.openssl;
4 import hio.tls.common;
5 
6 import std.socket;
7 import std.datetime;
8 import std.string;
9 import std.experimental.logger;
10 
11 import hio.loop;
12 import hio.socket;
13 import hio.scheduler;
14 
15 import nbuff: NbuffChunk, Nbuff, MutableNbuffChunk;
16 
version (Posix)
{
    import core.sys.posix.signal;

    // libssl is handed the raw socket descriptor (SSL_set_fd), so writing to a peer
    // that has already closed the connection would raise SIGPIPE, whose default
    // action terminates the process. Ignore it once, process-wide, at startup.
    // See https://github.com/ikod/hio/issues/1
    shared static this()
    {
        cast(void) signal(SIGPIPE, SIG_IGN);
    }
}
26 
///
/// TLS (OpenSSL) socket layered over a non-blocking hlSocket.
///
/// The underlying hlSocket performs the TCP connect/accept. This class then
/// drives the TLS handshake and the encrypted I/O. It registers itself in the
/// event loop for the readiness OpenSSL asks for (IN or OUT) and re-runs the
/// pending SSL operation from eventHandler().
///
class AsyncSSLSocket : FileEventHandler, AsyncSocketLike
{
    private
    {
        enum State
        {
            INIT,       // fresh object, or reset by close()
            CONNECTING, // client handshake (SSL_connect) in progress
            ACCEPTING,  // server handshake (SSL_accept) in progress
            IO,         // an io() request waits for socket readiness
            IDLE,       // ready for the next io() request
            ERROR,      // handshake failed or connect failed/timed out
            CLOSED      // not assigned anywhere in this class
        }
        //version (openssl11)
        //{
            SSL* _ssl;
            SSL_CTX* _ctx;
        //}
        State               _state;
        hlSocket            _so;
        immutable string    _file;
        immutable int       _line;
        Duration            _op_timeout = 15.seconds;
        hlEvLoop            _loop;
        Timer               _timer;
        AppEvent            _polling_for;   // readiness currently registered in _loop
        HandlerDelegate     _callback;      // completion callback of connect()
        void delegate(AsyncSocketLike) @safe _accept_callback;
        bool                _ssl_connected;
        IOResult            _ioResult;
        IOCallback          _ioCallback;
        size_t              _to_receive, _received;
        bool                _allowPartialInput = true;
        MutableNbuffChunk   _input;
        int                 _io_depth; // see io()
        string              _host;     // for SNI
        string              _cert_file;
        string              _key_file;
    }

    /// Human-readable dump of the socket state (safe to call after close()).
    override string describe()
    {
        return "AsyncSSLSocket: "
           ~"_state: %s; ".format(_state)
           ~"_file(_line): %s:%s; ".format(_file, _line)
           ~"_polling_for: %s; ".format(appeventToString(_polling_for))
           ~"_ssl_connected: %s; ".format(_ssl_connected)
           ~"_timer: [%s]; ".format(_timer)
           ~"underlying so: %s; ".format(_so is null ? "null" : _so.describe)
           ;
    }

    /// Create a TLS socket together with a new underlying hlSocket.
    this(ubyte af = AF_INET, int sock_type = SOCK_STREAM, string f = __FILE__, int l = __LINE__) @safe
    {
        _so = new hlSocket(af, sock_type, f, l);
        _file = f;
        _line = l;
    }

    /// Wrap an existing hlSocket (used for accepted connections).
    this(hlSocket so, string f = __FILE__, int l = __LINE__) @safe
    {
        _so = so;
        _file = f;
        _line = l;
    }

    /// Register for IN readiness unless already registered.
    void want_in() @safe
    {
        if( !(_polling_for & AppEvent.IN) )
        {
            _loop.startPoll(_so.fileno, AppEvent.IN, this);
            _polling_for |= AppEvent.IN;
        }
    }

    /// Unregister IN readiness if registered.
    void stop_in() @safe
    {
        if( _polling_for & AppEvent.IN )
        {
            _loop.stopPoll(_so.fileno, AppEvent.IN);
            _polling_for &= ~AppEvent.IN;
        }
    }

    /// Register for OUT readiness unless already registered.
    void want_out() @safe
    {
        if( !(_polling_for & AppEvent.OUT) )
        {
            _loop.startPoll(_so.fileno, AppEvent.OUT, this);
            _polling_for |= AppEvent.OUT;
        }
    }

    /// Unregister OUT readiness if registered.
    void stop_out() @safe
    {
        if( _polling_for & AppEvent.OUT )
        {
            _loop.stopPoll(_so.fileno, AppEvent.OUT);
            _polling_for &= ~AppEvent.OUT;
        }
    }

    // OpenSSL reported SSL_ERROR_WANT_READ: watch only for readability, so a
    // permanently writable socket cannot spin the loop.
    private void wait_readable() @safe
    {
        stop_out();
        want_in();
    }

    // OpenSSL reported SSL_ERROR_WANT_WRITE: watch only for writability.
    private void wait_writable() @safe
    {
        stop_in();
        want_out();
    }

    ///
    /// Readiness notification for the underlying socket.
    ///
    /// Any readiness (or error/hang-up) may let the pending SSL operation make
    /// progress. Whichever direction OpenSSL needs next is re-armed by the
    /// operation itself, so IN and OUT are handled the same way.
    ///
    override void eventHandler(int fd, AppEvent e)
    {
        debug(hiossl) tracef("got event %s for ssl underlying in state %s", e, _state);
        if ( !(e & (AppEvent.IN | AppEvent.OUT | AppEvent.ERR | AppEvent.HUP)) )
        {
            return;
        }
        switch ( _state )
        {
            case State.CONNECTING:
                after_connect_step(handleConnectEvent(), _callback);
                return;
            case State.ACCEPTING:
                after_accept_step(handleAcceptEvent(), _accept_callback);
                return;
            case State.IO:
                if ( drive_io() )
                {
                    finish_io();
                    _ioCallback(_ioResult);
                }
                return;
            default:
                // nothing is waiting for the socket: stop polling instead of spinning
                debug(hiossl) tracef("ignore event in state %s", _state);
                stop_in();
                stop_out();
                return;
        }
    }

    /// Open the underlying socket.
    override bool open() @safe
    {
        _so.open();
        return true;
    }

    ///
    /// Stop polling, close the underlying socket and release all OpenSSL objects.
    /// The object cannot be reused afterwards (the underlying socket is dropped).
    ///
    override void close() @safe
    {
        _state = State.INIT;
        _ssl_connected = false;
        stop_in();
        stop_out();
        if ( _so )
        {
            _so.close();
            _so = null;
        }
        // release the connection before the context it was created from
        if ( _ssl )
        {
            SSL_free(_ssl);
            _ssl = null;
        }
        if ( _ctx )
        {
            SSL_CTX_free(_ctx);
            _ctx = null;
        }
        _input.release;
        _ioResult = IOResult();
    }

    /// True when the TLS handshake finished and the underlying socket is connected.
    override bool connected() @safe
    {
        return _ssl_connected && _so.connected;
    }

    /// Bind the underlying socket.
    override void bind(Address addr) @safe
    {
        _so.bind(addr);
    }

    // Not referenced in this file as seen; no-op.
    private void io_callback(int fd, AppEvent ev)
    {
        debug(hiossl) tracef("read callback on underlying socket");

    }

    // Not referenced in this file as seen; no-op.
    private void timer_callback(AppEvent ev) @safe
    {
        debug(hiossl) tracef("timed out");

    }

    // Translate the return code of SSL_connect()/SSL_accept() into an
    // SSL_connect_call_result. `op` names the call in diagnostics.
    private SSL_connect_call_result handshake_result(int result, string op) @safe
    {
        if ( result > 0 )
        {
            _ssl_connected = true;
            return SSL_connect_call_result.CONNECTED;
        }
        if ( result == 0 )
        {
            // handshake failed; pop the queued OpenSSL error for diagnostics
            auto err = ERR_get_error();
            debug (hiossl)
            {
                const char* error_str = ERR_error_string(err, null);
                tracef("could not %s: %s", op, () @trusted { return fromStringz(error_str).idup; }());
            }
            return SSL_connect_call_result.ERROR;
        }
        // result < 0: either SSL has to wait for the socket, or a real error
        immutable int ssl_error = SSL_get_error(_ssl, result);
        debug (hiossl) tracef ("SSL_signal: %s", ssl_error);
        switch(ssl_error)
        {
            case SSL_ERROR_WANT_READ:
                debug (hiossl) tracef ("want read");
                return SSL_connect_call_result.WANT_READ;
            case SSL_ERROR_WANT_WRITE:
                debug (hiossl) tracef ("want write");
                return SSL_connect_call_result.WANT_WRITE;
            case SSL_ERROR_SSL:
                debug(hiossl) tracef("ssl handshake failure");
                return SSL_connect_call_result.ERROR;
            default:
                // e.g. SSL_ERROR_SYSCALL on a connection reset: report it, never abort
                warningf("%s failed: %s", op, SSL_error_strings[ssl_error]);
                return SSL_connect_call_result.ERROR;
        }
    }

    // Run one step of the server-side handshake.
    private SSL_connect_call_result handleAcceptEvent() @safe
    {
        immutable result = SSL_accept(_ssl);
        debug (hiossl) tracef("SSL_accept rc=%d", result);
        return handshake_result(result, "SSL_accept");
    }

    // Run one step of the client-side handshake.
    private SSL_connect_call_result handleConnectEvent() @safe
    {
        immutable result = SSL_connect(_ssl);
        debug (hiossl) tracef("SSL_connect rc=%d", result);
        return handshake_result(result, "SSL_connect");
    }

    // React to one SSL_connect() step: report completion via cb, or wait for readiness.
    private void after_connect_step(SSL_connect_call_result step, HandlerDelegate cb) @safe
    {
        final switch(step)
        {
            case SSL_connect_call_result.ERROR:
                _state = State.ERROR;
                _callback = null;
                stop_in();
                stop_out();
                cb(AppEvent.ERR);
                return;
            case SSL_connect_call_result.CONNECTED:
                _ssl_connected = true;
                _state = State.IDLE;
                _callback = null;
                stop_in();
                stop_out();
                cb(AppEvent.OUT);
                return;
            case SSL_connect_call_result.WANT_READ:
                wait_readable();
                return;
            case SSL_connect_call_result.WANT_WRITE:
                wait_writable();
                return;
        }
    }

    // React to one SSL_accept() step: hand this socket to cb when done, or wait for readiness.
    private void after_accept_step(SSL_connect_call_result step, void delegate(AsyncSocketLike) @safe cb) @safe
    {
        final switch(step)
        {
            case SSL_connect_call_result.ERROR:
                _state = State.ERROR;
                stop_in();
                stop_out();
                cb(this);
                return;
            case SSL_connect_call_result.CONNECTED:
                _ssl_connected = true;
                _state = State.IDLE;
                stop_in();
                stop_out();
                cb(this);
                return;
            case SSL_connect_call_result.WANT_READ:
                wait_readable();
                return;
            case SSL_connect_call_result.WANT_WRITE:
                wait_writable();
                return;
        }
    }

    /// Put the underlying socket into listening mode.
    void listen(int backlog = 512)
    {
        _so.listen(backlog);
    }

    ///
    /// Accept a connection and perform the server-side TLS handshake on it.
    ///
    /// `callback` receives null when the underlying accept reports null.
    /// Otherwise it receives a new AsyncSSLSocket once the handshake has finished
    /// or failed. On failure the new socket is not connected().
    ///
    override void accept(hlEvLoop loop, Duration timeout, void delegate(AsyncSocketLike) @safe callback) @safe
    {
        _loop = loop;
        void so_accept_callback(AsyncSocketLike s) @safe
        {
            debug(hiossl) tracef("ssl callback %s", s);
            if ( s is null )
            {
                callback(s);
                return;
            }
            // set up ssl on this socket
            hlSocket new_so = cast(hlSocket)s;
            assert(new_so.connected);
            AsyncSSLSocket new_ssl_so = new AsyncSSLSocket(new_so);
            new_ssl_so._loop = loop;
            new_ssl_so._state = State.ACCEPTING;
            new_ssl_so._accept_callback = callback;
            new_ssl_so._cert_file = _cert_file;
            new_ssl_so._key_file = _key_file;

            // TLS setup failed: report the connection as failed instead of
            // continuing with a half-configured context
            void reject(string why) @safe
            {
                errorf("while accepting: %s", why);
                new_ssl_so._state = State.ERROR;
                callback(new_ssl_so);
            }

            new_ssl_so._ctx = SSL_CTX_new(TLS_server_method());
            if ( new_ssl_so._ctx is null )
            {
                reject("SSL_CTX_new failed");
                return;
            }
            if ( _cert_file &&
                 SSL_CTX_use_certificate_file(new_ssl_so._ctx, toStringz(_cert_file), SSL_FILETYPE_PEM) != 1 )
            {
                reject("can't load certificate file " ~ _cert_file);
                return;
            }
            if ( _key_file &&
                 SSL_CTX_use_PrivateKey_file(new_ssl_so._ctx, toStringz(_key_file), SSL_FILETYPE_PEM) != 1 )
            {
                reject("can't load private key file " ~ _key_file);
                return;
            }

            //SSL_CTX_set_cipher_list(new_ssl_so._ctx, &"ALL:!MEDIUM:!LOW"[0]);

            new_ssl_so._ssl = SSL_new(new_ssl_so._ctx);
            if ( new_ssl_so._ssl is null )
            {
                reject("SSL_new failed");
                return;
            }
            SSL_set_fd(new_ssl_so._ssl, cast(int) new_ssl_so._so.fileno);
            SSL_set_accept_state(new_ssl_so._ssl);
            // start negotiation; remaining steps run from new_ssl_so.eventHandler()
            new_ssl_so.after_accept_step(new_ssl_so.handleAcceptEvent(), callback);
        }
        _so.accept(_loop, timeout, &so_accept_callback);
    }

    ///
    /// Turn on and set "host" for server name indication (SNI).
    /// Call this before calling connect().
    ///
    public void set_host(string host) @safe
    {
        _host = host;
    }

    /// PEM certificate file used by accept().
    public void cert_file(string cert_file)
    {
        _cert_file = cert_file;
    }

    /// PEM private key file used by accept().
    public void key_file(string key_file)
    {
        _key_file = key_file;
    }

    // Send the SNI host name, if one was set with set_host().
    private void SSL_set_tlsext_host_name() @trusted nothrow {
        enum int SSL_CTRL_SET_TLSEXT_HOSTNAME = 55;
        enum long TLSEXT_NAMETYPE_host_name = 0;
        if ( _host )
        {
            SSL_ctrl(_ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME,TLSEXT_NAMETYPE_host_name, cast(void*)toStringz(_host));
        }
    }

    ///
    /// Connect the underlying socket to `addr`, then perform the client TLS handshake.
    ///
    /// `callback` receives AppEvent.OUT on success, AppEvent.ERR on failure, and
    /// AppEvent.TMO when the underlying connect times out.
    /// Returns whatever the underlying hlSocket.connect returns.
    ///
    override bool connect(Address addr, hlEvLoop loop, HandlerDelegate callback, Duration timeout) @safe
    {
        assert(_loop is null);
        assert(_state == State.INIT);
        assert(_timer is null);
        assert(_callback is null);

        _loop = loop;
        _callback = callback;
        _state = State.CONNECTING;

        void so_connect_callback(AppEvent ev) @safe
        {
            debug (hiossl) tracef("underlying socket event %s", ev);
            if ( ev & AppEvent.TMO )
            {
                debug (hiossl) trace("Connection timeout");
                _state = State.ERROR;
                _callback = null;
                callback(AppEvent.TMO);
                return;
            }
            if ( ev & (AppEvent.ERR|AppEvent.HUP))
            {
                debug (hiossl) tracef("failed to connect: %d", _so.socket_errno);
                _state = State.ERROR;
                _callback = null;
                callback(AppEvent.ERR);
                return;
            }
            SSL_set_connect_state(_ssl);
            SSL_set_tlsext_host_name();
            // first handshake step; the rest runs from eventHandler()
            after_connect_step(handleConnectEvent(), callback);
        }
        _ctx = SSL_CTX_new(TLS_client_method());
        _ssl = SSL_new(_ctx);
        SSL_set_fd(_ssl, cast(int) _so.fileno);
        return _so.connect(addr, loop, &so_connect_callback, timeout);
    }

    // Write the pending output with SSL_write.
    // Returns true when all output is written or an error is recorded in _ioResult.
    // Returns false after arming polling for what SSL waits for. The retry then
    // passes the same front chunk again, as SSL_write requires.
    private bool flush_output() @safe
    {
        while ( _ioResult.output.length > 0 )
        {
            immutable int result = () @trusted {
                NbuffChunk front = _ioResult.output.frontChunk;
                return SSL_write(_ssl, cast(void*)&front.data[0], cast(int)front.length);
            }();
            if ( result > 0 )
            {
                debug(hiossl) tracef("sent %d out of %d", result, _ioResult.output.length);
                _ioResult.output.pop(result);
                continue;
            }
            immutable int reason = SSL_get_error(_ssl, result);
            debug(hiossl) tracef("result %d, reason %s", result, SSL_error_strings[reason]);
            switch(reason)
            {
                case SSL_ERROR_WANT_WRITE:
                    wait_writable();
                    return false;
                case SSL_ERROR_WANT_READ:
                    wait_readable();
                    return false;
                default:
                    _ioResult.error = true;
                    return true;
            }
        }
        return true;
    }

    // Read into _input until the request is satisfied or SSL has to wait.
    // Returns true when the request can complete in any of these cases:
    //   - all requested bytes arrived;
    //   - some bytes arrived and partial input is allowed;
    //   - an error was recorded.
    // Returns false after arming polling for what SSL waits for.
    private bool fill_input() @safe
    {
        while ( _to_receive > 0 )
        {
            debug(hiossl) trace("receiving");
            // SSL_read takes an int length; bigger requests are served by several reads
            immutable int want = _to_receive > int.max ? int.max : cast(int)_to_receive;
            immutable int result = () @trusted {
                return SSL_read(_ssl, cast(void*)&_input.data[_received], want);
            }();
            if ( result > 0 )
            {
                _received += result;
                _to_receive -= result;
                debug(hiossl) tracef("successfully received %d, have to receive %d more", _received, _to_receive);
                continue;
            }
            immutable int reason = SSL_get_error(_ssl, result);
            debug(hiossl) tracef("result %d, reason %s", result, SSL_error_strings[reason]);
            switch(reason)
            {
                case SSL_ERROR_WANT_READ:
                case SSL_ERROR_WANT_WRITE:
                    if ( _received > 0 && _allowPartialInput )
                    {
                        return true;
                    }
                    if ( reason == SSL_ERROR_WANT_READ )
                    {
                        wait_readable();
                    }
                    else
                    {
                        wait_writable();
                    }
                    return false;
                case SSL_ERROR_SYSCALL:
                    {
                        auto err = ERR_get_error();
                        debug (hiossl) tracef("syscall error: %s", err);
                    }
                    goto default;
                default:
                    _ioResult.error = true;
                    return true;
            }
        }
        return true;
    }

    // Make as much progress on the current request as the socket allows:
    // first write all output, then read input.
    // Returns true when the request is finished (possibly with _ioResult.error set).
    private bool drive_io() @safe
    {
        if ( !flush_output() )
        {
            return false;
        }
        return _ioResult.error || fill_input();
    }

    // Finish the current request. Polling is stopped before the user callback
    // runs, so a new io() issued from that callback can re-arm it.
    private void finish_io() @safe
    {
        stop_in();
        stop_out();
        _state = State.IDLE;
        if ( _received > 0 )
        {
            _ioResult.input = NbuffChunk(_input, _received);
        }
    }

    ///
    /// Start an I/O request: send iorq.output, then receive up to iorq.to_read bytes.
    ///
    /// iorq.callback is invoked once the request completes in one of these cases:
    ///   - everything was sent and received;
    ///   - partial input arrived and iorq.allowPartialInput is set;
    ///   - an error occurred (IOResult.error).
    /// Returns 0.
    ///
    int io(hlEvLoop loop, ref IORequest iorq, Duration timeout) @safe
    {
        assert(iorq.callback !is null);
        assert(connected);
        assert(_state == State.IDLE || _state == State.IO);
        _ioResult = IOResult();
        _ioResult.output = iorq.output;
        _to_receive = iorq.to_read;
        _allowPartialInput = iorq.allowPartialInput;
        _ioCallback = iorq.callback;
        _received = 0;
        _io_depth++;
        scope(exit)
        {
            _io_depth--;
        }

        assert(_io_depth < 10);

        if ( _to_receive > 0 )
        {
            _input = Nbuff.get(_to_receive);
        }

        if ( !drive_io() )
        {
            // polling is armed; eventHandler() continues the request
            _state = State.IO;
            return 0;
        }

        finish_io();
        debug(hiossl) tracef("we can return now");
        if ( _io_depth >= 5)
        {
            //
            // libssl reads from the socket internally. When data arrives in small
            // portions and the user calls io() from inside ioCallback, the call
            // stack grows too deep. So once the stack is deep enough, the call to
            // ioCallback is handed over to the event loop.
            //
            auto t = new Timer(0.seconds, (AppEvent e){
                _ioCallback(_ioResult);
            });
            _loop.startTimer(t);
            return 0;
        }
        _ioCallback(_ioResult);
        return 0;
    }
}
668 
unittest
{
    // smoke test: TCP + TLS handshake against a public HTTPS endpoint
    globalLogLevel = LogLevel.info;
    App({
        auto sslSocket = new AsyncSSLSocket();

        void onConnected(AppEvent ev) @safe
        {
            debug (hiossl) tracef("connected");
            getDefaultLoop.stop();
        }

        sslSocket.open();
        scope(exit) sslSocket.close();

        auto loop = getDefaultLoop();
        sslSocket.connect(new InternetAddress("1.1.1.1", 443), loop, &onConnected, 1.seconds);
        getDefaultLoop.run();
    });
    uninitializeLoops();
}