# Copyright (C) 2006-2010 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

from __future__ import absolute_import

"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""

import collections
from cStringIO import StringIO
import struct
import sys

from bzrlib import (
    debug,
    errors,
    osutils,
    )
from bzrlib.smart import message, request
from bzrlib.trace import log_exception_quietly, mutter
from bzrlib.bencode import bdecode_as_tuple, bencode

# Protocol version strings.  Each one is sent as a prefix of a bzr request or
# response so that the other end can tell which protocol version is in use.
# Protocol version one has no such strings: it sends no prefix at all.
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'

# Version three uses a single marker for both directions.
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
RESPONSE_VERSION_THREE = MESSAGE_VERSION_THREE

def _recv_tuple(from_file):
    """Read a single line from from_file and decode it as a request tuple.

    :param from_file: any object providing a ``readline()`` method.
    :return: the result of ``_decode_tuple`` for that line.
    """
    line = from_file.readline()
    return _decode_tuple(line)

def _decode_tuple(req_line):
58
if req_line is None or req_line == '':
60
if req_line[-1] != '\n':
61
raise errors.SmartProtocolError("request %r not terminated" % req_line)
62
return tuple(req_line[:-1].split('\x01'))
65
def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    The items are joined with '\\x01' and terminated with a newline, which is
    the inverse of _decode_tuple.

    :param args: a sequence of strings.
    :return: the encoded line.  If the joined result is unicode it is logged
        and encoded to ascii (so non-ascii unicode raises UnicodeEncodeError).
    """
    joined = '\x01'.join(args) + '\n'
    if type(joined) is unicode:
        # XXX: We should fix things so this never happens! -AJB, 20100304
        mutter('response args contain unicode, should be only bytes: %r',
               joined)
        joined = joined.encode('ascii')
    return joined

class Requester(object):
    """Abstract base class for an object that can issue requests on a smart
    medium.

    Every method here raises NotImplementedError; subclasses supply the
    real implementations.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :type body: str
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers to send with subsequent requests."""
        raise NotImplementedError(self.set_headers)

class SmartProtocolBase(object):
    """Methods common to client and server"""

    # TODO: this only actually accommodates a single block; possibly should
    # support multiple chunks?
    def _encode_bulk_data(self, body):
        """Encode body as a bulk data chunk.

        The result is the decimal length of body, a newline, body itself and
        the trailer 'done\\n'.
        """
        return ''.join(('%d\n' % len(body), body, 'done\n'))

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: iterable of (start, length) integer pairs.
        :return: one 'start,length' line per pair, joined by newlines, with
            no trailing newline.
        """
        return '\n'.join('%d,%d' % (start, length)
                         for start, length in offsets)

class SmartServerRequestProtocolOne(SmartProtocolBase):
    """Server-side encoding and decoding logic for smart version 1."""

    def __init__(self, backing_transport, write_func, root_client_path='/',
            jail_root=None):
        self._backing_transport = backing_transport
        self._root_client_path = root_client_path
        self._jail_root = jail_root
        # Bytes received after this request ended; not consumed here.
        self.unused_data = ''
        # True once a response has been written (see _send_response).
        self._finished = False
        # Bytes received but not yet consumed by the state machine.
        self.in_buffer = ''
        # True once the first line has been decoded and dispatched.
        self._has_dispatched = False
        self.request = None
        # Created lazily the first time body bytes need decoding.
        self._body_decoder = None
        self._write_func = write_func

    def accept_bytes(self, bytes):
        """Take bytes, and advance the internal state machine appropriately.

        The first newline-terminated line is decoded as the request arguments
        and dispatched to a SmartServerRequestHandler.  Later bytes are fed
        through a LengthPrefixedBodyDecoder into the handler until it has a
        response, which is then written with _send_response.

        :param bytes: must be a byte string
        :raises ValueError: if bytes is not a str.
        """
        if not isinstance(bytes, str):
            raise ValueError(bytes)
        self.in_buffer += bytes
        if not self._has_dispatched:
            if '\n' not in self.in_buffer:
                # no command line yet
                return
            self._has_dispatched = True
            try:
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
                # _decode_tuple expects the terminating newline.
                first_line += '\n'
                req_args = _decode_tuple(first_line)
                self.request = request.SmartServerRequestHandler(
                    self._backing_transport, commands=request.request_handlers,
                    root_client_path=self._root_client_path,
                    jail_root=self._jail_root)
                self.request.args_received(req_args)
                if self.request.finished_reading:
                    # trivial request
                    self.unused_data = self.in_buffer
                    self.in_buffer = ''
                    self._send_response(self.request.response)
            except KeyboardInterrupt:
                raise
            except errors.UnknownSmartMethod as err:
                protocol_error = errors.SmartProtocolError(
                    "bad request %r" % (err.verb,))
                failure = request.FailedSmartServerResponse(
                    ('error', str(protocol_error)))
                self._send_response(failure)
                return
            except Exception as exception:
                # everything else: pass to client, flush, and quit
                log_exception_quietly()
                self._send_response(request.FailedSmartServerResponse(
                    ('error', str(exception))))
                return

        # Every path that reaches this point has dispatched the request.
        if self._finished:
            # nothing to do.XXX: this routine should be a single state
            # machine too.
            self.unused_data += self.in_buffer
            self.in_buffer = ''
            return
        if self._body_decoder is None:
            self._body_decoder = LengthPrefixedBodyDecoder()
        self._body_decoder.accept_bytes(self.in_buffer)
        self.in_buffer = self._body_decoder.unused_data
        body_data = self._body_decoder.read_pending_data()
        self.request.accept_body(body_data)
        if self._body_decoder.finished_reading:
            self.request.end_of_body()
            if not self.request.finished_reading:
                raise AssertionError("no more body, request not finished")
        if self.request.response is not None:
            self._send_response(self.request.response)
            self.unused_data = self.in_buffer
            self.in_buffer = ''
        elif self.request.finished_reading:
            raise AssertionError(
                "no response and we have finished reading.")

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        :raises AssertionError: if a response has already been sent.
        :raises ValueError: if response.body is neither None nor a str.
        """
        if self._finished:
            raise AssertionError('response already sent')
        args = response.args
        body = response.body
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(args))
        if body is not None:
            if not isinstance(body, str):
                raise ValueError(body)
            self._write_func(self._encode_bulk_data(body))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix.

        For SmartServerRequestProtocolOne this is omitted but we
        call is_successful to ensure that the response is valid.
        """
        response.is_successful()

    def next_read_size(self):
        """Return how many more bytes this protocol wants to read.

        0 once a response has been sent, 1 until a body decoder exists,
        otherwise whatever the body decoder asks for.
        """
        if self._finished:
            return 0
        if self._body_decoder is None:
            return 1
        return self._body_decoder.next_read_size()

class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix."""
        if response.is_successful():
            self._write_func('success\n')
        else:
            self._write_func('failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        Unlike version one, a version two response may carry either a bulk
        body or a chunked body_stream, but not both.

        :raises AssertionError: if a response was already sent, the body is
            not a str, or both body and body_stream are set.
        """
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        if response.body is not None:
            if not isinstance(response.body, str):
                raise AssertionError('body must be a str')
            if response.body_stream is not None:
                raise AssertionError(
                    'body_stream and body cannot both be set')
            self._write_func(self._encode_bulk_data(response.body))
        elif response.body_stream is not None:
            _send_stream(response.body_stream, self._write_func)

def _send_stream(stream, write_func):
    """Write stream as a chunked body.

    The output is a 'chunked' header line, the chunks encoded by
    _send_chunks, and finally the 'END' marker that ChunkedBodyDecoder
    waits for before it reports finished_reading.

    :param stream: iterable of chunks accepted by _send_chunks.
    :param write_func: callable that writes one string to the wire.
    """
    write_func('chunked\n')
    _send_chunks(stream, write_func)
    write_func('END\n')

def _send_chunks(stream, write_func):
302
if isinstance(chunk, str):
303
bytes = "%x\n%s" % (len(chunk), chunk)
305
elif isinstance(chunk, request.FailedSmartServerResponse):
307
_send_chunks(chunk.args, write_func)
310
raise errors.BzrError(
311
'Chunks must be str or FailedSmartServerResponse, got %r'
315
class _NeedMoreBytes(Exception):
316
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
320
def __init__(self, count=None):
323
:param count: the total number of bytes needed by the current state.
324
May be None if the number of bytes needed is unknown.
329
class _StatefulDecoder(object):
330
"""Base class for writing state machines to decode byte streams.
332
Subclasses should provide a self.state_accept attribute that accepts bytes
333
and, if appropriate, updates self.state_accept to a different function.
334
accept_bytes will call state_accept as often as necessary to make sure the
335
state machine has progressed as far as possible before it returns.
337
See ProtocolThreeDecoder for an example subclass.
341
self.finished_reading = False
342
self._in_buffer_list = []
343
self._in_buffer_len = 0
344
self.unused_data = ''
345
self.bytes_left = None
346
self._number_needed_bytes = None
348
def _get_in_buffer(self):
349
if len(self._in_buffer_list) == 1:
350
return self._in_buffer_list[0]
351
in_buffer = ''.join(self._in_buffer_list)
352
if len(in_buffer) != self._in_buffer_len:
353
raise AssertionError(
354
"Length of buffer did not match expected value: %s != %s"
355
% self._in_buffer_len, len(in_buffer))
356
self._in_buffer_list = [in_buffer]
359
def _get_in_bytes(self, count):
360
"""Grab X bytes from the input_buffer.
362
Callers should have already checked that self._in_buffer_len is >
363
count. Note, this does not consume the bytes from the buffer. The
364
caller will still need to call _get_in_buffer() and then
365
_set_in_buffer() if they actually need to consume the bytes.
367
# check if we can yield the bytes from just the first entry in our list
368
if len(self._in_buffer_list) == 0:
369
raise AssertionError('Callers must be sure we have buffered bytes'
370
' before calling _get_in_bytes')
371
if len(self._in_buffer_list[0]) > count:
372
return self._in_buffer_list[0][:count]
373
# We can't yield it from the first buffer, so collapse all buffers, and
375
in_buf = self._get_in_buffer()
376
return in_buf[:count]
378
def _set_in_buffer(self, new_buf):
379
if new_buf is not None:
380
self._in_buffer_list = [new_buf]
381
self._in_buffer_len = len(new_buf)
383
self._in_buffer_list = []
384
self._in_buffer_len = 0
386
def accept_bytes(self, bytes):
387
"""Decode as much of bytes as possible.
389
If 'bytes' contains too much data it will be appended to
392
finished_reading will be set when no more data is required. Further
393
data will be appended to self.unused_data.
395
# accept_bytes is allowed to change the state
396
self._number_needed_bytes = None
397
# lsprof puts a very large amount of time on this specific call for
399
self._in_buffer_list.append(bytes)
400
self._in_buffer_len += len(bytes)
402
# Run the function for the current state.
403
current_state = self.state_accept
405
while current_state != self.state_accept:
406
# The current state has changed. Run the function for the new
407
# current state, so that it can:
408
# - decode any unconsumed bytes left in a buffer, and
409
# - signal how many more bytes are expected (via raising
411
current_state = self.state_accept
413
except _NeedMoreBytes, e:
414
self._number_needed_bytes = e.count
417
class ChunkedBodyDecoder(_StatefulDecoder):
    """Decoder for chunked body data.

    This is very similar to HTTP's chunked encoding.  See the description of
    streamed body data in `doc/developers/network-protocol.txt` for details.

    Accepted input: a 'chunked' header line, then chunks of the form
    '<hex length>\\n<bytes>', optionally an 'ERR' line after which the
    remaining chunks become the args of a FailedSmartServerResponse, and
    finally an 'END' line.  Decoded chunks are queued and retrieved with
    read_next_chunk.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_header
        # Bytes of the chunk currently being read; None between chunks.
        self.chunk_in_progress = None
        # Decoded chunks waiting for read_next_chunk.
        self.chunks = collections.deque()
        # Set once an 'ERR' line has been read.
        self.error = False
        # Chunks read after 'ERR'; they become the error's args.
        self.error_in_progress = None

    def next_read_size(self):
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
        # end-of-body marker is 4 bytes: 'END\n'.
        if self.state_accept == self._state_accept_reading_chunk:
            # We're expecting more chunk content. So we're expecting at least
            # the rest of this chunk plus an END chunk.
            return self.bytes_left + 4
        elif self.state_accept == self._state_accept_expecting_length:
            if self._in_buffer_len == 0:
                # We're expecting a chunk length. There's at least two bytes
                # left: a digit plus '\n'.
                return 2
            else:
                # We're in the middle of reading a chunk length. So there's at
                # least one byte left, the '\n' that terminates the length.
                return 1
        elif self.state_accept == self._state_accept_reading_unused:
            return 1
        elif self.state_accept == self._state_accept_expecting_header:
            return max(0, len('chunked\n') - self._in_buffer_len)
        else:
            raise AssertionError("Impossible state: %r" % (self.state_accept,))

    def read_next_chunk(self):
        """Pop and return the oldest decoded chunk, or None if none is queued."""
        try:
            return self.chunks.popleft()
        except IndexError:
            return None

    def _extract_line(self):
        """Consume one line from the buffer and return it without its newline.

        :raises _NeedMoreBytes: if no complete line is buffered yet.
        """
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            # We haven't read a complete line yet, so request more bytes before
            # we continue.
            raise _NeedMoreBytes(1)
        line = in_buf[:pos]
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
        self._set_in_buffer(in_buf[pos+1:])
        return line

    def _finished(self):
        """Handle the END marker: keep leftover bytes and stop decoding."""
        self.unused_data = self._get_in_buffer()
        self._in_buffer_list = []
        self._in_buffer_len = 0
        self.state_accept = self._state_accept_reading_unused
        if self.error:
            error_args = tuple(self.error_in_progress)
            self.chunks.append(request.FailedSmartServerResponse(error_args))
            self.error_in_progress = None
        self.finished_reading = True

    def _state_accept_expecting_header(self):
        prefix = self._extract_line()
        if prefix == 'chunked':
            self.state_accept = self._state_accept_expecting_length
        else:
            raise errors.SmartProtocolError(
                'Bad chunked body header: "%s"' % (prefix,))

    def _state_accept_expecting_length(self):
        prefix = self._extract_line()
        if prefix == 'ERR':
            self.error = True
            self.error_in_progress = []
            self._state_accept_expecting_length()
            return
        elif prefix == 'END':
            # We've read the end-of-body marker.
            # Any further bytes are unused data, including the bytes left in
            # the _in_buffer.
            self._finished()
            return
        else:
            self.bytes_left = int(prefix, 16)
            self.chunk_in_progress = ''
            self.state_accept = self._state_accept_reading_chunk

    def _state_accept_reading_chunk(self):
        in_buf = self._get_in_buffer()
        in_buffer_len = len(in_buf)
        self.chunk_in_progress += in_buf[:self.bytes_left]
        self._set_in_buffer(in_buf[self.bytes_left:])
        self.bytes_left -= in_buffer_len
        if self.bytes_left <= 0:
            # Finished with chunk
            self.bytes_left = None
            if self.error:
                self.error_in_progress.append(self.chunk_in_progress)
            else:
                self.chunks.append(self.chunk_in_progress)
            self.chunk_in_progress = None
            self.state_accept = self._state_accept_expecting_length

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        # Reset both the list and _in_buffer_len, as the other decoders do.
        self._set_in_buffer(None)

class LengthPrefixedBodyDecoder(_StatefulDecoder):
    """Decodes the length-prefixed bulk data.

    The expected input is a decimal length line, that many body bytes, then
    the trailer 'done\\n' (the inverse of SmartProtocolBase._encode_bulk_data).
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Decoded body bytes not yet returned by read_pending_data.
        self._body = ''
        self._trailer_buffer = ''

    def next_read_size(self):
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left ('\n' to end the length, plus
            # 'done\n').
            return 6
        else:
            # Reading excess data. Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self):
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            # The length line is incomplete; wait for more bytes.
            return
        self.bytes_left = int(in_buf[:pos])
        self._set_in_buffer(in_buf[pos+1:])
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_body_buffer

    def _state_accept_reading_body(self):
        in_buf = self._get_in_buffer()
        self._body += in_buf
        self.bytes_left -= len(in_buf)
        self._set_in_buffer(None)
        if self.bytes_left <= 0:
            # Finished with body
            if self.bytes_left != 0:
                # We over-read: the excess (a negative slice) is trailer.
                self._trailer_buffer = self._body[self.bytes_left:]
                self._body = self._body[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self):
        self._trailer_buffer += self._get_in_buffer()
        self._set_in_buffer(None)
        # TODO: what if the trailer does not match "done\n"? Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        self._set_in_buffer(None)

    def _state_read_no_data(self):
        return ''

    def _state_read_body_buffer(self):
        # Hand out the decoded body once, then reset.
        result = self._body
        self._body = ''
        return result

class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
        message.ResponseHandler):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        # Decoded response body; filled on the first read_body_bytes call.
        self._body_buffer = None
        # Set only while 'hpss' debugging is enabled, to log call latency.
        self._request_start_time = None
        # Verb of the most recent call; used to spot 'bad request' replies.
        self._last_verb = None
        self._headers = None

    def set_headers(self, headers):
        self._headers = dict(headers)

    def call(self, *args):
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            if getattr(self._request._medium, 'base', None) is not None:
                mutter(' (to %s)', self._request._medium.base)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.timer_func()
            if 'hpssdetail' in debug.debug_flags:
                mutter('hpss body content: %s', body)
        self._write_args(args)
        self._request.accept_bytes(self._encode_bulk_data(body))
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        self._request.accept_bytes(self._encode_bulk_data(readv_bytes))
        self._request.finished_writing()
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._last_verb = args[0]

    def call_with_body_stream(self, args, stream):
        # Protocols v1 and v2 don't support body streams. So it's safe to
        # assume that a v1/v2 server doesn't support whatever method we're
        # trying to call with a body stream.
        self._request.finished_writing()
        self._request.finished_reading()
        raise errors.UnknownSmartMethod(args[0])

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def _read_response_tuple(self):
        """Read and decode the response tuple, logging it under 'hpss'."""
        result = self._recv_tuple()
        if 'hpss' in debug.debug_flags:
            if self._request_start_time is not None:
                mutter(' result: %6.3fs %s',
                       osutils.timer_func() - self._request_start_time,
                       repr(result)[1:-1])
                self._request_start_time = None
            else:
                mutter(' result: %s', repr(result)[1:-1])
        return result

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if False the medium request is marked as finished
            reading; if True the caller is expected to read the body next.
        :raises errors.UnknownSmartMethod: for a 'bad request' reply.
        :raises errors.ErrorFromSmartServer: for a known v1 error code.
        """
        result = self._read_response_tuple()
        self._response_is_unknown_method(result)
        self._raise_args_if_error(result)
        if not expect_body:
            self._request.finished_reading()
        return result

    def _raise_args_if_error(self, result_tuple):
        # Later protocol versions have an explicit flag in the protocol to say
        # if an error response is "failed" or not. In version 1 we don't have
        # that luxury. So here is a complete list of errors that can be
        # returned in response to existing version 1 smart requests. Responses
        # starting with these codes are always "failed" responses.
        v1_error_codes = (
            'norepository',
            'NoSuchFile',
            'FileExists',
            'DirectoryNotEmpty',
            'ShortReadvError',
            'UnicodeEncodeError',
            'UnicodeDecodeError',
            'ReadOnlyError',
            'nobranch',
            'NoSuchRevision',
            'nosuchrevision',
            'LockContention',
            'UnlockableTransport',
            'LockFailed',
            'TokenMismatch',
            'ReadError',
            'PermissionDenied',
            )
        if result_tuple[0] in v1_error_codes:
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result_tuple)

    def _response_is_unknown_method(self, result_tuple):
        """Raise UnknownSmartMethod if the response is an 'unknown method'
        response to the request.

        :param result_tuple: the decoded response tuple of the last call.
        :raises errors.UnknownSmartMethod: naming self._last_verb, after
            marking the medium request as finished reading.
        """
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request '%s'" % self._last_verb) or
            result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request u'%s'" % self._last_verb)):
            # The response will have no body, so we've finished reading.
            self._request.finished_reading()
            raise errors.UnknownSmartMethod(self._last_verb)

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :raises errors.ConnectionReset: if the medium returns no bytes before
            the body is complete.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()

        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading response body.")
            _body_decoder.accept_bytes(data)
        self._request.finished_reading()
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        if 'hpss' in debug.debug_flags:
            mutter(' %d body bytes read',
                   len(self._body_buffer.getvalue()))
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        return _decode_tuple(self._request.read_line())

    def query_version(self):
        """Return protocol version number of the server.

        Sends a 'hello' request.

        :raises errors.SmartProtocolError: on any reply other than
            ('ok', '1') or ('ok', '2').
        """
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        elif resp == ('ok', '2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        self._write_protocol_version()
        self._request.accept_bytes(_encode_tuple(args))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """

class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    """Version two of the client side of the smart protocol.

    This prefixes the request with the value of REQUEST_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :raises errors.UnexpectedProtocolVersionMarker: if the first line is
            not RESPONSE_VERSION_TWO.
        :raises errors.ErrorFromSmartServer: if the status line is 'failed'.
        :raises errors.SmartProtocolError: for any other status line.
        """
        version = self._request.read_line()
        if version != self.response_marker:
            self._request.finished_reading()
            raise errors.UnexpectedProtocolVersionMarker(version)
        response_status = self._request.read_line()
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
        self._response_is_unknown_method(result)
        if response_status == 'success\n':
            self.response_status = True
            if not expect_body:
                self._request.finished_reading()
            return result
        elif response_status == 'failed\n':
            self.response_status = False
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result)
        else:
            raise errors.SmartProtocolError(
                'bad protocol status %r' % response_status)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version two sends the value of REQUEST_VERSION_TWO.
        """
        self._request.accept_bytes(self.request_marker)

    def read_streamed_body(self):
        """Read bytes from the body, decoding into a byte stream.

        This is a generator yielding each decoded chunk as soon as it is
        available.

        :raises errors.ConnectionReset: if the medium returns no bytes before
            the END marker.
        """
        # Read no more than 64k at a time so that we don't risk error 10055 (no
        # buffer space available) on Windows.
        # NOTE(review): nothing in this method caps the size passed to
        # read_bytes; presumably the medium enforces the limit -- confirm.
        _body_decoder = ChunkedBodyDecoder()
        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading streamed body.")
            _body_decoder.accept_bytes(data)
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
                    mutter(' %d byte chunk read',
                           len(body_bytes))
                yield body_bytes
        self._request.finished_reading()

def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Assemble a server-side protocol three decoder.

    Wires a SmartServerRequestHandler and a ProtocolThreeResponder (which
    writes through write_func) into a ConventionalRequestHandler, and returns
    a ProtocolThreeDecoder that feeds it.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport,
        commands=request.request_handlers,
        root_client_path=root_client_path,
        jail_root=jail_root)
    conventional = message.ConventionalRequestHandler(
        handler, ProtocolThreeResponder(write_func))
    return ProtocolThreeDecoder(conventional)

class ProtocolThreeDecoder(_StatefulDecoder):
888
response_marker = RESPONSE_VERSION_THREE
889
request_marker = REQUEST_VERSION_THREE
891
def __init__(self, message_handler, expect_version_marker=False):
892
_StatefulDecoder.__init__(self)
893
self._has_dispatched = False
895
if expect_version_marker:
896
self.state_accept = self._state_accept_expecting_protocol_version
897
# We're expecting at least the protocol version marker + some
899
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
901
self.state_accept = self._state_accept_expecting_headers
902
self._number_needed_bytes = 4
903
self.decoding_failed = False
904
self.request_handler = self.message_handler = message_handler
906
def accept_bytes(self, bytes):
907
self._number_needed_bytes = None
909
_StatefulDecoder.accept_bytes(self, bytes)
910
except KeyboardInterrupt:
912
except errors.SmartMessageHandlerError, exception:
913
# We do *not* set self.decoding_failed here. The message handler
914
# has raised an error, but the decoder is still able to parse bytes
915
# and determine when this message ends.
916
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
917
log_exception_quietly()
918
self.message_handler.protocol_error(exception.exc_value)
919
# The state machine is ready to continue decoding, but the
920
# exception has interrupted the loop that runs the state machine.
921
# So we call accept_bytes again to restart it.
922
self.accept_bytes('')
923
except Exception, exception:
924
# The decoder itself has raised an exception. We cannot continue
926
self.decoding_failed = True
927
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
928
# This happens during normal operation when the client tries a
929
# protocol version the server doesn't understand, so no need to
930
# log a traceback every time.
931
# Note that this can only happen when
932
# expect_version_marker=True, which is only the case on the
936
log_exception_quietly()
937
self.message_handler.protocol_error(exception)
939
def _extract_length_prefixed_bytes(self):
940
if self._in_buffer_len < 4:
941
# A length prefix by itself is 4 bytes, and we don't even have that
943
raise _NeedMoreBytes(4)
944
(length,) = struct.unpack('!L', self._get_in_bytes(4))
945
end_of_bytes = 4 + length
946
if self._in_buffer_len < end_of_bytes:
947
# We haven't yet read as many bytes as the length-prefix says there
949
raise _NeedMoreBytes(end_of_bytes)
950
# Extract the bytes from the buffer.
951
in_buf = self._get_in_buffer()
952
bytes = in_buf[4:end_of_bytes]
953
self._set_in_buffer(in_buf[end_of_bytes:])
956
def _extract_prefixed_bencoded_data(self):
957
prefixed_bytes = self._extract_length_prefixed_bytes()
959
decoded = bdecode_as_tuple(prefixed_bytes)
961
raise errors.SmartProtocolError(
962
'Bytes %r not bencoded' % (prefixed_bytes,))
965
def _extract_single_byte(self):
966
if self._in_buffer_len == 0:
967
# The buffer is empty
968
raise _NeedMoreBytes(1)
969
in_buf = self._get_in_buffer()
971
self._set_in_buffer(in_buf[1:])
974
def _state_accept_expecting_protocol_version(self):
975
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
976
in_buf = self._get_in_buffer()
978
# We don't have enough bytes to check if the protocol version
979
# marker is right. But we can check if it is already wrong by
980
# checking that the start of MESSAGE_VERSION_THREE matches what
982
# [In fact, if the remote end isn't bzr we might never receive
983
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
984
# are wrong then we should just raise immediately rather than
986
if not MESSAGE_VERSION_THREE.startswith(in_buf):
987
# We have enough bytes to know the protocol version is wrong
988
raise errors.UnexpectedProtocolVersionMarker(in_buf)
989
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
990
if not in_buf.startswith(MESSAGE_VERSION_THREE):
991
raise errors.UnexpectedProtocolVersionMarker(in_buf)
992
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
993
self.state_accept = self._state_accept_expecting_headers
995
def _state_accept_expecting_headers(self):
996
decoded = self._extract_prefixed_bencoded_data()
997
if type(decoded) is not dict:
998
raise errors.SmartProtocolError(
999
'Header object %r is not a dict' % (decoded,))
1000
self.state_accept = self._state_accept_expecting_message_part
1002
self.message_handler.headers_received(decoded)
1004
raise errors.SmartMessageHandlerError(sys.exc_info())
1006
def _state_accept_expecting_message_part(self):
1007
message_part_kind = self._extract_single_byte()
1008
if message_part_kind == 'o':
1009
self.state_accept = self._state_accept_expecting_one_byte
1010
elif message_part_kind == 's':
1011
self.state_accept = self._state_accept_expecting_structure
1012
elif message_part_kind == 'b':
1013
self.state_accept = self._state_accept_expecting_bytes
1014
elif message_part_kind == 'e':
1017
raise errors.SmartProtocolError(
1018
'Bad message kind byte: %r' % (message_part_kind,))
1020
def _state_accept_expecting_one_byte(self):
1021
byte = self._extract_single_byte()
1022
self.state_accept = self._state_accept_expecting_message_part
1024
self.message_handler.byte_part_received(byte)
1026
raise errors.SmartMessageHandlerError(sys.exc_info())
1028
def _state_accept_expecting_bytes(self):
1029
# XXX: this should not buffer whole message part, but instead deliver
1030
# the bytes as they arrive.
1031
prefixed_bytes = self._extract_length_prefixed_bytes()
1032
self.state_accept = self._state_accept_expecting_message_part
1034
self.message_handler.bytes_part_received(prefixed_bytes)
1036
raise errors.SmartMessageHandlerError(sys.exc_info())
1038
def _state_accept_expecting_structure(self):
1039
structure = self._extract_prefixed_bencoded_data()
1040
self.state_accept = self._state_accept_expecting_message_part
1042
self.message_handler.structure_part_received(structure)
1044
raise errors.SmartMessageHandlerError(sys.exc_info())
1047
def done(self):
    """Finish decoding the current message.

    Any bytes still buffered after the end marker are not part of this
    message: they are moved to ``self.unused_data``, the input buffer is
    cleared, and the decoder switches to ``_state_accept_reading_unused``.
    The message handler is then notified via ``end_received``.

    :raises errors.SmartMessageHandlerError: wrapping anything raised by
        ``self.message_handler.end_received``.
    """
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
1055
def _state_accept_reading_unused(self):
1056
self.unused_data += self._get_in_buffer()
1057
self._set_in_buffer(None)
1059
def next_read_size(self):
    """Return how many more bytes the decoder wants.

    Returns 0 once the message has ended or decoding has failed.  Otherwise
    it returns the number of bytes still missing for the current state.

    :raises AssertionError: if no needed-byte count has been recorded.
    """
    if self.state_accept == self._state_accept_reading_unused:
        # The message is complete; anything further is unused data.
        return 0
    elif self.decoding_failed:
        # An exception occurred while processing this message, probably from
        # self.message_handler. We're not sure that this state machine is
        # in a consistent state, so just signal that we're done (i.e. give
        # up).
        return 0
    if self._number_needed_bytes is not None:
        return self._number_needed_bytes - self._in_buffer_len
    raise AssertionError("don't know how many bytes are expected!")
1075
class _ProtocolThreeEncoder(object):
    """Writes protocol-three message parts through an output buffer.

    Everything goes through ``_write_func``, which accumulates chunks and
    hands them to the real write function in a single call.  This happens
    when more than ``BUFFER_SIZE`` bytes are pending, or when ``flush`` is
    called (``_write_end`` always flushes).
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing

    def __init__(self, write_func):
        # Pending output chunks and their combined length.
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Queue ``bytes`` for output, flushing once the buffer is too big."""
        # TODO: Another possibility would be to turn this into an async model.
        # Where we let another thread know that we have some bytes if
        # they want it, but we don't actually block for it
        # Note that osutils.send_all always sends 64kB chunks anyway, so
        # we might just push out smaller bits at a time?
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Send all pending bytes to the real write function, then reset."""
        if self._buf:
            self._real_write_func(''.join(self._buf))
            del self._buf[:]
            self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        Each ``(start, length)`` pair becomes a ``'start,length'`` line.  The
        lines are joined with ``'\\n'``, with no trailing newline.
        """
        return '\n'.join('%d,%d' % (start, length)
                         for start, length in offsets)

    def _write_protocol_version(self):
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by its 4-byte length."""
        encoded = bencode(structure)
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write a structure part: ``'s'`` then the prefixed bencoded args.

        Unicode arguments are encoded to UTF-8 first; all other arguments are
        passed to bencode unchanged.
        """
        self._write_func('s')
        utf8_args = [arg.encode('utf8') if type(arg) is unicode else arg
                     for arg in args]
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker and push everything out."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a body part: ``'b'``, a 4-byte length, then ``bytes``."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        self._write_func('oC')

    def _write_error_status(self):
        self._write_func('oE')

    def _write_success_status(self):
        self._write_func('oS')
1149
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Server side of protocol three: encodes and sends a single response.

    Exactly one of ``send_response`` / ``send_error`` may be used.  A second
    attempt raises AssertionError.
    """

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}
        # Tracing state is set up unconditionally.  _trace is reached under
        # either the 'hpss' or the 'hpssdetail' debug flag, and must not fail
        # with AttributeError when only one of them is enabled.
        self._thread_id = thread.get_ident()
        self._response_start_time = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Emit one debug trace line for this response via mutter.

        :param action: short label, right-aligned in a 12-character column.
        :param message: free-form description.
        :param extra_bytes: optional payload; a truncated repr is appended.
        :param include_time: prefix the seconds elapsed since the first
            trace call for this response.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # The start time was taken with osutils.timer_func(), so the end
            # reading must come from the same clock (time.clock() is a
            # different clock and made this difference meaningless).
            elapsed = osutils.timer_func() - self._response_start_time
            t = '%5.3fs ' % (elapsed,)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                # Keep the repr's closing quote and mark the truncation.
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send ``exception`` to the client as an error response.

        ``errors.UnknownSmartMethod`` is translated into an
        ``('UnknownMethod', verb)`` failure response.  Every other exception
        is sent as ``('error', str(exception))``.

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Encode and send ``response``, including any body or body stream.

        If iterating ``response.body_stream`` raises, or the stream yields a
        ``FailedSmartServerResponse``, an error status and structure are
        written in place of the remaining chunks.

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                num_bytes += len(chunk)
                if first_chunk is None:
                    first_chunk = chunk
                self._write_prefixed_body(chunk)
                # Push each chunk out as soon as it is produced instead of
                # holding a streamed body in the encoder buffer.
                self.flush()
                if 'hpssdetail' in debug.debug_flags:
                    # Not worth timing separately, as _write_func is
                    # buffered; include_time defaults to False.
                    self._trace('body chunk',
                                '%d bytes' % (len(chunk),),
                                chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
1249
def _iter_with_errors(iterable):
1250
"""Handle errors from iterable.next().
1254
for exc_info, value in _iter_with_errors(iterable):
1257
This is a safer alternative to::
1260
for value in iterable:
1265
Because the latter will catch errors from the for-loop body, not just
1268
If an error occurs, exc_info will be a exc_info tuple, and the generator
1269
will terminate. Otherwise exc_info will be None, and value will be the
1270
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1271
will not be itercepted.
1273
iterator = iter(iterable)
1276
yield None, iterator.next()
1277
except StopIteration:
1279
except (KeyboardInterrupt, SystemExit):
1282
mutter('_iter_with_errors caught error')
1283
log_exception_quietly()
1284
yield sys.exc_info(), None
1288
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1290
def __init__(self, medium_request):
1291
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1292
self._medium_request = medium_request
1294
self.body_stream_started = None
1296
def set_headers(self, headers):
1297
self._headers = headers.copy()
1299
def call(self, *args):
1300
if 'hpss' in debug.debug_flags:
1301
mutter('hpss call: %s', repr(args)[1:-1])
1302
base = getattr(self._medium_request._medium, 'base', None)
1303
if base is not None:
1304
mutter(' (to %s)', base)
1305
self._request_start_time = osutils.timer_func()
1306
self._write_protocol_version()
1307
self._write_headers(self._headers)
1308
self._write_structure(args)
1310
self._medium_request.finished_writing()
1312
def call_with_body_bytes(self, args, body):
1313
"""Make a remote call of args with body bytes 'body'.
1315
After calling this, call read_response_tuple to find the result out.
1317
if 'hpss' in debug.debug_flags:
1318
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1319
path = getattr(self._medium_request._medium, '_path', None)
1320
if path is not None:
1321
mutter(' (to %s)', path)
1322
mutter(' %d bytes', len(body))
1323
self._request_start_time = osutils.timer_func()
1324
self._write_protocol_version()
1325
self._write_headers(self._headers)
1326
self._write_structure(args)
1327
self._write_prefixed_body(body)
1329
self._medium_request.finished_writing()
1331
def call_with_body_readv_array(self, args, body):
1332
"""Make a remote call with a readv array.
1334
The body is encoded with one line per readv offset pair. The numbers in
1335
each pair are separated by a comma, and no trailing \\n is emitted.
1337
if 'hpss' in debug.debug_flags:
1338
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1339
path = getattr(self._medium_request._medium, '_path', None)
1340
if path is not None:
1341
mutter(' (to %s)', path)
1342
self._request_start_time = osutils.timer_func()
1343
self._write_protocol_version()
1344
self._write_headers(self._headers)
1345
self._write_structure(args)
1346
readv_bytes = self._serialise_offsets(body)
1347
if 'hpss' in debug.debug_flags:
1348
mutter(' %d bytes in readv request', len(readv_bytes))
1349
self._write_prefixed_body(readv_bytes)
1351
self._medium_request.finished_writing()
1353
def call_with_body_stream(self, args, stream):
1354
if 'hpss' in debug.debug_flags:
1355
mutter('hpss call w/body stream: %r', args)
1356
path = getattr(self._medium_request._medium, '_path', None)
1357
if path is not None:
1358
mutter(' (to %s)', path)
1359
self._request_start_time = osutils.timer_func()
1360
self.body_stream_started = False
1361
self._write_protocol_version()
1362
self._write_headers(self._headers)
1363
self._write_structure(args)
1364
# TODO: notice if the server has sent an early error reply before we
1365
# have finished sending the stream. We would notice at the end
1366
# anyway, but if the medium can deliver it early then it's good
1367
# to short-circuit the whole request...
1368
# Provoke any ConnectionReset failures before we start the body stream.
1370
self.body_stream_started = True
1371
for exc_info, part in _iter_with_errors(stream):
1372
if exc_info is not None:
1373
# Iterating the stream failed. Cleanly abort the request.
1374
self._write_error_status()
1375
# Currently the client unconditionally sends ('error',) as the
1377
self._write_structure(('error',))
1379
self._medium_request.finished_writing()
1380
raise exc_info[0], exc_info[1], exc_info[2]
1382
self._write_prefixed_body(part)
1385
self._medium_request.finished_writing()