22
from collections.abc import deque
23
except ImportError: # python < 3.7
24
from collections import deque
26
from io import BytesIO
22
from cStringIO import StringIO
38
from . import message, request
39
from ...trace import log_exception_quietly, mutter
40
from ...bencode import bdecode_as_tuple, bencode
28
from bzrlib import debug
29
from bzrlib import errors
30
from bzrlib.smart import message, request
31
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.bencode import bdecode_as_tuple, bencode
43
35
# Protocol version strings. These are sent as prefixes of bzr requests and
44
36
# responses to identify the protocol version being used. (There are no version
45
37
# one strings because that version doesn't send any).
46
REQUEST_VERSION_TWO = b'bzr request 2\n'
47
RESPONSE_VERSION_TWO = b'bzr response 2\n'
38
REQUEST_VERSION_TWO = 'bzr request 2\n'
39
RESPONSE_VERSION_TWO = 'bzr response 2\n'
49
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
41
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
50
42
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
112
101
# support multiple chunks?
113
102
def _encode_bulk_data(self, body):
114
103
"""Encode body as a bulk data chunk."""
115
return b''.join((b'%d\n' % len(body), body, b'done\n'))
104
return ''.join(('%d\n' % len(body), body, 'done\n'))
117
106
def _serialise_offsets(self, offsets):
118
107
"""Serialise a readv offset list."""
120
109
for start, length in offsets:
121
txt.append(b'%d,%d' % (start, length))
122
return b'\n'.join(txt)
110
txt.append('%d,%d' % (start, length))
111
return '\n'.join(txt)
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
"""Server-side encoding and decoding logic for smart version 1."""
128
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
self._backing_transport = backing_transport
131
119
self._root_client_path = root_client_path
132
self._jail_root = jail_root
133
self.unused_data = b''
120
self.unused_data = ''
134
121
self._finished = False
136
123
self._has_dispatched = False
137
124
self.request = None
138
125
self._body_decoder = None
139
126
self._write_func = write_func
141
def accept_bytes(self, data):
128
def accept_bytes(self, bytes):
142
129
"""Take bytes, and advance the internal state machine appropriately.
144
:param data: must be a byte string
131
:param bytes: must be a byte string
146
if not isinstance(data, bytes):
147
raise ValueError(data)
148
self.in_buffer += data
133
if not isinstance(bytes, str):
134
raise ValueError(bytes)
135
self.in_buffer += bytes
149
136
if not self._has_dispatched:
150
if b'\n' not in self.in_buffer:
137
if '\n' not in self.in_buffer:
151
138
# no command line yet
153
140
self._has_dispatched = True
155
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
142
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
157
144
req_args = _decode_tuple(first_line)
158
145
self.request = request.SmartServerRequestHandler(
159
146
self._backing_transport, commands=request.request_handlers,
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
163
149
if self.request.finished_reading:
164
150
# trivial request
165
151
self.unused_data = self.in_buffer
167
153
self._send_response(self.request.response)
168
154
except KeyboardInterrupt:
170
except errors.UnknownSmartMethod as err:
156
except errors.UnknownSmartMethod, err:
171
157
protocol_error = errors.SmartProtocolError(
172
"bad request '%s'" % (err.verb.decode('ascii'),))
158
"bad request %r" % (err.verb,))
173
159
failure = request.FailedSmartServerResponse(
174
(b'error', str(protocol_error).encode('utf-8')))
160
('error', str(protocol_error)))
175
161
self._send_response(failure)
177
except Exception as exception:
163
except Exception, exception:
178
164
# everything else: pass to client, flush, and quit
179
165
log_exception_quietly()
180
166
self._send_response(request.FailedSmartServerResponse(
181
(b'error', str(exception).encode('utf-8'))))
167
('error', str(exception))))
184
170
if self._has_dispatched:
278
264
self._write_success_or_failure_prefix(response)
279
265
self._write_func(_encode_tuple(response.args))
280
266
if response.body is not None:
281
if not isinstance(response.body, bytes):
282
raise AssertionError('body must be bytes')
267
if not isinstance(response.body, str):
268
raise AssertionError('body must be a str')
283
269
if not (response.body_stream is None):
284
270
raise AssertionError(
285
271
'body_stream and body cannot both be set')
286
data = self._encode_bulk_data(response.body)
287
self._write_func(data)
272
bytes = self._encode_bulk_data(response.body)
273
self._write_func(bytes)
288
274
elif response.body_stream is not None:
289
275
_send_stream(response.body_stream, self._write_func)
292
278
def _send_stream(stream, write_func):
293
write_func(b'chunked\n')
279
write_func('chunked\n')
294
280
_send_chunks(stream, write_func)
298
284
def _send_chunks(stream, write_func):
299
285
for chunk in stream:
300
if isinstance(chunk, bytes):
301
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
286
if isinstance(chunk, str):
287
bytes = "%x\n%s" % (len(chunk), chunk)
303
289
elif isinstance(chunk, request.FailedSmartServerResponse):
305
291
_send_chunks(chunk.args, write_func)
376
362
def _set_in_buffer(self, new_buf):
377
363
if new_buf is not None:
378
if not isinstance(new_buf, bytes):
379
raise TypeError(new_buf)
380
364
self._in_buffer_list = [new_buf]
381
365
self._in_buffer_len = len(new_buf)
383
367
self._in_buffer_list = []
384
368
self._in_buffer_len = 0
386
def accept_bytes(self, new_buf):
370
def accept_bytes(self, bytes):
387
371
"""Decode as much of bytes as possible.
389
If 'new_buf' contains too much data it will be appended to
373
If 'bytes' contains too much data it will be appended to
390
374
self.unused_data.
392
376
finished_reading will be set when no more data is required. Further
393
377
data will be appended to self.unused_data.
395
if not isinstance(new_buf, bytes):
396
raise TypeError(new_buf)
397
379
# accept_bytes is allowed to change the state
398
380
self._number_needed_bytes = None
399
381
# lsprof puts a very large amount of time on this specific call for
400
382
# large readv arrays
401
self._in_buffer_list.append(new_buf)
402
self._in_buffer_len += len(new_buf)
383
self._in_buffer_list.append(bytes)
384
self._in_buffer_len += len(bytes)
404
386
# Run the function for the current state.
405
387
current_state = self.state_accept
643
625
if 'hpss' in debug.debug_flags:
644
626
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
645
627
if getattr(self._request._medium, '_path', None) is not None:
647
self._request._medium._path)
628
mutter(' (to %s)', self._request._medium._path)
648
629
mutter(' %d bytes', len(body))
649
self._request_start_time = osutils.perf_counter()
630
self._request_start_time = time.time()
650
631
if 'hpssdetail' in debug.debug_flags:
651
632
mutter('hpss body content: %s', body)
652
633
self._write_args(args)
659
640
"""Make a remote call with a readv array.
661
642
The body is encoded with one line per readv offset pair. The numbers in
662
each pair are separated by a comma, and no trailing \\n is emitted.
643
each pair are separated by a comma, and no trailing \n is emitted.
664
645
if 'hpss' in debug.debug_flags:
665
646
mutter('hpss call w/readv: %s', repr(args)[1:-1])
666
647
if getattr(self._request._medium, '_path', None) is not None:
668
self._request._medium._path)
669
self._request_start_time = osutils.perf_counter()
648
mutter(' (to %s)', self._request._medium._path)
649
self._request_start_time = time.time()
670
650
self._write_args(args)
671
651
readv_bytes = self._serialise_offsets(body)
672
652
bytes = self._encode_bulk_data(readv_bytes)
755
735
:param verb: The verb used in that call.
756
736
:raises: UnexpectedSmartServerResponse
758
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
759
b"bad request '" + self._last_verb + b"'")
760
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
761
b"bad request u'%s'" % self._last_verb)):
738
if (result_tuple == ('error', "Generic bzr smart protocol error: "
739
"bad request '%s'" % self._last_verb) or
740
result_tuple == ('error', "Generic bzr smart protocol error: "
741
"bad request u'%s'" % self._last_verb)):
762
742
# The response will have no body, so we've finished reading.
763
743
self._request.finished_reading()
764
744
raise errors.UnknownSmartMethod(self._last_verb)
776
756
while not _body_decoder.finished_reading:
777
757
bytes = self._request.read_bytes(_body_decoder.next_read_size())
779
759
# end of file encountered reading from server
780
760
raise errors.ConnectionReset(
781
761
"Connection lost while reading response body.")
782
762
_body_decoder.accept_bytes(bytes)
783
763
self._request.finished_reading()
784
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
764
self._body_buffer = StringIO(_body_decoder.read_pending_data())
785
765
# XXX: TODO check the trailer result.
786
766
if 'hpss' in debug.debug_flags:
787
767
mutter(' %d body bytes read',
864
844
_body_decoder = ChunkedBodyDecoder()
865
845
while not _body_decoder.finished_reading:
866
846
bytes = self._request.read_bytes(_body_decoder.next_read_size())
868
848
# end of file encountered reading from server
869
849
raise errors.ConnectionReset(
870
850
"Connection lost while reading streamed body.")
871
851
_body_decoder.accept_bytes(bytes)
872
852
for body_bytes in iter(_body_decoder.read_next_chunk, None):
873
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
853
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
874
854
mutter(' %d byte chunk read',
880
860
def build_server_protocol_three(backing_transport, write_func,
881
root_client_path, jail_root=None):
882
862
request_handler = request.SmartServerRequestHandler(
883
863
backing_transport, commands=request.request_handlers,
884
root_client_path=root_client_path, jail_root=jail_root)
864
root_client_path=root_client_path)
885
865
responder = ProtocolThreeResponder(write_func)
886
message_handler = message.ConventionalRequestHandler(
887
request_handler, responder)
866
message_handler = message.ConventionalRequestHandler(request_handler, responder)
888
867
return ProtocolThreeDecoder(message_handler)
1011
990
def _state_accept_expecting_message_part(self):
1012
991
message_part_kind = self._extract_single_byte()
1013
if message_part_kind == b'o':
992
if message_part_kind == 'o':
1014
993
self.state_accept = self._state_accept_expecting_one_byte
1015
elif message_part_kind == b's':
994
elif message_part_kind == 's':
1016
995
self.state_accept = self._state_accept_expecting_structure
1017
elif message_part_kind == b'b':
996
elif message_part_kind == 'b':
1018
997
self.state_accept = self._state_accept_expecting_bytes
1019
elif message_part_kind == b'e':
998
elif message_part_kind == 'e':
1022
1001
raise errors.SmartProtocolError(
1080
1059
class _ProtocolThreeEncoder(object):
1082
1061
response_marker = request_marker = MESSAGE_VERSION_THREE
1083
BUFFER_SIZE = 1024 * 1024 # 1 MiB buffer before flushing
1085
1063
def __init__(self, write_func):
1088
1065
self._real_write_func = write_func
1090
1067
def _write_func(self, bytes):
1091
# TODO: Another possibility would be to turn this into an async model.
1092
# Where we let another thread know that we have some bytes if
1093
# they want it, but we don't actually block for it
1094
# Note that osutils.send_all always sends 64kB chunks anyway, so
1095
# we might just push out smaller bits at a time?
1096
1068
self._buf.append(bytes)
1097
self._buf_len += len(bytes)
1098
if self._buf_len > self.BUFFER_SIZE:
1069
if len(self._buf) > 100:
1101
1072
def flush(self):
1103
self._real_write_func(b''.join(self._buf))
1074
self._real_write_func(''.join(self._buf))
1104
1075
del self._buf[:]
1107
1077
def _serialise_offsets(self, offsets):
1108
1078
"""Serialise a readv offset list."""
1110
1080
for start, length in offsets:
1111
txt.append(b'%d,%d' % (start, length))
1112
return b'\n'.join(txt)
1081
txt.append('%d,%d' % (start, length))
1082
return '\n'.join(txt)
1114
1084
def _write_protocol_version(self):
1115
1085
self._write_func(MESSAGE_VERSION_THREE)
1156
1126
def __init__(self, write_func):
1157
1127
_ProtocolThreeEncoder.__init__(self, write_func)
1158
1128
self.response_sent = False
1160
b'Software version': breezy.__version__.encode('utf-8')}
1161
if 'hpss' in debug.debug_flags:
1162
self._thread_id = _thread.get_ident()
1163
self._response_start_time = None
1165
def _trace(self, action, message, extra_bytes=None, include_time=False):
1166
if self._response_start_time is None:
1167
self._response_start_time = osutils.perf_counter()
1169
t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
1172
if extra_bytes is None:
1175
extra = ' ' + repr(extra_bytes[:40])
1177
extra = extra[:29] + extra[-1] + '...'
1178
mutter('%12s: [%s] %s%s%s'
1179
% (action, self._thread_id, t, message, extra))
1129
self._headers = {'Software version': bzrlib.__version__}
1181
1131
def send_error(self, exception):
1182
1132
if self.response_sent:
1185
1135
% (exception,))
1186
1136
if isinstance(exception, errors.UnknownSmartMethod):
1187
1137
failure = request.FailedSmartServerResponse(
1188
(b'UnknownMethod', exception.verb))
1138
('UnknownMethod', exception.verb))
1189
1139
self.send_response(failure)
1191
if 'hpss' in debug.debug_flags:
1192
self._trace('error', str(exception))
1193
1141
self.response_sent = True
1194
1142
self._write_protocol_version()
1195
1143
self._write_headers(self._headers)
1196
1144
self._write_error_status()
1197
self._write_structure(
1198
(b'error', str(exception).encode('utf-8', 'replace')))
1145
self._write_structure(('error', str(exception)))
1199
1146
self._write_end()
1201
1148
def send_response(self, response):
1210
1157
self._write_success_status()
1212
1159
self._write_error_status()
1213
if 'hpss' in debug.debug_flags:
1214
self._trace('response', repr(response.args))
1215
1160
self._write_structure(response.args)
1216
1161
if response.body is not None:
1217
1162
self._write_prefixed_body(response.body)
1218
if 'hpss' in debug.debug_flags:
1219
self._trace('body', '%d bytes' % (len(response.body),),
1220
response.body, include_time=True)
1221
1163
elif response.body_stream is not None:
1222
count = num_bytes = 0
1224
1164
for exc_info, chunk in _iter_with_errors(response.body_stream):
1226
1165
if exc_info is not None:
1227
1166
self._write_error_status()
1228
1167
error_struct = request._translate_error(exc_info[1])
1233
1172
self._write_error_status()
1234
1173
self._write_structure(chunk.args)
1236
num_bytes += len(chunk)
1237
if first_chunk is None:
1239
1175
self._write_prefixed_body(chunk)
1241
if 'hpssdetail' in debug.debug_flags:
1242
# Not worth timing separately, as _write_func is
1244
self._trace('body chunk',
1245
'%d bytes' % (len(chunk),),
1246
chunk, suppress_time=True)
1247
if 'hpss' in debug.debug_flags:
1248
self._trace('body stream',
1249
'%d bytes %d chunks' % (num_bytes, count),
1251
1176
self._write_end()
1252
if 'hpss' in debug.debug_flags:
1253
self._trace('response end', '', include_time=True)
1256
1179
def _iter_with_errors(iterable):
1339
1261
"""Make a remote call with a readv array.
1341
1263
The body is encoded with one line per readv offset pair. The numbers in
1342
each pair are separated by a comma, and no trailing \\n is emitted.
1264
each pair are separated by a comma, and no trailing \n is emitted.
1344
1266
if 'hpss' in debug.debug_flags:
1345
1267
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1346
1268
path = getattr(self._medium_request._medium, '_path', None)
1347
1269
if path is not None:
1348
1270
mutter(' (to %s)', path)
1349
self._request_start_time = osutils.perf_counter()
1271
self._request_start_time = time.time()
1350
1272
self._write_protocol_version()
1351
1273
self._write_headers(self._headers)
1352
1274
self._write_structure(args)
1372
1293
# have finished sending the stream. We would notice at the end
1373
1294
# anyway, but if the medium can deliver it early then it's good
1374
1295
# to short-circuit the whole request...
1375
# Provoke any ConnectionReset failures before we start the body stream.
1377
self.body_stream_started = True
1378
1296
for exc_info, part in _iter_with_errors(stream):
1379
1297
if exc_info is not None:
1380
1298
# Iterating the stream failed. Cleanly abort the request.
1381
1299
self._write_error_status()
1382
1300
# Currently the client unconditionally sends ('error',) as the
1384
self._write_structure((b'error',))
1302
self._write_structure(('error',))
1385
1303
self._write_end()
1386
1304
self._medium_request.finished_writing()
1387
(exc_type, exc_val, exc_tb) = exc_info
1305
raise exc_info[0], exc_info[1], exc_info[2]
1393
1307
self._write_prefixed_body(part)
1395
1309
self._write_end()
1396
1310
self._medium_request.finished_writing()