summaryrefslogtreecommitdiff
path: root/www/py-hyper/files/files-server.py
blob: 610b4b9a02dd72132eb5c1c895aa373e684f7247 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-
"""
test/server
~~~~~~~~~~~

This module defines some testing infrastructure that is very useful for
integration-type testing of hyper. It works by spinning up a background thread
that listens on a local socket and runs test-defined logic against it.

This very-clever idea and most of its implementation are ripped off from
Andrey Petrov's excellent urllib3 project. I owe him a substantial debt in
ingenuity and about a million beers. The license is available in NOTICES.
"""

import threading
import socket
import sys

from hyper import HTTP20Connection
from hyper.compat import ssl
from hyper.http11.connection import HTTP11Connection
from hyper.packages.hpack.hpack import Encoder
from hyper.packages.hpack.huffman import HuffmanEncoder
from hyper.packages.hpack.huffman_constants import (
    REQUEST_CODES, REQUEST_CODES_LENGTH
)
from hyper.tls import NPN_PROTOCOL

class SocketServerThread(threading.Thread):
    """
    Background thread that owns a single listening server socket.

    Adapted wholesale from shazow/urllib3 under license. See NOTICES.

    :param socket_handler: Callable which receives the bound, listening
        server socket; it is expected to accept and service connections on it.
    :param host: Address to bind to. The port is chosen by the OS (port 0)
        and published as ``self.port`` once bound.
    :param ready_event: Event which gets set when the socket handler is
        ready to receive requests. It is also set if start-up fails, so that
        anyone waiting on it is never blocked forever.
    :param h2: When the socket is secure, advertise HTTP/2 via NPN if the
        ``ssl`` module supports NPN.
    :param secure: Whether to wrap the listening socket in TLS using the
        certificate/key pair under ``test/certs``.
    """
    def __init__(self,
                 socket_handler,
                 host='localhost',
                 ready_event=None,
                 h2=True,
                 secure=True):
        threading.Thread.__init__(self)

        self.socket_handler = socket_handler
        self.host = host
        self.secure = secure
        self.ready_event = ready_event
        self.daemon = True
        # Filled in by _start_server once the socket is bound; stays None if
        # start-up fails before that point.
        self.port = None

        if self.secure:
            self.cxt = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            if ssl.HAS_NPN and h2:
                self.cxt.set_npn_protocols([NPN_PROTOCOL])
            self.cxt.load_cert_chain(certfile='test/certs/server.crt',
                                     keyfile='test/certs/server.key')

    def _start_server(self):
        """
        Bind and listen on an IPv6 socket, signal readiness, then hand the
        listening socket to ``socket_handler``.

        The socket is always closed and ``ready_event`` is always set on the
        way out, even if start-up or the handler raises.
        """
        sock = None
        try:
            sock = socket.socket(socket.AF_INET6)
            if sys.platform != 'win32':
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            if self.secure:
                sock = self.cxt.wrap_socket(sock, server_side=True)
            sock.bind((self.host, 0))
            self.port = sock.getsockname()[1]

            # Once listen() returns, the server socket is ready
            sock.listen(1)

            if self.ready_event:
                self.ready_event.set()

            self.socket_handler(sock)
        finally:
            # Release waiters even if we failed before the set() above;
            # setting an already-set Event is harmless.
            if self.ready_event:
                self.ready_event.set()
            if sock is not None:
                sock.close()

    def _wrap_socket(self, sock):
        # Unused hook retained for interface compatibility.
        raise NotImplementedError()

    def run(self):
        # _start_server returns None; the attribute is kept for callers that
        # may read it.
        self.server = self._start_server()


class SocketLevelTest(object):
    """
    A test-class that defines a few helper methods for running socket-level
    tests.

    NOTE(review): ``self.h2`` is read by ``_start_server`` and
    ``get_connection`` but is never assigned in this class; it is presumably
    provided as a class attribute by subclasses — confirm.
    """
    def set_up(self, secure=True, proxy=False):
        """
        Reset per-test state.

        :param secure: Whether the test server should use TLS. Forced to
            ``False`` when ``proxy`` is true.
        :param proxy: Whether connections are routed through the test server
            acting as a proxy.
        """
        self.host = None
        self.port = None
        # Proxied tests never use TLS.
        self.secure = False if proxy else secure
        self.proxy = proxy
        self.server_thread = None

    def _start_server(self, socket_handler):
        """
        Starts a background thread that runs the given socket handler.

        Blocks until the server thread signals readiness, then records the
        server's host, port and security mode on ``self``.
        """
        ready_event = threading.Event()
        self.server_thread = SocketServerThread(
            socket_handler=socket_handler,
            ready_event=ready_event,
            h2=self.h2,
            secure=self.secure
        )
        self.server_thread.start()
        ready_event.wait()

        self.host = self.server_thread.host
        self.port = self.server_thread.port
        self.secure = self.server_thread.secure

    def get_connection(self):
        """
        Return a client connection for the current test.

        Uses ``HTTP20Connection`` when ``self.h2`` is true, otherwise
        ``HTTP11Connection``. Without a proxy it connects straight to the test
        server; with ``self.proxy`` set it targets a public httpbin host and
        uses the test server as the proxy.
        """
        if self.h2:
            conn_cls, proxied_host = HTTP20Connection, 'http2bin.org'
        else:
            conn_cls, proxied_host = HTTP11Connection, 'httpbin.org'

        if not self.proxy:
            return conn_cls(self.host, self.port, self.secure)
        return conn_cls(proxied_host, secure=self.secure,
                        proxy_host=self.host,
                        proxy_port=self.port)

    def get_encoder(self):
        """
        Returns a HPACK encoder set up for responses.

        The encoder's Huffman coder is built from hpack's
        ``REQUEST_CODES``/``REQUEST_CODES_LENGTH`` tables.
        """
        e = Encoder()
        e.huffman_coder = HuffmanEncoder(REQUEST_CODES, REQUEST_CODES_LENGTH)
        return e

    def tear_down(self):
        """
        Tears down the testing thread.

        Waits briefly (0.1s) for the server thread to finish. This is a no-op
        if no server was started for this test.
        """
        if self.server_thread is not None:
            self.server_thread.join(0.1)