/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/remote.py

Merge from bzr.dev.  Breaks a few tests because there are new methods not yet implemented.

Show diffs side-by-side

added added

removed removed

Lines of Context:
38
38
del scheme
39
39
 
40
40
 
 
41
# Port 4155 is the default port for bzr://, registered with IANA.
 
42
BZR_DEFAULT_PORT = 4155
 
43
 
 
44
 
 
45
def _recv_tuple(from_file):
 
46
    req_line = from_file.readline()
 
47
    return _decode_tuple(req_line)
 
48
 
 
49
 
 
50
def _decode_tuple(req_line):
 
51
    if req_line == None or req_line == '':
 
52
        return None
 
53
    if req_line[-1] != '\n':
 
54
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
55
    return tuple(req_line[:-1].split('\x01'))
 
56
 
 
57
 
 
58
def _encode_tuple(args):
 
59
    """Encode the tuple args to a bytestream."""
 
60
    return '\x01'.join(args) + '\n'
 
61
 
 
62
 
 
63
class SmartProtocolBase(object):
 
64
    """Methods common to client and server"""
 
65
 
 
66
    # TODO: this only actually accomodates a single block; possibly should
 
67
    # support multiple chunks?
 
68
    def _encode_bulk_data(self, body):
 
69
        """Encode body as a bulk data chunk."""
 
70
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
71
 
 
72
    def _serialise_offsets(self, offsets):
 
73
        """Serialise a readv offset list."""
 
74
        txt = []
 
75
        for start, length in offsets:
 
76
            txt.append('%d,%d' % (start, length))
 
77
        return '\n'.join(txt)
 
78
        
 
79
 
 
80
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
81
    """Server-side encoding and decoding logic for smart version 1."""
 
82
    
 
83
    def __init__(self, backing_transport, write_func):
 
84
        self._backing_transport = backing_transport
 
85
        self.excess_buffer = ''
 
86
        self._finished = False
 
87
        self.in_buffer = ''
 
88
        self.has_dispatched = False
 
89
        self.request = None
 
90
        self._body_decoder = None
 
91
        self._write_func = write_func
 
92
 
 
93
    def accept_bytes(self, bytes):
 
94
        """Take bytes, and advance the internal state machine appropriately.
 
95
        
 
96
        :param bytes: must be a byte string
 
97
        """
 
98
        assert isinstance(bytes, str)
 
99
        self.in_buffer += bytes
 
100
        if not self.has_dispatched:
 
101
            if '\n' not in self.in_buffer:
 
102
                # no command line yet
 
103
                return
 
104
            self.has_dispatched = True
 
105
            try:
 
106
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
107
                first_line += '\n'
 
108
                req_args = _decode_tuple(first_line)
 
109
                self.request = SmartServerRequestHandler(
 
110
                    self._backing_transport)
 
111
                self.request.dispatch_command(req_args[0], req_args[1:])
 
112
                if self.request.finished_reading:
 
113
                    # trivial request
 
114
                    self.excess_buffer = self.in_buffer
 
115
                    self.in_buffer = ''
 
116
                    self._send_response(self.request.response.args,
 
117
                        self.request.response.body)
 
118
            except KeyboardInterrupt:
 
119
                raise
 
120
            except Exception, exception:
 
121
                # everything else: pass to client, flush, and quit
 
122
                self._send_response(('error', str(exception)))
 
123
                return
 
124
 
 
125
        if self.has_dispatched:
 
126
            if self._finished:
 
127
                # nothing to do.XXX: this routine should be a single state 
 
128
                # machine too.
 
129
                self.excess_buffer += self.in_buffer
 
130
                self.in_buffer = ''
 
131
                return
 
132
            if self._body_decoder is None:
 
133
                self._body_decoder = LengthPrefixedBodyDecoder()
 
134
            self._body_decoder.accept_bytes(self.in_buffer)
 
135
            self.in_buffer = self._body_decoder.unused_data
 
136
            body_data = self._body_decoder.read_pending_data()
 
137
            self.request.accept_body(body_data)
 
138
            if self._body_decoder.finished_reading:
 
139
                self.request.end_of_body()
 
140
                assert self.request.finished_reading, \
 
141
                    "no more body, request not finished"
 
142
            if self.request.response is not None:
 
143
                self._send_response(self.request.response.args,
 
144
                    self.request.response.body)
 
145
                self.excess_buffer = self.in_buffer
 
146
                self.in_buffer = ''
 
147
            else:
 
148
                assert not self.request.finished_reading, \
 
149
                    "no response and we have finished reading."
 
150
 
 
151
    def _send_response(self, args, body=None):
 
152
        """Send a smart server response down the output stream."""
 
153
        assert not self._finished, 'response already sent'
 
154
        self._finished = True
 
155
        self._write_func(_encode_tuple(args))
 
156
        if body is not None:
 
157
            assert isinstance(body, str), 'body must be a str'
 
158
            bytes = self._encode_bulk_data(body)
 
159
            self._write_func(bytes)
 
160
 
 
161
    def next_read_size(self):
 
162
        if self._finished:
 
163
            return 0
 
164
        if self._body_decoder is None:
 
165
            return 1
 
166
        else:
 
167
            return self._body_decoder.next_read_size()
 
168
 
 
169
 
 
170
class LengthPrefixedBodyDecoder(object):
 
171
    """Decodes the length-prefixed bulk data."""
 
172
    
 
173
    def __init__(self):
 
174
        self.bytes_left = None
 
175
        self.finished_reading = False
 
176
        self.unused_data = ''
 
177
        self.state_accept = self._state_accept_expecting_length
 
178
        self.state_read = self._state_read_no_data
 
179
        self._in_buffer = ''
 
180
        self._trailer_buffer = ''
 
181
    
 
182
    def accept_bytes(self, bytes):
 
183
        """Decode as much of bytes as possible.
 
184
 
 
185
        If 'bytes' contains too much data it will be appended to
 
186
        self.unused_data.
 
187
 
 
188
        finished_reading will be set when no more data is required.  Further
 
189
        data will be appended to self.unused_data.
 
190
        """
 
191
        # accept_bytes is allowed to change the state
 
192
        current_state = self.state_accept
 
193
        self.state_accept(bytes)
 
194
        while current_state != self.state_accept:
 
195
            current_state = self.state_accept
 
196
            self.state_accept('')
 
197
 
 
198
    def next_read_size(self):
 
199
        if self.bytes_left is not None:
 
200
            # Ideally we want to read all the remainder of the body and the
 
201
            # trailer in one go.
 
202
            return self.bytes_left + 5
 
203
        elif self.state_accept == self._state_accept_reading_trailer:
 
204
            # Just the trailer left
 
205
            return 5 - len(self._trailer_buffer)
 
206
        elif self.state_accept == self._state_accept_expecting_length:
 
207
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
208
            # 'done\n').
 
209
            return 6
 
210
        else:
 
211
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
212
            return 1
 
213
        
 
214
    def read_pending_data(self):
 
215
        """Return any pending data that has been decoded."""
 
216
        return self.state_read()
 
217
 
 
218
    def _state_accept_expecting_length(self, bytes):
 
219
        self._in_buffer += bytes
 
220
        pos = self._in_buffer.find('\n')
 
221
        if pos == -1:
 
222
            return
 
223
        self.bytes_left = int(self._in_buffer[:pos])
 
224
        self._in_buffer = self._in_buffer[pos+1:]
 
225
        self.bytes_left -= len(self._in_buffer)
 
226
        self.state_accept = self._state_accept_reading_body
 
227
        self.state_read = self._state_read_in_buffer
 
228
 
 
229
    def _state_accept_reading_body(self, bytes):
 
230
        self._in_buffer += bytes
 
231
        self.bytes_left -= len(bytes)
 
232
        if self.bytes_left <= 0:
 
233
            # Finished with body
 
234
            if self.bytes_left != 0:
 
235
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
236
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
237
            self.bytes_left = None
 
238
            self.state_accept = self._state_accept_reading_trailer
 
239
        
 
240
    def _state_accept_reading_trailer(self, bytes):
 
241
        self._trailer_buffer += bytes
 
242
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
243
        # a ProtocolViolation exception?
 
244
        if self._trailer_buffer.startswith('done\n'):
 
245
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
246
            self.state_accept = self._state_accept_reading_unused
 
247
            self.finished_reading = True
 
248
    
 
249
    def _state_accept_reading_unused(self, bytes):
 
250
        self.unused_data += bytes
 
251
 
 
252
    def _state_read_no_data(self):
 
253
        return ''
 
254
 
 
255
    def _state_read_in_buffer(self):
 
256
        result = self._in_buffer
 
257
        self._in_buffer = ''
 
258
        return result
 
259
 
 
260
 
 
261
class SmartServerStreamMedium(object):
 
262
    """Handles smart commands coming over a stream.
 
263
 
 
264
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
265
    in-process fifo for testing.
 
266
 
 
267
    One instance is created for each connected client; it can serve multiple
 
268
    requests in the lifetime of the connection.
 
269
 
 
270
    The server passes requests through to an underlying backing transport, 
 
271
    which will typically be a LocalTransport looking at the server's filesystem.
 
272
    """
 
273
 
 
274
    def __init__(self, backing_transport):
 
275
        """Construct new server.
 
276
 
 
277
        :param backing_transport: Transport for the directory served.
 
278
        """
 
279
        # backing_transport could be passed to serve instead of __init__
 
280
        self.backing_transport = backing_transport
 
281
        self.finished = False
 
282
 
 
283
    def serve(self):
 
284
        """Serve requests until the client disconnects."""
 
285
        # Keep a reference to stderr because the sys module's globals get set to
 
286
        # None during interpreter shutdown.
 
287
        from sys import stderr
 
288
        try:
 
289
            while not self.finished:
 
290
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
 
291
                                                         self._write_out)
 
292
                self._serve_one_request(protocol)
 
293
        except Exception, e:
 
294
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
295
            raise
 
296
 
 
297
    def _serve_one_request(self, protocol):
 
298
        """Read one request from input, process, send back a response.
 
299
        
 
300
        :param protocol: a SmartServerRequestProtocol.
 
301
        """
 
302
        try:
 
303
            self._serve_one_request_unguarded(protocol)
 
304
        except KeyboardInterrupt:
 
305
            raise
 
306
        except Exception, e:
 
307
            self.terminate_due_to_error()
 
308
 
 
309
    def terminate_due_to_error(self):
 
310
        """Called when an unhandled exception from the protocol occurs."""
 
311
        raise NotImplementedError(self.terminate_due_to_error)
 
312
 
 
313
 
 
314
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
315
 
 
316
    def __init__(self, sock, backing_transport):
 
317
        """Constructor.
 
318
 
 
319
        :param sock: the socket the server will read from.  It will be put
 
320
            into blocking mode.
 
321
        """
 
322
        SmartServerStreamMedium.__init__(self, backing_transport)
 
323
        self.push_back = ''
 
324
        sock.setblocking(True)
 
325
        self.socket = sock
 
326
 
 
327
    def _serve_one_request_unguarded(self, protocol):
 
328
        while protocol.next_read_size():
 
329
            if self.push_back:
 
330
                protocol.accept_bytes(self.push_back)
 
331
                self.push_back = ''
 
332
            else:
 
333
                bytes = self.socket.recv(4096)
 
334
                if bytes == '':
 
335
                    self.finished = True
 
336
                    return
 
337
                protocol.accept_bytes(bytes)
 
338
        
 
339
        self.push_back = protocol.excess_buffer
 
340
    
 
341
    def terminate_due_to_error(self):
 
342
        """Called when an unhandled exception from the protocol occurs."""
 
343
        # TODO: This should log to a server log file, but no such thing
 
344
        # exists yet.  Andrew Bennetts 2006-09-29.
 
345
        self.socket.close()
 
346
        self.finished = True
 
347
 
 
348
    def _write_out(self, bytes):
 
349
        self.socket.sendall(bytes)
 
350
 
 
351
 
 
352
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
353
 
 
354
    def __init__(self, in_file, out_file, backing_transport):
 
355
        """Construct new server.
 
356
 
 
357
        :param in_file: Python file from which requests can be read.
 
358
        :param out_file: Python file to write responses.
 
359
        :param backing_transport: Transport for the directory served.
 
360
        """
 
361
        SmartServerStreamMedium.__init__(self, backing_transport)
 
362
        self._in = in_file
 
363
        self._out = out_file
 
364
 
 
365
    def _serve_one_request_unguarded(self, protocol):
 
366
        while True:
 
367
            bytes_to_read = protocol.next_read_size()
 
368
            if bytes_to_read == 0:
 
369
                # Finished serving this request.
 
370
                self._out.flush()
 
371
                return
 
372
            bytes = self._in.read(bytes_to_read)
 
373
            if bytes == '':
 
374
                # Connection has been closed.
 
375
                self.finished = True
 
376
                self._out.flush()
 
377
                return
 
378
            protocol.accept_bytes(bytes)
 
379
 
 
380
    def terminate_due_to_error(self):
 
381
        # TODO: This should log to a server log file, but no such thing
 
382
        # exists yet.  Andrew Bennetts 2006-09-29.
 
383
        self._out.close()
 
384
        self.finished = True
 
385
 
 
386
    def _write_out(self, bytes):
 
387
        self._out.write(bytes)
 
388
 
 
389
 
 
390
class SmartServerResponse(object):
 
391
    """Response generated by SmartServerRequestHandler."""
 
392
 
 
393
    def __init__(self, args, body=None):
 
394
        self.args = args
 
395
        self.body = body
 
396
 
 
397
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
 
398
# for delivering the data for a request. This could be done with as the
 
399
# StreamServer, though that would create conflation between request and response
 
400
# which may be undesirable.
 
401
 
 
402
 
 
403
class SmartServerRequestHandler(object):
 
404
    """Protocol logic for smart server.
 
405
    
 
406
    This doesn't handle serialization at all, it just processes requests and
 
407
    creates responses.
 
408
    """
 
409
 
 
410
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServerRequestHandler
 
411
    # not contain encoding or decoding logic to allow the wire protocol to vary
 
412
    # from the object protocol: we will want to tweak the wire protocol separate
 
413
    # from the object model, and ideally we will be able to do that without
 
414
    # having a SmartServerRequestHandler subclass for each wire protocol, rather
 
415
    # just a Protocol subclass.
 
416
 
 
417
    # TODO: Better way of representing the body for commands that take it,
 
418
    # and allow it to be streamed into the server.
 
419
    
 
420
    def __init__(self, backing_transport):
 
421
        self._backing_transport = backing_transport
 
422
        self._converted_command = False
 
423
        self.finished_reading = False
 
424
        self._body_bytes = ''
 
425
        self.response = None
 
426
 
 
427
    def accept_body(self, bytes):
 
428
        """Accept body data.
 
429
 
 
430
        This should be overriden for each command that desired body data to
 
431
        handle the right format of that data. I.e. plain bytes, a bundle etc.
 
432
 
 
433
        The deserialisation into that format should be done in the Protocol
 
434
        object. Set self.desired_body_format to the format your method will
 
435
        handle.
 
436
        """
 
437
        # default fallback is to accumulate bytes.
 
438
        self._body_bytes += bytes
 
439
        
 
440
    def _end_of_body_handler(self):
 
441
        """An unimplemented end of body handler."""
 
442
        raise NotImplementedError(self._end_of_body_handler)
 
443
        
 
444
    def do_hello(self):
 
445
        """Answer a version request with my version."""
 
446
        return SmartServerResponse(('ok', '1'))
 
447
 
 
448
    def do_has(self, relpath):
 
449
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
 
450
        return SmartServerResponse((r,))
 
451
 
 
452
    def do_get(self, relpath):
 
453
        backing_bytes = self._backing_transport.get_bytes(relpath)
 
454
        return SmartServerResponse(('ok',), backing_bytes)
 
455
 
 
456
    def _deserialise_optional_mode(self, mode):
 
457
        # XXX: FIXME this should be on the protocol object.
 
458
        if mode == '':
 
459
            return None
 
460
        else:
 
461
            return int(mode)
 
462
 
 
463
    def do_append(self, relpath, mode):
 
464
        self._converted_command = True
 
465
        self._relpath = relpath
 
466
        self._mode = self._deserialise_optional_mode(mode)
 
467
        self._end_of_body_handler = self._handle_do_append_end
 
468
    
 
469
    def _handle_do_append_end(self):
 
470
        old_length = self._backing_transport.append_bytes(
 
471
            self._relpath, self._body_bytes, self._mode)
 
472
        self.response = SmartServerResponse(('appended', '%d' % old_length))
 
473
 
 
474
    def do_delete(self, relpath):
 
475
        self._backing_transport.delete(relpath)
 
476
 
 
477
    def do_iter_files_recursive(self, relpath):
 
478
        transport = self._backing_transport.clone(relpath)
 
479
        filenames = transport.iter_files_recursive()
 
480
        return SmartServerResponse(('names',) + tuple(filenames))
 
481
 
 
482
    def do_list_dir(self, relpath):
 
483
        filenames = self._backing_transport.list_dir(relpath)
 
484
        return SmartServerResponse(('names',) + tuple(filenames))
 
485
 
 
486
    def do_mkdir(self, relpath, mode):
 
487
        self._backing_transport.mkdir(relpath,
 
488
                                      self._deserialise_optional_mode(mode))
 
489
 
 
490
    def do_move(self, rel_from, rel_to):
 
491
        self._backing_transport.move(rel_from, rel_to)
 
492
 
 
493
    def do_put(self, relpath, mode):
 
494
        self._converted_command = True
 
495
        self._relpath = relpath
 
496
        self._mode = self._deserialise_optional_mode(mode)
 
497
        self._end_of_body_handler = self._handle_do_put
 
498
 
 
499
    def _handle_do_put(self):
 
500
        self._backing_transport.put_bytes(self._relpath,
 
501
                self._body_bytes, self._mode)
 
502
        self.response = SmartServerResponse(('ok',))
 
503
 
 
504
    def _deserialise_offsets(self, text):
 
505
        # XXX: FIXME this should be on the protocol object.
 
506
        offsets = []
 
507
        for line in text.split('\n'):
 
508
            if not line:
 
509
                continue
 
510
            start, length = line.split(',')
 
511
            offsets.append((int(start), int(length)))
 
512
        return offsets
 
513
 
 
514
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
 
515
        self._converted_command = True
 
516
        self._end_of_body_handler = self._handle_put_non_atomic
 
517
        self._relpath = relpath
 
518
        self._dir_mode = self._deserialise_optional_mode(dir_mode)
 
519
        self._mode = self._deserialise_optional_mode(mode)
 
520
        # a boolean would be nicer XXX
 
521
        self._create_parent = (create_parent == 'T')
 
522
 
 
523
    def _handle_put_non_atomic(self):
 
524
        self._backing_transport.put_bytes_non_atomic(self._relpath,
 
525
                self._body_bytes,
 
526
                mode=self._mode,
 
527
                create_parent_dir=self._create_parent,
 
528
                dir_mode=self._dir_mode)
 
529
        self.response = SmartServerResponse(('ok',))
 
530
 
 
531
    def do_readv(self, relpath):
 
532
        self._converted_command = True
 
533
        self._end_of_body_handler = self._handle_readv_offsets
 
534
        self._relpath = relpath
 
535
 
 
536
    def end_of_body(self):
 
537
        """No more body data will be received."""
 
538
        self._run_handler_code(self._end_of_body_handler, (), {})
 
539
        # cannot read after this.
 
540
        self.finished_reading = True
 
541
 
 
542
    def _handle_readv_offsets(self):
 
543
        """accept offsets for a readv request."""
 
544
        offsets = self._deserialise_offsets(self._body_bytes)
 
545
        backing_bytes = ''.join(bytes for offset, bytes in
 
546
            self._backing_transport.readv(self._relpath, offsets))
 
547
        self.response = SmartServerResponse(('readv',), backing_bytes)
 
548
        
 
549
    def do_rename(self, rel_from, rel_to):
 
550
        self._backing_transport.rename(rel_from, rel_to)
 
551
 
 
552
    def do_rmdir(self, relpath):
 
553
        self._backing_transport.rmdir(relpath)
 
554
 
 
555
    def do_stat(self, relpath):
 
556
        stat = self._backing_transport.stat(relpath)
 
557
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
558
        
 
559
    def do_get_bundle(self, path, revision_id):
 
560
        # open transport relative to our base
 
561
        t = self._backing_transport.clone(path)
 
562
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
 
563
        repo = control.open_repository()
 
564
        tmpf = tempfile.TemporaryFile()
 
565
        base_revision = revision.NULL_REVISION
 
566
        write_bundle(repo, revision_id, base_revision, tmpf)
 
567
        tmpf.seek(0)
 
568
        return SmartServerResponse((), tmpf.read())
 
569
 
 
570
    def dispatch_command(self, cmd, args):
 
571
        """Deprecated compatibility method.""" # XXX XXX
 
572
        func = getattr(self, 'do_' + cmd, None)
 
573
        if func is None:
 
574
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
575
        self._run_handler_code(func, args, {})
 
576
 
 
577
    def _run_handler_code(self, callable, args, kwargs):
 
578
        """Run some handler specific code 'callable'.
 
579
 
 
580
        If a result is returned, it is considered to be the commands response,
 
581
        and finished_reading is set true, and its assigned to self.response.
 
582
 
 
583
        Any exceptions caught are translated and a response object created
 
584
        from them.
 
585
        """
 
586
        result = self._call_converting_errors(callable, args, kwargs)
 
587
        if result is not None:
 
588
            self.response = result
 
589
            self.finished_reading = True
 
590
        # handle unconverted commands
 
591
        if not self._converted_command:
 
592
            self.finished_reading = True
 
593
            if result is None:
 
594
                self.response = SmartServerResponse(('ok',))
 
595
 
 
596
    def _call_converting_errors(self, callable, args, kwargs):
 
597
        """Call callable converting errors to Response objects."""
 
598
        try:
 
599
            return callable(*args, **kwargs)
 
600
        except errors.NoSuchFile, e:
 
601
            return SmartServerResponse(('NoSuchFile', e.path))
 
602
        except errors.FileExists, e:
 
603
            return SmartServerResponse(('FileExists', e.path))
 
604
        except errors.DirectoryNotEmpty, e:
 
605
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
606
        except errors.ShortReadvError, e:
 
607
            return SmartServerResponse(('ShortReadvError',
 
608
                e.path, str(e.offset), str(e.length), str(e.actual)))
 
609
        except UnicodeError, e:
 
610
            # If it is a DecodeError, than most likely we are starting
 
611
            # with a plain string
 
612
            str_or_unicode = e.object
 
613
            if isinstance(str_or_unicode, unicode):
 
614
                # XXX: UTF-8 might have \x01 (our seperator byte) in it.  We
 
615
                # should escape it somehow.
 
616
                val = 'u:' + str_or_unicode.encode('utf-8')
 
617
            else:
 
618
                val = 's:' + str_or_unicode.encode('base64')
 
619
            # This handles UnicodeEncodeError or UnicodeDecodeError
 
620
            return SmartServerResponse((e.__class__.__name__,
 
621
                    e.encoding, val, str(e.start), str(e.end), e.reason))
 
622
        except errors.TransportNotPossible, e:
 
623
            if e.msg == "readonly transport":
 
624
                return SmartServerResponse(('ReadOnlyError', ))
 
625
            else:
 
626
                raise
 
627
 
 
628
 
 
629
class SmartTCPServer(object):
 
630
    """Listens on a TCP socket and accepts connections from smart clients"""
 
631
 
 
632
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
 
633
        """Construct a new server.
 
634
 
 
635
        To actually start it running, call either start_background_thread or
 
636
        serve.
 
637
 
 
638
        :param host: Name of the interface to listen on.
 
639
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
640
        """
 
641
        self._server_socket = socket.socket()
 
642
        self._server_socket.bind((host, port))
 
643
        self.port = self._server_socket.getsockname()[1]
 
644
        self._server_socket.listen(1)
 
645
        self._server_socket.settimeout(1)
 
646
        self.backing_transport = backing_transport
 
647
 
 
648
    def serve(self):
 
649
        # let connections timeout so that we get a chance to terminate
 
650
        # Keep a reference to the exceptions we want to catch because the socket
 
651
        # module's globals get set to None during interpreter shutdown.
 
652
        from socket import timeout as socket_timeout
 
653
        from socket import error as socket_error
 
654
        self._should_terminate = False
 
655
        while not self._should_terminate:
 
656
            try:
 
657
                self.accept_and_serve()
 
658
            except socket_timeout:
 
659
                # just check if we're asked to stop
 
660
                pass
 
661
            except socket_error, e:
 
662
                trace.warning("client disconnected: %s", e)
 
663
                pass
 
664
 
 
665
    def get_url(self):
 
666
        """Return the url of the server"""
 
667
        return "bzr://%s:%d/" % self._server_socket.getsockname()
 
668
 
 
669
    def accept_and_serve(self):
 
670
        conn, client_addr = self._server_socket.accept()
 
671
        # For WIN32, where the timeout value from the listening socket
 
672
        # propogates to the newly accepted socket.
 
673
        conn.setblocking(True)
 
674
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
675
        handler = SmartServerSocketStreamMedium(conn, self.backing_transport)
 
676
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
 
677
        connection_thread.setDaemon(True)
 
678
        connection_thread.start()
 
679
 
 
680
    def start_background_thread(self):
 
681
        self._server_thread = threading.Thread(None,
 
682
                self.serve,
 
683
                name='server-' + self.get_url())
 
684
        self._server_thread.setDaemon(True)
 
685
        self._server_thread.start()
 
686
 
 
687
    def stop_background_thread(self):
 
688
        self._should_terminate = True
 
689
        # At one point we would wait to join the threads here, but it looks
 
690
        # like they don't actually exit.  So now we just leave them running
 
691
        # and expect to terminate the process. -- mbp 20070215
 
692
        # self._server_socket.close()
 
693
        ## sys.stderr.write("waiting for server thread to finish...")
 
694
        ## self._server_thread.join()
 
695
 
 
696
 
 
697
class SmartTCPServer_for_testing(SmartTCPServer):
 
698
    """Server suitable for use by transport tests.
 
699
    
 
700
    This server is backed by the process's cwd.
 
701
    """
 
702
 
 
703
    def __init__(self):
 
704
        self._homedir = urlutils.local_path_to_url(os.getcwd())[7:]
 
705
        # The server is set up by default like for ssh access: the client
 
706
        # passes filesystem-absolute paths; therefore the server must look
 
707
        # them up relative to the root directory.  it might be better to act
 
708
        # a public server and have the server rewrite paths into the test
 
709
        # directory.
 
710
        SmartTCPServer.__init__(self,
 
711
            transport.get_transport(urlutils.local_path_to_url('/')))
 
712
        
 
713
    def setUp(self):
 
714
        """Set up server for testing"""
 
715
        self.start_background_thread()
 
716
 
 
717
    def tearDown(self):
 
718
        self.stop_background_thread()
 
719
 
 
720
    def get_url(self):
 
721
        """Return the url of the server"""
 
722
        host, port = self._server_socket.getsockname()
 
723
        return "bzr://%s:%d%s" % (host, port, urlutils.escape(self._homedir))
 
724
 
 
725
    def get_bogus_url(self):
 
726
        """Return a URL which will fail to connect"""
 
727
        return 'bzr://127.0.0.1:1/'
 
728
 
 
729
 
41
730
class _SmartStat(object):
42
731
 
43
732
    def __init__(self, size, mode):
410
1099
    def __init__(self, url):
411
1100
        _scheme, _username, _password, _host, _port, _path = \
412
1101
            transport.split_url(url)
413
 
        try:
414
 
            _port = int(_port)
415
 
        except (ValueError, TypeError), e:
416
 
            raise errors.InvalidURL(path=url, extra="invalid port %s" % _port)
 
1102
        if _port is None:
 
1103
            _port = BZR_DEFAULT_PORT
 
1104
        else:
 
1105
            try:
 
1106
                _port = int(_port)
 
1107
            except (ValueError, TypeError), e:
 
1108
                raise errors.InvalidURL(
 
1109
                    path=url, extra="invalid port %s" % _port)
417
1110
        client_medium = medium.SmartTCPClientMedium(_host, _port)
418
1111
        super(SmartTCPTransport, self).__init__(url, medium=client_medium)
419
1112