/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Robert Collins
  • Date: 2010-05-06 11:08:10 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506110810-h3j07fh5gmw54s25
Cleaner matcher matching revised unlocking protocol.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
57
62
 
58
63
def _encode_tuple(args):
59
64
    """Encode the tuple args to a bytestream."""
60
 
    return '\x01'.join(args) + '\n'
 
65
    joined = '\x01'.join(args) + '\n'
 
66
    if type(joined) is unicode:
 
67
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
68
        mutter('response args contain unicode, should be only bytes: %r',
 
69
               joined)
 
70
        joined = joined.encode('ascii')
 
71
    return joined
61
72
 
62
73
 
63
74
class Requester(object):
109
120
        for start, length in offsets:
110
121
            txt.append('%d,%d' % (start, length))
111
122
        return '\n'.join(txt)
112
 
        
 
123
 
113
124
 
114
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
126
    """Server-side encoding and decoding logic for smart version 1."""
116
 
    
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
127
 
 
128
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
129
            jail_root=None):
118
130
        self._backing_transport = backing_transport
119
131
        self._root_client_path = root_client_path
 
132
        self._jail_root = jail_root
120
133
        self.unused_data = ''
121
134
        self._finished = False
122
135
        self.in_buffer = ''
127
140
 
128
141
    def accept_bytes(self, bytes):
129
142
        """Take bytes, and advance the internal state machine appropriately.
130
 
        
 
143
 
131
144
        :param bytes: must be a byte string
132
145
        """
133
146
        if not isinstance(bytes, str):
144
157
                req_args = _decode_tuple(first_line)
145
158
                self.request = request.SmartServerRequestHandler(
146
159
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
160
                    root_client_path=self._root_client_path,
 
161
                    jail_root=self._jail_root)
 
162
                self.request.args_received(req_args)
149
163
                if self.request.finished_reading:
150
164
                    # trivial request
151
165
                    self.unused_data = self.in_buffer
169
183
 
170
184
        if self._has_dispatched:
171
185
            if self._finished:
172
 
                # nothing to do.XXX: this routine should be a single state 
 
186
                # nothing to do.XXX: this routine should be a single state
173
187
                # machine too.
174
188
                self.unused_data += self.in_buffer
175
189
                self.in_buffer = ''
211
225
 
212
226
    def _write_protocol_version(self):
213
227
        """Write any prefixes this protocol requires.
214
 
        
 
228
 
215
229
        Version one doesn't send protocol versions.
216
230
        """
217
231
 
234
248
 
235
249
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
236
250
    r"""Version two of the server side of the smart protocol.
237
 
   
 
251
 
238
252
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
253
    """
240
254
 
250
264
 
251
265
    def _write_protocol_version(self):
252
266
        r"""Write any prefixes this protocol requires.
253
 
        
 
267
 
254
268
        Version two sends the value of RESPONSE_VERSION_TWO.
255
269
        """
256
270
        self._write_func(self.response_marker)
412
426
        self.chunks = collections.deque()
413
427
        self.error = False
414
428
        self.error_in_progress = None
415
 
    
 
429
 
416
430
    def next_read_size(self):
417
431
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
418
432
        # end-of-body marker is 4 bytes: 'END\n'.
506
520
                self.chunks.append(self.chunk_in_progress)
507
521
            self.chunk_in_progress = None
508
522
            self.state_accept = self._state_accept_expecting_length
509
 
        
 
523
 
510
524
    def _state_accept_reading_unused(self):
511
525
        self.unused_data += self._get_in_buffer()
512
526
        self._in_buffer_list = []
514
528
 
515
529
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
530
    """Decodes the length-prefixed bulk data."""
517
 
    
 
531
 
518
532
    def __init__(self):
519
533
        _StatefulDecoder.__init__(self)
520
534
        self.state_accept = self._state_accept_expecting_length
521
535
        self.state_read = self._state_read_no_data
522
536
        self._body = ''
523
537
        self._trailer_buffer = ''
524
 
    
 
538
 
525
539
    def next_read_size(self):
526
540
        if self.bytes_left is not None:
527
541
            # Ideally we want to read all the remainder of the body and the
537
551
        else:
538
552
            # Reading excess data.  Either way, 1 byte at a time is fine.
539
553
            return 1
540
 
        
 
554
 
541
555
    def read_pending_data(self):
542
556
        """Return any pending data that has been decoded."""
543
557
        return self.state_read()
564
578
                self._body = self._body[:self.bytes_left]
565
579
            self.bytes_left = None
566
580
            self.state_accept = self._state_accept_reading_trailer
567
 
        
 
581
 
568
582
    def _state_accept_reading_trailer(self):
569
583
        self._trailer_buffer += self._get_in_buffer()
570
584
        self._set_in_buffer(None)
574
588
            self.unused_data = self._trailer_buffer[len('done\n'):]
575
589
            self.state_accept = self._state_accept_reading_unused
576
590
            self.finished_reading = True
577
 
    
 
591
 
578
592
    def _state_accept_reading_unused(self):
579
593
        self.unused_data += self._get_in_buffer()
580
594
        self._set_in_buffer(None)
612
626
            mutter('hpss call:   %s', repr(args)[1:-1])
613
627
            if getattr(self._request._medium, 'base', None) is not None:
614
628
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
629
            self._request_start_time = osutils.timer_func()
616
630
        self._write_args(args)
617
631
        self._request.finished_writing()
618
632
        self._last_verb = args[0]
627
641
            if getattr(self._request._medium, '_path', None) is not None:
628
642
                mutter('                  (to %s)', self._request._medium._path)
629
643
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
644
            self._request_start_time = osutils.timer_func()
631
645
            if 'hpssdetail' in debug.debug_flags:
632
646
                mutter('hpss body content: %s', body)
633
647
        self._write_args(args)
646
660
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
661
            if getattr(self._request._medium, '_path', None) is not None:
648
662
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
663
            self._request_start_time = osutils.timer_func()
650
664
        self._write_args(args)
651
665
        readv_bytes = self._serialise_offsets(body)
652
666
        bytes = self._encode_bulk_data(readv_bytes)
656
670
            mutter('              %d bytes in readv request', len(readv_bytes))
657
671
        self._last_verb = args[0]
658
672
 
 
673
    def call_with_body_stream(self, args, stream):
 
674
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
675
        # assume that a v1/v2 server doesn't support whatever method we're
 
676
        # trying to call with a body stream.
 
677
        self._request.finished_writing()
 
678
        self._request.finished_reading()
 
679
        raise errors.UnknownSmartMethod(args[0])
 
680
 
659
681
    def cancel_read_body(self):
660
682
        """After expecting a body, a response code may indicate one otherwise.
661
683
 
670
692
        if 'hpss' in debug.debug_flags:
671
693
            if self._request_start_time is not None:
672
694
                mutter('   result:   %6.3fs  %s',
673
 
                       time.time() - self._request_start_time,
 
695
                       osutils.timer_func() - self._request_start_time,
674
696
                       repr(result)[1:-1])
675
697
                self._request_start_time = None
676
698
            else:
721
743
    def _response_is_unknown_method(self, result_tuple):
722
744
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
723
745
        method' response to the request.
724
 
        
 
746
 
725
747
        :param response: The response from a smart client call_expecting_body
726
748
            call.
727
749
        :param verb: The verb used in that call.
734
756
            # The response will have no body, so we've finished reading.
735
757
            self._request.finished_reading()
736
758
            raise errors.UnknownSmartMethod(self._last_verb)
737
 
        
 
759
 
738
760
    def read_body_bytes(self, count=-1):
739
761
        """Read bytes from the body, decoding into a byte stream.
740
 
        
741
 
        We read all bytes at once to ensure we've checked the trailer for 
 
762
 
 
763
        We read all bytes at once to ensure we've checked the trailer for
742
764
        errors, and then feed the buffer back as read_body_bytes is called.
743
765
        """
744
766
        if self._body_buffer is not None:
782
804
 
783
805
    def _write_protocol_version(self):
784
806
        """Write any prefixes this protocol requires.
785
 
        
 
807
 
786
808
        Version one doesn't send protocol versions.
787
809
        """
788
810
 
789
811
 
790
812
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
791
813
    """Version two of the client side of the smart protocol.
792
 
    
 
814
 
793
815
    This prefixes the request with the value of REQUEST_VERSION_TWO.
794
816
    """
795
817
 
823
845
 
824
846
    def _write_protocol_version(self):
825
847
        """Write any prefixes this protocol requires.
826
 
        
 
848
 
827
849
        Version two sends the value of REQUEST_VERSION_TWO.
828
850
        """
829
851
        self._request.accept_bytes(self.request_marker)
850
872
 
851
873
 
852
874
def build_server_protocol_three(backing_transport, write_func,
853
 
                                root_client_path):
 
875
                                root_client_path, jail_root=None):
854
876
    request_handler = request.SmartServerRequestHandler(
855
877
        backing_transport, commands=request.request_handlers,
856
 
        root_client_path=root_client_path)
 
878
        root_client_path=root_client_path, jail_root=jail_root)
857
879
    responder = ProtocolThreeResponder(write_func)
858
880
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
859
881
    return ProtocolThreeDecoder(message_handler)
889
911
            # We do *not* set self.decoding_failed here.  The message handler
890
912
            # has raised an error, but the decoder is still able to parse bytes
891
913
            # and determine when this message ends.
892
 
            log_exception_quietly()
 
914
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
915
                log_exception_quietly()
893
916
            self.message_handler.protocol_error(exception.exc_value)
894
917
            # The state machine is ready to continue decoding, but the
895
918
            # exception has interrupted the loop that runs the state machine.
931
954
    def _extract_prefixed_bencoded_data(self):
932
955
        prefixed_bytes = self._extract_length_prefixed_bytes()
933
956
        try:
934
 
            decoded = bdecode(prefixed_bytes)
 
957
            decoded = bdecode_as_tuple(prefixed_bytes)
935
958
        except ValueError:
936
959
            raise errors.SmartProtocolError(
937
960
                'Bytes %r not bencoded' % (prefixed_bytes,))
977
1000
            self.message_handler.headers_received(decoded)
978
1001
        except:
979
1002
            raise errors.SmartMessageHandlerError(sys.exc_info())
980
 
    
 
1003
 
981
1004
    def _state_accept_expecting_message_part(self):
982
1005
        message_part_kind = self._extract_single_byte()
983
1006
        if message_part_kind == 'o':
1028
1051
            raise errors.SmartMessageHandlerError(sys.exc_info())
1029
1052
 
1030
1053
    def _state_accept_reading_unused(self):
1031
 
        self.unused_data = self._get_in_buffer()
 
1054
        self.unused_data += self._get_in_buffer()
1032
1055
        self._set_in_buffer(None)
1033
1056
 
1034
1057
    def next_read_size(self):
1050
1073
class _ProtocolThreeEncoder(object):
1051
1074
 
1052
1075
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1076
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1053
1077
 
1054
1078
    def __init__(self, write_func):
1055
 
        self._buf = ''
 
1079
        self._buf = []
 
1080
        self._buf_len = 0
1056
1081
        self._real_write_func = write_func
1057
1082
 
1058
1083
    def _write_func(self, bytes):
1059
 
        self._buf += bytes
 
1084
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1085
        #       for total number of bytes to write, rather than buffer based on
 
1086
        #       the number of write() calls
 
1087
        # TODO: Another possibility would be to turn this into an async model.
 
1088
        #       Where we let another thread know that we have some bytes if
 
1089
        #       they want it, but we don't actually block for it
 
1090
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1091
        #       we might just push out smaller bits at a time?
 
1092
        self._buf.append(bytes)
 
1093
        self._buf_len += len(bytes)
 
1094
        if self._buf_len > self.BUFFER_SIZE:
 
1095
            self.flush()
1060
1096
 
1061
1097
    def flush(self):
1062
1098
        if self._buf:
1063
 
            self._real_write_func(self._buf)
1064
 
            self._buf = ''
 
1099
            self._real_write_func(''.join(self._buf))
 
1100
            del self._buf[:]
 
1101
            self._buf_len = 0
1065
1102
 
1066
1103
    def _serialise_offsets(self, offsets):
1067
1104
        """Serialise a readv offset list."""
1069
1106
        for start, length in offsets:
1070
1107
            txt.append('%d,%d' % (start, length))
1071
1108
        return '\n'.join(txt)
1072
 
        
 
1109
 
1073
1110
    def _write_protocol_version(self):
1074
1111
        self._write_func(MESSAGE_VERSION_THREE)
1075
1112
 
1100
1137
        self._write_func(struct.pack('!L', len(bytes)))
1101
1138
        self._write_func(bytes)
1102
1139
 
 
1140
    def _write_chunked_body_start(self):
 
1141
        self._write_func('oC')
 
1142
 
1103
1143
    def _write_error_status(self):
1104
1144
        self._write_func('oE')
1105
1145
 
1113
1153
        _ProtocolThreeEncoder.__init__(self, write_func)
1114
1154
        self.response_sent = False
1115
1155
        self._headers = {'Software version': bzrlib.__version__}
 
1156
        if 'hpss' in debug.debug_flags:
 
1157
            self._thread_id = thread.get_ident()
 
1158
            self._response_start_time = None
 
1159
 
 
1160
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1161
        if self._response_start_time is None:
 
1162
            self._response_start_time = osutils.timer_func()
 
1163
        if include_time:
 
1164
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1165
        else:
 
1166
            t = ''
 
1167
        if extra_bytes is None:
 
1168
            extra = ''
 
1169
        else:
 
1170
            extra = ' ' + repr(extra_bytes[:40])
 
1171
            if len(extra) > 33:
 
1172
                extra = extra[:29] + extra[-1] + '...'
 
1173
        mutter('%12s: [%s] %s%s%s'
 
1174
               % (action, self._thread_id, t, message, extra))
1116
1175
 
1117
1176
    def send_error(self, exception):
1118
1177
        if self.response_sent:
1124
1183
                ('UnknownMethod', exception.verb))
1125
1184
            self.send_response(failure)
1126
1185
            return
 
1186
        if 'hpss' in debug.debug_flags:
 
1187
            self._trace('error', str(exception))
1127
1188
        self.response_sent = True
1128
1189
        self._write_protocol_version()
1129
1190
        self._write_headers(self._headers)
1143
1204
            self._write_success_status()
1144
1205
        else:
1145
1206
            self._write_error_status()
 
1207
        if 'hpss' in debug.debug_flags:
 
1208
            self._trace('response', repr(response.args))
1146
1209
        self._write_structure(response.args)
1147
1210
        if response.body is not None:
1148
1211
            self._write_prefixed_body(response.body)
 
1212
            if 'hpss' in debug.debug_flags:
 
1213
                self._trace('body', '%d bytes' % (len(response.body),),
 
1214
                            response.body, include_time=True)
1149
1215
        elif response.body_stream is not None:
1150
 
            for chunk in response.body_stream:
1151
 
                self._write_prefixed_body(chunk)
1152
 
                self.flush()
 
1216
            count = num_bytes = 0
 
1217
            first_chunk = None
 
1218
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1219
                count += 1
 
1220
                if exc_info is not None:
 
1221
                    self._write_error_status()
 
1222
                    error_struct = request._translate_error(exc_info[1])
 
1223
                    self._write_structure(error_struct)
 
1224
                    break
 
1225
                else:
 
1226
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1227
                        self._write_error_status()
 
1228
                        self._write_structure(chunk.args)
 
1229
                        break
 
1230
                    num_bytes += len(chunk)
 
1231
                    if first_chunk is None:
 
1232
                        first_chunk = chunk
 
1233
                    self._write_prefixed_body(chunk)
 
1234
                    if 'hpssdetail' in debug.debug_flags:
 
1235
                        # Not worth timing separately, as _write_func is
 
1236
                        # actually buffered
 
1237
                        self._trace('body chunk',
 
1238
                                    '%d bytes' % (len(chunk),),
 
1239
                                    chunk, suppress_time=True)
 
1240
            if 'hpss' in debug.debug_flags:
 
1241
                self._trace('body stream',
 
1242
                            '%d bytes %d chunks' % (num_bytes, count),
 
1243
                            first_chunk)
1153
1244
        self._write_end()
1154
 
        
 
1245
        if 'hpss' in debug.debug_flags:
 
1246
            self._trace('response end', '', include_time=True)
 
1247
 
 
1248
 
 
1249
def _iter_with_errors(iterable):
 
1250
    """Handle errors from iterable.next().
 
1251
 
 
1252
    Use like::
 
1253
 
 
1254
        for exc_info, value in _iter_with_errors(iterable):
 
1255
            ...
 
1256
 
 
1257
    This is a safer alternative to::
 
1258
 
 
1259
        try:
 
1260
            for value in iterable:
 
1261
               ...
 
1262
        except:
 
1263
            ...
 
1264
 
 
1265
    Because the latter will catch errors from the for-loop body, not just
 
1266
    iterable.next()
 
1267
 
 
1268
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1269
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1270
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1271
    will not be itercepted.
 
1272
    """
 
1273
    iterator = iter(iterable)
 
1274
    while True:
 
1275
        try:
 
1276
            yield None, iterator.next()
 
1277
        except StopIteration:
 
1278
            return
 
1279
        except (KeyboardInterrupt, SystemExit):
 
1280
            raise
 
1281
        except Exception:
 
1282
            mutter('_iter_with_errors caught error')
 
1283
            log_exception_quietly()
 
1284
            yield sys.exc_info(), None
 
1285
            return
 
1286
 
1155
1287
 
1156
1288
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1157
1289
 
1162
1294
 
1163
1295
    def set_headers(self, headers):
1164
1296
        self._headers = headers.copy()
1165
 
        
 
1297
 
1166
1298
    def call(self, *args):
1167
1299
        if 'hpss' in debug.debug_flags:
1168
1300
            mutter('hpss call:   %s', repr(args)[1:-1])
1169
1301
            base = getattr(self._medium_request._medium, 'base', None)
1170
1302
            if base is not None:
1171
1303
                mutter('             (to %s)', base)
1172
 
            self._request_start_time = time.time()
 
1304
            self._request_start_time = osutils.timer_func()
1173
1305
        self._write_protocol_version()
1174
1306
        self._write_headers(self._headers)
1175
1307
        self._write_structure(args)
1187
1319
            if path is not None:
1188
1320
                mutter('                  (to %s)', path)
1189
1321
            mutter('              %d bytes', len(body))
1190
 
            self._request_start_time = time.time()
 
1322
            self._request_start_time = osutils.timer_func()
1191
1323
        self._write_protocol_version()
1192
1324
        self._write_headers(self._headers)
1193
1325
        self._write_structure(args)
1206
1338
            path = getattr(self._medium_request._medium, '_path', None)
1207
1339
            if path is not None:
1208
1340
                mutter('                  (to %s)', path)
1209
 
            self._request_start_time = time.time()
 
1341
            self._request_start_time = osutils.timer_func()
1210
1342
        self._write_protocol_version()
1211
1343
        self._write_headers(self._headers)
1212
1344
        self._write_structure(args)
1217
1349
        self._write_end()
1218
1350
        self._medium_request.finished_writing()
1219
1351
 
 
1352
    def call_with_body_stream(self, args, stream):
 
1353
        if 'hpss' in debug.debug_flags:
 
1354
            mutter('hpss call w/body stream: %r', args)
 
1355
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
            if path is not None:
 
1357
                mutter('                  (to %s)', path)
 
1358
            self._request_start_time = osutils.timer_func()
 
1359
        self._write_protocol_version()
 
1360
        self._write_headers(self._headers)
 
1361
        self._write_structure(args)
 
1362
        # TODO: notice if the server has sent an early error reply before we
 
1363
        #       have finished sending the stream.  We would notice at the end
 
1364
        #       anyway, but if the medium can deliver it early then it's good
 
1365
        #       to short-circuit the whole request...
 
1366
        for exc_info, part in _iter_with_errors(stream):
 
1367
            if exc_info is not None:
 
1368
                # Iterating the stream failed.  Cleanly abort the request.
 
1369
                self._write_error_status()
 
1370
                # Currently the client unconditionally sends ('error',) as the
 
1371
                # error args.
 
1372
                self._write_structure(('error',))
 
1373
                self._write_end()
 
1374
                self._medium_request.finished_writing()
 
1375
                raise exc_info[0], exc_info[1], exc_info[2]
 
1376
            else:
 
1377
                self._write_prefixed_body(part)
 
1378
                self.flush()
 
1379
        self._write_end()
 
1380
        self._medium_request.finished_writing()
 
1381