22
from collections.abc import deque
23
except ImportError: # python < 3.7
24
from collections import deque
21
from __future__ import absolute_import
26
from io import BytesIO
24
from cStringIO import StringIO
38
from . import message, request
39
from ...trace import log_exception_quietly, mutter
40
from ...bencode import bdecode_as_tuple, bencode
36
from brzlib.smart import message, request
37
from brzlib.trace import log_exception_quietly, mutter
38
from brzlib.bencode import bdecode_as_tuple, bencode
43
41
# Protocol version strings. These are sent as prefixes of bzr requests and
44
42
# responses to identify the protocol version being used. (There are no version
45
43
# one strings because that version doesn't send any).
46
REQUEST_VERSION_TWO = b'bzr request 2\n'
47
RESPONSE_VERSION_TWO = b'bzr response 2\n'
44
REQUEST_VERSION_TWO = 'bzr request 2\n'
45
RESPONSE_VERSION_TWO = 'bzr response 2\n'
49
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
47
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
50
48
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
58
56
def _decode_tuple(req_line):
59
if req_line is None or req_line == b'':
57
if req_line is None or req_line == '':
61
if not req_line.endswith(b'\n'):
59
if req_line[-1] != '\n':
62
60
raise errors.SmartProtocolError("request %r not terminated" % req_line)
63
return tuple(req_line[:-1].split(b'\x01'))
61
return tuple(req_line[:-1].split('\x01'))
66
64
def _encode_tuple(args):
67
65
"""Encode the tuple args to a bytestream."""
69
if isinstance(arg, str):
71
return b'\x01'.join(args) + b'\n'
66
joined = '\x01'.join(args) + '\n'
67
if type(joined) is unicode:
68
# XXX: We should fix things so this never happens! -AJB, 20100304
69
mutter('response args contain unicode, should be only bytes: %r',
71
joined = joined.encode('ascii')
74
75
class Requester(object):
112
113
# support multiple chunks?
113
114
def _encode_bulk_data(self, body):
114
115
"""Encode body as a bulk data chunk."""
115
return b''.join((b'%d\n' % len(body), body, b'done\n'))
116
return ''.join(('%d\n' % len(body), body, 'done\n'))
117
118
def _serialise_offsets(self, offsets):
118
119
"""Serialise a readv offset list."""
120
121
for start, length in offsets:
121
txt.append(b'%d,%d' % (start, length))
122
return b'\n'.join(txt)
122
txt.append('%d,%d' % (start, length))
123
return '\n'.join(txt)
125
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
127
"""Server-side encoding and decoding logic for smart version 1."""
128
129
def __init__(self, backing_transport, write_func, root_client_path='/',
130
131
self._backing_transport = backing_transport
131
132
self._root_client_path = root_client_path
132
133
self._jail_root = jail_root
133
self.unused_data = b''
134
self.unused_data = ''
134
135
self._finished = False
136
137
self._has_dispatched = False
137
138
self.request = None
138
139
self._body_decoder = None
139
140
self._write_func = write_func
141
def accept_bytes(self, data):
142
def accept_bytes(self, bytes):
142
143
"""Take bytes, and advance the internal state machine appropriately.
144
:param data: must be a byte string
145
:param bytes: must be a byte string
146
if not isinstance(data, bytes):
147
raise ValueError(data)
148
self.in_buffer += data
147
if not isinstance(bytes, str):
148
raise ValueError(bytes)
149
self.in_buffer += bytes
149
150
if not self._has_dispatched:
150
if b'\n' not in self.in_buffer:
151
if '\n' not in self.in_buffer:
151
152
# no command line yet
153
154
self._has_dispatched = True
155
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
156
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
157
158
req_args = _decode_tuple(first_line)
158
159
self.request = request.SmartServerRequestHandler(
159
160
self._backing_transport, commands=request.request_handlers,
163
164
if self.request.finished_reading:
164
165
# trivial request
165
166
self.unused_data = self.in_buffer
167
168
self._send_response(self.request.response)
168
169
except KeyboardInterrupt:
170
except errors.UnknownSmartMethod as err:
171
except errors.UnknownSmartMethod, err:
171
172
protocol_error = errors.SmartProtocolError(
172
"bad request '%s'" % (err.verb.decode('ascii'),))
173
"bad request %r" % (err.verb,))
173
174
failure = request.FailedSmartServerResponse(
174
(b'error', str(protocol_error).encode('utf-8')))
175
('error', str(protocol_error)))
175
176
self._send_response(failure)
177
except Exception as exception:
178
except Exception, exception:
178
179
# everything else: pass to client, flush, and quit
179
180
log_exception_quietly()
180
181
self._send_response(request.FailedSmartServerResponse(
181
(b'error', str(exception).encode('utf-8'))))
182
('error', str(exception))))
184
185
if self._has_dispatched:
218
219
self._write_success_or_failure_prefix(response)
219
220
self._write_func(_encode_tuple(args))
220
221
if body is not None:
221
if not isinstance(body, bytes):
222
if not isinstance(body, str):
222
223
raise ValueError(body)
223
data = self._encode_bulk_data(body)
224
self._write_func(data)
224
bytes = self._encode_bulk_data(body)
225
self._write_func(bytes)
226
227
def _write_protocol_version(self):
227
228
"""Write any prefixes this protocol requires.
258
259
def _write_success_or_failure_prefix(self, response):
259
260
"""Write the protocol specific success/failure prefix."""
260
261
if response.is_successful():
261
self._write_func(b'success\n')
262
self._write_func('success\n')
263
self._write_func(b'failed\n')
264
self._write_func('failed\n')
265
266
def _write_protocol_version(self):
266
267
r"""Write any prefixes this protocol requires.
278
279
self._write_success_or_failure_prefix(response)
279
280
self._write_func(_encode_tuple(response.args))
280
281
if response.body is not None:
281
if not isinstance(response.body, bytes):
282
raise AssertionError('body must be bytes')
282
if not isinstance(response.body, str):
283
raise AssertionError('body must be a str')
283
284
if not (response.body_stream is None):
284
285
raise AssertionError(
285
286
'body_stream and body cannot both be set')
286
data = self._encode_bulk_data(response.body)
287
self._write_func(data)
287
bytes = self._encode_bulk_data(response.body)
288
self._write_func(bytes)
288
289
elif response.body_stream is not None:
289
290
_send_stream(response.body_stream, self._write_func)
292
293
def _send_stream(stream, write_func):
293
write_func(b'chunked\n')
294
write_func('chunked\n')
294
295
_send_chunks(stream, write_func)
298
299
def _send_chunks(stream, write_func):
299
300
for chunk in stream:
300
if isinstance(chunk, bytes):
301
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
301
if isinstance(chunk, str):
302
bytes = "%x\n%s" % (len(chunk), chunk)
303
304
elif isinstance(chunk, request.FailedSmartServerResponse):
305
306
_send_chunks(chunk.args, write_func)
339
340
self.finished_reading = False
340
341
self._in_buffer_list = []
341
342
self._in_buffer_len = 0
342
self.unused_data = b''
343
self.unused_data = ''
343
344
self.bytes_left = None
344
345
self._number_needed_bytes = None
346
347
def _get_in_buffer(self):
347
348
if len(self._in_buffer_list) == 1:
348
349
return self._in_buffer_list[0]
349
in_buffer = b''.join(self._in_buffer_list)
350
in_buffer = ''.join(self._in_buffer_list)
350
351
if len(in_buffer) != self._in_buffer_len:
351
352
raise AssertionError(
352
353
"Length of buffer did not match expected value: %s != %s"
365
366
# check if we can yield the bytes from just the first entry in our list
366
367
if len(self._in_buffer_list) == 0:
367
368
raise AssertionError('Callers must be sure we have buffered bytes'
368
' before calling _get_in_bytes')
369
' before calling _get_in_bytes')
369
370
if len(self._in_buffer_list[0]) > count:
370
371
return self._in_buffer_list[0][:count]
371
372
# We can't yield it from the first buffer, so collapse all buffers, and
376
377
def _set_in_buffer(self, new_buf):
377
378
if new_buf is not None:
378
if not isinstance(new_buf, bytes):
379
raise TypeError(new_buf)
380
379
self._in_buffer_list = [new_buf]
381
380
self._in_buffer_len = len(new_buf)
383
382
self._in_buffer_list = []
384
383
self._in_buffer_len = 0
386
def accept_bytes(self, new_buf):
385
def accept_bytes(self, bytes):
387
386
"""Decode as much of bytes as possible.
389
If 'new_buf' contains too much data it will be appended to
388
If 'bytes' contains too much data it will be appended to
390
389
self.unused_data.
392
391
finished_reading will be set when no more data is required. Further
393
392
data will be appended to self.unused_data.
395
if not isinstance(new_buf, bytes):
396
raise TypeError(new_buf)
397
394
# accept_bytes is allowed to change the state
398
395
self._number_needed_bytes = None
399
396
# lsprof puts a very large amount of time on this specific call for
400
397
# large readv arrays
401
self._in_buffer_list.append(new_buf)
402
self._in_buffer_len += len(new_buf)
398
self._in_buffer_list.append(bytes)
399
self._in_buffer_len += len(bytes)
404
401
# Run the function for the current state.
405
402
current_state = self.state_accept
494
491
def _state_accept_expecting_length(self):
495
492
prefix = self._extract_line()
497
494
self.error = True
498
495
self.error_in_progress = []
499
496
self._state_accept_expecting_length()
501
elif prefix == b'END':
498
elif prefix == 'END':
502
499
# We've read the end-of-body marker.
503
500
# Any further bytes are unused data, including the bytes left in
504
501
# the _in_buffer.
537
534
_StatefulDecoder.__init__(self)
538
535
self.state_accept = self._state_accept_expecting_length
539
536
self.state_read = self._state_read_no_data
541
self._trailer_buffer = b''
538
self._trailer_buffer = ''
543
540
def next_read_size(self):
544
541
if self.bytes_left is not None:
563
560
def _state_accept_expecting_length(self):
564
561
in_buf = self._get_in_buffer()
565
pos = in_buf.find(b'\n')
562
pos = in_buf.find('\n')
568
565
self.bytes_left = int(in_buf[:pos])
569
self._set_in_buffer(in_buf[pos + 1:])
566
self._set_in_buffer(in_buf[pos+1:])
570
567
self.state_accept = self._state_accept_reading_body
571
568
self.state_read = self._state_read_body_buffer
588
585
self._set_in_buffer(None)
589
586
# TODO: what if the trailer does not match "done\n"? Should this raise
590
587
# a ProtocolViolation exception?
591
if self._trailer_buffer.startswith(b'done\n'):
592
self.unused_data = self._trailer_buffer[len(b'done\n'):]
588
if self._trailer_buffer.startswith('done\n'):
589
self.unused_data = self._trailer_buffer[len('done\n'):]
593
590
self.state_accept = self._state_accept_reading_unused
594
591
self.finished_reading = True
630
627
mutter('hpss call: %s', repr(args)[1:-1])
631
628
if getattr(self._request._medium, 'base', None) is not None:
632
629
mutter(' (to %s)', self._request._medium.base)
633
self._request_start_time = osutils.perf_counter()
630
self._request_start_time = osutils.timer_func()
634
631
self._write_args(args)
635
632
self._request.finished_writing()
636
633
self._last_verb = args[0]
643
640
if 'hpss' in debug.debug_flags:
644
641
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
645
642
if getattr(self._request._medium, '_path', None) is not None:
647
self._request._medium._path)
643
mutter(' (to %s)', self._request._medium._path)
648
644
mutter(' %d bytes', len(body))
649
self._request_start_time = osutils.perf_counter()
645
self._request_start_time = osutils.timer_func()
650
646
if 'hpssdetail' in debug.debug_flags:
651
647
mutter('hpss body content: %s', body)
652
648
self._write_args(args)
664
660
if 'hpss' in debug.debug_flags:
665
661
mutter('hpss call w/readv: %s', repr(args)[1:-1])
666
662
if getattr(self._request._medium, '_path', None) is not None:
668
self._request._medium._path)
669
self._request_start_time = osutils.perf_counter()
663
mutter(' (to %s)', self._request._medium._path)
664
self._request_start_time = osutils.timer_func()
670
665
self._write_args(args)
671
666
readv_bytes = self._serialise_offsets(body)
672
667
bytes = self._encode_bulk_data(readv_bytes)
698
693
if 'hpss' in debug.debug_flags:
699
694
if self._request_start_time is not None:
700
695
mutter(' result: %6.3fs %s',
701
osutils.perf_counter() - self._request_start_time,
696
osutils.timer_func() - self._request_start_time,
702
697
repr(result)[1:-1])
703
698
self._request_start_time = None
755
750
:param verb: The verb used in that call.
756
751
:raises: UnexpectedSmartServerResponse
758
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
759
b"bad request '" + self._last_verb + b"'")
760
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
761
b"bad request u'%s'" % self._last_verb)):
753
if (result_tuple == ('error', "Generic bzr smart protocol error: "
754
"bad request '%s'" % self._last_verb) or
755
result_tuple == ('error', "Generic bzr smart protocol error: "
756
"bad request u'%s'" % self._last_verb)):
762
757
# The response will have no body, so we've finished reading.
763
758
self._request.finished_reading()
764
759
raise errors.UnknownSmartMethod(self._last_verb)
776
771
while not _body_decoder.finished_reading:
777
772
bytes = self._request.read_bytes(_body_decoder.next_read_size())
779
774
# end of file encountered reading from server
780
775
raise errors.ConnectionReset(
781
776
"Connection lost while reading response body.")
782
777
_body_decoder.accept_bytes(bytes)
783
778
self._request.finished_reading()
784
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
779
self._body_buffer = StringIO(_body_decoder.read_pending_data())
785
780
# XXX: TODO check the trailer result.
786
781
if 'hpss' in debug.debug_flags:
787
782
mutter(' %d body bytes read',
795
790
def query_version(self):
796
791
"""Return protocol version number of the server."""
798
793
resp = self.read_response_tuple()
799
if resp == (b'ok', b'1'):
794
if resp == ('ok', '1'):
801
elif resp == (b'ok', b'2'):
796
elif resp == ('ok', '2'):
804
799
raise errors.SmartProtocolError("bad response %r" % (resp,))
836
831
response_status = self._request.read_line()
837
832
result = SmartClientRequestProtocolOne._read_response_tuple(self)
838
833
self._response_is_unknown_method(result)
839
if response_status == b'success\n':
834
if response_status == 'success\n':
840
835
self.response_status = True
841
836
if not expect_body:
842
837
self._request.finished_reading()
844
elif response_status == b'failed\n':
839
elif response_status == 'failed\n':
845
840
self.response_status = False
846
841
self._request.finished_reading()
847
842
raise errors.ErrorFromSmartServer(result)
864
859
_body_decoder = ChunkedBodyDecoder()
865
860
while not _body_decoder.finished_reading:
866
861
bytes = self._request.read_bytes(_body_decoder.next_read_size())
868
863
# end of file encountered reading from server
869
864
raise errors.ConnectionReset(
870
865
"Connection lost while reading streamed body.")
871
866
_body_decoder.accept_bytes(bytes)
872
867
for body_bytes in iter(_body_decoder.read_next_chunk, None):
873
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
868
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
874
869
mutter(' %d byte chunk read',
883
878
backing_transport, commands=request.request_handlers,
884
879
root_client_path=root_client_path, jail_root=jail_root)
885
880
responder = ProtocolThreeResponder(write_func)
886
message_handler = message.ConventionalRequestHandler(
887
request_handler, responder)
881
message_handler = message.ConventionalRequestHandler(request_handler, responder)
888
882
return ProtocolThreeDecoder(message_handler)
914
908
_StatefulDecoder.accept_bytes(self, bytes)
915
909
except KeyboardInterrupt:
917
except errors.SmartMessageHandlerError as exception:
911
except errors.SmartMessageHandlerError, exception:
918
912
# We do *not* set self.decoding_failed here. The message handler
919
913
# has raised an error, but the decoder is still able to parse bytes
920
914
# and determine when this message ends.
924
918
# The state machine is ready to continue decoding, but the
925
919
# exception has interrupted the loop that runs the state machine.
926
920
# So we call accept_bytes again to restart it.
927
self.accept_bytes(b'')
928
except Exception as exception:
921
self.accept_bytes('')
922
except Exception, exception:
929
923
# The decoder itself has raised an exception. We cannot continue
931
925
self.decoding_failed = True
1000
994
def _state_accept_expecting_headers(self):
1001
995
decoded = self._extract_prefixed_bencoded_data()
1002
if not isinstance(decoded, dict):
996
if type(decoded) is not dict:
1003
997
raise errors.SmartProtocolError(
1004
998
'Header object %r is not a dict' % (decoded,))
1005
999
self.state_accept = self._state_accept_expecting_message_part
1011
1005
def _state_accept_expecting_message_part(self):
1012
1006
message_part_kind = self._extract_single_byte()
1013
if message_part_kind == b'o':
1007
if message_part_kind == 'o':
1014
1008
self.state_accept = self._state_accept_expecting_one_byte
1015
elif message_part_kind == b's':
1009
elif message_part_kind == 's':
1016
1010
self.state_accept = self._state_accept_expecting_structure
1017
elif message_part_kind == b'b':
1011
elif message_part_kind == 'b':
1018
1012
self.state_accept = self._state_accept_expecting_bytes
1019
elif message_part_kind == b'e':
1013
elif message_part_kind == 'e':
1022
1016
raise errors.SmartProtocolError(
1108
1102
"""Serialise a readv offset list."""
1110
1104
for start, length in offsets:
1111
txt.append(b'%d,%d' % (start, length))
1112
return b'\n'.join(txt)
1105
txt.append('%d,%d' % (start, length))
1106
return '\n'.join(txt)
1114
1108
def _write_protocol_version(self):
1115
1109
self._write_func(MESSAGE_VERSION_THREE)
1123
1117
self._write_prefixed_bencode(headers)
1125
1119
def _write_structure(self, args):
1126
self._write_func(b's')
1120
self._write_func('s')
1128
1122
for arg in args:
1129
if isinstance(arg, str):
1123
if type(arg) is unicode:
1130
1124
utf8_args.append(arg.encode('utf8'))
1132
1126
utf8_args.append(arg)
1133
1127
self._write_prefixed_bencode(utf8_args)
1135
1129
def _write_end(self):
1136
self._write_func(b'e')
1130
self._write_func('e')
1139
1133
def _write_prefixed_body(self, bytes):
1140
self._write_func(b'b')
1134
self._write_func('b')
1141
1135
self._write_func(struct.pack('!L', len(bytes)))
1142
1136
self._write_func(bytes)
1144
1138
def _write_chunked_body_start(self):
1145
self._write_func(b'oC')
1139
self._write_func('oC')
1147
1141
def _write_error_status(self):
1148
self._write_func(b'oE')
1142
self._write_func('oE')
1150
1144
def _write_success_status(self):
1151
self._write_func(b'oS')
1145
self._write_func('oS')
1154
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1156
1150
def __init__(self, write_func):
1157
1151
_ProtocolThreeEncoder.__init__(self, write_func)
1158
1152
self.response_sent = False
1160
b'Software version': breezy.__version__.encode('utf-8')}
1153
self._headers = {'Software version': brzlib.__version__}
1161
1154
if 'hpss' in debug.debug_flags:
1162
self._thread_id = _thread.get_ident()
1155
self._thread_id = thread.get_ident()
1163
1156
self._response_start_time = None
1165
1158
def _trace(self, action, message, extra_bytes=None, include_time=False):
1166
1159
if self._response_start_time is None:
1167
self._response_start_time = osutils.perf_counter()
1160
self._response_start_time = osutils.timer_func()
1168
1161
if include_time:
1169
t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1172
1165
if extra_bytes is None:
1185
1178
% (exception,))
1186
1179
if isinstance(exception, errors.UnknownSmartMethod):
1187
1180
failure = request.FailedSmartServerResponse(
1188
(b'UnknownMethod', exception.verb))
1181
('UnknownMethod', exception.verb))
1189
1182
self.send_response(failure)
1191
1184
if 'hpss' in debug.debug_flags:
1194
1187
self._write_protocol_version()
1195
1188
self._write_headers(self._headers)
1196
1189
self._write_error_status()
1197
self._write_structure(
1198
(b'error', str(exception).encode('utf-8', 'replace')))
1190
self._write_structure(('error', str(exception)))
1199
1191
self._write_end()
1201
1193
def send_response(self, response):
1309
1301
base = getattr(self._medium_request._medium, 'base', None)
1310
1302
if base is not None:
1311
1303
mutter(' (to %s)', base)
1312
self._request_start_time = osutils.perf_counter()
1304
self._request_start_time = osutils.timer_func()
1313
1305
self._write_protocol_version()
1314
1306
self._write_headers(self._headers)
1315
1307
self._write_structure(args)
1327
1319
if path is not None:
1328
1320
mutter(' (to %s)', path)
1329
1321
mutter(' %d bytes', len(body))
1330
self._request_start_time = osutils.perf_counter()
1322
self._request_start_time = osutils.timer_func()
1331
1323
self._write_protocol_version()
1332
1324
self._write_headers(self._headers)
1333
1325
self._write_structure(args)
1346
1338
path = getattr(self._medium_request._medium, '_path', None)
1347
1339
if path is not None:
1348
1340
mutter(' (to %s)', path)
1349
self._request_start_time = osutils.perf_counter()
1341
self._request_start_time = osutils.timer_func()
1350
1342
self._write_protocol_version()
1351
1343
self._write_headers(self._headers)
1352
1344
self._write_structure(args)
1363
1355
path = getattr(self._medium_request._medium, '_path', None)
1364
1356
if path is not None:
1365
1357
mutter(' (to %s)', path)
1366
self._request_start_time = osutils.perf_counter()
1358
self._request_start_time = osutils.timer_func()
1367
1359
self.body_stream_started = False
1368
1360
self._write_protocol_version()
1369
1361
self._write_headers(self._headers)
1381
1373
self._write_error_status()
1382
1374
# Currently the client unconditionally sends ('error',) as the
1384
self._write_structure((b'error',))
1376
self._write_structure(('error',))
1385
1377
self._write_end()
1386
1378
self._medium_request.finished_writing()
1387
(exc_type, exc_val, exc_tb) = exc_info
1379
raise exc_info[0], exc_info[1], exc_info[2]
1393
1381
self._write_prefixed_body(part)
1395
1383
self._write_end()
1396
1384
self._medium_request.finished_writing()