module hio.tls.tls_impl;

import hio.tls.openssl;
import hio.tls.common;

import std.socket;
import std.datetime;
import std.string;
import std.experimental.logger;

import hio.loop;
import hio.socket;
import hio.scheduler;

import nbuff: NbuffChunk, Nbuff, MutableNbuffChunk;

version(Posix)
{
    // see https://github.com/ikod/hio/issues/1
    // Ignore SIGPIPE so that writing to a socket the peer already closed
    // does not terminate the process.
    import core.sys.posix.signal;
    shared static this()
    {
        signal(SIGPIPE, SIG_IGN);
    }
}

/**
 * TLS (OpenSSL) layer on top of a non-blocking hlSocket.
 *
 * The object is a small state machine driven by readiness events on the
 * underlying socket (see eventHandler). It performs the client handshake
 * (connect), the server handshake (accept) and read/write requests (io).
 */
class AsyncSSLSocket : FileEventHandler, AsyncSocketLike
{
    private
    {
        /// Phase of the state machine.
        enum State
        {
            INIT,       // fresh object, also the state after close()
            CONNECTING, // client handshake in progress
            ACCEPTING,  // server handshake in progress
            IO,         // an io() request is waiting for socket readiness
            IDLE,       // handshake done, no request pending
            ERROR,      // handshake failed
            CLOSED
        }
        //version (openssl11)
        //{
        SSL*        _ssl;
        SSL_CTX*    _ctx;
        //}
        State       _state;
        hlSocket    _so;                // underlying plain socket
        immutable string _file;         // creation site, reported by describe()
        immutable int    _line;
        Duration    _op_timeout = 15.seconds;
        hlEvLoop    _loop;
        Timer       _timer;
        AppEvent    _polling_for;       // directions this object registered with _loop
        HandlerDelegate _callback;      // completion callback of connect()
        void delegate(AsyncSocketLike) @safe _accept_callback;
        bool        _ssl_connected;
        IOResult    _ioResult;          // result being built for the current io() request
        IOCallback  _ioCallback;
        size_t      _to_receive, _received;
        bool        _allowPartialInput = true;
        MutableNbuffChunk _input;       // receive buffer of the current io() request
        int         _io_depth;          // see io()
        string      _host;              // for SNI
        string      _cert_file;
        string      _key_file;
    }

    /// Human-readable dump of the socket state, for diagnostics.
    override string describe()
    {
        return "AsyncSSLSocket: "
            ~"_state: %s; ".format(_state)
            ~"_file(_line): %s:%s; ".format(_file, _line)
            ~"_polling_for: %s; ".format(appeventToString(_polling_for))
            ~"_ssl_connected: %s; ".format(_ssl_connected)
            ~"_timer: [%s]; ".format(_timer)
            // _so is set to null by close()
            ~"underlying so: %s; ".format(_so is null ? "null" : _so.describe)
            ;
    }

    /// Create a TLS socket over a new hlSocket of the given family and type.
    this(ubyte af = AF_INET, int sock_type = SOCK_STREAM, string f = __FILE__, int l = __LINE__) @safe
    {
        _so = new hlSocket(af, sock_type, f, l);
        _file = f;
        _line = l;
    }

    /// Create a TLS socket over an existing hlSocket (used for accepted connections).
    this(hlSocket so, string f = __FILE__, int l = __LINE__) @safe
    {
        _so = so;
        _file = f;
        _line = l;
    }

    /// Register for IN events on the underlying socket unless already registered.
    void want_in() @safe
    {
        if( !(_polling_for & AppEvent.IN) )
        {
            _loop.startPoll(_so.fileno, AppEvent.IN, this);
            _polling_for |= AppEvent.IN;
        }
    }

    /// Drop the IN registration if present.
    void stop_in() @safe
    {
        if( _polling_for & AppEvent.IN )
        {
            _loop.stopPoll(_so.fileno, AppEvent.IN);
            _polling_for &= ~AppEvent.IN;
        }
    }

    /// Register for OUT events on the underlying socket unless already registered.
    void want_out() @safe
    {
        if( !(_polling_for & AppEvent.OUT) )
        {
            _loop.startPoll(_so.fileno, AppEvent.OUT, this);
            _polling_for |= AppEvent.OUT;
        }
    }

    /// Drop the OUT registration if present.
    void stop_out() @safe
    {
        if( _polling_for & AppEvent.OUT )
        {
            _loop.stopPoll(_so.fileno, AppEvent.OUT);
            _polling_for &= ~AppEvent.OUT;
        }
    }

    /// Wait for readability only. A leftover OUT registration would keep
    /// waking eventHandler for nothing.
    private void wait_readable() @safe
    {
        stop_out();
        want_in();
    }

    /// Wait for writability only (counterpart of wait_readable).
    private void wait_writable() @safe
    {
        stop_in();
        want_out();
    }

    /// Readiness callback for the underlying socket. It drives whichever
    /// operation is in progress. Events without IN/OUT are ignored.
    override void eventHandler(int fd, AppEvent e)
    {
        debug(hiossl) tracef("got event %s for ssl underlying in state %s", e, _state);
        if ( !(e & (AppEvent.IN | AppEvent.OUT)) )
        {
            return;
        }
        // OpenSSL can need either direction at any stage, so IN and OUT are
        // handled alike: retry the pending operation and let it report what
        // to wait for next.
        switch(_state)
        {
            case State.CONNECTING:
                continue_connect();
                return;
            case State.ACCEPTING:
                continue_accept();
                return;
            case State.IO:
                if ( pump_io() )
                {
                    complete_io();
                }
                return;
            default:
                return;
        }
    }

    /// Retry the client handshake and act on the outcome. On completion or
    /// failure this notifies the connect() callback.
    private void continue_connect()
    {
        auto cb = _callback;
        final switch(handleConnectEvent())
        {
            case SSL_connect_call_result.ERROR:
                _callback = null;
                _state = State.ERROR;
                stop_in();
                stop_out();
                cb(AppEvent.ERR);
                return;
            case SSL_connect_call_result.CONNECTED:
                _ssl_connected = true;
                _callback = null;
                _state = State.IDLE;
                stop_in();
                stop_out();
                cb(AppEvent.OUT);
                return;
            case SSL_connect_call_result.WANT_READ:
                debug (hiossl) tracef ("want read");
                wait_readable();
                return;
            case SSL_connect_call_result.WANT_WRITE:
                debug (hiossl) tracef ("want write");
                wait_writable();
                return;
        }
    }

    /// Retry the server handshake and act on the outcome. On completion or
    /// failure this hands the socket to the accept() callback; on failure
    /// _state is ERROR and connected() is false.
    private void continue_accept()
    {
        auto cb = _accept_callback;
        final switch(handleAcceptEvent())
        {
            case SSL_connect_call_result.ERROR:
                _accept_callback = null;
                _state = State.ERROR;
                stop_in();
                stop_out();
                cb(this);
                return;
            case SSL_connect_call_result.CONNECTED:
                _ssl_connected = true;
                _accept_callback = null;
                _state = State.IDLE;
                stop_in();
                stop_out();
                cb(this);
                return;
            case SSL_connect_call_result.WANT_READ:
                debug (hiossl) tracef ("want read");
                wait_readable();
                return;
            case SSL_connect_call_result.WANT_WRITE:
                debug (hiossl) tracef ("want write");
                wait_writable();
                return;
        }
    }

    /// Open the underlying socket. Always returns true.
    override bool open() @safe
    {
        _so.open();
        return true;
    }

    /// Stop polling, close the underlying socket and free OpenSSL objects.
    /// Resets the state to INIT. Calling it twice is harmless.
    override void close() @safe
    {
        _state = State.INIT;
        _ssl_connected = false;
        // must run while _so is still set: stop_in/stop_out use its fileno
        stop_in();
        stop_out();
        if ( _so )
        {
            _so.close();
            _so = null;
        }
        // free the connection before the context it was created from
        if ( _ssl )
        {
            SSL_free(_ssl);
            _ssl = null;
        }
        if ( _ctx )
        {
            SSL_CTX_free(_ctx);
            _ctx = null;
        }
        _input.release;
        _ioResult = IOResult();
    }

    /// True when the TLS handshake completed and the underlying socket is connected.
    override bool connected() @safe
    {
        return _ssl_connected && _so !is null && _so.connected;
    }

    /// Bind the underlying socket.
    override void bind(Address addr) @safe
    {
        _so.bind(addr);
    }

    // Not referenced anywhere in this class; kept as a placeholder.
    private void io_callback(int fd, AppEvent ev)
    {
        debug(hiossl) tracef("read callback on underlying socket");
    }

    // Not referenced anywhere in this class; kept as a placeholder.
    private void timer_callback(AppEvent ev) @safe
    {
        debug(hiossl) tracef("timed out");
    }

    /**
     * Run one step of SSL_accept and classify the result.
     *
     * Returns: CONNECTED when the handshake finished, WANT_READ/WANT_WRITE
     * when it must be retried once the socket is readable/writable, and
     * ERROR otherwise.
     */
    private SSL_connect_call_result handleAcceptEvent() @safe
    {
        auto result = SSL_accept(_ssl);
        debug (hiossl) tracef("SSL_accept rc=%d", result);

        if (result == 0)
        {
            long error = ERR_get_error();
            const char* error_str = ERR_error_string(error, null);
            debug (hiossl) tracef ("could not SSL_accept: %s\n", error_str);
            return SSL_connect_call_result.ERROR;
        }
        if (result > 0)
        {
            // connected
            _ssl_connected = true;
            return SSL_connect_call_result.CONNECTED;
        }
        // result < 0: the handshake is not finished yet or failed
        int ssl_error = SSL_get_error(_ssl, result);
        debug (hiossl) tracef ("SSL_signal: %s", ssl_error);
        switch(ssl_error)
        {
            case SSL_ERROR_WANT_READ:
                debug (hiossl) tracef ("want read");
                return SSL_connect_call_result.WANT_READ;
            case SSL_ERROR_WANT_WRITE:
                debug (hiossl) tracef ("want write");
                return SSL_connect_call_result.WANT_WRITE;
            case SSL_ERROR_SSL:
                debug(hiossl) tracef("ssl handshake failure");
                return SSL_connect_call_result.ERROR;
            default:
                warningf("while accepting: %s", SSL_error_strings[ssl_error]);
                return SSL_connect_call_result.ERROR;
        }
    }

    /**
     * Run one step of SSL_connect and classify the result.
     *
     * Returns: CONNECTED when the handshake finished, WANT_READ/WANT_WRITE
     * when it must be retried once the socket is readable/writable, and
     * ERROR otherwise (including unexpected error codes, which are logged
     * instead of aborting the process).
     */
    private SSL_connect_call_result handleConnectEvent() @safe
    {
        auto result = SSL_connect(_ssl);
        debug (hiossl) tracef("SSL_connect rc=%d", result);

        if (result == 0)
        {
            long error = ERR_get_error();
            const char* error_str = ERR_error_string(error, null);
            debug (hiossl) tracef ("could not SSL_connect: %s\n", error_str);
            return SSL_connect_call_result.ERROR;
        }
        if (result > 0)
        {
            // connected
            _ssl_connected = true;
            return SSL_connect_call_result.CONNECTED;
        }
        // result < 0: the handshake is not finished yet or failed
        int ssl_error = SSL_get_error(_ssl, result);
        debug (hiossl) tracef ("SSL_signal: %s", ssl_error);
        switch(ssl_error)
        {
            case SSL_ERROR_WANT_READ:
                debug (hiossl) tracef ("want read");
                return SSL_connect_call_result.WANT_READ;
            case SSL_ERROR_WANT_WRITE:
                debug (hiossl) tracef ("want write");
                return SSL_connect_call_result.WANT_WRITE;
            case SSL_ERROR_SSL:
                debug(hiossl) tracef("ssl handshake failure");
                return SSL_connect_call_result.ERROR;
            default:
                // e.g. SSL_ERROR_SYSCALL on a reset connection: a network
                // failure, not a programming error
                warningf("while connecting: %s", SSL_error_strings[ssl_error]);
                return SSL_connect_call_result.ERROR;
        }
    }

    /// Put the underlying socket into listening mode.
    void listen(int backlog = 512)
    {
        _so.listen(backlog);
    }

    /**
     * Accept a connection and run the server-side TLS handshake on it.
     *
     * `callback` receives null when the underlying accept reports null.
     * Otherwise it receives the new AsyncSSLSocket once the handshake has
     * finished or failed (check connected()). Certificate and key files set
     * through cert_file()/key_file() are loaded into a fresh SSL_CTX for
     * every accepted connection.
     */
    override void accept(hlEvLoop loop, Duration timeout, void delegate(AsyncSocketLike) @safe callback) @safe
    {
        _loop = loop;
        void so_accept_callback(AsyncSocketLike s) @safe
        {
            debug(hiossl) tracef("ssl callback %s", s);
            if ( s is null )
            {
                callback(s);
                return;
            }
            // set up ssl on this socket
            hlSocket new_so = cast(hlSocket)s;
            assert(new_so.connected);
            AsyncSSLSocket new_ssl_so = new AsyncSSLSocket(new_so);
            new_ssl_so._loop = loop;
            new_ssl_so._state = State.ACCEPTING;
            new_ssl_so._accept_callback = callback;

            new_ssl_so._ctx = SSL_CTX_new(TLS_server_method());
            if ( _cert_file )
            {
                new_ssl_so._cert_file = _cert_file;
                int r = SSL_CTX_use_certificate_file(new_ssl_so._ctx, toStringz(_cert_file), SSL_FILETYPE_PEM);
                assert(r==1);
            }
            if ( _key_file )
            {
                new_ssl_so._key_file = _key_file;
                int r = SSL_CTX_use_PrivateKey_file(new_ssl_so._ctx, toStringz(_key_file), SSL_FILETYPE_PEM);
                assert(r==1);
            }

            //SSL_CTX_set_cipher_list(new_ssl_so._ctx, &"ALL:!MEDIUM:!LOW"[0]);

            new_ssl_so._ssl = SSL_new(new_ssl_so._ctx);
            SSL_set_fd(new_ssl_so._ssl, cast(int) new_ssl_so._so.fileno);
            SSL_set_accept_state(new_ssl_so._ssl);
            // start negotiation; unfinished handshakes continue in eventHandler
            final switch(new_ssl_so.handleAcceptEvent())
            {
                case SSL_connect_call_result.ERROR:
                    new_ssl_so._state = State.ERROR;
                    callback(new_ssl_so);
                    return;
                case SSL_connect_call_result.CONNECTED:
                    new_ssl_so._ssl_connected = true;
                    // the *accepted* socket becomes IDLE, not the listener
                    new_ssl_so._state = State.IDLE;
                    callback(new_ssl_so);
                    return;
                case SSL_connect_call_result.WANT_READ:
                    debug (hiossl) tracef ("want read");
                    new_ssl_so.want_in();
                    return;
                case SSL_connect_call_result.WANT_WRITE:
                    debug (hiossl) tracef ("want write");
                    new_ssl_so.want_out();
                    return;
            }
        }
        _so.accept(_loop, timeout, &so_accept_callback);
    }

    ///
    /// Turn on Server Name Indication (SNI) and set the host name to send.
    /// Call this before connect().
    ///
    public void set_host(string host) @safe
    {
        _host = host;
    }

    /// Path of the PEM certificate used for accepted connections.
    public void cert_file(string cert_file)
    {
        _cert_file = cert_file;
    }

    /// Path of the PEM private key used for accepted connections.
    public void key_file(string key_file)
    {
        _key_file = key_file;
    }

    /// Equivalent of OpenSSL's SSL_set_tlsext_host_name() macro (an
    /// SSL_ctrl call). Does nothing when no host was set.
    private void SSL_set_tlsext_host_name() @trusted nothrow {
        enum int SSL_CTRL_SET_TLSEXT_HOSTNAME = 55;
        enum long TLSEXT_NAMETYPE_host_name = 0;
        if ( _host )
        {
            SSL_ctrl(_ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME,TLSEXT_NAMETYPE_host_name, cast(void*)toStringz(_host));
        }
    }

    /**
     * Connect the underlying socket to `addr`, then run the client TLS handshake.
     *
     * `callback` receives AppEvent.OUT on success, AppEvent.ERR on failure,
     * or AppEvent.TMO when the underlying connect times out.
     *
     * Returns: the result of the underlying hlSocket.connect().
     */
    override bool connect(Address addr, hlEvLoop loop, HandlerDelegate callback, Duration timeout) @safe
    {
        assert(_loop is null);
        assert(_state == State.INIT);
        assert(_timer is null);
        assert(_callback is null);

        _loop = loop;
        _callback = callback;
        _state = State.CONNECTING;

        void so_connect_callback(AppEvent ev) @safe
        {
            debug (hiossl) tracef("underlying socket event %s", ev);
            if ( ev & AppEvent.TMO )
            {
                debug (hiossl) trace("Connection timeout");
                callback(AppEvent.TMO);
                return;
            }
            if ( ev & (AppEvent.ERR|AppEvent.HUP))
            {
                debug (hiossl) tracef("failed to connect: %d", _so.socket_errno);
                callback(AppEvent.ERR);
                return;
            }
            SSL_set_connect_state(_ssl);
            SSL_set_tlsext_host_name();
            // start negotiation; unfinished handshakes continue in eventHandler
            final switch(handleConnectEvent())
            {
                case SSL_connect_call_result.ERROR:
                    _state = State.ERROR;
                    callback(AppEvent.ERR);
                    return;
                case SSL_connect_call_result.CONNECTED:
                    _ssl_connected = true;
                    _state = State.IDLE;
                    callback(AppEvent.OUT);
                    return;
                case SSL_connect_call_result.WANT_READ:
                    debug (hiossl) tracef ("want read");
                    want_in();
                    return;
                case SSL_connect_call_result.WANT_WRITE:
                    debug (hiossl) tracef ("want write");
                    want_out();
                    return;
            }
        }
        _ctx = SSL_CTX_new(TLS_client_method());
        _ssl = SSL_new(_ctx);
        SSL_set_fd(_ssl, cast(int) _so.fileno);
        return _so.connect(addr, loop, &so_connect_callback, timeout);
    }

    /**
     * Start an I/O request: write all of iorq.output, then read up to
     * iorq.to_read bytes.
     *
     * iorq.callback receives the IOResult when one of these happens:
     * all output is written and all requested input has arrived; some
     * input arrived, iorq.allowPartialInput is set and OpenSSL has nothing
     * more right now; or an error occurred. The callback may run before
     * io() returns. Otherwise the request is continued from eventHandler.
     * `loop` and `timeout` are not used by this implementation.
     *
     * Returns: always 0.
     */
    int io(hlEvLoop loop, ref IORequest iorq, Duration timeout) @safe
    {
        assert(iorq.callback !is null);
        assert(connected);
        assert(_state == State.IDLE || _state == State.IO);
        _ioResult = IOResult();
        _ioResult.output = iorq.output;
        _to_receive = iorq.to_read;
        _allowPartialInput = iorq.allowPartialInput;
        _ioCallback = iorq.callback;
        _received = 0;
        _io_depth++;
        scope(exit)
        {
            _io_depth--;
        }

        assert(_io_depth < 10);

        if ( _to_receive > 0 )
        {
            _input = Nbuff.get(_to_receive);
        }

        if ( pump_io() )
        {
            complete_io();
        }
        else
        {
            _state = State.IO;
        }
        return 0;
    }

    /**
     * Advance the current io() request as far as OpenSSL allows without blocking.
     *
     * Returns: true when the request is finished (successfully or with
     * _ioResult.error set) and complete_io() must be called. Returns false
     * when it has to wait for socket readiness; polling has then been set
     * up for the direction OpenSSL asked for.
     */
    private bool pump_io() @safe
    {
        // 1. push all pending output
        while ( _ioResult.output.length > 0 && !_ioResult.error )
        {
            immutable int result = () @trusted {
                NbuffChunk front = _ioResult.output.frontChunk;
                return SSL_write(_ssl, cast(void*)&front.data[0], cast(int)front.length);
            }();
            if ( result > 0 )
            {
                debug(hiossl) tracef("sent %d out of %d", result, _ioResult.output.length);
                _ioResult.output.pop(result);
                continue;
            }
            immutable int reason = SSL_get_error(_ssl, result);
            debug(hiossl) tracef("result %d, reason %s", result, SSL_error_strings[reason]);
            switch(reason)
            {
                // Non-blocking socket not ready: retry later. The front chunk
                // is popped only after success, so the retry passes the same
                // buffer, as OpenSSL requires.
                case SSL_ERROR_WANT_WRITE:
                    debug (hiossl) tracef ("want write");
                    wait_writable();
                    return false;
                case SSL_ERROR_WANT_READ:
                    debug (hiossl) tracef ("want read");
                    wait_readable();
                    return false;
                default:
                    _ioResult.error = true;
                    break;
            }
        }

        // 2. read as much as possible; OpenSSL may already hold buffered
        //    records, so keep reading until it reports WANT_*
        while ( _to_receive > 0 && !_ioResult.error )
        {
            debug(hiossl) tracef("going to read %d bytes more", _to_receive);
            immutable int result = () @trusted {
                return SSL_read(_ssl, cast(void*)&_input.data[_received], cast(int)_to_receive);
            }();
            if ( result > 0 )
            {
                _received += result;
                _to_receive -= result;
                debug(hiossl) tracef("successfully received %d, have to receive %d more", _received, _to_receive);
                debug(hiossl) tracef("<%s>", cast(string)_input.data[0.._received]);
                continue;
            }
            immutable int reason = SSL_get_error(_ssl, result);
            debug(hiossl) tracef("result %d, reason %s", result, SSL_error_strings[reason]);
            switch(reason)
            {
                case SSL_ERROR_WANT_READ:
                    debug (hiossl) tracef ("want read");
                    if ( _received > 0 && _allowPartialInput )
                    {
                        return true;    // deliver what we have
                    }
                    wait_readable();
                    return false;
                case SSL_ERROR_WANT_WRITE:
                    debug (hiossl) tracef ("want write");
                    if ( _received > 0 && _allowPartialInput )
                    {
                        return true;    // deliver what we have
                    }
                    wait_writable();
                    return false;
                case SSL_ERROR_SYSCALL:
                    auto e = ERR_get_error();
                    debug (hiossl) tracef("syscall error: %s", e);
                    goto default;
                default:
                    _ioResult.error = true;
                    break;
            }
        }
        return true;
    }

    /// Finish the current io() request: stop polling, go IDLE, attach any
    /// received input and pass the result to the user callback. The
    /// callback runs last, so it may start a new io() request.
    private void complete_io() @safe
    {
        stop_in();
        stop_out();
        _state = State.IDLE;
        if ( _received > 0 )
        {
            _ioResult.input = NbuffChunk(_input, _received);
        }
        debug(hiossl) tracef("we can return now");
        if ( _io_depth >= 5 )
        {
            //
            // as libssl reads from socket internally I have next problem when reads in small
            // portions: if user call io() inside from ioCallback I receive too deep call stack.
            // So if stack become too deep I'll pass call to ioCallback to event loop.
            //
            auto t = new Timer(0.seconds, (AppEvent e){
                _ioCallback(_ioResult);
            });
            _loop.startTimer(t);
            return;
        }
        _ioCallback(_ioResult);
    }
}

// Needs network access: connects to 1.1.1.1:443 and stops the loop when
// the connect callback fires.
unittest
{
    globalLogLevel = LogLevel.info;
    App({
        AsyncSSLSocket s = new AsyncSSLSocket();
        void connected(AppEvent ev) @safe
        {
            debug (hiossl)
                tracef("connected");
            getDefaultLoop.stop();
        }
        s.open();
        scope(exit)
        {
            s.close();
        }
        s.connect(new InternetAddress("1.1.1.1", 443), getDefaultLoop(), &connected, 1.seconds);
        getDefaultLoop.run();
    });
    uninitializeLoops();
}