21
from __future__ import absolute_import
22
from collections.abc import deque
23
except ImportError: # python < 3.7
24
from collections import deque
24
from cStringIO import StringIO
26
from io import BytesIO
36
from bzrlib.smart import message, request
37
from bzrlib.trace import log_exception_quietly, mutter
38
from bzrlib.bencode import bdecode_as_tuple, bencode
38
from . import message, request
39
from ...trace import log_exception_quietly, mutter
40
from ...bencode import bdecode_as_tuple, bencode
41
43
# Protocol version strings. These are sent as prefixes of bzr requests and
42
44
# responses to identify the protocol version being used. (There are no version
43
45
# one strings because that version doesn't send any).
44
REQUEST_VERSION_TWO = 'bzr request 2\n'
45
RESPONSE_VERSION_TWO = 'bzr response 2\n'
46
REQUEST_VERSION_TWO = b'bzr request 2\n'
47
RESPONSE_VERSION_TWO = b'bzr response 2\n'
47
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
49
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
48
50
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
56
58
def _decode_tuple(req_line):
57
if req_line is None or req_line == '':
59
if req_line is None or req_line == b'':
59
if req_line[-1] != '\n':
61
if not req_line.endswith(b'\n'):
60
62
raise errors.SmartProtocolError("request %r not terminated" % req_line)
61
return tuple(req_line[:-1].split('\x01'))
63
return tuple(req_line[:-1].split(b'\x01'))
64
66
def _encode_tuple(args):
65
67
"""Encode the tuple args to a bytestream."""
66
joined = '\x01'.join(args) + '\n'
67
if type(joined) is unicode:
68
# XXX: We should fix things so this never happens! -AJB, 20100304
69
mutter('response args contain unicode, should be only bytes: %r',
71
joined = joined.encode('ascii')
69
if isinstance(arg, str):
71
return b'\x01'.join(args) + b'\n'
75
74
class Requester(object):
113
112
# support multiple chunks?
114
113
def _encode_bulk_data(self, body):
115
114
"""Encode body as a bulk data chunk."""
116
return ''.join(('%d\n' % len(body), body, 'done\n'))
115
return b''.join((b'%d\n' % len(body), body, b'done\n'))
118
117
def _serialise_offsets(self, offsets):
119
118
"""Serialise a readv offset list."""
121
120
for start, length in offsets:
122
txt.append('%d,%d' % (start, length))
123
return '\n'.join(txt)
121
txt.append(b'%d,%d' % (start, length))
122
return b'\n'.join(txt)
126
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
127
126
"""Server-side encoding and decoding logic for smart version 1."""
129
128
def __init__(self, backing_transport, write_func, root_client_path='/',
131
130
self._backing_transport = backing_transport
132
131
self._root_client_path = root_client_path
133
132
self._jail_root = jail_root
134
self.unused_data = ''
133
self.unused_data = b''
135
134
self._finished = False
137
136
self._has_dispatched = False
138
137
self.request = None
139
138
self._body_decoder = None
140
139
self._write_func = write_func
142
def accept_bytes(self, bytes):
141
def accept_bytes(self, data):
143
142
"""Take bytes, and advance the internal state machine appropriately.
145
:param bytes: must be a byte string
144
:param data: must be a byte string
147
if not isinstance(bytes, str):
148
raise ValueError(bytes)
149
self.in_buffer += bytes
146
if not isinstance(data, bytes):
147
raise ValueError(data)
148
self.in_buffer += data
150
149
if not self._has_dispatched:
151
if '\n' not in self.in_buffer:
150
if b'\n' not in self.in_buffer:
152
151
# no command line yet
154
153
self._has_dispatched = True
156
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
155
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
158
157
req_args = _decode_tuple(first_line)
159
158
self.request = request.SmartServerRequestHandler(
160
159
self._backing_transport, commands=request.request_handlers,
164
163
if self.request.finished_reading:
165
164
# trivial request
166
165
self.unused_data = self.in_buffer
168
167
self._send_response(self.request.response)
169
168
except KeyboardInterrupt:
171
except errors.UnknownSmartMethod, err:
170
except errors.UnknownSmartMethod as err:
172
171
protocol_error = errors.SmartProtocolError(
173
"bad request %r" % (err.verb,))
172
"bad request '%s'" % (err.verb.decode('ascii'),))
174
173
failure = request.FailedSmartServerResponse(
175
('error', str(protocol_error)))
174
(b'error', str(protocol_error).encode('utf-8')))
176
175
self._send_response(failure)
178
except Exception, exception:
177
except Exception as exception:
179
178
# everything else: pass to client, flush, and quit
180
179
log_exception_quietly()
181
180
self._send_response(request.FailedSmartServerResponse(
182
('error', str(exception))))
181
(b'error', str(exception).encode('utf-8'))))
185
184
if self._has_dispatched:
219
218
self._write_success_or_failure_prefix(response)
220
219
self._write_func(_encode_tuple(args))
221
220
if body is not None:
222
if not isinstance(body, str):
221
if not isinstance(body, bytes):
223
222
raise ValueError(body)
224
bytes = self._encode_bulk_data(body)
225
self._write_func(bytes)
223
data = self._encode_bulk_data(body)
224
self._write_func(data)
227
226
def _write_protocol_version(self):
228
227
"""Write any prefixes this protocol requires.
259
258
def _write_success_or_failure_prefix(self, response):
260
259
"""Write the protocol specific success/failure prefix."""
261
260
if response.is_successful():
262
self._write_func('success\n')
261
self._write_func(b'success\n')
264
self._write_func('failed\n')
263
self._write_func(b'failed\n')
266
265
def _write_protocol_version(self):
267
266
r"""Write any prefixes this protocol requires.
279
278
self._write_success_or_failure_prefix(response)
280
279
self._write_func(_encode_tuple(response.args))
281
280
if response.body is not None:
282
if not isinstance(response.body, str):
283
raise AssertionError('body must be a str')
281
if not isinstance(response.body, bytes):
282
raise AssertionError('body must be bytes')
284
283
if not (response.body_stream is None):
285
284
raise AssertionError(
286
285
'body_stream and body cannot both be set')
287
bytes = self._encode_bulk_data(response.body)
288
self._write_func(bytes)
286
data = self._encode_bulk_data(response.body)
287
self._write_func(data)
289
288
elif response.body_stream is not None:
290
289
_send_stream(response.body_stream, self._write_func)
293
292
def _send_stream(stream, write_func):
294
write_func('chunked\n')
293
write_func(b'chunked\n')
295
294
_send_chunks(stream, write_func)
299
298
def _send_chunks(stream, write_func):
300
299
for chunk in stream:
301
if isinstance(chunk, str):
302
bytes = "%x\n%s" % (len(chunk), chunk)
300
if isinstance(chunk, bytes):
301
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
304
303
elif isinstance(chunk, request.FailedSmartServerResponse):
306
305
_send_chunks(chunk.args, write_func)
340
339
self.finished_reading = False
341
340
self._in_buffer_list = []
342
341
self._in_buffer_len = 0
343
self.unused_data = ''
342
self.unused_data = b''
344
343
self.bytes_left = None
345
344
self._number_needed_bytes = None
347
346
def _get_in_buffer(self):
348
347
if len(self._in_buffer_list) == 1:
349
348
return self._in_buffer_list[0]
350
in_buffer = ''.join(self._in_buffer_list)
349
in_buffer = b''.join(self._in_buffer_list)
351
350
if len(in_buffer) != self._in_buffer_len:
352
351
raise AssertionError(
353
352
"Length of buffer did not match expected value: %s != %s"
366
365
# check if we can yield the bytes from just the first entry in our list
367
366
if len(self._in_buffer_list) == 0:
368
367
raise AssertionError('Callers must be sure we have buffered bytes'
369
' before calling _get_in_bytes')
368
' before calling _get_in_bytes')
370
369
if len(self._in_buffer_list[0]) > count:
371
370
return self._in_buffer_list[0][:count]
372
371
# We can't yield it from the first buffer, so collapse all buffers, and
377
376
def _set_in_buffer(self, new_buf):
378
377
if new_buf is not None:
378
if not isinstance(new_buf, bytes):
379
raise TypeError(new_buf)
379
380
self._in_buffer_list = [new_buf]
380
381
self._in_buffer_len = len(new_buf)
382
383
self._in_buffer_list = []
383
384
self._in_buffer_len = 0
385
def accept_bytes(self, bytes):
386
def accept_bytes(self, new_buf):
386
387
"""Decode as much of bytes as possible.
388
If 'bytes' contains too much data it will be appended to
389
If 'new_buf' contains too much data it will be appended to
389
390
self.unused_data.
391
392
finished_reading will be set when no more data is required. Further
392
393
data will be appended to self.unused_data.
395
if not isinstance(new_buf, bytes):
396
raise TypeError(new_buf)
394
397
# accept_bytes is allowed to change the state
395
398
self._number_needed_bytes = None
396
399
# lsprof puts a very large amount of time on this specific call for
397
400
# large readv arrays
398
self._in_buffer_list.append(bytes)
399
self._in_buffer_len += len(bytes)
401
self._in_buffer_list.append(new_buf)
402
self._in_buffer_len += len(new_buf)
401
404
# Run the function for the current state.
402
405
current_state = self.state_accept
491
494
def _state_accept_expecting_length(self):
492
495
prefix = self._extract_line()
494
497
self.error = True
495
498
self.error_in_progress = []
496
499
self._state_accept_expecting_length()
498
elif prefix == 'END':
501
elif prefix == b'END':
499
502
# We've read the end-of-body marker.
500
503
# Any further bytes are unused data, including the bytes left in
501
504
# the _in_buffer.
534
537
_StatefulDecoder.__init__(self)
535
538
self.state_accept = self._state_accept_expecting_length
536
539
self.state_read = self._state_read_no_data
538
self._trailer_buffer = ''
541
self._trailer_buffer = b''
540
543
def next_read_size(self):
541
544
if self.bytes_left is not None:
560
563
def _state_accept_expecting_length(self):
561
564
in_buf = self._get_in_buffer()
562
pos = in_buf.find('\n')
565
pos = in_buf.find(b'\n')
565
568
self.bytes_left = int(in_buf[:pos])
566
self._set_in_buffer(in_buf[pos+1:])
569
self._set_in_buffer(in_buf[pos + 1:])
567
570
self.state_accept = self._state_accept_reading_body
568
571
self.state_read = self._state_read_body_buffer
585
588
self._set_in_buffer(None)
586
589
# TODO: what if the trailer does not match "done\n"? Should this raise
587
590
# a ProtocolViolation exception?
588
if self._trailer_buffer.startswith('done\n'):
589
self.unused_data = self._trailer_buffer[len('done\n'):]
591
if self._trailer_buffer.startswith(b'done\n'):
592
self.unused_data = self._trailer_buffer[len(b'done\n'):]
590
593
self.state_accept = self._state_accept_reading_unused
591
594
self.finished_reading = True
627
630
mutter('hpss call: %s', repr(args)[1:-1])
628
631
if getattr(self._request._medium, 'base', None) is not None:
629
632
mutter(' (to %s)', self._request._medium.base)
630
self._request_start_time = osutils.timer_func()
633
self._request_start_time = osutils.perf_counter()
631
634
self._write_args(args)
632
635
self._request.finished_writing()
633
636
self._last_verb = args[0]
640
643
if 'hpss' in debug.debug_flags:
641
644
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
642
645
if getattr(self._request._medium, '_path', None) is not None:
643
mutter(' (to %s)', self._request._medium._path)
647
self._request._medium._path)
644
648
mutter(' %d bytes', len(body))
645
self._request_start_time = osutils.timer_func()
649
self._request_start_time = osutils.perf_counter()
646
650
if 'hpssdetail' in debug.debug_flags:
647
651
mutter('hpss body content: %s', body)
648
652
self._write_args(args)
660
664
if 'hpss' in debug.debug_flags:
661
665
mutter('hpss call w/readv: %s', repr(args)[1:-1])
662
666
if getattr(self._request._medium, '_path', None) is not None:
663
mutter(' (to %s)', self._request._medium._path)
664
self._request_start_time = osutils.timer_func()
668
self._request._medium._path)
669
self._request_start_time = osutils.perf_counter()
665
670
self._write_args(args)
666
671
readv_bytes = self._serialise_offsets(body)
667
672
bytes = self._encode_bulk_data(readv_bytes)
693
698
if 'hpss' in debug.debug_flags:
694
699
if self._request_start_time is not None:
695
700
mutter(' result: %6.3fs %s',
696
osutils.timer_func() - self._request_start_time,
701
osutils.perf_counter() - self._request_start_time,
697
702
repr(result)[1:-1])
698
703
self._request_start_time = None
750
755
:param verb: The verb used in that call.
751
756
:raises: UnexpectedSmartServerResponse
753
if (result_tuple == ('error', "Generic bzr smart protocol error: "
754
"bad request '%s'" % self._last_verb) or
755
result_tuple == ('error', "Generic bzr smart protocol error: "
756
"bad request u'%s'" % self._last_verb)):
758
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
759
b"bad request '" + self._last_verb + b"'")
760
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
761
b"bad request u'%s'" % self._last_verb)):
757
762
# The response will have no body, so we've finished reading.
758
763
self._request.finished_reading()
759
764
raise errors.UnknownSmartMethod(self._last_verb)
771
776
while not _body_decoder.finished_reading:
772
777
bytes = self._request.read_bytes(_body_decoder.next_read_size())
774
779
# end of file encountered reading from server
775
780
raise errors.ConnectionReset(
776
781
"Connection lost while reading response body.")
777
782
_body_decoder.accept_bytes(bytes)
778
783
self._request.finished_reading()
779
self._body_buffer = StringIO(_body_decoder.read_pending_data())
784
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
780
785
# XXX: TODO check the trailer result.
781
786
if 'hpss' in debug.debug_flags:
782
787
mutter(' %d body bytes read',
790
795
def query_version(self):
791
796
"""Return protocol version number of the server."""
793
798
resp = self.read_response_tuple()
794
if resp == ('ok', '1'):
799
if resp == (b'ok', b'1'):
796
elif resp == ('ok', '2'):
801
elif resp == (b'ok', b'2'):
799
804
raise errors.SmartProtocolError("bad response %r" % (resp,))
831
836
response_status = self._request.read_line()
832
837
result = SmartClientRequestProtocolOne._read_response_tuple(self)
833
838
self._response_is_unknown_method(result)
834
if response_status == 'success\n':
839
if response_status == b'success\n':
835
840
self.response_status = True
836
841
if not expect_body:
837
842
self._request.finished_reading()
839
elif response_status == 'failed\n':
844
elif response_status == b'failed\n':
840
845
self.response_status = False
841
846
self._request.finished_reading()
842
847
raise errors.ErrorFromSmartServer(result)
859
864
_body_decoder = ChunkedBodyDecoder()
860
865
while not _body_decoder.finished_reading:
861
866
bytes = self._request.read_bytes(_body_decoder.next_read_size())
863
868
# end of file encountered reading from server
864
869
raise errors.ConnectionReset(
865
870
"Connection lost while reading streamed body.")
866
871
_body_decoder.accept_bytes(bytes)
867
872
for body_bytes in iter(_body_decoder.read_next_chunk, None):
868
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
873
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
869
874
mutter(' %d byte chunk read',
878
883
backing_transport, commands=request.request_handlers,
879
884
root_client_path=root_client_path, jail_root=jail_root)
880
885
responder = ProtocolThreeResponder(write_func)
881
message_handler = message.ConventionalRequestHandler(request_handler, responder)
886
message_handler = message.ConventionalRequestHandler(
887
request_handler, responder)
882
888
return ProtocolThreeDecoder(message_handler)
908
914
_StatefulDecoder.accept_bytes(self, bytes)
909
915
except KeyboardInterrupt:
911
except errors.SmartMessageHandlerError, exception:
917
except errors.SmartMessageHandlerError as exception:
912
918
# We do *not* set self.decoding_failed here. The message handler
913
919
# has raised an error, but the decoder is still able to parse bytes
914
920
# and determine when this message ends.
918
924
# The state machine is ready to continue decoding, but the
919
925
# exception has interrupted the loop that runs the state machine.
920
926
# So we call accept_bytes again to restart it.
921
self.accept_bytes('')
922
except Exception, exception:
927
self.accept_bytes(b'')
928
except Exception as exception:
923
929
# The decoder itself has raised an exception. We cannot continue
925
931
self.decoding_failed = True
994
1000
def _state_accept_expecting_headers(self):
995
1001
decoded = self._extract_prefixed_bencoded_data()
996
if type(decoded) is not dict:
1002
if not isinstance(decoded, dict):
997
1003
raise errors.SmartProtocolError(
998
1004
'Header object %r is not a dict' % (decoded,))
999
1005
self.state_accept = self._state_accept_expecting_message_part
1005
1011
def _state_accept_expecting_message_part(self):
1006
1012
message_part_kind = self._extract_single_byte()
1007
if message_part_kind == 'o':
1013
if message_part_kind == b'o':
1008
1014
self.state_accept = self._state_accept_expecting_one_byte
1009
elif message_part_kind == 's':
1015
elif message_part_kind == b's':
1010
1016
self.state_accept = self._state_accept_expecting_structure
1011
elif message_part_kind == 'b':
1017
elif message_part_kind == b'b':
1012
1018
self.state_accept = self._state_accept_expecting_bytes
1013
elif message_part_kind == 'e':
1019
elif message_part_kind == b'e':
1016
1022
raise errors.SmartProtocolError(
1102
1108
"""Serialise a readv offset list."""
1104
1110
for start, length in offsets:
1105
txt.append('%d,%d' % (start, length))
1106
return '\n'.join(txt)
1111
txt.append(b'%d,%d' % (start, length))
1112
return b'\n'.join(txt)
1108
1114
def _write_protocol_version(self):
1109
1115
self._write_func(MESSAGE_VERSION_THREE)
1117
1123
self._write_prefixed_bencode(headers)
1119
1125
def _write_structure(self, args):
1120
self._write_func('s')
1126
self._write_func(b's')
1122
1128
for arg in args:
1123
if type(arg) is unicode:
1129
if isinstance(arg, str):
1124
1130
utf8_args.append(arg.encode('utf8'))
1126
1132
utf8_args.append(arg)
1127
1133
self._write_prefixed_bencode(utf8_args)
1129
1135
def _write_end(self):
1130
self._write_func('e')
1136
self._write_func(b'e')
1133
1139
def _write_prefixed_body(self, bytes):
1134
self._write_func('b')
1140
self._write_func(b'b')
1135
1141
self._write_func(struct.pack('!L', len(bytes)))
1136
1142
self._write_func(bytes)
1138
1144
def _write_chunked_body_start(self):
1139
self._write_func('oC')
1145
self._write_func(b'oC')
1141
1147
def _write_error_status(self):
1142
self._write_func('oE')
1148
self._write_func(b'oE')
1144
1150
def _write_success_status(self):
1145
self._write_func('oS')
1151
self._write_func(b'oS')
1148
1154
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1150
1156
def __init__(self, write_func):
1151
1157
_ProtocolThreeEncoder.__init__(self, write_func)
1152
1158
self.response_sent = False
1153
self._headers = {'Software version': bzrlib.__version__}
1160
b'Software version': breezy.__version__.encode('utf-8')}
1154
1161
if 'hpss' in debug.debug_flags:
1155
self._thread_id = thread.get_ident()
1162
self._thread_id = _thread.get_ident()
1156
1163
self._response_start_time = None
1158
1165
def _trace(self, action, message, extra_bytes=None, include_time=False):
1159
1166
if self._response_start_time is None:
1160
self._response_start_time = osutils.timer_func()
1167
self._response_start_time = osutils.perf_counter()
1161
1168
if include_time:
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1169
t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
1165
1172
if extra_bytes is None:
1178
1185
% (exception,))
1179
1186
if isinstance(exception, errors.UnknownSmartMethod):
1180
1187
failure = request.FailedSmartServerResponse(
1181
('UnknownMethod', exception.verb))
1188
(b'UnknownMethod', exception.verb))
1182
1189
self.send_response(failure)
1184
1191
if 'hpss' in debug.debug_flags:
1187
1194
self._write_protocol_version()
1188
1195
self._write_headers(self._headers)
1189
1196
self._write_error_status()
1190
self._write_structure(('error', str(exception)))
1197
self._write_structure(
1198
(b'error', str(exception).encode('utf-8', 'replace')))
1191
1199
self._write_end()
1193
1201
def send_response(self, response):
1301
1309
base = getattr(self._medium_request._medium, 'base', None)
1302
1310
if base is not None:
1303
1311
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1312
self._request_start_time = osutils.perf_counter()
1305
1313
self._write_protocol_version()
1306
1314
self._write_headers(self._headers)
1307
1315
self._write_structure(args)
1319
1327
if path is not None:
1320
1328
mutter(' (to %s)', path)
1321
1329
mutter(' %d bytes', len(body))
1322
self._request_start_time = osutils.timer_func()
1330
self._request_start_time = osutils.perf_counter()
1323
1331
self._write_protocol_version()
1324
1332
self._write_headers(self._headers)
1325
1333
self._write_structure(args)
1338
1346
path = getattr(self._medium_request._medium, '_path', None)
1339
1347
if path is not None:
1340
1348
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1349
self._request_start_time = osutils.perf_counter()
1342
1350
self._write_protocol_version()
1343
1351
self._write_headers(self._headers)
1344
1352
self._write_structure(args)
1355
1363
path = getattr(self._medium_request._medium, '_path', None)
1356
1364
if path is not None:
1357
1365
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1366
self._request_start_time = osutils.perf_counter()
1359
1367
self.body_stream_started = False
1360
1368
self._write_protocol_version()
1361
1369
self._write_headers(self._headers)
1373
1381
self._write_error_status()
1374
1382
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1384
self._write_structure((b'error',))
1377
1385
self._write_end()
1378
1386
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1387
(exc_type, exc_val, exc_tb) = exc_info
1381
1393
self._write_prefixed_body(part)
1383
1395
self._write_end()
1384
1396
self._medium_request.finished_writing()