1
# Copyright (C) 2006-2010 Canonical Ltd
1
# Copyright (C) 2006, 2007 Canonical Ltd
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode_as_tuple, bencode
40
35
# Protocol version strings. These are sent as prefixes of bzr requests and
63
58
def _encode_tuple(args):
64
59
"""Encode the tuple args to a bytestream."""
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
60
return '\x01'.join(args) + '\n'
74
63
class Requester(object):
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
"""Server-side encoding and decoding logic for smart version 1."""
128
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
self._backing_transport = backing_transport
131
119
self._root_client_path = root_client_path
132
self._jail_root = jail_root
133
120
self.unused_data = ''
134
121
self._finished = False
135
122
self.in_buffer = ''
157
144
req_args = _decode_tuple(first_line)
158
145
self.request = request.SmartServerRequestHandler(
159
146
self._backing_transport, commands=request.request_handlers,
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
163
149
if self.request.finished_reading:
164
150
# trivial request
165
151
self.unused_data = self.in_buffer
626
612
mutter('hpss call: %s', repr(args)[1:-1])
627
613
if getattr(self._request._medium, 'base', None) is not None:
628
614
mutter(' (to %s)', self._request._medium.base)
629
self._request_start_time = osutils.timer_func()
615
self._request_start_time = time.time()
630
616
self._write_args(args)
631
617
self._request.finished_writing()
632
618
self._last_verb = args[0]
641
627
if getattr(self._request._medium, '_path', None) is not None:
642
628
mutter(' (to %s)', self._request._medium._path)
643
629
mutter(' %d bytes', len(body))
644
self._request_start_time = osutils.timer_func()
630
self._request_start_time = time.time()
645
631
if 'hpssdetail' in debug.debug_flags:
646
632
mutter('hpss body content: %s', body)
647
633
self._write_args(args)
660
646
mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
647
if getattr(self._request._medium, '_path', None) is not None:
662
648
mutter(' (to %s)', self._request._medium._path)
663
self._request_start_time = osutils.timer_func()
649
self._request_start_time = time.time()
664
650
self._write_args(args)
665
651
readv_bytes = self._serialise_offsets(body)
666
652
bytes = self._encode_bulk_data(readv_bytes)
692
678
if 'hpss' in debug.debug_flags:
693
679
if self._request_start_time is not None:
694
680
mutter(' result: %6.3fs %s',
695
osutils.timer_func() - self._request_start_time,
681
time.time() - self._request_start_time,
696
682
repr(result)[1:-1])
697
683
self._request_start_time = None
874
860
def build_server_protocol_three(backing_transport, write_func,
875
root_client_path, jail_root=None):
876
862
request_handler = request.SmartServerRequestHandler(
877
863
backing_transport, commands=request.request_handlers,
878
root_client_path=root_client_path, jail_root=jail_root)
864
root_client_path=root_client_path)
879
865
responder = ProtocolThreeResponder(write_func)
880
866
message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
867
return ProtocolThreeDecoder(message_handler)
911
897
# We do *not* set self.decoding_failed here. The message handler
912
898
# has raised an error, but the decoder is still able to parse bytes
913
899
# and determine when this message ends.
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
900
log_exception_quietly()
916
901
self.message_handler.protocol_error(exception.exc_value)
917
902
# The state machine is ready to continue decoding, but the
918
903
# exception has interrupted the loop that runs the state machine.
1051
1036
raise errors.SmartMessageHandlerError(sys.exc_info())
1053
1038
def _state_accept_reading_unused(self):
1054
self.unused_data += self._get_in_buffer()
1039
self.unused_data = self._get_in_buffer()
1055
1040
self._set_in_buffer(None)
1057
1042
def next_read_size(self):
1073
1058
class _ProtocolThreeEncoder(object):
1075
1060
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1078
1062
def __init__(self, write_func):
1081
1064
self._real_write_func = write_func
1083
1066
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1097
1069
def flush(self):
1099
self._real_write_func(''.join(self._buf))
1071
self._real_write_func(self._buf)
1103
1074
def _serialise_offsets(self, offsets):
1104
1075
"""Serialise a readv offset list."""
1153
1124
_ProtocolThreeEncoder.__init__(self, write_func)
1154
1125
self.response_sent = False
1155
1126
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1176
1128
def send_error(self, exception):
1177
1129
if self.response_sent:
1183
1135
('UnknownMethod', exception.verb))
1184
1136
self.send_response(failure)
1186
if 'hpss' in debug.debug_flags:
1187
self._trace('error', str(exception))
1188
1138
self.response_sent = True
1189
1139
self._write_protocol_version()
1190
1140
self._write_headers(self._headers)
1204
1154
self._write_success_status()
1206
1156
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1209
1157
self._write_structure(response.args)
1210
1158
if response.body is not None:
1211
1159
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1215
1160
elif response.body_stream is not None:
1216
count = num_bytes = 0
1218
1161
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
1162
if exc_info is not None:
1221
1163
self._write_error_status()
1222
1164
error_struct = request._translate_error(exc_info[1])
1223
1165
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
1168
self._write_prefixed_body(chunk)
1234
if 'hpssdetail' in debug.debug_flags:
1235
# Not worth timing separately, as _write_func is
1237
self._trace('body chunk',
1238
'%d bytes' % (len(chunk),),
1239
chunk, suppress_time=True)
1240
if 'hpss' in debug.debug_flags:
1241
self._trace('body stream',
1242
'%d bytes %d chunks' % (num_bytes, count),
1244
1170
self._write_end()
1245
if 'hpss' in debug.debug_flags:
1246
self._trace('response end', '', include_time=True)
1249
1173
def _iter_with_errors(iterable):
1301
1223
base = getattr(self._medium_request._medium, 'base', None)
1302
1224
if base is not None:
1303
1225
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1226
self._request_start_time = time.time()
1305
1227
self._write_protocol_version()
1306
1228
self._write_headers(self._headers)
1307
1229
self._write_structure(args)
1319
1241
if path is not None:
1320
1242
mutter(' (to %s)', path)
1321
1243
mutter(' %d bytes', len(body))
1322
self._request_start_time = osutils.timer_func()
1244
self._request_start_time = time.time()
1323
1245
self._write_protocol_version()
1324
1246
self._write_headers(self._headers)
1325
1247
self._write_structure(args)
1338
1260
path = getattr(self._medium_request._medium, '_path', None)
1339
1261
if path is not None:
1340
1262
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1263
self._request_start_time = time.time()
1342
1264
self._write_protocol_version()
1343
1265
self._write_headers(self._headers)
1344
1266
self._write_structure(args)
1355
1277
path = getattr(self._medium_request._medium, '_path', None)
1356
1278
if path is not None:
1357
1279
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1280
self._request_start_time = time.time()
1359
1281
self._write_protocol_version()
1360
1282
self._write_headers(self._headers)
1361
1283
self._write_structure(args)