/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart/_smart.py

Split up more smart server code, this time into bzrlib/transport/smart/protocol.py

Show diffs side-by-side

added added

removed removed

Lines of Context:
212
212
    )
213
213
from bzrlib.bundle.serializer import write_bundle
214
214
from bzrlib.transport.smart import medium
 
215
from bzrlib.transport.smart import protocol
215
216
try:
216
217
    from bzrlib.transport import ssh
217
218
except errors.ParamikoNotPresent:
224
225
del scheme
225
226
 
226
227
 
227
 
def _recv_tuple(from_file):
228
 
    req_line = from_file.readline()
229
 
    return _decode_tuple(req_line)
230
 
 
231
 
 
232
 
def _decode_tuple(req_line):
233
 
    if req_line == None or req_line == '':
234
 
        return None
235
 
    if req_line[-1] != '\n':
236
 
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
237
 
    return tuple(req_line[:-1].split('\x01'))
238
 
 
239
 
 
240
 
def _encode_tuple(args):
241
 
    """Encode the tuple args to a bytestream."""
242
 
    return '\x01'.join(args) + '\n'
243
 
 
244
 
 
245
 
class SmartProtocolBase(object):
246
 
    """Methods common to client and server"""
247
 
 
248
 
    # TODO: this only actually accomodates a single block; possibly should
249
 
    # support multiple chunks?
250
 
    def _encode_bulk_data(self, body):
251
 
        """Encode body as a bulk data chunk."""
252
 
        return ''.join(('%d\n' % len(body), body, 'done\n'))
253
 
 
254
 
    def _serialise_offsets(self, offsets):
255
 
        """Serialise a readv offset list."""
256
 
        txt = []
257
 
        for start, length in offsets:
258
 
            txt.append('%d,%d' % (start, length))
259
 
        return '\n'.join(txt)
260
 
        
261
 
 
262
 
class SmartServerRequestProtocolOne(SmartProtocolBase):
263
 
    """Server-side encoding and decoding logic for smart version 1."""
264
 
    
265
 
    def __init__(self, backing_transport, write_func):
266
 
        self._backing_transport = backing_transport
267
 
        self.excess_buffer = ''
268
 
        self._finished = False
269
 
        self.in_buffer = ''
270
 
        self.has_dispatched = False
271
 
        self.request = None
272
 
        self._body_decoder = None
273
 
        self._write_func = write_func
274
 
 
275
 
    def accept_bytes(self, bytes):
276
 
        """Take bytes, and advance the internal state machine appropriately.
277
 
        
278
 
        :param bytes: must be a byte string
279
 
        """
280
 
        assert isinstance(bytes, str)
281
 
        self.in_buffer += bytes
282
 
        if not self.has_dispatched:
283
 
            if '\n' not in self.in_buffer:
284
 
                # no command line yet
285
 
                return
286
 
            self.has_dispatched = True
287
 
            try:
288
 
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
289
 
                first_line += '\n'
290
 
                req_args = _decode_tuple(first_line)
291
 
                self.request = SmartServerRequestHandler(
292
 
                    self._backing_transport)
293
 
                self.request.dispatch_command(req_args[0], req_args[1:])
294
 
                if self.request.finished_reading:
295
 
                    # trivial request
296
 
                    self.excess_buffer = self.in_buffer
297
 
                    self.in_buffer = ''
298
 
                    self._send_response(self.request.response.args,
299
 
                        self.request.response.body)
300
 
            except KeyboardInterrupt:
301
 
                raise
302
 
            except Exception, exception:
303
 
                # everything else: pass to client, flush, and quit
304
 
                self._send_response(('error', str(exception)))
305
 
                return
306
 
 
307
 
        if self.has_dispatched:
308
 
            if self._finished:
309
 
                # nothing to do.XXX: this routine should be a single state 
310
 
                # machine too.
311
 
                self.excess_buffer += self.in_buffer
312
 
                self.in_buffer = ''
313
 
                return
314
 
            if self._body_decoder is None:
315
 
                self._body_decoder = LengthPrefixedBodyDecoder()
316
 
            self._body_decoder.accept_bytes(self.in_buffer)
317
 
            self.in_buffer = self._body_decoder.unused_data
318
 
            body_data = self._body_decoder.read_pending_data()
319
 
            self.request.accept_body(body_data)
320
 
            if self._body_decoder.finished_reading:
321
 
                self.request.end_of_body()
322
 
                assert self.request.finished_reading, \
323
 
                    "no more body, request not finished"
324
 
            if self.request.response is not None:
325
 
                self._send_response(self.request.response.args,
326
 
                    self.request.response.body)
327
 
                self.excess_buffer = self.in_buffer
328
 
                self.in_buffer = ''
329
 
            else:
330
 
                assert not self.request.finished_reading, \
331
 
                    "no response and we have finished reading."
332
 
 
333
 
    def _send_response(self, args, body=None):
334
 
        """Send a smart server response down the output stream."""
335
 
        assert not self._finished, 'response already sent'
336
 
        self._finished = True
337
 
        self._write_func(_encode_tuple(args))
338
 
        if body is not None:
339
 
            assert isinstance(body, str), 'body must be a str'
340
 
            bytes = self._encode_bulk_data(body)
341
 
            self._write_func(bytes)
342
 
 
343
 
    def next_read_size(self):
344
 
        if self._finished:
345
 
            return 0
346
 
        if self._body_decoder is None:
347
 
            return 1
348
 
        else:
349
 
            return self._body_decoder.next_read_size()
350
 
 
351
 
 
352
 
class LengthPrefixedBodyDecoder(object):
353
 
    """Decodes the length-prefixed bulk data."""
354
 
    
355
 
    def __init__(self):
356
 
        self.bytes_left = None
357
 
        self.finished_reading = False
358
 
        self.unused_data = ''
359
 
        self.state_accept = self._state_accept_expecting_length
360
 
        self.state_read = self._state_read_no_data
361
 
        self._in_buffer = ''
362
 
        self._trailer_buffer = ''
363
 
    
364
 
    def accept_bytes(self, bytes):
365
 
        """Decode as much of bytes as possible.
366
 
 
367
 
        If 'bytes' contains too much data it will be appended to
368
 
        self.unused_data.
369
 
 
370
 
        finished_reading will be set when no more data is required.  Further
371
 
        data will be appended to self.unused_data.
372
 
        """
373
 
        # accept_bytes is allowed to change the state
374
 
        current_state = self.state_accept
375
 
        self.state_accept(bytes)
376
 
        while current_state != self.state_accept:
377
 
            current_state = self.state_accept
378
 
            self.state_accept('')
379
 
 
380
 
    def next_read_size(self):
381
 
        if self.bytes_left is not None:
382
 
            # Ideally we want to read all the remainder of the body and the
383
 
            # trailer in one go.
384
 
            return self.bytes_left + 5
385
 
        elif self.state_accept == self._state_accept_reading_trailer:
386
 
            # Just the trailer left
387
 
            return 5 - len(self._trailer_buffer)
388
 
        elif self.state_accept == self._state_accept_expecting_length:
389
 
            # There's still at least 6 bytes left ('\n' to end the length, plus
390
 
            # 'done\n').
391
 
            return 6
392
 
        else:
393
 
            # Reading excess data.  Either way, 1 byte at a time is fine.
394
 
            return 1
395
 
        
396
 
    def read_pending_data(self):
397
 
        """Return any pending data that has been decoded."""
398
 
        return self.state_read()
399
 
 
400
 
    def _state_accept_expecting_length(self, bytes):
401
 
        self._in_buffer += bytes
402
 
        pos = self._in_buffer.find('\n')
403
 
        if pos == -1:
404
 
            return
405
 
        self.bytes_left = int(self._in_buffer[:pos])
406
 
        self._in_buffer = self._in_buffer[pos+1:]
407
 
        self.bytes_left -= len(self._in_buffer)
408
 
        self.state_accept = self._state_accept_reading_body
409
 
        self.state_read = self._state_read_in_buffer
410
 
 
411
 
    def _state_accept_reading_body(self, bytes):
412
 
        self._in_buffer += bytes
413
 
        self.bytes_left -= len(bytes)
414
 
        if self.bytes_left <= 0:
415
 
            # Finished with body
416
 
            if self.bytes_left != 0:
417
 
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
418
 
                self._in_buffer = self._in_buffer[:self.bytes_left]
419
 
            self.bytes_left = None
420
 
            self.state_accept = self._state_accept_reading_trailer
421
 
        
422
 
    def _state_accept_reading_trailer(self, bytes):
423
 
        self._trailer_buffer += bytes
424
 
        # TODO: what if the trailer does not match "done\n"?  Should this raise
425
 
        # a ProtocolViolation exception?
426
 
        if self._trailer_buffer.startswith('done\n'):
427
 
            self.unused_data = self._trailer_buffer[len('done\n'):]
428
 
            self.state_accept = self._state_accept_reading_unused
429
 
            self.finished_reading = True
430
 
    
431
 
    def _state_accept_reading_unused(self, bytes):
432
 
        self.unused_data += bytes
433
 
 
434
 
    def _state_read_no_data(self):
435
 
        return ''
436
 
 
437
 
    def _state_read_in_buffer(self):
438
 
        result = self._in_buffer
439
 
        self._in_buffer = ''
440
 
        return result
441
 
 
442
 
 
443
 
class SmartServerResponse(object):
444
 
    """Response generated by SmartServerRequestHandler."""
445
 
 
446
 
    def __init__(self, args, body=None):
447
 
        self.args = args
448
 
        self.body = body
449
 
 
450
 
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
451
 
# for delivering the data for a request. This could be done with as the
452
 
# StreamServer, though that would create conflation between request and response
453
 
# which may be undesirable.
454
 
 
455
228
 
456
229
class SmartServerRequestHandler(object):
457
230
    """Protocol logic for smart server.
496
269
        
497
270
    def do_hello(self):
498
271
        """Answer a version request with my version."""
499
 
        return SmartServerResponse(('ok', '1'))
 
272
        return protocol.SmartServerResponse(('ok', '1'))
500
273
 
501
274
    def do_has(self, relpath):
502
275
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
503
 
        return SmartServerResponse((r,))
 
276
        return protocol.SmartServerResponse((r,))
504
277
 
505
278
    def do_get(self, relpath):
506
279
        backing_bytes = self._backing_transport.get_bytes(relpath)
507
 
        return SmartServerResponse(('ok',), backing_bytes)
 
280
        return protocol.SmartServerResponse(('ok',), backing_bytes)
508
281
 
509
282
    def _deserialise_optional_mode(self, mode):
510
283
        # XXX: FIXME this should be on the protocol object.
522
295
    def _handle_do_append_end(self):
523
296
        old_length = self._backing_transport.append_bytes(
524
297
            self._relpath, self._body_bytes, self._mode)
525
 
        self.response = SmartServerResponse(('appended', '%d' % old_length))
 
298
        self.response = protocol.SmartServerResponse(('appended', '%d' % old_length))
526
299
 
527
300
    def do_delete(self, relpath):
528
301
        self._backing_transport.delete(relpath)
530
303
    def do_iter_files_recursive(self, relpath):
531
304
        transport = self._backing_transport.clone(relpath)
532
305
        filenames = transport.iter_files_recursive()
533
 
        return SmartServerResponse(('names',) + tuple(filenames))
 
306
        return protocol.SmartServerResponse(('names',) + tuple(filenames))
534
307
 
535
308
    def do_list_dir(self, relpath):
536
309
        filenames = self._backing_transport.list_dir(relpath)
537
 
        return SmartServerResponse(('names',) + tuple(filenames))
 
310
        return protocol.SmartServerResponse(('names',) + tuple(filenames))
538
311
 
539
312
    def do_mkdir(self, relpath, mode):
540
313
        self._backing_transport.mkdir(relpath,
552
325
    def _handle_do_put(self):
553
326
        self._backing_transport.put_bytes(self._relpath,
554
327
                self._body_bytes, self._mode)
555
 
        self.response = SmartServerResponse(('ok',))
 
328
        self.response = protocol.SmartServerResponse(('ok',))
556
329
 
557
330
    def _deserialise_offsets(self, text):
558
331
        # XXX: FIXME this should be on the protocol object.
579
352
                mode=self._mode,
580
353
                create_parent_dir=self._create_parent,
581
354
                dir_mode=self._dir_mode)
582
 
        self.response = SmartServerResponse(('ok',))
 
355
        self.response = protocol.SmartServerResponse(('ok',))
583
356
 
584
357
    def do_readv(self, relpath):
585
358
        self._converted_command = True
597
370
        offsets = self._deserialise_offsets(self._body_bytes)
598
371
        backing_bytes = ''.join(bytes for offset, bytes in
599
372
            self._backing_transport.readv(self._relpath, offsets))
600
 
        self.response = SmartServerResponse(('readv',), backing_bytes)
 
373
        self.response = protocol.SmartServerResponse(('readv',), backing_bytes)
601
374
        
602
375
    def do_rename(self, rel_from, rel_to):
603
376
        self._backing_transport.rename(rel_from, rel_to)
607
380
 
608
381
    def do_stat(self, relpath):
609
382
        stat = self._backing_transport.stat(relpath)
610
 
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
383
        return protocol.SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
611
384
        
612
385
    def do_get_bundle(self, path, revision_id):
613
386
        # open transport relative to our base
618
391
        base_revision = revision.NULL_REVISION
619
392
        write_bundle(repo, revision_id, base_revision, tmpf)
620
393
        tmpf.seek(0)
621
 
        return SmartServerResponse((), tmpf.read())
 
394
        return protocol.SmartServerResponse((), tmpf.read())
622
395
 
623
396
    def dispatch_command(self, cmd, args):
624
397
        """Deprecated compatibility method.""" # XXX XXX
644
417
        if not self._converted_command:
645
418
            self.finished_reading = True
646
419
            if result is None:
647
 
                self.response = SmartServerResponse(('ok',))
 
420
                self.response = protocol.SmartServerResponse(('ok',))
648
421
 
649
422
    def _call_converting_errors(self, callable, args, kwargs):
650
423
        """Call callable converting errors to Response objects."""
651
424
        try:
652
425
            return callable(*args, **kwargs)
653
426
        except errors.NoSuchFile, e:
654
 
            return SmartServerResponse(('NoSuchFile', e.path))
 
427
            return protocol.SmartServerResponse(('NoSuchFile', e.path))
655
428
        except errors.FileExists, e:
656
 
            return SmartServerResponse(('FileExists', e.path))
 
429
            return protocol.SmartServerResponse(('FileExists', e.path))
657
430
        except errors.DirectoryNotEmpty, e:
658
 
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
431
            return protocol.SmartServerResponse(('DirectoryNotEmpty', e.path))
659
432
        except errors.ShortReadvError, e:
660
 
            return SmartServerResponse(('ShortReadvError',
 
433
            return protocol.SmartServerResponse(('ShortReadvError',
661
434
                e.path, str(e.offset), str(e.length), str(e.actual)))
662
435
        except UnicodeError, e:
663
436
            # If it is a DecodeError, than most likely we are starting
670
443
            else:
671
444
                val = 's:' + str_or_unicode.encode('base64')
672
445
            # This handles UnicodeEncodeError or UnicodeDecodeError
673
 
            return SmartServerResponse((e.__class__.__name__,
 
446
            return protocol.SmartServerResponse((e.__class__.__name__,
674
447
                    e.encoding, val, str(e.start), str(e.end), e.reason))
675
448
        except errors.TransportNotPossible, e:
676
449
            if e.msg == "readonly transport":
677
 
                return SmartServerResponse(('ReadOnlyError', ))
 
450
                return protocol.SmartServerResponse(('ReadOnlyError', ))
678
451
            else:
679
452
                raise
680
453
 
887
660
 
888
661
    def _call2(self, method, *args):
889
662
        """Call a method on the remote server."""
890
 
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
891
 
        protocol.call(method, *args)
892
 
        return protocol.read_response_tuple()
 
663
        request = self._medium.get_request()
 
664
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
665
        smart_protocol.call(method, *args)
 
666
        return smart_protocol.read_response_tuple()
893
667
 
894
668
    def _call_with_body_bytes(self, method, args, body):
895
669
        """Call a method on the remote server with body bytes."""
896
 
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
897
 
        protocol.call_with_body_bytes((method, ) + args, body)
898
 
        return protocol.read_response_tuple()
 
670
        request = self._medium.get_request()
 
671
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
672
        smart_protocol.call_with_body_bytes((method, ) + args, body)
 
673
        return smart_protocol.read_response_tuple()
899
674
 
900
675
    def has(self, relpath):
901
676
        """Indicate whether a remote file of the given name exists or not.
919
694
 
920
695
    def get_bytes(self, relpath):
921
696
        remote = self._remote_path(relpath)
922
 
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
923
 
        protocol.call('get', remote)
924
 
        resp = protocol.read_response_tuple(True)
 
697
        request = self._medium.get_request()
 
698
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
699
        smart_protocol.call('get', remote)
 
700
        resp = smart_protocol.read_response_tuple(True)
925
701
        if resp != ('ok', ):
926
 
            protocol.cancel_read_body()
 
702
            smart_protocol.cancel_read_body()
927
703
            self._translate_error(resp, relpath)
928
 
        return protocol.read_body_bytes()
 
704
        return smart_protocol.read_body_bytes()
929
705
 
930
706
    def _serialise_optional_mode(self, mode):
931
707
        if mode is None:
1011
787
                               limit=self._max_readv_combine,
1012
788
                               fudge_factor=self._bytes_to_read_before_seek))
1013
789
 
1014
 
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
1015
 
        protocol.call_with_body_readv_array(
 
790
        request = self._medium.get_request()
 
791
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
792
        smart_protocol.call_with_body_readv_array(
1016
793
            ('readv', self._remote_path(relpath)),
1017
794
            [(c.start, c.length) for c in coalesced])
1018
 
        resp = protocol.read_response_tuple(True)
 
795
        resp = smart_protocol.read_response_tuple(True)
1019
796
 
1020
797
        if resp[0] != 'readv':
1021
798
            # This should raise an exception
1022
 
            protocol.cancel_read_body()
 
799
            smart_protocol.cancel_read_body()
1023
800
            self._translate_error(resp)
1024
801
            return
1025
802
 
1026
803
        # FIXME: this should know how many bytes are needed, for clarity.
1027
 
        data = protocol.read_body_bytes()
 
804
        data = smart_protocol.read_body_bytes()
1028
805
        # Cache the results, but only until they have been fulfilled
1029
806
        data_map = {}
1030
807
        for c_offset in coalesced:
1141
918
            self._translate_error(resp)
1142
919
 
1143
920
 
1144
 
class SmartClientRequestProtocolOne(SmartProtocolBase):
1145
 
    """The client-side protocol for smart version 1."""
1146
 
 
1147
 
    def __init__(self, request):
1148
 
        """Construct a SmartClientRequestProtocolOne.
1149
 
 
1150
 
        :param request: A SmartClientMediumRequest to serialise onto and
1151
 
            deserialise from.
1152
 
        """
1153
 
        self._request = request
1154
 
        self._body_buffer = None
1155
 
 
1156
 
    def call(self, *args):
1157
 
        bytes = _encode_tuple(args)
1158
 
        self._request.accept_bytes(bytes)
1159
 
        self._request.finished_writing()
1160
 
 
1161
 
    def call_with_body_bytes(self, args, body):
1162
 
        """Make a remote call of args with body bytes 'body'.
1163
 
 
1164
 
        After calling this, call read_response_tuple to find the result out.
1165
 
        """
1166
 
        bytes = _encode_tuple(args)
1167
 
        self._request.accept_bytes(bytes)
1168
 
        bytes = self._encode_bulk_data(body)
1169
 
        self._request.accept_bytes(bytes)
1170
 
        self._request.finished_writing()
1171
 
 
1172
 
    def call_with_body_readv_array(self, args, body):
1173
 
        """Make a remote call with a readv array.
1174
 
 
1175
 
        The body is encoded with one line per readv offset pair. The numbers in
1176
 
        each pair are separated by a comma, and no trailing \n is emitted.
1177
 
        """
1178
 
        bytes = _encode_tuple(args)
1179
 
        self._request.accept_bytes(bytes)
1180
 
        readv_bytes = self._serialise_offsets(body)
1181
 
        bytes = self._encode_bulk_data(readv_bytes)
1182
 
        self._request.accept_bytes(bytes)
1183
 
        self._request.finished_writing()
1184
 
 
1185
 
    def cancel_read_body(self):
1186
 
        """After expecting a body, a response code may indicate one otherwise.
1187
 
 
1188
 
        This method lets the domain client inform the protocol that no body
1189
 
        will be transmitted. This is a terminal method: after calling it the
1190
 
        protocol is not able to be used further.
1191
 
        """
1192
 
        self._request.finished_reading()
1193
 
 
1194
 
    def read_response_tuple(self, expect_body=False):
1195
 
        """Read a response tuple from the wire.
1196
 
 
1197
 
        This should only be called once.
1198
 
        """
1199
 
        result = self._recv_tuple()
1200
 
        if not expect_body:
1201
 
            self._request.finished_reading()
1202
 
        return result
1203
 
 
1204
 
    def read_body_bytes(self, count=-1):
1205
 
        """Read bytes from the body, decoding into a byte stream.
1206
 
        
1207
 
        We read all bytes at once to ensure we've checked the trailer for 
1208
 
        errors, and then feed the buffer back as read_body_bytes is called.
1209
 
        """
1210
 
        if self._body_buffer is not None:
1211
 
            return self._body_buffer.read(count)
1212
 
        _body_decoder = LengthPrefixedBodyDecoder()
1213
 
 
1214
 
        while not _body_decoder.finished_reading:
1215
 
            bytes_wanted = _body_decoder.next_read_size()
1216
 
            bytes = self._request.read_bytes(bytes_wanted)
1217
 
            _body_decoder.accept_bytes(bytes)
1218
 
        self._request.finished_reading()
1219
 
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
1220
 
        # XXX: TODO check the trailer result.
1221
 
        return self._body_buffer.read(count)
1222
 
 
1223
 
    def _recv_tuple(self):
1224
 
        """Receive a tuple from the medium request."""
1225
 
        line = ''
1226
 
        while not line or line[-1] != '\n':
1227
 
            # TODO: this is inefficient - but tuples are short.
1228
 
            new_char = self._request.read_bytes(1)
1229
 
            line += new_char
1230
 
            assert new_char != '', "end of file reading from server."
1231
 
        return _decode_tuple(line)
1232
 
 
1233
 
    def query_version(self):
1234
 
        """Return protocol version number of the server."""
1235
 
        self.call('hello')
1236
 
        resp = self.read_response_tuple()
1237
 
        if resp == ('ok', '1'):
1238
 
            return 1
1239
 
        else:
1240
 
            raise errors.SmartProtocolError("bad response %r" % (resp,))
1241
 
 
1242
 
 
1243
921
class SmartTCPTransport(SmartTransport):
1244
922
    """Connection to smart server over plain tcp.
1245
923