Revision e6daa90e share/websockify/websocketproxy.py

View differences:

share/websockify/websocketproxy.py
11 11

  
12 12
'''
13 13

  
14
import signal, socket, optparse, time, os, sys, subprocess
14
import signal, socket, optparse, time, os, sys, subprocess, logging
15
try:    from socketserver import ForkingMixIn
16
except: from SocketServer import ForkingMixIn
17
try:    from http.server import HTTPServer
18
except: from BaseHTTPServer import HTTPServer
15 19
from select import select
16 20
import websocket
17
try:    from urllib.parse import parse_qs, urlparse
18
except: from urlparse import parse_qs, urlparse
21
try:
22
    from urllib.parse import parse_qs, urlparse
23
except:
24
    from cgi import parse_qs
25
    from urlparse import urlparse
19 26

  
20
class WebSocketProxy(websocket.WebSocketServer):
21
    """
22
    Proxy traffic to and from a WebSockets client to a normal TCP
23
    socket server target. All traffic to/from the client is base64
24
    encoded/decoded to allow binary data to be sent/received to/from
25
    the target.
26
    """
27

  
28
    buffer_size = 65536
27
class ProxyRequestHandler(websocket.WebSocketRequestHandler):
29 28

  
30 29
    traffic_legend = """
31 30
Traffic Legend:
......
39 38
    <. - Client send partial
40 39
"""
41 40

  
42
    def __init__(self, *args, **kwargs):
43
        # Save off proxy specific options
44
        self.target_host    = kwargs.pop('target_host', None)
45
        self.target_port    = kwargs.pop('target_port', None)
46
        self.wrap_cmd       = kwargs.pop('wrap_cmd', None)
47
        self.wrap_mode      = kwargs.pop('wrap_mode', None)
48
        self.unix_target    = kwargs.pop('unix_target', None)
49
        self.ssl_target     = kwargs.pop('ssl_target', None)
50
        self.target_cfg     = kwargs.pop('target_cfg', None)
51
        # Last 3 timestamps command was run
52
        self.wrap_times    = [0, 0, 0]
53

  
54
        if self.wrap_cmd:
55
            rebinder_path = ['./', os.path.dirname(sys.argv[0])]
56
            self.rebinder = None
57

  
58
            for rdir in rebinder_path:
59
                rpath = os.path.join(rdir, "rebind.so")
60
                if os.path.exists(rpath):
61
                    self.rebinder = rpath
62
                    break
63

  
64
            if not self.rebinder:
65
                raise Exception("rebind.so not found, perhaps you need to run make")
66
            self.rebinder = os.path.abspath(self.rebinder)
67

  
68
            self.target_host = "127.0.0.1"  # Loopback
69
            # Find a free high port
70
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
71
            sock.bind(('', 0))
72
            self.target_port = sock.getsockname()[1]
73
            sock.close()
74

  
75
            os.environ.update({
76
                "LD_PRELOAD": self.rebinder,
77
                "REBIND_OLD_PORT": str(kwargs['listen_port']),
78
                "REBIND_NEW_PORT": str(self.target_port)})
79

  
80
        if self.target_cfg:
81
            self.target_cfg = os.path.abspath(self.target_cfg)
82

  
83
        websocket.WebSocketServer.__init__(self, *args, **kwargs)
84

  
85
    def run_wrap_cmd(self):
86
        print("Starting '%s'" % " ".join(self.wrap_cmd))
87
        self.wrap_times.append(time.time())
88
        self.wrap_times.pop(0)
89
        self.cmd = subprocess.Popen(
90
                self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
91
        self.spawn_message = True
92

  
93
    def started(self):
94
        """
95
        Called after Websockets server startup (i.e. after daemonize)
96
        """
97
        # Need to call wrapped command after daemonization so we can
98
        # know when the wrapped command exits
99
        if self.wrap_cmd:
100
            dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
101
        elif self.unix_target:
102
            dst_string = self.unix_target
103
        else:
104
            dst_string = "%s:%s" % (self.target_host, self.target_port)
105

  
106
        if self.target_cfg:
107
            msg = "  - proxying from %s:%s to targets in %s" % (
108
                self.listen_host, self.listen_port, self.target_cfg)
109
        else:
110
            msg = "  - proxying from %s:%s to %s" % (
111
                self.listen_host, self.listen_port, dst_string)
112

  
113
        if self.ssl_target:
114
            msg += " (using SSL)"
115

  
116
        print(msg + "\n")
117

  
118
        if self.wrap_cmd:
119
            self.run_wrap_cmd()
120

  
121
    def poll(self):
122
        # If we are wrapping a command, check it's status
123

  
124
        if self.wrap_cmd and self.cmd:
125
            ret = self.cmd.poll()
126
            if ret != None:
127
                self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
128
                self.cmd = None
129

  
130
        if self.wrap_cmd and self.cmd == None:
131
            # Response to wrapped command being gone
132
            if self.wrap_mode == "ignore":
133
                pass
134
            elif self.wrap_mode == "exit":
135
                sys.exit(ret)
136
            elif self.wrap_mode == "respawn":
137
                now = time.time()
138
                avg = sum(self.wrap_times)/len(self.wrap_times)
139
                if (now - avg) < 10:
140
                    # 3 times in the last 10 seconds
141
                    if self.spawn_message:
142
                        print("Command respawning too fast")
143
                        self.spawn_message = False
144
                else:
145
                    self.run_wrap_cmd()
146

  
147
    #
148
    # Routines above this point are run in the master listener
149
    # process.
150
    #
151

  
152
    #
153
    # Routines below this point are connection handler routines and
154
    # will be run in a separate forked process for each connection.
155
    #
156

  
157
    def new_client(self):
41
    def new_websocket_client(self):
158 42
        """
159 43
        Called after a new WebSocket connection has been established.
160 44
        """
161 45
        # Checks if we receive a token, and look
162 46
        # for a valid target for it then
163
        if self.target_cfg:
164
            (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path)
47
        if self.server.target_cfg:
48
            (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path)
165 49

  
166 50
        # Connect to the target
167
        if self.wrap_cmd:
168
            msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
169
        elif self.unix_target:
170
            msg = "connecting to unix socket: %s" % self.unix_target
51
        if self.server.wrap_cmd:
52
            msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
53
        elif self.server.unix_target:
54
            msg = "connecting to unix socket: %s" % self.server.unix_target
171 55
        else:
172 56
            msg = "connecting to: %s:%s" % (
173
                                    self.target_host, self.target_port)
57
                                    self.server.target_host, self.server.target_port)
174 58

  
175
        if self.ssl_target:
59
        if self.server.ssl_target:
176 60
            msg += " (using SSL)"
177
        self.msg(msg)
61
        self.log_message(msg)
178 62

  
179
        tsock = self.socket(self.target_host, self.target_port,
180
                connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target)
63
        tsock = websocket.WebSocketServer.socket(self.server.target_host,
64
                                                 self.server.target_port,
65
                connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target)
181 66

  
182
        if self.verbose and not self.daemon:
183
            print(self.traffic_legend)
67
        self.print_traffic(self.traffic_legend)
184 68

  
185 69
        # Start proxying
186 70
        try:
......
189 73
            if tsock:
190 74
                tsock.shutdown(socket.SHUT_RDWR)
191 75
                tsock.close()
192
                self.vmsg("%s:%s: Closed target" %(
193
                    self.target_host, self.target_port))
76
                if self.verbose: 
77
                    self.log_message("%s:%s: Closed target",
78
                            self.server.target_host, self.server.target_port)
194 79
            raise
195 80

  
196 81
    def get_target(self, target_cfg, path):
......
205 90
        # Extract the token parameter from url
206 91
        args = parse_qs(urlparse(path)[4]) # 4 is the query from url
207 92

  
208
        if not len(args['token']):
93
        if not args.has_key('token') or not len(args['token']):
209 94
            raise self.EClose("Token not present")
210 95

  
211 96
        token = args['token'][0].rstrip('\n')
......
239 124
        cqueue = []
240 125
        c_pend = 0
241 126
        tqueue = []
242
        rlist = [self.client, target]
127
        rlist = [self.request, target]
243 128

  
244 129
        while True:
245 130
            wlist = []
246 131

  
247 132
            if tqueue: wlist.append(target)
248
            if cqueue or c_pend: wlist.append(self.client)
133
            if cqueue or c_pend: wlist.append(self.request)
249 134
            ins, outs, excepts = select(rlist, wlist, [], 1)
250 135
            if excepts: raise Exception("Socket exception")
251 136

  
137
            if self.request in outs:
138
                # Send queued target data to the client
139
                c_pend = self.send_frames(cqueue)
140

  
141
                cqueue = []
142

  
143
            if self.request in ins:
144
                # Receive client data, decode it, and queue for target
145
                bufs, closed = self.recv_frames()
146
                tqueue.extend(bufs)
147

  
148
                if closed:
149
                    # TODO: What about blocking on client socket?
150
                    if self.verbose: 
151
                        self.log_message("%s:%s: Client closed connection",
152
                                self.server.target_host, self.server.target_port)
153
                    raise self.CClose(closed['code'], closed['reason'])
154

  
155

  
252 156
            if target in outs:
253 157
                # Send queued client data to the target
254 158
                dat = tqueue.pop(0)
255 159
                sent = target.send(dat)
256 160
                if sent == len(dat):
257
                    self.traffic(">")
161
                    self.print_traffic(">")
258 162
                else:
259 163
                    # requeue the remaining data
260 164
                    tqueue.insert(0, dat[sent:])
261
                    self.traffic(".>")
165
                    self.print_traffic(".>")
262 166

  
263 167

  
264 168
            if target in ins:
265 169
                # Receive target data, encode it and queue for client
266 170
                buf = target.recv(self.buffer_size)
267 171
                if len(buf) == 0:
268
                    self.vmsg("%s:%s: Target closed connection" %(
269
                        self.target_host, self.target_port))
172
                    if self.verbose:
173
                        self.log_message("%s:%s: Target closed connection",
174
                                self.server.target_host, self.server.target_port)
270 175
                    raise self.CClose(1000, "Target closed")
271 176

  
272 177
                cqueue.append(buf)
273
                self.traffic("{")
178
                self.print_traffic("{")
179

  
180
class WebSocketProxy(websocket.WebSocketServer):
181
    """
182
    Proxy traffic to and from a WebSockets client to a normal TCP
183
    socket server target. All traffic to/from the client is base64
184
    encoded/decoded to allow binary data to be sent/received to/from
185
    the target.
186
    """
274 187

  
188
    buffer_size = 65536
275 189

  
276
            if self.client in outs:
277
                # Send queued target data to the client
278
                c_pend = self.send_frames(cqueue)
190
    def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
191
        # Save off proxy specific options
192
        self.target_host    = kwargs.pop('target_host', None)
193
        self.target_port    = kwargs.pop('target_port', None)
194
        self.wrap_cmd       = kwargs.pop('wrap_cmd', None)
195
        self.wrap_mode      = kwargs.pop('wrap_mode', None)
196
        self.unix_target    = kwargs.pop('unix_target', None)
197
        self.ssl_target     = kwargs.pop('ssl_target', None)
198
        self.target_cfg     = kwargs.pop('target_cfg', None)
199
        # Last 3 timestamps command was run
200
        self.wrap_times    = [0, 0, 0]
279 201

  
280
                cqueue = []
202
        if self.wrap_cmd:
203
            wsdir = os.path.dirname(sys.argv[0])
204
            rebinder_path = [os.path.join(wsdir, "..", "lib"),
205
                             os.path.join(wsdir, "..", "lib", "websockify"),
206
                             wsdir]
207
            self.rebinder = None
281 208

  
209
            for rdir in rebinder_path:
210
                rpath = os.path.join(rdir, "rebind.so")
211
                if os.path.exists(rpath):
212
                    self.rebinder = rpath
213
                    break
282 214

  
283
            if self.client in ins:
284
                # Receive client data, decode it, and queue for target
285
                bufs, closed = self.recv_frames()
286
                tqueue.extend(bufs)
215
            if not self.rebinder:
216
                raise Exception("rebind.so not found, perhaps you need to run make")
217
            self.rebinder = os.path.abspath(self.rebinder)
287 218

  
288
                if closed:
289
                    # TODO: What about blocking on client socket?
290
                    self.vmsg("%s:%s: Client closed connection" %(
291
                        self.target_host, self.target_port))
292
                    raise self.CClose(closed['code'], closed['reason'])
219
            self.target_host = "127.0.0.1"  # Loopback
220
            # Find a free high port
221
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
222
            sock.bind(('', 0))
223
            self.target_port = sock.getsockname()[1]
224
            sock.close()
225

  
226
            os.environ.update({
227
                "LD_PRELOAD": self.rebinder,
228
                "REBIND_OLD_PORT": str(kwargs['listen_port']),
229
                "REBIND_NEW_PORT": str(self.target_port)})
230

  
231
        websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs)
232

  
233
    def run_wrap_cmd(self):
234
        self.msg("Starting '%s'", " ".join(self.wrap_cmd))
235
        self.wrap_times.append(time.time())
236
        self.wrap_times.pop(0)
237
        self.cmd = subprocess.Popen(
238
                self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
239
        self.spawn_message = True
240

  
241
    def started(self):
242
        """
243
        Called after Websockets server startup (i.e. after daemonize)
244
        """
245
        # Need to call wrapped command after daemonization so we can
246
        # know when the wrapped command exits
247
        if self.wrap_cmd:
248
            dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
249
        elif self.unix_target:
250
            dst_string = self.unix_target
251
        else:
252
            dst_string = "%s:%s" % (self.target_host, self.target_port)
253

  
254
        if self.target_cfg:
255
            msg = "  - proxying from %s:%s to targets in %s" % (
256
                self.listen_host, self.listen_port, self.target_cfg)
257
        else:
258
            msg = "  - proxying from %s:%s to %s" % (
259
                self.listen_host, self.listen_port, dst_string)
260

  
261
        if self.ssl_target:
262
            msg += " (using SSL)"
263

  
264
        self.msg("%s", msg)
265

  
266
        if self.wrap_cmd:
267
            self.run_wrap_cmd()
268

  
269
    def poll(self):
270
        # If we are wrapping a command, check it's status
271

  
272
        if self.wrap_cmd and self.cmd:
273
            ret = self.cmd.poll()
274
            if ret != None:
275
                self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
276
                self.cmd = None
277

  
278
        if self.wrap_cmd and self.cmd == None:
279
            # Response to wrapped command being gone
280
            if self.wrap_mode == "ignore":
281
                pass
282
            elif self.wrap_mode == "exit":
283
                sys.exit(ret)
284
            elif self.wrap_mode == "respawn":
285
                now = time.time()
286
                avg = sum(self.wrap_times)/len(self.wrap_times)
287
                if (now - avg) < 10:
288
                    # 3 times in the last 10 seconds
289
                    if self.spawn_message:
290
                        self.warn("Command respawning too fast")
291
                        self.spawn_message = False
292
                else:
293
                    self.run_wrap_cmd()
293 294

  
294 295

  
295 296
def _subprocess_setup():
......
298 299
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)
299 300

  
300 301

  
302
def logger_init():
303
    logger = logging.getLogger(WebSocketProxy.log_prefix)
304
    logger.propagate = False
305
    logger.setLevel(logging.INFO)
306
    h = logging.StreamHandler()
307
    h.setLevel(logging.DEBUG)
308
    h.setFormatter(logging.Formatter("%(message)s"))
309
    logger.addHandler(h)
310

  
311

  
301 312
def websockify_init():
313
    logger_init()
314

  
302 315
    usage = "\n    %prog [options]"
303 316
    usage += " [source_addr:]source_port [target_addr:target_port]"
304 317
    usage += "\n    %prog [options]"
305 318
    usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
306 319
    parser = optparse.OptionParser(usage=usage)
307 320
    parser.add_option("--verbose", "-v", action="store_true",
308
            help="verbose messages and per frame traffic")
321
            help="verbose messages")
322
    parser.add_option("--traffic", action="store_true",
323
            help="per frame traffic")
309 324
    parser.add_option("--record",
310 325
            help="record sessions to FILE.[session_number]", metavar="FILE")
311 326
    parser.add_option("--daemon", "-D",
......
342 357
            help="Configuration file containing valid targets "
343 358
            "in the form 'token: host:port' or, alternatively, a "
344 359
            "directory containing configuration files of this form")
360
    parser.add_option("--libserver", action="store_true",
361
            help="use Python library SocketServer engine")
345 362
    (opts, args) = parser.parse_args()
346 363

  
364
    if opts.verbose:
365
        logging.getLogger(WebSocketProxy.log_prefix).setLevel(logging.DEBUG)
366

  
347 367
    # Sanity checks
348 368
    if len(args) < 2 and not (opts.target_cfg or opts.unix_target):
349 369
        parser.error("Too few arguments")
......
382 402
        try:    opts.target_port = int(opts.target_port)
383 403
        except: parser.error("Error parsing target port")
384 404

  
405
    # Transform to absolute path as daemon may chdir
406
    if opts.target_cfg:
407
        opts.target_cfg = os.path.abspath(opts.target_cfg)
408

  
385 409
    # Create and start the WebSockets proxy
386
    server = WebSocketProxy(**opts.__dict__)
387
    server.start_server()
410
    libserver = opts.libserver
411
    del opts.libserver
412
    if libserver:
413
        # Use standard Python SocketServer framework
414
        server = LibProxyServer(**opts.__dict__)
415
        server.serve_forever()
416
    else:
417
        # Use internal service framework
418
        server = WebSocketProxy(**opts.__dict__)
419
        server.start_server()
420

  
421

  
422
class LibProxyServer(ForkingMixIn, HTTPServer):
423
    """
424
    Just like WebSocketProxy, but uses standard Python SocketServer
425
    framework.
426
    """
427

  
428
    def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
429
        # Save off proxy specific options
430
        self.target_host    = kwargs.pop('target_host', None)
431
        self.target_port    = kwargs.pop('target_port', None)
432
        self.wrap_cmd       = kwargs.pop('wrap_cmd', None)
433
        self.wrap_mode      = kwargs.pop('wrap_mode', None)
434
        self.unix_target    = kwargs.pop('unix_target', None)
435
        self.ssl_target     = kwargs.pop('ssl_target', None)
436
        self.target_cfg     = kwargs.pop('target_cfg', None)
437
        self.daemon = False
438
        self.target_cfg = None
439

  
440
        # Server configuration
441
        listen_host    = kwargs.pop('listen_host', '')
442
        listen_port    = kwargs.pop('listen_port', None)
443
        web            = kwargs.pop('web', '')
444

  
445
        # Configuration affecting base request handler
446
        self.only_upgrade   = not web
447
        self.verbose   = kwargs.pop('verbose', False)
448
        record = kwargs.pop('record', '')
449
        if record:
450
            self.record = os.path.abspath(record)
451
        self.run_once  = kwargs.pop('run_once', False)
452
        self.handler_id = 0
453

  
454
        for arg in kwargs.keys():
455
            print("warning: option %s ignored when using --libserver" % arg)
456

  
457
        if web:
458
            os.chdir(web)
459
            
460
        HTTPServer.__init__(self, (listen_host, listen_port), 
461
                            RequestHandlerClass)
462

  
463

  
464
    def process_request(self, request, client_address):
465
        """Override process_request to implement a counter"""
466
        self.handler_id += 1
467
        ForkingMixIn.process_request(self, request, client_address)
468

  
388 469

  
389 470
if __name__ == '__main__':
390 471
    websockify_init()

Also available in: Unified diff