13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode, bencode
40
35
# Protocol version strings. These are sent as prefixes of bzr requests and
63
58
def _encode_tuple(args):
64
59
"""Encode the tuple args to a bytestream."""
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
60
return '\x01'.join(args) + '\n'
74
63
class Requester(object):
120
109
for start, length in offsets:
121
110
txt.append('%d,%d' % (start, length))
122
111
return '\n'.join(txt)
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
"""Server-side encoding and decoding logic for smart version 1."""
128
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
self._backing_transport = backing_transport
131
119
self._root_client_path = root_client_path
132
self._jail_root = jail_root
133
120
self.unused_data = ''
134
121
self._finished = False
135
122
self.in_buffer = ''
136
self._has_dispatched = False
123
self.has_dispatched = False
137
124
self.request = None
138
125
self._body_decoder = None
139
126
self._write_func = write_func
141
128
def accept_bytes(self, bytes):
142
129
"""Take bytes, and advance the internal state machine appropriately.
144
131
:param bytes: must be a byte string
146
if not isinstance(bytes, str):
147
raise ValueError(bytes)
133
assert isinstance(bytes, str)
148
134
self.in_buffer += bytes
149
if not self._has_dispatched:
135
if not self.has_dispatched:
150
136
if '\n' not in self.in_buffer:
151
137
# no command line yet
153
self._has_dispatched = True
139
self.has_dispatched = True
155
141
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
156
142
first_line += '\n'
157
143
req_args = _decode_tuple(first_line)
158
144
self.request = request.SmartServerRequestHandler(
159
145
self._backing_transport, commands=request.request_handlers,
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
146
root_client_path=self._root_client_path)
147
self.request.dispatch_command(req_args[0], req_args[1:])
163
148
if self.request.finished_reading:
164
149
# trivial request
165
150
self.unused_data = self.in_buffer
196
181
self.request.accept_body(body_data)
197
182
if self._body_decoder.finished_reading:
198
183
self.request.end_of_body()
199
if not self.request.finished_reading:
200
raise AssertionError("no more body, request not finished")
184
assert self.request.finished_reading, \
185
"no more body, request not finished"
201
186
if self.request.response is not None:
202
187
self._send_response(self.request.response)
203
188
self.unused_data = self.in_buffer
204
189
self.in_buffer = ''
206
if self.request.finished_reading:
207
raise AssertionError(
208
"no response and we have finished reading.")
191
assert not self.request.finished_reading, \
192
"no response and we have finished reading."
210
194
def _send_response(self, response):
211
195
"""Send a smart server response down the output stream."""
213
raise AssertionError('response already sent')
196
assert not self._finished, 'response already sent'
214
197
args = response.args
215
198
body = response.body
216
199
self._finished = True
218
201
self._write_success_or_failure_prefix(response)
219
202
self._write_func(_encode_tuple(args))
220
203
if body is not None:
221
if not isinstance(body, str):
222
raise ValueError(body)
204
assert isinstance(body, str), 'body must be a str'
223
205
bytes = self._encode_bulk_data(body)
224
206
self._write_func(bytes)
226
208
def _write_protocol_version(self):
227
209
"""Write any prefixes this protocol requires.
229
211
Version one doesn't send protocol versions.
265
247
def _write_protocol_version(self):
266
248
r"""Write any prefixes this protocol requires.
268
250
Version two sends the value of RESPONSE_VERSION_TWO.
270
252
self._write_func(self.response_marker)
272
254
def _send_response(self, response):
273
255
"""Send a smart server response down the output stream."""
275
raise AssertionError('response already sent')
256
assert not self._finished, 'response already sent'
276
257
self._finished = True
277
258
self._write_protocol_version()
278
259
self._write_success_or_failure_prefix(response)
279
260
self._write_func(_encode_tuple(response.args))
280
261
if response.body is not None:
281
if not isinstance(response.body, str):
282
raise AssertionError('body must be a str')
283
if not (response.body_stream is None):
284
raise AssertionError(
285
'body_stream and body cannot both be set')
262
assert isinstance(response.body, str), 'body must be a str'
263
assert response.body_stream is None, (
264
'body_stream and body cannot both be set')
286
265
bytes = self._encode_bulk_data(response.body)
287
266
self._write_func(bytes)
288
267
elif response.body_stream is not None:
318
297
def __init__(self, count=None):
321
:param count: the total number of bytes needed by the current state.
322
May be None if the number of bytes needed is unknown.
324
298
self.count = count
327
301
class _StatefulDecoder(object):
328
"""Base class for writing state machines to decode byte streams.
330
Subclasses should provide a self.state_accept attribute that accepts bytes
331
and, if appropriate, updates self.state_accept to a different function.
332
accept_bytes will call state_accept as often as necessary to make sure the
333
state machine has progressed as far as possible before it returns.
335
See ProtocolThreeDecoder for an example subclass.
338
303
def __init__(self):
339
304
self.finished_reading = False
340
self._in_buffer_list = []
341
self._in_buffer_len = 0
342
305
self.unused_data = ''
343
306
self.bytes_left = None
344
307
self._number_needed_bytes = None
346
def _get_in_buffer(self):
347
if len(self._in_buffer_list) == 1:
348
return self._in_buffer_list[0]
349
in_buffer = ''.join(self._in_buffer_list)
350
if len(in_buffer) != self._in_buffer_len:
351
raise AssertionError(
352
"Length of buffer did not match expected value: %s != %s"
353
% self._in_buffer_len, len(in_buffer))
354
self._in_buffer_list = [in_buffer]
357
def _get_in_bytes(self, count):
358
"""Grab X bytes from the input_buffer.
360
Callers should have already checked that self._in_buffer_len is >
361
count. Note, this does not consume the bytes from the buffer. The
362
caller will still need to call _get_in_buffer() and then
363
_set_in_buffer() if they actually need to consume the bytes.
365
# check if we can yield the bytes from just the first entry in our list
366
if len(self._in_buffer_list) == 0:
367
raise AssertionError('Callers must be sure we have buffered bytes'
368
' before calling _get_in_bytes')
369
if len(self._in_buffer_list[0]) > count:
370
return self._in_buffer_list[0][:count]
371
# We can't yield it from the first buffer, so collapse all buffers, and
373
in_buf = self._get_in_buffer()
374
return in_buf[:count]
376
def _set_in_buffer(self, new_buf):
377
if new_buf is not None:
378
self._in_buffer_list = [new_buf]
379
self._in_buffer_len = len(new_buf)
381
self._in_buffer_list = []
382
self._in_buffer_len = 0
384
309
def accept_bytes(self, bytes):
385
310
"""Decode as much of bytes as possible.
391
316
data will be appended to self.unused_data.
393
318
# accept_bytes is allowed to change the state
319
current_state = self.state_accept
394
320
self._number_needed_bytes = None
395
# lsprof puts a very large amount of time on this specific call for
397
self._in_buffer_list.append(bytes)
398
self._in_buffer_len += len(bytes)
400
# Run the function for the current state.
401
current_state = self.state_accept
322
self.state_accept(bytes)
403
323
while current_state != self.state_accept:
404
# The current state has changed. Run the function for the new
405
# current state, so that it can:
406
# - decode any unconsumed bytes left in a buffer, and
407
# - signal how many more bytes are expected (via raising
409
324
current_state = self.state_accept
325
self.state_accept('')
411
326
except _NeedMoreBytes, e:
412
327
self._number_needed_bytes = e.count
422
337
def __init__(self):
423
338
_StatefulDecoder.__init__(self)
424
339
self.state_accept = self._state_accept_expecting_header
425
341
self.chunk_in_progress = None
426
342
self.chunks = collections.deque()
427
343
self.error = False
428
344
self.error_in_progress = None
430
346
def next_read_size(self):
431
347
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
432
348
# end-of-body marker is 4 bytes: 'END\n'.
459
375
def _extract_line(self):
460
in_buf = self._get_in_buffer()
461
pos = in_buf.find('\n')
376
pos = self._in_buffer.find('\n')
463
# We haven't read a complete line yet, so request more bytes before
378
# We haven't read a complete line yet, so there's nothing to do.
465
379
raise _NeedMoreBytes(1)
380
line = self._in_buffer[:pos]
467
381
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
468
self._set_in_buffer(in_buf[pos+1:])
382
self._in_buffer = self._in_buffer[pos+1:]
471
385
def _finished(self):
472
self.unused_data = self._get_in_buffer()
473
self._in_buffer_list = []
474
self._in_buffer_len = 0
386
self.unused_data = self._in_buffer
387
self._in_buffer = None
475
388
self.state_accept = self._state_accept_reading_unused
477
390
error_args = tuple(self.error_in_progress)
487
401
raise errors.SmartProtocolError(
488
402
'Bad chunked body header: "%s"' % (prefix,))
490
def _state_accept_expecting_length(self):
404
def _state_accept_expecting_length(self, bytes):
405
self._in_buffer += bytes
491
406
prefix = self._extract_line()
492
407
if prefix == 'ERR':
493
408
self.error = True
494
409
self.error_in_progress = []
495
self._state_accept_expecting_length()
410
self._state_accept_expecting_length('')
497
412
elif prefix == 'END':
498
413
# We've read the end-of-body marker.
505
420
self.chunk_in_progress = ''
506
421
self.state_accept = self._state_accept_reading_chunk
508
def _state_accept_reading_chunk(self):
509
in_buf = self._get_in_buffer()
510
in_buffer_len = len(in_buf)
511
self.chunk_in_progress += in_buf[:self.bytes_left]
512
self._set_in_buffer(in_buf[self.bytes_left:])
423
def _state_accept_reading_chunk(self, bytes):
424
self._in_buffer += bytes
425
in_buffer_len = len(self._in_buffer)
426
self.chunk_in_progress += self._in_buffer[:self.bytes_left]
427
self._in_buffer = self._in_buffer[self.bytes_left:]
513
428
self.bytes_left -= in_buffer_len
514
429
if self.bytes_left <= 0:
515
430
# Finished with chunk
520
435
self.chunks.append(self.chunk_in_progress)
521
436
self.chunk_in_progress = None
522
437
self.state_accept = self._state_accept_expecting_length
524
def _state_accept_reading_unused(self):
525
self.unused_data += self._get_in_buffer()
526
self._in_buffer_list = []
439
def _state_accept_reading_unused(self, bytes):
440
self.unused_data += bytes
529
443
class LengthPrefixedBodyDecoder(_StatefulDecoder):
530
444
"""Decodes the length-prefixed bulk data."""
532
446
def __init__(self):
533
447
_StatefulDecoder.__init__(self)
534
448
self.state_accept = self._state_accept_expecting_length
535
449
self.state_read = self._state_read_no_data
537
451
self._trailer_buffer = ''
539
453
def next_read_size(self):
540
454
if self.bytes_left is not None:
541
455
# Ideally we want to read all the remainder of the body and the
552
466
# Reading excess data. Either way, 1 byte at a time is fine.
555
469
def read_pending_data(self):
556
470
"""Return any pending data that has been decoded."""
557
471
return self.state_read()
559
def _state_accept_expecting_length(self):
560
in_buf = self._get_in_buffer()
561
pos = in_buf.find('\n')
473
def _state_accept_expecting_length(self, bytes):
474
self._in_buffer += bytes
475
pos = self._in_buffer.find('\n')
564
self.bytes_left = int(in_buf[:pos])
565
self._set_in_buffer(in_buf[pos+1:])
478
self.bytes_left = int(self._in_buffer[:pos])
479
self._in_buffer = self._in_buffer[pos+1:]
480
self.bytes_left -= len(self._in_buffer)
566
481
self.state_accept = self._state_accept_reading_body
567
self.state_read = self._state_read_body_buffer
482
self.state_read = self._state_read_in_buffer
569
def _state_accept_reading_body(self):
570
in_buf = self._get_in_buffer()
572
self.bytes_left -= len(in_buf)
573
self._set_in_buffer(None)
484
def _state_accept_reading_body(self, bytes):
485
self._in_buffer += bytes
486
self.bytes_left -= len(bytes)
574
487
if self.bytes_left <= 0:
575
488
# Finished with body
576
489
if self.bytes_left != 0:
577
self._trailer_buffer = self._body[self.bytes_left:]
578
self._body = self._body[:self.bytes_left]
490
self._trailer_buffer = self._in_buffer[self.bytes_left:]
491
self._in_buffer = self._in_buffer[:self.bytes_left]
579
492
self.bytes_left = None
580
493
self.state_accept = self._state_accept_reading_trailer
582
def _state_accept_reading_trailer(self):
583
self._trailer_buffer += self._get_in_buffer()
584
self._set_in_buffer(None)
495
def _state_accept_reading_trailer(self, bytes):
496
self._trailer_buffer += bytes
585
497
# TODO: what if the trailer does not match "done\n"? Should this raise
586
498
# a ProtocolViolation exception?
587
499
if self._trailer_buffer.startswith('done\n'):
588
500
self.unused_data = self._trailer_buffer[len('done\n'):]
589
501
self.state_accept = self._state_accept_reading_unused
590
502
self.finished_reading = True
592
def _state_accept_reading_unused(self):
593
self.unused_data += self._get_in_buffer()
594
self._set_in_buffer(None)
504
def _state_accept_reading_unused(self, bytes):
505
self.unused_data += bytes
596
507
def _state_read_no_data(self):
599
def _state_read_body_buffer(self):
510
def _state_read_in_buffer(self):
511
result = self._in_buffer
605
516
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
606
message.ResponseHandler):
517
message.ResponseHandler):
607
518
"""The client-side protocol for smart version 1."""
609
520
def __init__(self, request):
626
537
mutter('hpss call: %s', repr(args)[1:-1])
627
538
if getattr(self._request._medium, 'base', None) is not None:
628
539
mutter(' (to %s)', self._request._medium.base)
629
self._request_start_time = osutils.timer_func()
540
self._request_start_time = time.time()
630
541
self._write_args(args)
631
542
self._request.finished_writing()
632
543
self._last_verb = args[0]
641
552
if getattr(self._request._medium, '_path', None) is not None:
642
553
mutter(' (to %s)', self._request._medium._path)
643
554
mutter(' %d bytes', len(body))
644
self._request_start_time = osutils.timer_func()
555
self._request_start_time = time.time()
645
556
if 'hpssdetail' in debug.debug_flags:
646
557
mutter('hpss body content: %s', body)
647
558
self._write_args(args)
660
571
mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
572
if getattr(self._request._medium, '_path', None) is not None:
662
573
mutter(' (to %s)', self._request._medium._path)
663
self._request_start_time = osutils.timer_func()
574
self._request_start_time = time.time()
664
575
self._write_args(args)
665
576
readv_bytes = self._serialise_offsets(body)
666
577
bytes = self._encode_bulk_data(readv_bytes)
670
581
mutter(' %d bytes in readv request', len(readv_bytes))
671
582
self._last_verb = args[0]
673
def call_with_body_stream(self, args, stream):
674
# Protocols v1 and v2 don't support body streams. So it's safe to
675
# assume that a v1/v2 server doesn't support whatever method we're
676
# trying to call with a body stream.
677
self._request.finished_writing()
678
self._request.finished_reading()
679
raise errors.UnknownSmartMethod(args[0])
681
584
def cancel_read_body(self):
682
585
"""After expecting a body, a response code may indicate one otherwise.
756
654
# The response will have no body, so we've finished reading.
757
655
self._request.finished_reading()
758
656
raise errors.UnknownSmartMethod(self._last_verb)
760
658
def read_body_bytes(self, count=-1):
761
659
"""Read bytes from the body, decoding into a byte stream.
763
We read all bytes at once to ensure we've checked the trailer for
661
We read all bytes at once to ensure we've checked the trailer for
764
662
errors, and then feed the buffer back as read_body_bytes is called.
766
664
if self._body_buffer is not None:
767
665
return self._body_buffer.read(count)
768
666
_body_decoder = LengthPrefixedBodyDecoder()
668
# Read no more than 64k at a time so that we don't risk error 10055 (no
669
# buffer space available) on Windows.
770
671
while not _body_decoder.finished_reading:
771
bytes = self._request.read_bytes(_body_decoder.next_read_size())
773
# end of file encountered reading from server
774
raise errors.ConnectionReset(
775
"Connection lost while reading response body.")
672
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
673
bytes = self._request.read_bytes(bytes_wanted)
776
674
_body_decoder.accept_bytes(bytes)
777
675
self._request.finished_reading()
778
676
self._body_buffer = StringIO(_body_decoder.read_pending_data())
785
683
def _recv_tuple(self):
786
684
"""Receive a tuple from the medium request."""
787
return _decode_tuple(self._request.read_line())
685
return _decode_tuple(self._recv_line())
687
def _recv_line(self):
688
"""Read an entire line from the medium request."""
690
while not line or line[-1] != '\n':
691
# TODO: this is inefficient - but tuples are short.
692
new_char = self._request.read_bytes(1)
694
# end of file encountered reading from server
695
raise errors.ConnectionReset(
696
"please check connectivity and permissions",
697
"(and try -Dhpss if further diagnosis is required)")
789
701
def query_version(self):
790
702
"""Return protocol version number of the server."""
827
741
if version != self.response_marker:
828
742
self._request.finished_reading()
829
743
raise errors.UnexpectedProtocolVersionMarker(version)
830
response_status = self._request.read_line()
744
response_status = self._recv_line()
831
745
result = SmartClientRequestProtocolOne._read_response_tuple(self)
832
self._response_is_unknown_method(result)
833
746
if response_status == 'success\n':
834
747
self.response_status = True
835
748
if not expect_body:
856
769
# Read no more than 64k at a time so that we don't risk error 10055 (no
857
770
# buffer space available) on Windows.
858
772
_body_decoder = ChunkedBodyDecoder()
859
773
while not _body_decoder.finished_reading:
860
bytes = self._request.read_bytes(_body_decoder.next_read_size())
862
# end of file encountered reading from server
863
raise errors.ConnectionReset(
864
"Connection lost while reading streamed body.")
774
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
775
bytes = self._request.read_bytes(bytes_wanted)
865
776
_body_decoder.accept_bytes(bytes)
866
777
for body_bytes in iter(_body_decoder.read_next_chunk, None):
867
778
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
874
785
def build_server_protocol_three(backing_transport, write_func,
875
root_client_path, jail_root=None):
876
787
request_handler = request.SmartServerRequestHandler(
877
788
backing_transport, commands=request.request_handlers,
878
root_client_path=root_client_path, jail_root=jail_root)
789
root_client_path=root_client_path)
879
790
responder = ProtocolThreeResponder(write_func)
880
791
message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
792
return ProtocolThreeDecoder(message_handler)
889
800
def __init__(self, message_handler, expect_version_marker=False):
890
801
_StatefulDecoder.__init__(self)
891
self._has_dispatched = False
802
self.has_dispatched = False
893
805
if expect_version_marker:
894
806
self.state_accept = self._state_accept_expecting_protocol_version
895
807
# We're expecting at least the protocol version marker + some
911
823
# We do *not* set self.decoding_failed here. The message handler
912
824
# has raised an error, but the decoder is still able to parse bytes
913
825
# and determine when this message ends.
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
826
log_exception_quietly()
916
827
self.message_handler.protocol_error(exception.exc_value)
917
828
# The state machine is ready to continue decoding, but the
918
829
# exception has interrupted the loop that runs the state machine.
926
837
# This happens during normal operation when the client tries a
927
838
# protocol version the server doesn't understand, so no need to
928
839
# log a traceback every time.
929
# Note that this can only happen when
930
# expect_version_marker=True, which is only the case on the
934
842
log_exception_quietly()
935
843
self.message_handler.protocol_error(exception)
937
845
def _extract_length_prefixed_bytes(self):
938
if self._in_buffer_len < 4:
846
if len(self._in_buffer) < 4:
939
847
# A length prefix by itself is 4 bytes, and we don't even have that
941
849
raise _NeedMoreBytes(4)
942
(length,) = struct.unpack('!L', self._get_in_bytes(4))
850
(length,) = struct.unpack('!L', self._in_buffer[:4])
943
851
end_of_bytes = 4 + length
944
if self._in_buffer_len < end_of_bytes:
852
if len(self._in_buffer) < end_of_bytes:
945
853
# We haven't yet read as many bytes as the length-prefix says there
947
855
raise _NeedMoreBytes(end_of_bytes)
948
856
# Extract the bytes from the buffer.
949
in_buf = self._get_in_buffer()
950
bytes = in_buf[4:end_of_bytes]
951
self._set_in_buffer(in_buf[end_of_bytes:])
857
bytes = self._in_buffer[4:end_of_bytes]
858
self._in_buffer = self._in_buffer[end_of_bytes:]
954
861
def _extract_prefixed_bencoded_data(self):
955
862
prefixed_bytes = self._extract_length_prefixed_bytes()
957
decoded = bdecode_as_tuple(prefixed_bytes)
864
decoded = bdecode(prefixed_bytes)
958
865
except ValueError:
959
866
raise errors.SmartProtocolError(
960
867
'Bytes %r not bencoded' % (prefixed_bytes,))
963
870
def _extract_single_byte(self):
964
if self._in_buffer_len == 0:
871
if self._in_buffer == '':
965
872
# The buffer is empty
966
873
raise _NeedMoreBytes(1)
967
in_buf = self._get_in_buffer()
969
self._set_in_buffer(in_buf[1:])
874
one_byte = self._in_buffer[0]
875
self._in_buffer = self._in_buffer[1:]
972
def _state_accept_expecting_protocol_version(self):
973
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
974
in_buf = self._get_in_buffer()
878
def _state_accept_expecting_protocol_version(self, bytes):
879
self._in_buffer += bytes
880
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
975
881
if needed_bytes > 0:
976
882
# We don't have enough bytes to check if the protocol version
977
883
# marker is right. But we can check if it is already wrong by
981
887
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
982
888
# are wrong then we should just raise immediately rather than
984
if not MESSAGE_VERSION_THREE.startswith(in_buf):
890
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
985
891
# We have enough bytes to know the protocol version is wrong
986
raise errors.UnexpectedProtocolVersionMarker(in_buf)
892
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
987
893
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
988
if not in_buf.startswith(MESSAGE_VERSION_THREE):
989
raise errors.UnexpectedProtocolVersionMarker(in_buf)
990
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
894
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
895
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
896
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
991
897
self.state_accept = self._state_accept_expecting_headers
993
def _state_accept_expecting_headers(self):
899
def _state_accept_expecting_headers(self, bytes):
900
self._in_buffer += bytes
994
901
decoded = self._extract_prefixed_bencoded_data()
995
902
if type(decoded) is not dict:
996
903
raise errors.SmartProtocolError(
1024
933
raise errors.SmartMessageHandlerError(sys.exc_info())
1026
def _state_accept_expecting_bytes(self):
935
def _state_accept_expecting_bytes(self, bytes):
1027
936
# XXX: this should not buffer whole message part, but instead deliver
1028
937
# the bytes as they arrive.
938
self._in_buffer += bytes
1029
939
prefixed_bytes = self._extract_length_prefixed_bytes()
1030
940
self.state_accept = self._state_accept_expecting_message_part
1042
953
raise errors.SmartMessageHandlerError(sys.exc_info())
1045
self.unused_data = self._get_in_buffer()
1046
self._set_in_buffer(None)
956
self.unused_data = self._in_buffer
957
self._in_buffer = None
1047
958
self.state_accept = self._state_accept_reading_unused
1049
960
self.message_handler.end_received()
1051
962
raise errors.SmartMessageHandlerError(sys.exc_info())
1053
def _state_accept_reading_unused(self):
1054
self.unused_data += self._get_in_buffer()
1055
self._set_in_buffer(None)
964
def _state_accept_reading_unused(self, bytes):
965
self.unused_data += bytes
1057
967
def next_read_size(self):
1058
968
if self.state_accept == self._state_accept_reading_unused:
1073
983
class _ProtocolThreeEncoder(object):
1075
985
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1078
987
def __init__(self, write_func):
1081
self._real_write_func = write_func
1083
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1099
self._real_write_func(''.join(self._buf))
988
self._write_func = write_func
1103
990
def _serialise_offsets(self, offsets):
1104
991
"""Serialise a readv offset list."""
1153
1036
_ProtocolThreeEncoder.__init__(self, write_func)
1154
1037
self.response_sent = False
1155
1038
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1176
1040
def send_error(self, exception):
1177
if self.response_sent:
1178
raise AssertionError(
1179
"send_error(%s) called, but response already sent."
1041
assert not self.response_sent
1181
1042
if isinstance(exception, errors.UnknownSmartMethod):
1182
1043
failure = request.FailedSmartServerResponse(
1183
1044
('UnknownMethod', exception.verb))
1184
1045
self.send_response(failure)
1186
if 'hpss' in debug.debug_flags:
1187
self._trace('error', str(exception))
1188
1047
self.response_sent = True
1189
1048
self._write_protocol_version()
1190
1049
self._write_headers(self._headers)
1204
1060
self._write_success_status()
1206
1062
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1209
1063
self._write_structure(response.args)
1210
1064
if response.body is not None:
1211
1065
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1215
1066
elif response.body_stream is not None:
1216
count = num_bytes = 0
1218
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
if exc_info is not None:
1221
self._write_error_status()
1222
error_struct = request._translate_error(exc_info[1])
1223
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
self._write_prefixed_body(chunk)
1234
if 'hpssdetail' in debug.debug_flags:
1235
# Not worth timing separately, as _write_func is
1237
self._trace('body chunk',
1238
'%d bytes' % (len(chunk),),
1239
chunk, suppress_time=True)
1240
if 'hpss' in debug.debug_flags:
1241
self._trace('body stream',
1242
'%d bytes %d chunks' % (num_bytes, count),
1067
for chunk in response.body_stream:
1068
self._write_prefixed_body(chunk)
1244
1069
self._write_end()
1245
if 'hpss' in debug.debug_flags:
1246
self._trace('response end', '', include_time=True)
1249
def _iter_with_errors(iterable):
1250
"""Handle errors from iterable.next().
1254
for exc_info, value in _iter_with_errors(iterable):
1257
This is a safer alternative to::
1260
for value in iterable:
1265
Because the latter will catch errors from the for-loop body, not just
1268
If an error occurs, exc_info will be a exc_info tuple, and the generator
1269
will terminate. Otherwise exc_info will be None, and value will be the
1270
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1271
will not be itercepted.
1273
iterator = iter(iterable)
1276
yield None, iterator.next()
1277
except StopIteration:
1279
except (KeyboardInterrupt, SystemExit):
1282
mutter('_iter_with_errors caught error')
1283
log_exception_quietly()
1284
yield sys.exc_info(), None
1288
1072
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1295
1079
def set_headers(self, headers):
1296
1080
self._headers = headers.copy()
1298
1082
def call(self, *args):
1299
1083
if 'hpss' in debug.debug_flags:
1300
1084
mutter('hpss call: %s', repr(args)[1:-1])
1301
1085
base = getattr(self._medium_request._medium, 'base', None)
1302
1086
if base is not None:
1303
1087
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1088
self._request_start_time = time.time()
1305
1089
self._write_protocol_version()
1306
1090
self._write_headers(self._headers)
1307
1091
self._write_structure(args)
1349
1133
self._write_end()
1350
1134
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self._write_protocol_version()
1360
self._write_headers(self._headers)
1361
self._write_structure(args)
1362
# TODO: notice if the server has sent an early error reply before we
1363
# have finished sending the stream. We would notice at the end
1364
# anyway, but if the medium can deliver it early then it's good
1365
# to short-circuit the whole request...
1366
for exc_info, part in _iter_with_errors(stream):
1367
if exc_info is not None:
1368
# Iterating the stream failed. Cleanly abort the request.
1369
self._write_error_status()
1370
# Currently the client unconditionally sends ('error',) as the
1372
self._write_structure(('error',))
1374
self._medium_request.finished_writing()
1375
raise exc_info[0], exc_info[1], exc_info[2]
1377
self._write_prefixed_body(part)
1380
self._medium_request.finished_writing()