21
from __future__ import absolute_import
24
from collections.abc import deque
25
except ImportError: # python < 3.7
26
from collections import deque
22
from cStringIO import StringIO
33
import thread as _thread
42
from ...sixish import (
46
from . import message, request
47
from ...sixish import text_type
48
from ...trace import log_exception_quietly, mutter
49
from ...bencode import bdecode_as_tuple, bencode
28
from bzrlib import debug
29
from bzrlib import errors
30
from bzrlib.smart import message, request
31
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.bencode import bdecode_as_tuple, bencode
52
35
# Protocol version strings. These are sent as prefixes of bzr requests and
53
36
# responses to identify the protocol version being used. (There are no version
54
37
# one strings because that version doesn't send any).
55
REQUEST_VERSION_TWO = b'bzr request 2\n'
56
RESPONSE_VERSION_TWO = b'bzr response 2\n'
38
REQUEST_VERSION_TWO = 'bzr request 2\n'
39
RESPONSE_VERSION_TWO = 'bzr response 2\n'
58
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
41
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
59
42
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
121
101
# support multiple chunks?
122
102
def _encode_bulk_data(self, body):
123
103
"""Encode body as a bulk data chunk."""
124
return b''.join((b'%d\n' % len(body), body, b'done\n'))
104
return ''.join(('%d\n' % len(body), body, 'done\n'))
126
106
def _serialise_offsets(self, offsets):
127
107
"""Serialise a readv offset list."""
129
109
for start, length in offsets:
130
txt.append(b'%d,%d' % (start, length))
131
return b'\n'.join(txt)
110
txt.append('%d,%d' % (start, length))
111
return '\n'.join(txt)
134
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
135
115
"""Server-side encoding and decoding logic for smart version 1."""
137
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
139
118
self._backing_transport = backing_transport
140
119
self._root_client_path = root_client_path
141
self._jail_root = jail_root
142
self.unused_data = b''
120
self.unused_data = ''
143
121
self._finished = False
145
123
self._has_dispatched = False
146
124
self.request = None
147
125
self._body_decoder = None
148
126
self._write_func = write_func
150
def accept_bytes(self, data):
128
def accept_bytes(self, bytes):
151
129
"""Take bytes, and advance the internal state machine appropriately.
153
:param data: must be a byte string
131
:param bytes: must be a byte string
155
if not isinstance(data, bytes):
156
raise ValueError(data)
157
self.in_buffer += data
133
if not isinstance(bytes, str):
134
raise ValueError(bytes)
135
self.in_buffer += bytes
158
136
if not self._has_dispatched:
159
if b'\n' not in self.in_buffer:
137
if '\n' not in self.in_buffer:
160
138
# no command line yet
162
140
self._has_dispatched = True
164
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
142
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
166
144
req_args = _decode_tuple(first_line)
167
145
self.request = request.SmartServerRequestHandler(
168
146
self._backing_transport, commands=request.request_handlers,
169
root_client_path=self._root_client_path,
170
jail_root=self._jail_root)
171
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
172
149
if self.request.finished_reading:
173
150
# trivial request
174
151
self.unused_data = self.in_buffer
176
153
self._send_response(self.request.response)
177
154
except KeyboardInterrupt:
179
except errors.UnknownSmartMethod as err:
156
except errors.UnknownSmartMethod, err:
180
157
protocol_error = errors.SmartProtocolError(
181
"bad request '%s'" % (err.verb.decode('ascii'),))
158
"bad request %r" % (err.verb,))
182
159
failure = request.FailedSmartServerResponse(
183
(b'error', str(protocol_error).encode('utf-8')))
160
('error', str(protocol_error)))
184
161
self._send_response(failure)
186
except Exception as exception:
163
except Exception, exception:
187
164
# everything else: pass to client, flush, and quit
188
165
log_exception_quietly()
189
166
self._send_response(request.FailedSmartServerResponse(
190
(b'error', str(exception).encode('utf-8'))))
167
('error', str(exception))))
193
170
if self._has_dispatched:
287
264
self._write_success_or_failure_prefix(response)
288
265
self._write_func(_encode_tuple(response.args))
289
266
if response.body is not None:
290
if not isinstance(response.body, bytes):
291
raise AssertionError('body must be bytes')
267
if not isinstance(response.body, str):
268
raise AssertionError('body must be a str')
292
269
if not (response.body_stream is None):
293
270
raise AssertionError(
294
271
'body_stream and body cannot both be set')
295
data = self._encode_bulk_data(response.body)
296
self._write_func(data)
272
bytes = self._encode_bulk_data(response.body)
273
self._write_func(bytes)
297
274
elif response.body_stream is not None:
298
275
_send_stream(response.body_stream, self._write_func)
301
278
def _send_stream(stream, write_func):
302
write_func(b'chunked\n')
279
write_func('chunked\n')
303
280
_send_chunks(stream, write_func)
307
284
def _send_chunks(stream, write_func):
308
285
for chunk in stream:
309
if isinstance(chunk, bytes):
310
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
286
if isinstance(chunk, str):
287
bytes = "%x\n%s" % (len(chunk), chunk)
312
289
elif isinstance(chunk, request.FailedSmartServerResponse):
314
291
_send_chunks(chunk.args, write_func)
385
362
def _set_in_buffer(self, new_buf):
386
363
if new_buf is not None:
387
if not isinstance(new_buf, bytes):
388
raise TypeError(new_buf)
389
364
self._in_buffer_list = [new_buf]
390
365
self._in_buffer_len = len(new_buf)
392
367
self._in_buffer_list = []
393
368
self._in_buffer_len = 0
395
def accept_bytes(self, new_buf):
370
def accept_bytes(self, bytes):
396
371
"""Decode as much of bytes as possible.
398
If 'new_buf' contains too much data it will be appended to
373
If 'bytes' contains too much data it will be appended to
399
374
self.unused_data.
401
376
finished_reading will be set when no more data is required. Further
402
377
data will be appended to self.unused_data.
404
if not isinstance(new_buf, bytes):
405
raise TypeError(new_buf)
406
379
# accept_bytes is allowed to change the state
407
380
self._number_needed_bytes = None
408
381
# lsprof puts a very large amount of time on this specific call for
409
382
# large readv arrays
410
self._in_buffer_list.append(new_buf)
411
self._in_buffer_len += len(new_buf)
383
self._in_buffer_list.append(bytes)
384
self._in_buffer_len += len(bytes)
413
386
# Run the function for the current state.
414
387
current_state = self.state_accept
652
625
if 'hpss' in debug.debug_flags:
653
626
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
654
627
if getattr(self._request._medium, '_path', None) is not None:
656
self._request._medium._path)
628
mutter(' (to %s)', self._request._medium._path)
657
629
mutter(' %d bytes', len(body))
658
self._request_start_time = osutils.perf_counter()
630
self._request_start_time = time.time()
659
631
if 'hpssdetail' in debug.debug_flags:
660
632
mutter('hpss body content: %s', body)
661
633
self._write_args(args)
668
640
"""Make a remote call with a readv array.
670
642
The body is encoded with one line per readv offset pair. The numbers in
671
each pair are separated by a comma, and no trailing \\n is emitted.
643
each pair are separated by a comma, and no trailing \n is emitted.
673
645
if 'hpss' in debug.debug_flags:
674
646
mutter('hpss call w/readv: %s', repr(args)[1:-1])
675
647
if getattr(self._request._medium, '_path', None) is not None:
677
self._request._medium._path)
678
self._request_start_time = osutils.perf_counter()
648
mutter(' (to %s)', self._request._medium._path)
649
self._request_start_time = time.time()
679
650
self._write_args(args)
680
651
readv_bytes = self._serialise_offsets(body)
681
652
bytes = self._encode_bulk_data(readv_bytes)
764
735
:param verb: The verb used in that call.
765
736
:raises: UnexpectedSmartServerResponse
767
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
768
b"bad request '" + self._last_verb + b"'")
769
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
770
b"bad request u'%s'" % self._last_verb)):
738
if (result_tuple == ('error', "Generic bzr smart protocol error: "
739
"bad request '%s'" % self._last_verb) or
740
result_tuple == ('error', "Generic bzr smart protocol error: "
741
"bad request u'%s'" % self._last_verb)):
771
742
# The response will have no body, so we've finished reading.
772
743
self._request.finished_reading()
773
744
raise errors.UnknownSmartMethod(self._last_verb)
785
756
while not _body_decoder.finished_reading:
786
757
bytes = self._request.read_bytes(_body_decoder.next_read_size())
788
759
# end of file encountered reading from server
789
760
raise errors.ConnectionReset(
790
761
"Connection lost while reading response body.")
791
762
_body_decoder.accept_bytes(bytes)
792
763
self._request.finished_reading()
793
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
764
self._body_buffer = StringIO(_body_decoder.read_pending_data())
794
765
# XXX: TODO check the trailer result.
795
766
if 'hpss' in debug.debug_flags:
796
767
mutter(' %d body bytes read',
873
844
_body_decoder = ChunkedBodyDecoder()
874
845
while not _body_decoder.finished_reading:
875
846
bytes = self._request.read_bytes(_body_decoder.next_read_size())
877
848
# end of file encountered reading from server
878
849
raise errors.ConnectionReset(
879
850
"Connection lost while reading streamed body.")
880
851
_body_decoder.accept_bytes(bytes)
881
852
for body_bytes in iter(_body_decoder.read_next_chunk, None):
882
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
853
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
883
854
mutter(' %d byte chunk read',
889
860
def build_server_protocol_three(backing_transport, write_func,
890
root_client_path, jail_root=None):
891
862
request_handler = request.SmartServerRequestHandler(
892
863
backing_transport, commands=request.request_handlers,
893
root_client_path=root_client_path, jail_root=jail_root)
864
root_client_path=root_client_path)
894
865
responder = ProtocolThreeResponder(write_func)
895
message_handler = message.ConventionalRequestHandler(
896
request_handler, responder)
866
message_handler = message.ConventionalRequestHandler(request_handler, responder)
897
867
return ProtocolThreeDecoder(message_handler)
1020
990
def _state_accept_expecting_message_part(self):
1021
991
message_part_kind = self._extract_single_byte()
1022
if message_part_kind == b'o':
992
if message_part_kind == 'o':
1023
993
self.state_accept = self._state_accept_expecting_one_byte
1024
elif message_part_kind == b's':
994
elif message_part_kind == 's':
1025
995
self.state_accept = self._state_accept_expecting_structure
1026
elif message_part_kind == b'b':
996
elif message_part_kind == 'b':
1027
997
self.state_accept = self._state_accept_expecting_bytes
1028
elif message_part_kind == b'e':
998
elif message_part_kind == 'e':
1031
1001
raise errors.SmartProtocolError(
1089
1059
class _ProtocolThreeEncoder(object):
1091
1061
response_marker = request_marker = MESSAGE_VERSION_THREE
1092
BUFFER_SIZE = 1024 * 1024 # 1 MiB buffer before flushing
1094
1063
def __init__(self, write_func):
1097
1065
self._real_write_func = write_func
1099
1067
def _write_func(self, bytes):
1100
# TODO: Another possibility would be to turn this into an async model.
1101
# Where we let another thread know that we have some bytes if
1102
# they want it, but we don't actually block for it
1103
# Note that osutils.send_all always sends 64kB chunks anyway, so
1104
# we might just push out smaller bits at a time?
1105
1068
self._buf.append(bytes)
1106
self._buf_len += len(bytes)
1107
if self._buf_len > self.BUFFER_SIZE:
1069
if len(self._buf) > 100:
1110
1072
def flush(self):
1112
self._real_write_func(b''.join(self._buf))
1074
self._real_write_func(''.join(self._buf))
1113
1075
del self._buf[:]
1116
1077
def _serialise_offsets(self, offsets):
1117
1078
"""Serialise a readv offset list."""
1119
1080
for start, length in offsets:
1120
txt.append(b'%d,%d' % (start, length))
1121
return b'\n'.join(txt)
1081
txt.append('%d,%d' % (start, length))
1082
return '\n'.join(txt)
1123
1084
def _write_protocol_version(self):
1124
1085
self._write_func(MESSAGE_VERSION_THREE)
1132
1093
self._write_prefixed_bencode(headers)
1134
1095
def _write_structure(self, args):
1135
self._write_func(b's')
1096
self._write_func('s')
1137
1098
for arg in args:
1138
if isinstance(arg, text_type):
1099
if type(arg) is unicode:
1139
1100
utf8_args.append(arg.encode('utf8'))
1141
1102
utf8_args.append(arg)
1142
1103
self._write_prefixed_bencode(utf8_args)
1144
1105
def _write_end(self):
1145
self._write_func(b'e')
1106
self._write_func('e')
1148
1109
def _write_prefixed_body(self, bytes):
1149
self._write_func(b'b')
1110
self._write_func('b')
1150
1111
self._write_func(struct.pack('!L', len(bytes)))
1151
1112
self._write_func(bytes)
1153
1114
def _write_chunked_body_start(self):
1154
self._write_func(b'oC')
1115
self._write_func('oC')
1156
1117
def _write_error_status(self):
1157
self._write_func(b'oE')
1118
self._write_func('oE')
1159
1120
def _write_success_status(self):
1160
self._write_func(b'oS')
1121
self._write_func('oS')
1163
1124
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1165
1126
def __init__(self, write_func):
1166
1127
_ProtocolThreeEncoder.__init__(self, write_func)
1167
1128
self.response_sent = False
1169
b'Software version': breezy.__version__.encode('utf-8')}
1170
if 'hpss' in debug.debug_flags:
1171
self._thread_id = _thread.get_ident()
1172
self._response_start_time = None
1174
def _trace(self, action, message, extra_bytes=None, include_time=False):
1175
if self._response_start_time is None:
1176
self._response_start_time = osutils.perf_counter()
1178
t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
1181
if extra_bytes is None:
1184
extra = ' ' + repr(extra_bytes[:40])
1186
extra = extra[:29] + extra[-1] + '...'
1187
mutter('%12s: [%s] %s%s%s'
1188
% (action, self._thread_id, t, message, extra))
1129
self._headers = {'Software version': bzrlib.__version__}
1190
1131
def send_error(self, exception):
1191
1132
if self.response_sent:
1194
1135
% (exception,))
1195
1136
if isinstance(exception, errors.UnknownSmartMethod):
1196
1137
failure = request.FailedSmartServerResponse(
1197
(b'UnknownMethod', exception.verb))
1138
('UnknownMethod', exception.verb))
1198
1139
self.send_response(failure)
1200
if 'hpss' in debug.debug_flags:
1201
self._trace('error', str(exception))
1202
1141
self.response_sent = True
1203
1142
self._write_protocol_version()
1204
1143
self._write_headers(self._headers)
1205
1144
self._write_error_status()
1206
self._write_structure(
1207
(b'error', str(exception).encode('utf-8', 'replace')))
1145
self._write_structure(('error', str(exception)))
1208
1146
self._write_end()
1210
1148
def send_response(self, response):
1219
1157
self._write_success_status()
1221
1159
self._write_error_status()
1222
if 'hpss' in debug.debug_flags:
1223
self._trace('response', repr(response.args))
1224
1160
self._write_structure(response.args)
1225
1161
if response.body is not None:
1226
1162
self._write_prefixed_body(response.body)
1227
if 'hpss' in debug.debug_flags:
1228
self._trace('body', '%d bytes' % (len(response.body),),
1229
response.body, include_time=True)
1230
1163
elif response.body_stream is not None:
1231
count = num_bytes = 0
1233
1164
for exc_info, chunk in _iter_with_errors(response.body_stream):
1235
1165
if exc_info is not None:
1236
1166
self._write_error_status()
1237
1167
error_struct = request._translate_error(exc_info[1])
1242
1172
self._write_error_status()
1243
1173
self._write_structure(chunk.args)
1245
num_bytes += len(chunk)
1246
if first_chunk is None:
1248
1175
self._write_prefixed_body(chunk)
1250
if 'hpssdetail' in debug.debug_flags:
1251
# Not worth timing separately, as _write_func is
1253
self._trace('body chunk',
1254
'%d bytes' % (len(chunk),),
1255
chunk, suppress_time=True)
1256
if 'hpss' in debug.debug_flags:
1257
self._trace('body stream',
1258
'%d bytes %d chunks' % (num_bytes, count),
1260
1176
self._write_end()
1261
if 'hpss' in debug.debug_flags:
1262
self._trace('response end', '', include_time=True)
1265
1179
def _iter_with_errors(iterable):
1348
1261
"""Make a remote call with a readv array.
1350
1263
The body is encoded with one line per readv offset pair. The numbers in
1351
each pair are separated by a comma, and no trailing \\n is emitted.
1264
each pair are separated by a comma, and no trailing \n is emitted.
1353
1266
if 'hpss' in debug.debug_flags:
1354
1267
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1355
1268
path = getattr(self._medium_request._medium, '_path', None)
1356
1269
if path is not None:
1357
1270
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.perf_counter()
1271
self._request_start_time = time.time()
1359
1272
self._write_protocol_version()
1360
1273
self._write_headers(self._headers)
1361
1274
self._write_structure(args)
1381
1293
# have finished sending the stream. We would notice at the end
1382
1294
# anyway, but if the medium can deliver it early then it's good
1383
1295
# to short-circuit the whole request...
1384
# Provoke any ConnectionReset failures before we start the body stream.
1386
self.body_stream_started = True
1387
1296
for exc_info, part in _iter_with_errors(stream):
1388
1297
if exc_info is not None:
1389
1298
# Iterating the stream failed. Cleanly abort the request.
1390
1299
self._write_error_status()
1391
1300
# Currently the client unconditionally sends ('error',) as the
1393
self._write_structure((b'error',))
1302
self._write_structure(('error',))
1394
1303
self._write_end()
1395
1304
self._medium_request.finished_writing()
1305
raise exc_info[0], exc_info[1], exc_info[2]
1401
1307
self._write_prefixed_body(part)
1403
1309
self._write_end()
1404
1310
self._medium_request.finished_writing()