21
from __future__ import absolute_import
22
from cStringIO import StringIO
29
import thread as _thread
35
from bzrlib.smart import message, request
36
from bzrlib.trace import log_exception_quietly, mutter
37
from bzrlib.bencode import bdecode_as_tuple, bencode
38
from ...sixish import (
42
from . import message, request
43
from ...sixish import text_type
44
from ...trace import log_exception_quietly, mutter
45
from ...bencode import bdecode_as_tuple, bencode
40
48
# Protocol version strings.  These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used.  (There are no
# version one strings because that version doesn't send any.)
#
# They are byte strings: the encoders in this module hand them straight to
# their write function, which deals in bytes only.
REQUEST_VERSION_TWO = b'bzr request 2\n'
RESPONSE_VERSION_TWO = b'bzr response 2\n'

# Protocol 3 uses one marker for both directions, so the request and response
# names are aliases of the message marker.
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
55
63
def _decode_tuple(req_line):
56
if req_line is None or req_line == '':
64
if req_line is None or req_line == b'':
58
if req_line[-1] != '\n':
66
if not req_line.endswith(b'\n'):
59
67
raise errors.SmartProtocolError("request %r not terminated" % req_line)
60
return tuple(req_line[:-1].split('\x01'))
68
return tuple(req_line[:-1].split(b'\x01'))
63
71
def _encode_tuple(args):
64
72
"""Encode the tuple args to a bytestream."""
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
74
if isinstance(arg, text_type):
76
return b'\x01'.join(args) + b'\n'
74
79
class Requester(object):
112
117
# support multiple chunks?
113
118
def _encode_bulk_data(self, body):
114
119
"""Encode body as a bulk data chunk."""
115
return ''.join(('%d\n' % len(body), body, 'done\n'))
120
return b''.join((b'%d\n' % len(body), body, b'done\n'))
117
122
def _serialise_offsets(self, offsets):
118
123
"""Serialise a readv offset list."""
120
125
for start, length in offsets:
121
txt.append('%d,%d' % (start, length))
122
return '\n'.join(txt)
126
txt.append(b'%d,%d' % (start, length))
127
return b'\n'.join(txt)
125
130
class SmartServerRequestProtocolOne(SmartProtocolBase):
130
135
self._backing_transport = backing_transport
131
136
self._root_client_path = root_client_path
132
137
self._jail_root = jail_root
133
self.unused_data = ''
138
self.unused_data = b''
134
139
self._finished = False
136
141
self._has_dispatched = False
137
142
self.request = None
138
143
self._body_decoder = None
139
144
self._write_func = write_func
141
def accept_bytes(self, bytes):
146
def accept_bytes(self, data):
142
147
"""Take bytes, and advance the internal state machine appropriately.
144
:param bytes: must be a byte string
149
:param data: must be a byte string
146
if not isinstance(bytes, str):
147
raise ValueError(bytes)
148
self.in_buffer += bytes
151
if not isinstance(data, bytes):
152
raise ValueError(data)
153
self.in_buffer += data
149
154
if not self._has_dispatched:
150
if '\n' not in self.in_buffer:
155
if b'\n' not in self.in_buffer:
151
156
# no command line yet
153
158
self._has_dispatched = True
155
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
160
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
157
162
req_args = _decode_tuple(first_line)
158
163
self.request = request.SmartServerRequestHandler(
159
164
self._backing_transport, commands=request.request_handlers,
163
168
if self.request.finished_reading:
164
169
# trivial request
165
170
self.unused_data = self.in_buffer
167
172
self._send_response(self.request.response)
168
173
except KeyboardInterrupt:
170
except errors.UnknownSmartMethod, err:
175
except errors.UnknownSmartMethod as err:
171
176
protocol_error = errors.SmartProtocolError(
172
177
"bad request %r" % (err.verb,))
173
178
failure = request.FailedSmartServerResponse(
174
('error', str(protocol_error)))
179
(b'error', str(protocol_error).encode('utf-8')))
175
180
self._send_response(failure)
177
except Exception, exception:
182
except Exception as exception:
178
183
# everything else: pass to client, flush, and quit
179
184
log_exception_quietly()
180
185
self._send_response(request.FailedSmartServerResponse(
181
('error', str(exception))))
186
(b'error', str(exception).encode('utf-8'))))
184
189
if self._has_dispatched:
258
263
def _write_success_or_failure_prefix(self, response):
259
264
"""Write the protocol specific success/failure prefix."""
260
265
if response.is_successful():
261
self._write_func('success\n')
266
self._write_func(b'success\n')
263
self._write_func('failed\n')
268
self._write_func(b'failed\n')
265
270
def _write_protocol_version(self):
266
271
r"""Write any prefixes this protocol requires.
278
283
self._write_success_or_failure_prefix(response)
279
284
self._write_func(_encode_tuple(response.args))
280
285
if response.body is not None:
281
if not isinstance(response.body, str):
282
raise AssertionError('body must be a str')
286
if not isinstance(response.body, bytes):
287
raise AssertionError('body must be bytes')
283
288
if not (response.body_stream is None):
284
289
raise AssertionError(
285
290
'body_stream and body cannot both be set')
286
bytes = self._encode_bulk_data(response.body)
287
self._write_func(bytes)
291
data = self._encode_bulk_data(response.body)
292
self._write_func(data)
288
293
elif response.body_stream is not None:
289
294
_send_stream(response.body_stream, self._write_func)
292
297
def _send_stream(stream, write_func):
293
write_func('chunked\n')
298
write_func(b'chunked\n')
294
299
_send_chunks(stream, write_func)
298
303
def _send_chunks(stream, write_func):
299
304
for chunk in stream:
300
if isinstance(chunk, str):
301
bytes = "%x\n%s" % (len(chunk), chunk)
305
if isinstance(chunk, bytes):
306
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
303
308
elif isinstance(chunk, request.FailedSmartServerResponse):
305
310
_send_chunks(chunk.args, write_func)
339
344
self.finished_reading = False
340
345
self._in_buffer_list = []
341
346
self._in_buffer_len = 0
342
self.unused_data = ''
347
self.unused_data = b''
343
348
self.bytes_left = None
344
349
self._number_needed_bytes = None
346
351
def _get_in_buffer(self):
347
352
if len(self._in_buffer_list) == 1:
348
353
return self._in_buffer_list[0]
349
in_buffer = ''.join(self._in_buffer_list)
354
in_buffer = b''.join(self._in_buffer_list)
350
355
if len(in_buffer) != self._in_buffer_len:
351
356
raise AssertionError(
352
357
"Length of buffer did not match expected value: %s != %s"
408
413
# _NeedMoreBytes).
409
414
current_state = self.state_accept
410
415
self.state_accept()
411
except _NeedMoreBytes, e:
416
except _NeedMoreBytes as e:
412
417
self._number_needed_bytes = e.count
490
495
def _state_accept_expecting_length(self):
491
496
prefix = self._extract_line()
493
498
self.error = True
494
499
self.error_in_progress = []
495
500
self._state_accept_expecting_length()
497
elif prefix == 'END':
502
elif prefix == b'END':
498
503
# We've read the end-of-body marker.
499
504
# Any further bytes are unused data, including the bytes left in
500
505
# the _in_buffer.
504
509
self.bytes_left = int(prefix, 16)
505
self.chunk_in_progress = ''
510
self.chunk_in_progress = b''
506
511
self.state_accept = self._state_accept_reading_chunk
508
513
def _state_accept_reading_chunk(self):
533
538
_StatefulDecoder.__init__(self)
534
539
self.state_accept = self._state_accept_expecting_length
535
540
self.state_read = self._state_read_no_data
537
self._trailer_buffer = ''
542
self._trailer_buffer = b''
539
544
def next_read_size(self):
540
545
if self.bytes_left is not None:
584
589
self._set_in_buffer(None)
585
590
# TODO: what if the trailer does not match "done\n"? Should this raise
586
591
# a ProtocolViolation exception?
587
if self._trailer_buffer.startswith('done\n'):
588
self.unused_data = self._trailer_buffer[len('done\n'):]
592
if self._trailer_buffer.startswith(b'done\n'):
593
self.unused_data = self._trailer_buffer[len(b'done\n'):]
589
594
self.state_accept = self._state_accept_reading_unused
590
595
self.finished_reading = True
654
659
"""Make a remote call with a readv array.
656
661
The body is encoded with one line per readv offset pair. The numbers in
657
each pair are separated by a comma, and no trailing \n is emitted.
662
each pair are separated by a comma, and no trailing \\n is emitted.
659
664
if 'hpss' in debug.debug_flags:
660
665
mutter('hpss call w/readv: %s', repr(args)[1:-1])
749
754
:param verb: The verb used in that call.
750
755
:raises: UnexpectedSmartServerResponse
752
if (result_tuple == ('error', "Generic bzr smart protocol error: "
753
"bad request '%s'" % self._last_verb) or
754
result_tuple == ('error', "Generic bzr smart protocol error: "
755
"bad request u'%s'" % self._last_verb)):
757
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
758
b"bad request '" + self._last_verb + b"'") or
759
result_tuple == (b'error', b"Generic bzr smart protocol error: "
760
b"bad request u'%s'" % self._last_verb)):
756
761
# The response will have no body, so we've finished reading.
757
762
self._request.finished_reading()
758
763
raise errors.UnknownSmartMethod(self._last_verb)
770
775
while not _body_decoder.finished_reading:
771
776
bytes = self._request.read_bytes(_body_decoder.next_read_size())
773
778
# end of file encountered reading from server
774
779
raise errors.ConnectionReset(
775
780
"Connection lost while reading response body.")
776
781
_body_decoder.accept_bytes(bytes)
777
782
self._request.finished_reading()
778
self._body_buffer = StringIO(_body_decoder.read_pending_data())
783
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
779
784
# XXX: TODO check the trailer result.
780
785
if 'hpss' in debug.debug_flags:
781
786
mutter(' %d body bytes read',
789
794
def query_version(self):
790
795
"""Return protocol version number of the server."""
792
797
resp = self.read_response_tuple()
793
if resp == ('ok', '1'):
798
if resp == (b'ok', '1'):
795
elif resp == ('ok', '2'):
800
elif resp == (b'ok', '2'):
798
803
raise errors.SmartProtocolError("bad response %r" % (resp,))
830
835
response_status = self._request.read_line()
831
836
result = SmartClientRequestProtocolOne._read_response_tuple(self)
832
837
self._response_is_unknown_method(result)
833
if response_status == 'success\n':
838
if response_status == b'success\n':
834
839
self.response_status = True
835
840
if not expect_body:
836
841
self._request.finished_reading()
838
elif response_status == 'failed\n':
843
elif response_status == b'failed\n':
839
844
self.response_status = False
840
845
self._request.finished_reading()
841
846
raise errors.ErrorFromSmartServer(result)
858
863
_body_decoder = ChunkedBodyDecoder()
859
864
while not _body_decoder.finished_reading:
860
865
bytes = self._request.read_bytes(_body_decoder.next_read_size())
862
867
# end of file encountered reading from server
863
868
raise errors.ConnectionReset(
864
869
"Connection lost while reading streamed body.")
865
870
_body_decoder.accept_bytes(bytes)
866
871
for body_bytes in iter(_body_decoder.read_next_chunk, None):
867
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
872
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
868
873
mutter(' %d byte chunk read',
907
912
_StatefulDecoder.accept_bytes(self, bytes)
908
913
except KeyboardInterrupt:
910
except errors.SmartMessageHandlerError, exception:
915
except errors.SmartMessageHandlerError as exception:
911
916
# We do *not* set self.decoding_failed here. The message handler
912
917
# has raised an error, but the decoder is still able to parse bytes
913
918
# and determine when this message ends.
918
923
# exception has interrupted the loop that runs the state machine.
919
924
# So we call accept_bytes again to restart it.
920
925
self.accept_bytes('')
921
except Exception, exception:
926
except Exception as exception:
922
927
# The decoder itself has raised an exception. We cannot continue
924
929
self.decoding_failed = True
993
998
def _state_accept_expecting_headers(self):
994
999
decoded = self._extract_prefixed_bencoded_data()
995
if type(decoded) is not dict:
1000
if not isinstance(decoded, dict):
996
1001
raise errors.SmartProtocolError(
997
1002
'Header object %r is not a dict' % (decoded,))
998
1003
self.state_accept = self._state_accept_expecting_message_part
1004
1009
def _state_accept_expecting_message_part(self):
1005
1010
message_part_kind = self._extract_single_byte()
1006
if message_part_kind == 'o':
1011
if message_part_kind == b'o':
1007
1012
self.state_accept = self._state_accept_expecting_one_byte
1008
elif message_part_kind == 's':
1013
elif message_part_kind == b's':
1009
1014
self.state_accept = self._state_accept_expecting_structure
1010
elif message_part_kind == 'b':
1015
elif message_part_kind == b'b':
1011
1016
self.state_accept = self._state_accept_expecting_bytes
1012
elif message_part_kind == 'e':
1017
elif message_part_kind == b'e':
1015
1020
raise errors.SmartProtocolError(
1081
1086
self._real_write_func = write_func
1083
1088
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
1089
# TODO: Another possibility would be to turn this into an async model.
1088
1090
# Where we let another thread know that we have some bytes if
1089
1091
# they want it, but we don't actually block for it
1104
1106
"""Serialise a readv offset list."""
1106
1108
for start, length in offsets:
1107
txt.append('%d,%d' % (start, length))
1108
return '\n'.join(txt)
1109
txt.append(b'%d,%d' % (start, length))
1110
return b'\n'.join(txt)
1110
1112
def _write_protocol_version(self):
    """Emit the protocol 3 message version marker through the write function."""
    version_marker = MESSAGE_VERSION_THREE
    self._write_func(version_marker)
1119
1121
self._write_prefixed_bencode(headers)
1121
1123
def _write_structure(self, args):
1122
self._write_func('s')
1124
self._write_func(b's')
1124
1126
for arg in args:
1125
if type(arg) is unicode:
1127
if isinstance(arg, text_type):
1126
1128
utf8_args.append(arg.encode('utf8'))
1128
1130
utf8_args.append(arg)
1129
1131
self._write_prefixed_bencode(utf8_args)
1131
1133
def _write_end(self):
1132
self._write_func('e')
1134
self._write_func(b'e')
1135
1137
def _write_prefixed_body(self, bytes):
1136
self._write_func('b')
1138
self._write_func(b'b')
1137
1139
self._write_func(struct.pack('!L', len(bytes)))
1138
1140
self._write_func(bytes)
1140
1142
def _write_chunked_body_start(self):
1141
self._write_func('oC')
1143
self._write_func(b'oC')
1143
1145
def _write_error_status(self):
1144
self._write_func('oE')
1146
self._write_func(b'oE')
1146
1148
def _write_success_status(self):
1147
self._write_func('oS')
1149
self._write_func(b'oS')
1150
1152
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1152
1154
def __init__(self, write_func):
1153
1155
_ProtocolThreeEncoder.__init__(self, write_func)
1154
1156
self.response_sent = False
1155
self._headers = {'Software version': bzrlib.__version__}
1158
b'Software version': breezy.__version__.encode('utf-8')}
1156
1159
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1160
self._thread_id = _thread.get_ident()
1158
1161
self._response_start_time = None
1160
1163
def _trace(self, action, message, extra_bytes=None, include_time=False):
1180
1183
% (exception,))
1181
1184
if isinstance(exception, errors.UnknownSmartMethod):
1182
1185
failure = request.FailedSmartServerResponse(
1183
('UnknownMethod', exception.verb))
1186
(b'UnknownMethod', exception.verb))
1184
1187
self.send_response(failure)
1186
1189
if 'hpss' in debug.debug_flags:
1189
1192
self._write_protocol_version()
1190
1193
self._write_headers(self._headers)
1191
1194
self._write_error_status()
1192
self._write_structure(('error', str(exception)))
1195
self._write_structure((b'error', str(exception).encode('utf-8', 'replace')))
1193
1196
self._write_end()
1195
1198
def send_response(self, response):
1273
1277
iterator = iter(iterable)
1276
yield None, iterator.next()
1280
yield None, next(iterator)
1277
1281
except StopIteration:
1279
1283
except (KeyboardInterrupt, SystemExit):
1291
1295
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1292
1296
self._medium_request = medium_request
1293
1297
self._headers = {}
1298
self.body_stream_started = None
1295
1300
def set_headers(self, headers):
    """Remember a copy of ``headers`` for later use.

    The object returned by ``headers.copy()`` is stored, not ``headers``
    itself.

    :param headers: an object providing a ``copy()`` method.
    """
    snapshot = headers.copy()
    self._headers = snapshot
1331
1336
"""Make a remote call with a readv array.
1333
1338
The body is encoded with one line per readv offset pair. The numbers in
1334
each pair are separated by a comma, and no trailing \n is emitted.
1339
each pair are separated by a comma, and no trailing \\n is emitted.
1336
1341
if 'hpss' in debug.debug_flags:
1337
1342
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1356
1361
if path is not None:
1357
1362
mutter(' (to %s)', path)
1358
1363
self._request_start_time = osutils.timer_func()
1364
self.body_stream_started = False
1359
1365
self._write_protocol_version()
1360
1366
self._write_headers(self._headers)
1361
1367
self._write_structure(args)
1363
1369
# have finished sending the stream. We would notice at the end
1364
1370
# anyway, but if the medium can deliver it early then it's good
1365
1371
# to short-circuit the whole request...
1372
# Provoke any ConnectionReset failures before we start the body stream.
1374
self.body_stream_started = True
1366
1375
for exc_info, part in _iter_with_errors(stream):
1367
1376
if exc_info is not None:
1368
1377
# Iterating the stream failed. Cleanly abort the request.
1369
1378
self._write_error_status()
1370
1379
# Currently the client unconditionally sends ('error',) as the
1372
self._write_structure(('error',))
1381
self._write_structure((b'error',))
1373
1382
self._write_end()
1374
1383
self._medium_request.finished_writing()
1375
raise exc_info[0], exc_info[1], exc_info[2]
1377
1389
self._write_prefixed_body(part)