/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

remove all trailing whitespace from bzr source

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
25
 
import thread
26
 
import threading
27
25
import time
28
26
 
29
27
import bzrlib
30
 
from bzrlib import (
31
 
    debug,
32
 
    errors,
33
 
    osutils,
34
 
    )
 
28
from bzrlib import debug
 
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
 
from bzrlib.bencode import bdecode_as_tuple, bencode
 
32
from bzrlib.util.bencode import bdecode, bencode
38
33
 
39
34
 
40
35
# Protocol version strings.  These are sent as prefixes of bzr requests and
62
57
 
63
58
def _encode_tuple(args):
64
59
    """Encode the tuple args to a bytestream."""
65
 
    joined = '\x01'.join(args) + '\n'
66
 
    if type(joined) is unicode:
67
 
        # XXX: We should fix things so this never happens!  -AJB, 20100304
68
 
        mutter('response args contain unicode, should be only bytes: %r',
69
 
               joined)
70
 
        joined = joined.encode('ascii')
71
 
    return joined
 
60
    return '\x01'.join(args) + '\n'
72
61
 
73
62
 
74
63
class Requester(object):
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
    """Server-side encoding and decoding logic for smart version 1."""
127
116
 
128
 
    def __init__(self, backing_transport, write_func, root_client_path='/',
129
 
            jail_root=None):
 
117
    def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
        self._backing_transport = backing_transport
131
119
        self._root_client_path = root_client_path
132
 
        self._jail_root = jail_root
133
120
        self.unused_data = ''
134
121
        self._finished = False
135
122
        self.in_buffer = ''
157
144
                req_args = _decode_tuple(first_line)
158
145
                self.request = request.SmartServerRequestHandler(
159
146
                    self._backing_transport, commands=request.request_handlers,
160
 
                    root_client_path=self._root_client_path,
161
 
                    jail_root=self._jail_root)
162
 
                self.request.args_received(req_args)
 
147
                    root_client_path=self._root_client_path)
 
148
                self.request.dispatch_command(req_args[0], req_args[1:])
163
149
                if self.request.finished_reading:
164
150
                    # trivial request
165
151
                    self.unused_data = self.in_buffer
626
612
            mutter('hpss call:   %s', repr(args)[1:-1])
627
613
            if getattr(self._request._medium, 'base', None) is not None:
628
614
                mutter('             (to %s)', self._request._medium.base)
629
 
            self._request_start_time = osutils.timer_func()
 
615
            self._request_start_time = time.time()
630
616
        self._write_args(args)
631
617
        self._request.finished_writing()
632
618
        self._last_verb = args[0]
641
627
            if getattr(self._request._medium, '_path', None) is not None:
642
628
                mutter('                  (to %s)', self._request._medium._path)
643
629
            mutter('              %d bytes', len(body))
644
 
            self._request_start_time = osutils.timer_func()
 
630
            self._request_start_time = time.time()
645
631
            if 'hpssdetail' in debug.debug_flags:
646
632
                mutter('hpss body content: %s', body)
647
633
        self._write_args(args)
660
646
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
647
            if getattr(self._request._medium, '_path', None) is not None:
662
648
                mutter('                  (to %s)', self._request._medium._path)
663
 
            self._request_start_time = osutils.timer_func()
 
649
            self._request_start_time = time.time()
664
650
        self._write_args(args)
665
651
        readv_bytes = self._serialise_offsets(body)
666
652
        bytes = self._encode_bulk_data(readv_bytes)
670
656
            mutter('              %d bytes in readv request', len(readv_bytes))
671
657
        self._last_verb = args[0]
672
658
 
673
 
    def call_with_body_stream(self, args, stream):
674
 
        # Protocols v1 and v2 don't support body streams.  So it's safe to
675
 
        # assume that a v1/v2 server doesn't support whatever method we're
676
 
        # trying to call with a body stream.
677
 
        self._request.finished_writing()
678
 
        self._request.finished_reading()
679
 
        raise errors.UnknownSmartMethod(args[0])
680
 
 
681
659
    def cancel_read_body(self):
682
660
        """After expecting a body, a response code may indicate one otherwise.
683
661
 
692
670
        if 'hpss' in debug.debug_flags:
693
671
            if self._request_start_time is not None:
694
672
                mutter('   result:   %6.3fs  %s',
695
 
                       osutils.timer_func() - self._request_start_time,
 
673
                       time.time() - self._request_start_time,
696
674
                       repr(result)[1:-1])
697
675
                self._request_start_time = None
698
676
            else:
872
850
 
873
851
 
874
852
def build_server_protocol_three(backing_transport, write_func,
875
 
                                root_client_path, jail_root=None):
 
853
                                root_client_path):
876
854
    request_handler = request.SmartServerRequestHandler(
877
855
        backing_transport, commands=request.request_handlers,
878
 
        root_client_path=root_client_path, jail_root=jail_root)
 
856
        root_client_path=root_client_path)
879
857
    responder = ProtocolThreeResponder(write_func)
880
858
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
859
    return ProtocolThreeDecoder(message_handler)
911
889
            # We do *not* set self.decoding_failed here.  The message handler
912
890
            # has raised an error, but the decoder is still able to parse bytes
913
891
            # and determine when this message ends.
914
 
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
 
                log_exception_quietly()
 
892
            log_exception_quietly()
916
893
            self.message_handler.protocol_error(exception.exc_value)
917
894
            # The state machine is ready to continue decoding, but the
918
895
            # exception has interrupted the loop that runs the state machine.
954
931
    def _extract_prefixed_bencoded_data(self):
955
932
        prefixed_bytes = self._extract_length_prefixed_bytes()
956
933
        try:
957
 
            decoded = bdecode_as_tuple(prefixed_bytes)
 
934
            decoded = bdecode(prefixed_bytes)
958
935
        except ValueError:
959
936
            raise errors.SmartProtocolError(
960
937
                'Bytes %r not bencoded' % (prefixed_bytes,))
1051
1028
            raise errors.SmartMessageHandlerError(sys.exc_info())
1052
1029
 
1053
1030
    def _state_accept_reading_unused(self):
1054
 
        self.unused_data += self._get_in_buffer()
 
1031
        self.unused_data = self._get_in_buffer()
1055
1032
        self._set_in_buffer(None)
1056
1033
 
1057
1034
    def next_read_size(self):
1073
1050
class _ProtocolThreeEncoder(object):
1074
1051
 
1075
1052
    response_marker = request_marker = MESSAGE_VERSION_THREE
1076
 
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1077
1053
 
1078
1054
    def __init__(self, write_func):
1079
 
        self._buf = []
1080
 
        self._buf_len = 0
 
1055
        self._buf = ''
1081
1056
        self._real_write_func = write_func
1082
1057
 
1083
1058
    def _write_func(self, bytes):
1084
 
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
 
        #       for total number of bytes to write, rather than buffer based on
1086
 
        #       the number of write() calls
1087
 
        # TODO: Another possibility would be to turn this into an async model.
1088
 
        #       Where we let another thread know that we have some bytes if
1089
 
        #       they want it, but we don't actually block for it
1090
 
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
1091
 
        #       we might just push out smaller bits at a time?
1092
 
        self._buf.append(bytes)
1093
 
        self._buf_len += len(bytes)
1094
 
        if self._buf_len > self.BUFFER_SIZE:
1095
 
            self.flush()
 
1059
        self._buf += bytes
1096
1060
 
1097
1061
    def flush(self):
1098
1062
        if self._buf:
1099
 
            self._real_write_func(''.join(self._buf))
1100
 
            del self._buf[:]
1101
 
            self._buf_len = 0
 
1063
            self._real_write_func(self._buf)
 
1064
            self._buf = ''
1102
1065
 
1103
1066
    def _serialise_offsets(self, offsets):
1104
1067
        """Serialise a readv offset list."""
1137
1100
        self._write_func(struct.pack('!L', len(bytes)))
1138
1101
        self._write_func(bytes)
1139
1102
 
1140
 
    def _write_chunked_body_start(self):
1141
 
        self._write_func('oC')
1142
 
 
1143
1103
    def _write_error_status(self):
1144
1104
        self._write_func('oE')
1145
1105
 
1153
1113
        _ProtocolThreeEncoder.__init__(self, write_func)
1154
1114
        self.response_sent = False
1155
1115
        self._headers = {'Software version': bzrlib.__version__}
1156
 
        if 'hpss' in debug.debug_flags:
1157
 
            self._thread_id = thread.get_ident()
1158
 
            self._response_start_time = None
1159
 
 
1160
 
    def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
 
        if self._response_start_time is None:
1162
 
            self._response_start_time = osutils.timer_func()
1163
 
        if include_time:
1164
 
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
 
        else:
1166
 
            t = ''
1167
 
        if extra_bytes is None:
1168
 
            extra = ''
1169
 
        else:
1170
 
            extra = ' ' + repr(extra_bytes[:40])
1171
 
            if len(extra) > 33:
1172
 
                extra = extra[:29] + extra[-1] + '...'
1173
 
        mutter('%12s: [%s] %s%s%s'
1174
 
               % (action, self._thread_id, t, message, extra))
1175
1116
 
1176
1117
    def send_error(self, exception):
1177
1118
        if self.response_sent:
1183
1124
                ('UnknownMethod', exception.verb))
1184
1125
            self.send_response(failure)
1185
1126
            return
1186
 
        if 'hpss' in debug.debug_flags:
1187
 
            self._trace('error', str(exception))
1188
1127
        self.response_sent = True
1189
1128
        self._write_protocol_version()
1190
1129
        self._write_headers(self._headers)
1204
1143
            self._write_success_status()
1205
1144
        else:
1206
1145
            self._write_error_status()
1207
 
        if 'hpss' in debug.debug_flags:
1208
 
            self._trace('response', repr(response.args))
1209
1146
        self._write_structure(response.args)
1210
1147
        if response.body is not None:
1211
1148
            self._write_prefixed_body(response.body)
1212
 
            if 'hpss' in debug.debug_flags:
1213
 
                self._trace('body', '%d bytes' % (len(response.body),),
1214
 
                            response.body, include_time=True)
1215
1149
        elif response.body_stream is not None:
1216
 
            count = num_bytes = 0
1217
 
            first_chunk = None
1218
 
            for exc_info, chunk in _iter_with_errors(response.body_stream):
1219
 
                count += 1
1220
 
                if exc_info is not None:
1221
 
                    self._write_error_status()
1222
 
                    error_struct = request._translate_error(exc_info[1])
1223
 
                    self._write_structure(error_struct)
1224
 
                    break
1225
 
                else:
1226
 
                    if isinstance(chunk, request.FailedSmartServerResponse):
1227
 
                        self._write_error_status()
1228
 
                        self._write_structure(chunk.args)
1229
 
                        break
1230
 
                    num_bytes += len(chunk)
1231
 
                    if first_chunk is None:
1232
 
                        first_chunk = chunk
1233
 
                    self._write_prefixed_body(chunk)
1234
 
                    if 'hpssdetail' in debug.debug_flags:
1235
 
                        # Not worth timing separately, as _write_func is
1236
 
                        # actually buffered
1237
 
                        self._trace('body chunk',
1238
 
                                    '%d bytes' % (len(chunk),),
1239
 
                                    chunk, suppress_time=True)
1240
 
            if 'hpss' in debug.debug_flags:
1241
 
                self._trace('body stream',
1242
 
                            '%d bytes %d chunks' % (num_bytes, count),
1243
 
                            first_chunk)
 
1150
            for chunk in response.body_stream:
 
1151
                self._write_prefixed_body(chunk)
 
1152
                self.flush()
1244
1153
        self._write_end()
1245
 
        if 'hpss' in debug.debug_flags:
1246
 
            self._trace('response end', '', include_time=True)
1247
 
 
1248
 
 
1249
 
def _iter_with_errors(iterable):
1250
 
    """Handle errors from iterable.next().
1251
 
 
1252
 
    Use like::
1253
 
 
1254
 
        for exc_info, value in _iter_with_errors(iterable):
1255
 
            ...
1256
 
 
1257
 
    This is a safer alternative to::
1258
 
 
1259
 
        try:
1260
 
            for value in iterable:
1261
 
               ...
1262
 
        except:
1263
 
            ...
1264
 
 
1265
 
    Because the latter will catch errors from the for-loop body, not just
1266
 
    iterable.next()
1267
 
 
1268
 
    If an error occurs, exc_info will be an exc_info tuple, and the generator
1269
 
    will terminate.  Otherwise exc_info will be None, and value will be the
1270
 
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
1271
 
    will not be intercepted.
1272
 
    """
1273
 
    iterator = iter(iterable)
1274
 
    while True:
1275
 
        try:
1276
 
            yield None, iterator.next()
1277
 
        except StopIteration:
1278
 
            return
1279
 
        except (KeyboardInterrupt, SystemExit):
1280
 
            raise
1281
 
        except Exception:
1282
 
            mutter('_iter_with_errors caught error')
1283
 
            log_exception_quietly()
1284
 
            yield sys.exc_info(), None
1285
 
            return
1286
1154
 
1287
1155
 
1288
1156
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1301
1169
            base = getattr(self._medium_request._medium, 'base', None)
1302
1170
            if base is not None:
1303
1171
                mutter('             (to %s)', base)
1304
 
            self._request_start_time = osutils.timer_func()
 
1172
            self._request_start_time = time.time()
1305
1173
        self._write_protocol_version()
1306
1174
        self._write_headers(self._headers)
1307
1175
        self._write_structure(args)
1319
1187
            if path is not None:
1320
1188
                mutter('                  (to %s)', path)
1321
1189
            mutter('              %d bytes', len(body))
1322
 
            self._request_start_time = osutils.timer_func()
 
1190
            self._request_start_time = time.time()
1323
1191
        self._write_protocol_version()
1324
1192
        self._write_headers(self._headers)
1325
1193
        self._write_structure(args)
1338
1206
            path = getattr(self._medium_request._medium, '_path', None)
1339
1207
            if path is not None:
1340
1208
                mutter('                  (to %s)', path)
1341
 
            self._request_start_time = osutils.timer_func()
 
1209
            self._request_start_time = time.time()
1342
1210
        self._write_protocol_version()
1343
1211
        self._write_headers(self._headers)
1344
1212
        self._write_structure(args)
1349
1217
        self._write_end()
1350
1218
        self._medium_request.finished_writing()
1351
1219
 
1352
 
    def call_with_body_stream(self, args, stream):
1353
 
        if 'hpss' in debug.debug_flags:
1354
 
            mutter('hpss call w/body stream: %r', args)
1355
 
            path = getattr(self._medium_request._medium, '_path', None)
1356
 
            if path is not None:
1357
 
                mutter('                  (to %s)', path)
1358
 
            self._request_start_time = osutils.timer_func()
1359
 
        self._write_protocol_version()
1360
 
        self._write_headers(self._headers)
1361
 
        self._write_structure(args)
1362
 
        # TODO: notice if the server has sent an early error reply before we
1363
 
        #       have finished sending the stream.  We would notice at the end
1364
 
        #       anyway, but if the medium can deliver it early then it's good
1365
 
        #       to short-circuit the whole request...
1366
 
        for exc_info, part in _iter_with_errors(stream):
1367
 
            if exc_info is not None:
1368
 
                # Iterating the stream failed.  Cleanly abort the request.
1369
 
                self._write_error_status()
1370
 
                # Currently the client unconditionally sends ('error',) as the
1371
 
                # error args.
1372
 
                self._write_structure(('error',))
1373
 
                self._write_end()
1374
 
                self._medium_request.finished_writing()
1375
 
                raise exc_info[0], exc_info[1], exc_info[2]
1376
 
            else:
1377
 
                self._write_prefixed_body(part)
1378
 
                self.flush()
1379
 
        self._write_end()
1380
 
        self._medium_request.finished_writing()
1381