1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
from cStringIO import StringIO
34
from bzrlib.smart import message, request
35
from bzrlib.trace import log_exception_quietly, mutter
36
from bzrlib.bencode import bdecode_as_tuple, bencode
39
# Protocol version strings.  Each one is written as a prefix of a bzr request
# or response so that the receiving end can tell which protocol version is in
# use.  Version one has no such strings because it never sends a marker.
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'

# Version three uses one marker string for both requests and responses.
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
RESPONSE_VERSION_THREE = MESSAGE_VERSION_THREE
49
def _recv_tuple(from_file):
    """Read one line from ``from_file`` and decode it with _decode_tuple.

    :param from_file: an object providing a ``readline()`` method.
    :return: the result of _decode_tuple for the line that was read.
    """
    return _decode_tuple(from_file.readline())
54
def _decode_tuple(req_line):
55
if req_line is None or req_line == '':
57
if req_line[-1] != '\n':
58
raise errors.SmartProtocolError("request %r not terminated" % req_line)
59
return tuple(req_line[:-1].split('\x01'))
62
def _encode_tuple(args):
63
"""Encode the tuple args to a bytestream."""
64
return '\x01'.join(args) + '\n'
67
class Requester(object):
68
"""Abstract base class for an object that can issue requests on a smart
72
def call(self, *args):
73
"""Make a remote call.
75
:param args: the arguments of this call.
77
raise NotImplementedError(self.call)
79
def call_with_body_bytes(self, args, body):
80
"""Make a remote call with a body.
82
:param args: the arguments of this call.
84
:param body: the body to send with the request.
86
raise NotImplementedError(self.call_with_body_bytes)
88
def call_with_body_readv_array(self, args, body):
89
"""Make a remote call with a readv array.
91
:param args: the arguments of this call.
92
:type body: iterable of (start, length) tuples.
93
:param body: the readv ranges to send with this request.
95
raise NotImplementedError(self.call_with_body_readv_array)
97
def set_headers(self, headers):
98
raise NotImplementedError(self.set_headers)
101
class SmartProtocolBase(object):
102
"""Methods common to client and server"""
104
# TODO: this only actually accomodates a single block; possibly should
105
# support multiple chunks?
106
def _encode_bulk_data(self, body):
107
"""Encode body as a bulk data chunk."""
108
return ''.join(('%d\n' % len(body), body, 'done\n'))
110
def _serialise_offsets(self, offsets):
111
"""Serialise a readv offset list."""
113
for start, length in offsets:
114
txt.append('%d,%d' % (start, length))
115
return '\n'.join(txt)
118
class SmartServerRequestProtocolOne(SmartProtocolBase):
119
"""Server-side encoding and decoding logic for smart version 1."""
121
def __init__(self, backing_transport, write_func, root_client_path='/',
123
self._backing_transport = backing_transport
124
self._root_client_path = root_client_path
125
self._jail_root = jail_root
126
self.unused_data = ''
127
self._finished = False
129
self._has_dispatched = False
131
self._body_decoder = None
132
self._write_func = write_func
134
def accept_bytes(self, bytes):
135
"""Take bytes, and advance the internal state machine appropriately.
137
:param bytes: must be a byte string
139
if not isinstance(bytes, str):
140
raise ValueError(bytes)
141
self.in_buffer += bytes
142
if not self._has_dispatched:
143
if '\n' not in self.in_buffer:
144
# no command line yet
146
self._has_dispatched = True
148
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
150
req_args = _decode_tuple(first_line)
151
self.request = request.SmartServerRequestHandler(
152
self._backing_transport, commands=request.request_handlers,
153
root_client_path=self._root_client_path,
154
jail_root=self._jail_root)
155
self.request.args_received(req_args)
156
if self.request.finished_reading:
158
self.unused_data = self.in_buffer
160
self._send_response(self.request.response)
161
except KeyboardInterrupt:
163
except errors.UnknownSmartMethod, err:
164
protocol_error = errors.SmartProtocolError(
165
"bad request %r" % (err.verb,))
166
failure = request.FailedSmartServerResponse(
167
('error', str(protocol_error)))
168
self._send_response(failure)
170
except Exception, exception:
171
# everything else: pass to client, flush, and quit
172
log_exception_quietly()
173
self._send_response(request.FailedSmartServerResponse(
174
('error', str(exception))))
177
if self._has_dispatched:
179
# nothing to do.XXX: this routine should be a single state
181
self.unused_data += self.in_buffer
184
if self._body_decoder is None:
185
self._body_decoder = LengthPrefixedBodyDecoder()
186
self._body_decoder.accept_bytes(self.in_buffer)
187
self.in_buffer = self._body_decoder.unused_data
188
body_data = self._body_decoder.read_pending_data()
189
self.request.accept_body(body_data)
190
if self._body_decoder.finished_reading:
191
self.request.end_of_body()
192
if not self.request.finished_reading:
193
raise AssertionError("no more body, request not finished")
194
if self.request.response is not None:
195
self._send_response(self.request.response)
196
self.unused_data = self.in_buffer
199
if self.request.finished_reading:
200
raise AssertionError(
201
"no response and we have finished reading.")
203
def _send_response(self, response):
204
"""Send a smart server response down the output stream."""
206
raise AssertionError('response already sent')
209
self._finished = True
210
self._write_protocol_version()
211
self._write_success_or_failure_prefix(response)
212
self._write_func(_encode_tuple(args))
214
if not isinstance(body, str):
215
raise ValueError(body)
216
bytes = self._encode_bulk_data(body)
217
self._write_func(bytes)
219
def _write_protocol_version(self):
220
"""Write any prefixes this protocol requires.
222
Version one doesn't send protocol versions.
225
def _write_success_or_failure_prefix(self, response):
226
"""Write the protocol specific success/failure prefix.
228
For SmartServerRequestProtocolOne this is omitted but we
229
call is_successful to ensure that the response is valid.
231
response.is_successful()
233
def next_read_size(self):
236
if self._body_decoder is None:
239
return self._body_decoder.next_read_size()
242
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
243
r"""Version two of the server side of the smart protocol.
245
This prefixes responses with the value of RESPONSE_VERSION_TWO.
248
response_marker = RESPONSE_VERSION_TWO
249
request_marker = REQUEST_VERSION_TWO
251
def _write_success_or_failure_prefix(self, response):
252
"""Write the protocol specific success/failure prefix."""
253
if response.is_successful():
254
self._write_func('success\n')
256
self._write_func('failed\n')
258
def _write_protocol_version(self):
259
r"""Write any prefixes this protocol requires.
261
Version two sends the value of RESPONSE_VERSION_TWO.
263
self._write_func(self.response_marker)
265
def _send_response(self, response):
266
"""Send a smart server response down the output stream."""
268
raise AssertionError('response already sent')
269
self._finished = True
270
self._write_protocol_version()
271
self._write_success_or_failure_prefix(response)
272
self._write_func(_encode_tuple(response.args))
273
if response.body is not None:
274
if not isinstance(response.body, str):
275
raise AssertionError('body must be a str')
276
if not (response.body_stream is None):
277
raise AssertionError(
278
'body_stream and body cannot both be set')
279
bytes = self._encode_bulk_data(response.body)
280
self._write_func(bytes)
281
elif response.body_stream is not None:
282
_send_stream(response.body_stream, self._write_func)
285
def _send_stream(stream, write_func):
286
write_func('chunked\n')
287
_send_chunks(stream, write_func)
291
def _send_chunks(stream, write_func):
293
if isinstance(chunk, str):
294
bytes = "%x\n%s" % (len(chunk), chunk)
296
elif isinstance(chunk, request.FailedSmartServerResponse):
298
_send_chunks(chunk.args, write_func)
301
raise errors.BzrError(
302
'Chunks must be str or FailedSmartServerResponse, got %r'
306
class _NeedMoreBytes(Exception):
307
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
311
def __init__(self, count=None):
314
:param count: the total number of bytes needed by the current state.
315
May be None if the number of bytes needed is unknown.
320
class _StatefulDecoder(object):
321
"""Base class for writing state machines to decode byte streams.
323
Subclasses should provide a self.state_accept attribute that accepts bytes
324
and, if appropriate, updates self.state_accept to a different function.
325
accept_bytes will call state_accept as often as necessary to make sure the
326
state machine has progressed as far as possible before it returns.
328
See ProtocolThreeDecoder for an example subclass.
332
self.finished_reading = False
333
self._in_buffer_list = []
334
self._in_buffer_len = 0
335
self.unused_data = ''
336
self.bytes_left = None
337
self._number_needed_bytes = None
339
def _get_in_buffer(self):
340
if len(self._in_buffer_list) == 1:
341
return self._in_buffer_list[0]
342
in_buffer = ''.join(self._in_buffer_list)
343
if len(in_buffer) != self._in_buffer_len:
344
raise AssertionError(
345
"Length of buffer did not match expected value: %s != %s"
346
% self._in_buffer_len, len(in_buffer))
347
self._in_buffer_list = [in_buffer]
350
def _get_in_bytes(self, count):
351
"""Grab X bytes from the input_buffer.
353
Callers should have already checked that self._in_buffer_len is >
354
count. Note, this does not consume the bytes from the buffer. The
355
caller will still need to call _get_in_buffer() and then
356
_set_in_buffer() if they actually need to consume the bytes.
358
# check if we can yield the bytes from just the first entry in our list
359
if len(self._in_buffer_list) == 0:
360
raise AssertionError('Callers must be sure we have buffered bytes'
361
' before calling _get_in_bytes')
362
if len(self._in_buffer_list[0]) > count:
363
return self._in_buffer_list[0][:count]
364
# We can't yield it from the first buffer, so collapse all buffers, and
366
in_buf = self._get_in_buffer()
367
return in_buf[:count]
369
def _set_in_buffer(self, new_buf):
370
if new_buf is not None:
371
self._in_buffer_list = [new_buf]
372
self._in_buffer_len = len(new_buf)
374
self._in_buffer_list = []
375
self._in_buffer_len = 0
377
def accept_bytes(self, bytes):
378
"""Decode as much of bytes as possible.
380
If 'bytes' contains too much data it will be appended to
383
finished_reading will be set when no more data is required. Further
384
data will be appended to self.unused_data.
386
# accept_bytes is allowed to change the state
387
self._number_needed_bytes = None
388
# lsprof puts a very large amount of time on this specific call for
390
self._in_buffer_list.append(bytes)
391
self._in_buffer_len += len(bytes)
393
# Run the function for the current state.
394
current_state = self.state_accept
396
while current_state != self.state_accept:
397
# The current state has changed. Run the function for the new
398
# current state, so that it can:
399
# - decode any unconsumed bytes left in a buffer, and
400
# - signal how many more bytes are expected (via raising
402
current_state = self.state_accept
404
except _NeedMoreBytes, e:
405
self._number_needed_bytes = e.count
408
class ChunkedBodyDecoder(_StatefulDecoder):
409
"""Decoder for chunked body data.
411
This is very similar the HTTP's chunked encoding. See the description of
412
streamed body data in `doc/developers/network-protocol.txt` for details.
416
_StatefulDecoder.__init__(self)
417
self.state_accept = self._state_accept_expecting_header
418
self.chunk_in_progress = None
419
self.chunks = collections.deque()
421
self.error_in_progress = None
423
def next_read_size(self):
424
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
425
# end-of-body marker is 4 bytes: 'END\n'.
426
if self.state_accept == self._state_accept_reading_chunk:
427
# We're expecting more chunk content. So we're expecting at least
428
# the rest of this chunk plus an END chunk.
429
return self.bytes_left + 4
430
elif self.state_accept == self._state_accept_expecting_length:
431
if self._in_buffer_len == 0:
432
# We're expecting a chunk length. There's at least two bytes
433
# left: a digit plus '\n'.
436
# We're in the middle of reading a chunk length. So there's at
437
# least one byte left, the '\n' that terminates the length.
439
elif self.state_accept == self._state_accept_reading_unused:
441
elif self.state_accept == self._state_accept_expecting_header:
442
return max(0, len('chunked\n') - self._in_buffer_len)
444
raise AssertionError("Impossible state: %r" % (self.state_accept,))
446
def read_next_chunk(self):
448
return self.chunks.popleft()
452
def _extract_line(self):
453
in_buf = self._get_in_buffer()
454
pos = in_buf.find('\n')
456
# We haven't read a complete line yet, so request more bytes before
458
raise _NeedMoreBytes(1)
460
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
461
self._set_in_buffer(in_buf[pos+1:])
465
self.unused_data = self._get_in_buffer()
466
self._in_buffer_list = []
467
self._in_buffer_len = 0
468
self.state_accept = self._state_accept_reading_unused
470
error_args = tuple(self.error_in_progress)
471
self.chunks.append(request.FailedSmartServerResponse(error_args))
472
self.error_in_progress = None
473
self.finished_reading = True
475
def _state_accept_expecting_header(self):
476
prefix = self._extract_line()
477
if prefix == 'chunked':
478
self.state_accept = self._state_accept_expecting_length
480
raise errors.SmartProtocolError(
481
'Bad chunked body header: "%s"' % (prefix,))
483
def _state_accept_expecting_length(self):
484
prefix = self._extract_line()
487
self.error_in_progress = []
488
self._state_accept_expecting_length()
490
elif prefix == 'END':
491
# We've read the end-of-body marker.
492
# Any further bytes are unused data, including the bytes left in
497
self.bytes_left = int(prefix, 16)
498
self.chunk_in_progress = ''
499
self.state_accept = self._state_accept_reading_chunk
501
def _state_accept_reading_chunk(self):
502
in_buf = self._get_in_buffer()
503
in_buffer_len = len(in_buf)
504
self.chunk_in_progress += in_buf[:self.bytes_left]
505
self._set_in_buffer(in_buf[self.bytes_left:])
506
self.bytes_left -= in_buffer_len
507
if self.bytes_left <= 0:
508
# Finished with chunk
509
self.bytes_left = None
511
self.error_in_progress.append(self.chunk_in_progress)
513
self.chunks.append(self.chunk_in_progress)
514
self.chunk_in_progress = None
515
self.state_accept = self._state_accept_expecting_length
517
def _state_accept_reading_unused(self):
518
self.unused_data += self._get_in_buffer()
519
self._in_buffer_list = []
522
class LengthPrefixedBodyDecoder(_StatefulDecoder):
523
"""Decodes the length-prefixed bulk data."""
526
_StatefulDecoder.__init__(self)
527
self.state_accept = self._state_accept_expecting_length
528
self.state_read = self._state_read_no_data
530
self._trailer_buffer = ''
532
def next_read_size(self):
533
if self.bytes_left is not None:
534
# Ideally we want to read all the remainder of the body and the
536
return self.bytes_left + 5
537
elif self.state_accept == self._state_accept_reading_trailer:
538
# Just the trailer left
539
return 5 - len(self._trailer_buffer)
540
elif self.state_accept == self._state_accept_expecting_length:
541
# There's still at least 6 bytes left ('\n' to end the length, plus
545
# Reading excess data. Either way, 1 byte at a time is fine.
548
def read_pending_data(self):
    """Return any pending data that has been decoded.

    This delegates to the reader the state machine currently has installed
    in self.state_read.
    """
    reader = self.state_read
    return reader()
552
def _state_accept_expecting_length(self):
553
in_buf = self._get_in_buffer()
554
pos = in_buf.find('\n')
557
self.bytes_left = int(in_buf[:pos])
558
self._set_in_buffer(in_buf[pos+1:])
559
self.state_accept = self._state_accept_reading_body
560
self.state_read = self._state_read_body_buffer
562
def _state_accept_reading_body(self):
563
in_buf = self._get_in_buffer()
565
self.bytes_left -= len(in_buf)
566
self._set_in_buffer(None)
567
if self.bytes_left <= 0:
569
if self.bytes_left != 0:
570
self._trailer_buffer = self._body[self.bytes_left:]
571
self._body = self._body[:self.bytes_left]
572
self.bytes_left = None
573
self.state_accept = self._state_accept_reading_trailer
575
def _state_accept_reading_trailer(self):
576
self._trailer_buffer += self._get_in_buffer()
577
self._set_in_buffer(None)
578
# TODO: what if the trailer does not match "done\n"? Should this raise
579
# a ProtocolViolation exception?
580
if self._trailer_buffer.startswith('done\n'):
581
self.unused_data = self._trailer_buffer[len('done\n'):]
582
self.state_accept = self._state_accept_reading_unused
583
self.finished_reading = True
585
def _state_accept_reading_unused(self):
586
self.unused_data += self._get_in_buffer()
587
self._set_in_buffer(None)
589
def _state_read_no_data(self):
592
def _state_read_body_buffer(self):
598
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
599
message.ResponseHandler):
600
"""The client-side protocol for smart version 1."""
602
def __init__(self, request):
603
"""Construct a SmartClientRequestProtocolOne.
605
:param request: A SmartClientMediumRequest to serialise onto and
608
self._request = request
609
self._body_buffer = None
610
self._request_start_time = None
611
self._last_verb = None
614
def set_headers(self, headers):
    """Store a private dict copy of ``headers`` as self._headers.

    :param headers: a mapping or an iterable of (key, value) pairs.
    """
    copied = dict(headers)
    self._headers = copied
617
def call(self, *args):
618
if 'hpss' in debug.debug_flags:
619
mutter('hpss call: %s', repr(args)[1:-1])
620
if getattr(self._request._medium, 'base', None) is not None:
621
mutter(' (to %s)', self._request._medium.base)
622
self._request_start_time = osutils.timer_func()
623
self._write_args(args)
624
self._request.finished_writing()
625
self._last_verb = args[0]
627
def call_with_body_bytes(self, args, body):
628
"""Make a remote call of args with body bytes 'body'.
630
After calling this, call read_response_tuple to find the result out.
632
if 'hpss' in debug.debug_flags:
633
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
634
if getattr(self._request._medium, '_path', None) is not None:
635
mutter(' (to %s)', self._request._medium._path)
636
mutter(' %d bytes', len(body))
637
self._request_start_time = osutils.timer_func()
638
if 'hpssdetail' in debug.debug_flags:
639
mutter('hpss body content: %s', body)
640
self._write_args(args)
641
bytes = self._encode_bulk_data(body)
642
self._request.accept_bytes(bytes)
643
self._request.finished_writing()
644
self._last_verb = args[0]
646
def call_with_body_readv_array(self, args, body):
647
"""Make a remote call with a readv array.
649
The body is encoded with one line per readv offset pair. The numbers in
650
each pair are separated by a comma, and no trailing \n is emitted.
652
if 'hpss' in debug.debug_flags:
653
mutter('hpss call w/readv: %s', repr(args)[1:-1])
654
if getattr(self._request._medium, '_path', None) is not None:
655
mutter(' (to %s)', self._request._medium._path)
656
self._request_start_time = osutils.timer_func()
657
self._write_args(args)
658
readv_bytes = self._serialise_offsets(body)
659
bytes = self._encode_bulk_data(readv_bytes)
660
self._request.accept_bytes(bytes)
661
self._request.finished_writing()
662
if 'hpss' in debug.debug_flags:
663
mutter(' %d bytes in readv request', len(readv_bytes))
664
self._last_verb = args[0]
666
def call_with_body_stream(self, args, stream):
    """Refuse a body-stream call, which protocols v1 and v2 cannot express.

    The medium request is marked finished for writing and then for reading,
    and UnknownSmartMethod is raised for the verb ``args[0]``. ``stream``
    is never consumed.
    """
    # Protocols v1 and v2 have no body streams, so it is safe to assume a
    # v1/v2 server doesn't support whatever method is being called with one.
    medium_request = self._request
    medium_request.finished_writing()
    medium_request.finished_reading()
    raise errors.UnknownSmartMethod(args[0])
674
def cancel_read_body(self):
675
"""After expecting a body, a response code may indicate one otherwise.
677
This method lets the domain client inform the protocol that no body
678
will be transmitted. This is a terminal method: after calling it the
679
protocol is not able to be used further.
681
self._request.finished_reading()
683
def _read_response_tuple(self):
684
result = self._recv_tuple()
685
if 'hpss' in debug.debug_flags:
686
if self._request_start_time is not None:
687
mutter(' result: %6.3fs %s',
688
osutils.timer_func() - self._request_start_time,
690
self._request_start_time = None
692
mutter(' result: %s', repr(result)[1:-1])
695
def read_response_tuple(self, expect_body=False):
696
"""Read a response tuple from the wire.
698
This should only be called once.
700
result = self._read_response_tuple()
701
self._response_is_unknown_method(result)
702
self._raise_args_if_error(result)
704
self._request.finished_reading()
707
def _raise_args_if_error(self, result_tuple):
708
# Later protocol versions have an explicit flag in the protocol to say
709
# if an error response is "failed" or not. In version 1 we don't have
710
# that luxury. So here is a complete list of errors that can be
711
# returned in response to existing version 1 smart requests. Responses
712
# starting with these codes are always "failed" responses.
719
'UnicodeEncodeError',
720
'UnicodeDecodeError',
726
'UnlockableTransport',
732
if result_tuple[0] in v1_error_codes:
733
self._request.finished_reading()
734
raise errors.ErrorFromSmartServer(result_tuple)
736
def _response_is_unknown_method(self, result_tuple):
737
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
738
method' response to the request.
740
:param response: The response from a smart client call_expecting_body
742
:param verb: The verb used in that call.
743
:raises: UnexpectedSmartServerResponse
745
if (result_tuple == ('error', "Generic bzr smart protocol error: "
746
"bad request '%s'" % self._last_verb) or
747
result_tuple == ('error', "Generic bzr smart protocol error: "
748
"bad request u'%s'" % self._last_verb)):
749
# The response will have no body, so we've finished reading.
750
self._request.finished_reading()
751
raise errors.UnknownSmartMethod(self._last_verb)
753
def read_body_bytes(self, count=-1):
754
"""Read bytes from the body, decoding into a byte stream.
756
We read all bytes at once to ensure we've checked the trailer for
757
errors, and then feed the buffer back as read_body_bytes is called.
759
if self._body_buffer is not None:
760
return self._body_buffer.read(count)
761
_body_decoder = LengthPrefixedBodyDecoder()
763
while not _body_decoder.finished_reading:
764
bytes = self._request.read_bytes(_body_decoder.next_read_size())
766
# end of file encountered reading from server
767
raise errors.ConnectionReset(
768
"Connection lost while reading response body.")
769
_body_decoder.accept_bytes(bytes)
770
self._request.finished_reading()
771
self._body_buffer = StringIO(_body_decoder.read_pending_data())
772
# XXX: TODO check the trailer result.
773
if 'hpss' in debug.debug_flags:
774
mutter(' %d body bytes read',
775
len(self._body_buffer.getvalue()))
776
return self._body_buffer.read(count)
778
def _recv_tuple(self):
    """Receive a tuple from the medium request.

    One line is read from self._request and decoded with _decode_tuple.
    """
    line = self._request.read_line()
    return _decode_tuple(line)
782
def query_version(self):
783
"""Return protocol version number of the server."""
785
resp = self.read_response_tuple()
786
if resp == ('ok', '1'):
788
elif resp == ('ok', '2'):
791
raise errors.SmartProtocolError("bad response %r" % (resp,))
793
def _write_args(self, args):
    """Write the protocol version prefix, then ``args`` encoded as a tuple line.

    Both pieces go to the medium request. The version prefix comes from
    _write_protocol_version, which is empty for protocol version one.
    """
    self._write_protocol_version()
    self._request.accept_bytes(_encode_tuple(args))
798
def _write_protocol_version(self):
799
"""Write any prefixes this protocol requires.
801
Version one doesn't send protocol versions.
805
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
806
"""Version two of the client side of the smart protocol.
808
This prefixes the request with the value of REQUEST_VERSION_TWO.
811
response_marker = RESPONSE_VERSION_TWO
812
request_marker = REQUEST_VERSION_TWO
814
def read_response_tuple(self, expect_body=False):
815
"""Read a response tuple from the wire.
817
This should only be called once.
819
version = self._request.read_line()
820
if version != self.response_marker:
821
self._request.finished_reading()
822
raise errors.UnexpectedProtocolVersionMarker(version)
823
response_status = self._request.read_line()
824
result = SmartClientRequestProtocolOne._read_response_tuple(self)
825
self._response_is_unknown_method(result)
826
if response_status == 'success\n':
827
self.response_status = True
829
self._request.finished_reading()
831
elif response_status == 'failed\n':
832
self.response_status = False
833
self._request.finished_reading()
834
raise errors.ErrorFromSmartServer(result)
836
raise errors.SmartProtocolError(
837
'bad protocol status %r' % response_status)
839
def _write_protocol_version(self):
840
"""Write any prefixes this protocol requires.
842
Version two sends the value of REQUEST_VERSION_TWO.
844
self._request.accept_bytes(self.request_marker)
846
def read_streamed_body(self):
847
"""Read bytes from the body, decoding into a byte stream.
849
# Read no more than 64k at a time so that we don't risk error 10055 (no
850
# buffer space available) on Windows.
851
_body_decoder = ChunkedBodyDecoder()
852
while not _body_decoder.finished_reading:
853
bytes = self._request.read_bytes(_body_decoder.next_read_size())
855
# end of file encountered reading from server
856
raise errors.ConnectionReset(
857
"Connection lost while reading streamed body.")
858
_body_decoder.accept_bytes(bytes)
859
for body_bytes in iter(_body_decoder.read_next_chunk, None):
860
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
861
mutter(' %d byte chunk read',
864
self._request.finished_reading()
867
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Assemble a protocol-three server-side decoder.

    The pieces are connected as follows:
    - A SmartServerRequestHandler is built over ``backing_transport``, using
      the registered request handlers, ``root_client_path`` and
      ``jail_root``.
    - A ProtocolThreeResponder writes through ``write_func``.
    - A ConventionalRequestHandler combines the two.
    - The returned ProtocolThreeDecoder passes decoded message parts to
      that handler.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    message_handler = message.ConventionalRequestHandler(
        handler, ProtocolThreeResponder(write_func))
    return ProtocolThreeDecoder(message_handler)
877
class ProtocolThreeDecoder(_StatefulDecoder):
879
response_marker = RESPONSE_VERSION_THREE
880
request_marker = REQUEST_VERSION_THREE
882
def __init__(self, message_handler, expect_version_marker=False):
883
_StatefulDecoder.__init__(self)
884
self._has_dispatched = False
886
if expect_version_marker:
887
self.state_accept = self._state_accept_expecting_protocol_version
888
# We're expecting at least the protocol version marker + some
890
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
892
self.state_accept = self._state_accept_expecting_headers
893
self._number_needed_bytes = 4
894
self.decoding_failed = False
895
self.request_handler = self.message_handler = message_handler
897
def accept_bytes(self, bytes):
898
self._number_needed_bytes = None
900
_StatefulDecoder.accept_bytes(self, bytes)
901
except KeyboardInterrupt:
903
except errors.SmartMessageHandlerError, exception:
904
# We do *not* set self.decoding_failed here. The message handler
905
# has raised an error, but the decoder is still able to parse bytes
906
# and determine when this message ends.
907
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
908
log_exception_quietly()
909
self.message_handler.protocol_error(exception.exc_value)
910
# The state machine is ready to continue decoding, but the
911
# exception has interrupted the loop that runs the state machine.
912
# So we call accept_bytes again to restart it.
913
self.accept_bytes('')
914
except Exception, exception:
915
# The decoder itself has raised an exception. We cannot continue
917
self.decoding_failed = True
918
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
919
# This happens during normal operation when the client tries a
920
# protocol version the server doesn't understand, so no need to
921
# log a traceback every time.
922
# Note that this can only happen when
923
# expect_version_marker=True, which is only the case on the
927
log_exception_quietly()
928
self.message_handler.protocol_error(exception)
930
def _extract_length_prefixed_bytes(self):
931
if self._in_buffer_len < 4:
932
# A length prefix by itself is 4 bytes, and we don't even have that
934
raise _NeedMoreBytes(4)
935
(length,) = struct.unpack('!L', self._get_in_bytes(4))
936
end_of_bytes = 4 + length
937
if self._in_buffer_len < end_of_bytes:
938
# We haven't yet read as many bytes as the length-prefix says there
940
raise _NeedMoreBytes(end_of_bytes)
941
# Extract the bytes from the buffer.
942
in_buf = self._get_in_buffer()
943
bytes = in_buf[4:end_of_bytes]
944
self._set_in_buffer(in_buf[end_of_bytes:])
947
def _extract_prefixed_bencoded_data(self):
948
prefixed_bytes = self._extract_length_prefixed_bytes()
950
decoded = bdecode_as_tuple(prefixed_bytes)
952
raise errors.SmartProtocolError(
953
'Bytes %r not bencoded' % (prefixed_bytes,))
956
def _extract_single_byte(self):
957
if self._in_buffer_len == 0:
958
# The buffer is empty
959
raise _NeedMoreBytes(1)
960
in_buf = self._get_in_buffer()
962
self._set_in_buffer(in_buf[1:])
965
def _state_accept_expecting_protocol_version(self):
966
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
967
in_buf = self._get_in_buffer()
969
# We don't have enough bytes to check if the protocol version
970
# marker is right. But we can check if it is already wrong by
971
# checking that the start of MESSAGE_VERSION_THREE matches what
973
# [In fact, if the remote end isn't bzr we might never receive
974
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
975
# are wrong then we should just raise immediately rather than
977
if not MESSAGE_VERSION_THREE.startswith(in_buf):
978
# We have enough bytes to know the protocol version is wrong
979
raise errors.UnexpectedProtocolVersionMarker(in_buf)
980
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
981
if not in_buf.startswith(MESSAGE_VERSION_THREE):
982
raise errors.UnexpectedProtocolVersionMarker(in_buf)
983
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
984
self.state_accept = self._state_accept_expecting_headers
986
def _state_accept_expecting_headers(self):
987
decoded = self._extract_prefixed_bencoded_data()
988
if type(decoded) is not dict:
989
raise errors.SmartProtocolError(
990
'Header object %r is not a dict' % (decoded,))
991
self.state_accept = self._state_accept_expecting_message_part
993
self.message_handler.headers_received(decoded)
995
raise errors.SmartMessageHandlerError(sys.exc_info())
997
def _state_accept_expecting_message_part(self):
998
message_part_kind = self._extract_single_byte()
999
if message_part_kind == 'o':
1000
self.state_accept = self._state_accept_expecting_one_byte
1001
elif message_part_kind == 's':
1002
self.state_accept = self._state_accept_expecting_structure
1003
elif message_part_kind == 'b':
1004
self.state_accept = self._state_accept_expecting_bytes
1005
elif message_part_kind == 'e':
1008
raise errors.SmartProtocolError(
1009
'Bad message kind byte: %r' % (message_part_kind,))
1011
def _state_accept_expecting_one_byte(self):
1012
byte = self._extract_single_byte()
1013
self.state_accept = self._state_accept_expecting_message_part
1015
self.message_handler.byte_part_received(byte)
1017
raise errors.SmartMessageHandlerError(sys.exc_info())
1019
def _state_accept_expecting_bytes(self):
    """Read a length-prefixed bytes part and pass it to the handler.

    :raises SmartMessageHandlerError: if the handler raises; the original
        exc_info is wrapped.
    """
    # XXX: this should not buffer whole message part, but instead deliver
    # the bytes as they arrive.
    prefixed_bytes = self._extract_length_prefixed_bytes()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.bytes_part_received(prefixed_bytes)
    except:
        # Deliberately broad: wrap any handler failure with its exc_info.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_expecting_structure(self):
    """Read a bencoded structure part and pass it to the handler.

    :raises SmartMessageHandlerError: if the handler raises; the original
        exc_info is wrapped.
    """
    structure = self._extract_prefixed_bencoded_data()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.structure_part_received(structure)
    except:
        # Deliberately broad: wrap any handler failure with its exc_info.
        raise errors.SmartMessageHandlerError(sys.exc_info())

def _state_accept_expecting_end(self):
    """Finish the message and notify the handler.

    Whatever is left in the input buffer after the end marker becomes
    ``unused_data``, the buffer is cleared, and the decoder switches to
    collecting further unused data.

    :raises SmartMessageHandlerError: if the handler raises; the original
        exc_info is wrapped.
    """
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        # Deliberately broad: wrap any handler failure with its exc_info.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_reading_unused(self):
    """Append everything now in the input buffer to ``unused_data``.

    The input buffer is cleared afterwards.
    """
    self.unused_data += self._get_in_buffer()
    self._set_in_buffer(None)
def next_read_size(self):
    """Return how many more bytes the decoder wants.

    :return: 0 once the message is complete (only unused data is being
        collected) or decoding has failed; otherwise the number of bytes
        still needed by the current state.
    :raises AssertionError: if the current state has not recorded how many
        bytes it needs.
    """
    if self.state_accept == self._state_accept_reading_unused:
        return 0
    if self.decoding_failed:
        # An exception occurred while processing this message, probably
        # from self.message_handler.  We're not sure that this state
        # machine is in a consistent state, so just signal that we're done
        # (i.e. give up).
        return 0
    if self._number_needed_bytes is not None:
        return self._number_needed_bytes - self._in_buffer_len
    raise AssertionError("don't know how many bytes are expected!")
class _ProtocolThreeEncoder(object):
    """Shared encoding machinery for protocol-three requests and responses.

    Output is buffered: every piece written while encoding is kept in
    memory and handed to the real write function as one joined string once
    more than BUFFER_SIZE bytes are pending, or whenever flush() is called
    (which _write_end() does at the end of every message).
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing

    def __init__(self, write_func):
        """Create an encoder.

        :param write_func: callable receiving the encoded output, one
            joined string per flush.
        """
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Buffer ``bytes``; flush once more than BUFFER_SIZE are pending."""
        # TODO: Another possibility would be to turn this into an async model.
        #       Where we let another thread know that we have some bytes if
        #       they want it, but we don't actually block for it
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
        #       we might just push out smaller bits at a time?
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Hand all buffered bytes to the real write function, if any."""
        if self._buf:
            self._real_write_func(''.join(self._buf))
            del self._buf[:]
            self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        Each (start, length) pair becomes a ``start,length`` line.  Lines are
        separated by newlines, with no trailing newline.
        """
        return '\n'.join(['%d,%d' % (start, length)
                          for start, length in offsets])

    def _write_protocol_version(self):
        """Write the protocol-three version marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by its 4-byte length."""
        bytes = bencode(structure)
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_headers(self, headers):
        """Write the message headers as length-prefixed bencode."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write an 's' part holding ``args``; unicode args go out as UTF-8."""
        self._write_func('s')
        utf8_args = [arg.encode('utf8') if type(arg) is unicode else arg
                     for arg in args]
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker and flush the buffer."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a 'b' part: ``bytes`` preceded by its 4-byte length."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        self._write_func('oC')

    def _write_error_status(self):
        self._write_func('oE')

    def _write_success_status(self):
        self._write_func('oS')
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Encode one protocol-three response (or error) on the server side.

    An instance sends at most one response: a second call to
    send_response() or send_error() raises AssertionError.
    """

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}
        # Tracing state.  Both attributes always exist, so _trace works
        # whichever debug flag ('hpss' or 'hpssdetail') triggered it.
        self._thread_id = None
        self._response_start_time = None
        if ('hpss' in debug.debug_flags
            or 'hpssdetail' in debug.debug_flags):
            # threading.Thread objects have no get_ident() method; the
            # low-level thread module provides it.
            import thread
            self._thread_id = thread.get_ident()

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Log one trace line via mutter.

        :param action: short label, right-aligned in a 12-character column.
        :param message: free-form description.
        :param extra_bytes: optional payload; a repr of its first 40 bytes
            (further shortened if long) is appended.
        :param include_time: if True, prefix the time elapsed since the
            first _trace call for this response.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # Use the same clock as the start time: mixing time.clock()
            # with osutils.timer_func() gives meaningless differences on
            # platforms where the two clocks differ.
            t = '%5.3fs ' % (osutils.timer_func() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                # Keep the closing quote of the repr before the ellipsis.
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send ``exception`` to the client as an error response.

        UnknownSmartMethod is sent as a failed 'UnknownMethod' response;
        any other exception becomes an ('error', str(exception)) structure.

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send ``response``: status, args, and optional body or body stream.

        If iterating ``response.body_stream`` raises, or the stream yields a
        FailedSmartServerResponse, an error status and structure are written
        in place of the rest of the body.

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    # Iterating the stream failed: report the translated
                    # error instead of the remaining body.
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                num_bytes += len(chunk)
                if first_chunk is None:
                    first_chunk = chunk
                self._write_prefixed_body(chunk)
                if 'hpssdetail' in debug.debug_flags:
                    # Not worth timing separately, as _write_func is
                    # actually buffered; the default include_time=False
                    # leaves the timestamp out.
                    self._trace('body chunk',
                                '%d bytes' % (len(chunk),),
                                chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
def _iter_with_errors(iterable):
    """Handle errors from iterable.next().

    Use like::

        for exc_info, value in _iter_with_errors(iterable):
            if exc_info is not None:
                # handle the error
            else:
                # process value

    This is a safer alternative to::

        try:
            for value in iterable:
                # process value
        except:
            # handle the error

    Because the latter will catch errors from the for-loop body, not just
    iterable.next()

    If an error occurs, exc_info will be a exc_info tuple, and the generator
    will terminate. Otherwise exc_info will be None, and value will be the
    value from iterable.next(). Note that KeyboardInterrupt and SystemExit
    will not be intercepted.
    """
    iterator = iter(iterable)
    while True:
        try:
            yield None, iterator.next()
        except StopIteration:
            return
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            mutter('_iter_with_errors caught error')
            log_exception_quietly()
            yield sys.exc_info(), None
            # Stop after reporting one error, as documented above.
            return
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1283
def __init__(self, medium_request):
1284
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1285
self._medium_request = medium_request
1288
def set_headers(self, headers):
1289
self._headers = headers.copy()
1291
def call(self, *args):
1292
if 'hpss' in debug.debug_flags:
1293
mutter('hpss call: %s', repr(args)[1:-1])
1294
base = getattr(self._medium_request._medium, 'base', None)
1295
if base is not None:
1296
mutter(' (to %s)', base)
1297
self._request_start_time = osutils.timer_func()
1298
self._write_protocol_version()
1299
self._write_headers(self._headers)
1300
self._write_structure(args)
1302
self._medium_request.finished_writing()
1304
def call_with_body_bytes(self, args, body):
1305
"""Make a remote call of args with body bytes 'body'.
1307
After calling this, call read_response_tuple to find the result out.
1309
if 'hpss' in debug.debug_flags:
1310
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1311
path = getattr(self._medium_request._medium, '_path', None)
1312
if path is not None:
1313
mutter(' (to %s)', path)
1314
mutter(' %d bytes', len(body))
1315
self._request_start_time = osutils.timer_func()
1316
self._write_protocol_version()
1317
self._write_headers(self._headers)
1318
self._write_structure(args)
1319
self._write_prefixed_body(body)
1321
self._medium_request.finished_writing()
1323
def call_with_body_readv_array(self, args, body):
1324
"""Make a remote call with a readv array.
1326
The body is encoded with one line per readv offset pair. The numbers in
1327
each pair are separated by a comma, and no trailing \n is emitted.
1329
if 'hpss' in debug.debug_flags:
1330
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1331
path = getattr(self._medium_request._medium, '_path', None)
1332
if path is not None:
1333
mutter(' (to %s)', path)
1334
self._request_start_time = osutils.timer_func()
1335
self._write_protocol_version()
1336
self._write_headers(self._headers)
1337
self._write_structure(args)
1338
readv_bytes = self._serialise_offsets(body)
1339
if 'hpss' in debug.debug_flags:
1340
mutter(' %d bytes in readv request', len(readv_bytes))
1341
self._write_prefixed_body(readv_bytes)
1343
self._medium_request.finished_writing()
1345
def call_with_body_stream(self, args, stream):
1346
if 'hpss' in debug.debug_flags:
1347
mutter('hpss call w/body stream: %r', args)
1348
path = getattr(self._medium_request._medium, '_path', None)
1349
if path is not None:
1350
mutter(' (to %s)', path)
1351
self._request_start_time = osutils.timer_func()
1352
self._write_protocol_version()
1353
self._write_headers(self._headers)
1354
self._write_structure(args)
1355
# TODO: notice if the server has sent an early error reply before we
1356
# have finished sending the stream. We would notice at the end
1357
# anyway, but if the medium can deliver it early then it's good
1358
# to short-circuit the whole request...
1359
for exc_info, part in _iter_with_errors(stream):
1360
if exc_info is not None:
1361
# Iterating the stream failed. Cleanly abort the request.
1362
self._write_error_status()
1363
# Currently the client unconditionally sends ('error',) as the
1365
self._write_structure(('error',))
1367
self._medium_request.finished_writing()
1368
raise exc_info[0], exc_info[1], exc_info[2]
1370
self._write_prefixed_body(part)
1373
self._medium_request.finished_writing()