1
# Copyright (C) 2006-2010 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
29
import thread as _thread
38
from ..sixish import (
42
from . import message, request
43
from ..trace import log_exception_quietly, mutter
44
from ..bencode import bdecode_as_tuple, bencode
47
# Protocol version strings. These are sent as prefixes of bzr requests and
48
# responses to identify the protocol version being used. (There are no version
49
# one strings because that version doesn't send any).
50
REQUEST_VERSION_TWO = 'bzr request 2\n'
51
RESPONSE_VERSION_TWO = 'bzr response 2\n'
53
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
54
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
57
def _recv_tuple(from_file):
58
req_line = from_file.readline()
59
return _decode_tuple(req_line)
62
def _decode_tuple(req_line):
63
if req_line is None or req_line == '':
65
if req_line[-1] != '\n':
66
raise errors.SmartProtocolError("request %r not terminated" % req_line)
67
return tuple(req_line[:-1].split('\x01'))
70
def _encode_tuple(args):
71
"""Encode the tuple args to a bytestream."""
72
joined = '\x01'.join(args) + '\n'
73
if isinstance(joined, unicode):
74
# XXX: We should fix things so this never happens! -AJB, 20100304
75
mutter('response args contain unicode, should be only bytes: %r',
77
joined = joined.encode('ascii')
81
class Requester(object):
82
"""Abstract base class for an object that can issue requests on a smart
86
def call(self, *args):
87
"""Make a remote call.
89
:param args: the arguments of this call.
91
raise NotImplementedError(self.call)
93
def call_with_body_bytes(self, args, body):
94
"""Make a remote call with a body.
96
:param args: the arguments of this call.
98
:param body: the body to send with the request.
100
raise NotImplementedError(self.call_with_body_bytes)
102
def call_with_body_readv_array(self, args, body):
103
"""Make a remote call with a readv array.
105
:param args: the arguments of this call.
106
:type body: iterable of (start, length) tuples.
107
:param body: the readv ranges to send with this request.
109
raise NotImplementedError(self.call_with_body_readv_array)
111
def set_headers(self, headers):
112
raise NotImplementedError(self.set_headers)
115
class SmartProtocolBase(object):
116
"""Methods common to client and server"""
118
# TODO: this only actually accomodates a single block; possibly should
119
# support multiple chunks?
120
def _encode_bulk_data(self, body):
121
"""Encode body as a bulk data chunk."""
122
return ''.join(('%d\n' % len(body), body, 'done\n'))
124
def _serialise_offsets(self, offsets):
125
"""Serialise a readv offset list."""
127
for start, length in offsets:
128
txt.append('%d,%d' % (start, length))
129
return '\n'.join(txt)
132
class SmartServerRequestProtocolOne(SmartProtocolBase):
133
"""Server-side encoding and decoding logic for smart version 1."""
135
def __init__(self, backing_transport, write_func, root_client_path='/',
137
self._backing_transport = backing_transport
138
self._root_client_path = root_client_path
139
self._jail_root = jail_root
140
self.unused_data = ''
141
self._finished = False
143
self._has_dispatched = False
145
self._body_decoder = None
146
self._write_func = write_func
148
def accept_bytes(self, bytes):
149
"""Take bytes, and advance the internal state machine appropriately.
151
:param bytes: must be a byte string
153
if not isinstance(bytes, str):
154
raise ValueError(bytes)
155
self.in_buffer += bytes
156
if not self._has_dispatched:
157
if '\n' not in self.in_buffer:
158
# no command line yet
160
self._has_dispatched = True
162
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
164
req_args = _decode_tuple(first_line)
165
self.request = request.SmartServerRequestHandler(
166
self._backing_transport, commands=request.request_handlers,
167
root_client_path=self._root_client_path,
168
jail_root=self._jail_root)
169
self.request.args_received(req_args)
170
if self.request.finished_reading:
172
self.unused_data = self.in_buffer
174
self._send_response(self.request.response)
175
except KeyboardInterrupt:
177
except errors.UnknownSmartMethod as err:
178
protocol_error = errors.SmartProtocolError(
179
"bad request %r" % (err.verb,))
180
failure = request.FailedSmartServerResponse(
181
('error', str(protocol_error)))
182
self._send_response(failure)
184
except Exception as exception:
185
# everything else: pass to client, flush, and quit
186
log_exception_quietly()
187
self._send_response(request.FailedSmartServerResponse(
188
('error', str(exception))))
191
if self._has_dispatched:
193
# nothing to do.XXX: this routine should be a single state
195
self.unused_data += self.in_buffer
198
if self._body_decoder is None:
199
self._body_decoder = LengthPrefixedBodyDecoder()
200
self._body_decoder.accept_bytes(self.in_buffer)
201
self.in_buffer = self._body_decoder.unused_data
202
body_data = self._body_decoder.read_pending_data()
203
self.request.accept_body(body_data)
204
if self._body_decoder.finished_reading:
205
self.request.end_of_body()
206
if not self.request.finished_reading:
207
raise AssertionError("no more body, request not finished")
208
if self.request.response is not None:
209
self._send_response(self.request.response)
210
self.unused_data = self.in_buffer
213
if self.request.finished_reading:
214
raise AssertionError(
215
"no response and we have finished reading.")
217
def _send_response(self, response):
218
"""Send a smart server response down the output stream."""
220
raise AssertionError('response already sent')
223
self._finished = True
224
self._write_protocol_version()
225
self._write_success_or_failure_prefix(response)
226
self._write_func(_encode_tuple(args))
228
if not isinstance(body, str):
229
raise ValueError(body)
230
bytes = self._encode_bulk_data(body)
231
self._write_func(bytes)
233
def _write_protocol_version(self):
234
"""Write any prefixes this protocol requires.
236
Version one doesn't send protocol versions.
239
def _write_success_or_failure_prefix(self, response):
240
"""Write the protocol specific success/failure prefix.
242
For SmartServerRequestProtocolOne this is omitted but we
243
call is_successful to ensure that the response is valid.
245
response.is_successful()
247
def next_read_size(self):
250
if self._body_decoder is None:
253
return self._body_decoder.next_read_size()
256
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
257
r"""Version two of the server side of the smart protocol.
259
This prefixes responses with the value of RESPONSE_VERSION_TWO.
262
response_marker = RESPONSE_VERSION_TWO
263
request_marker = REQUEST_VERSION_TWO
265
def _write_success_or_failure_prefix(self, response):
266
"""Write the protocol specific success/failure prefix."""
267
if response.is_successful():
268
self._write_func('success\n')
270
self._write_func('failed\n')
272
def _write_protocol_version(self):
273
r"""Write any prefixes this protocol requires.
275
Version two sends the value of RESPONSE_VERSION_TWO.
277
self._write_func(self.response_marker)
279
def _send_response(self, response):
280
"""Send a smart server response down the output stream."""
282
raise AssertionError('response already sent')
283
self._finished = True
284
self._write_protocol_version()
285
self._write_success_or_failure_prefix(response)
286
self._write_func(_encode_tuple(response.args))
287
if response.body is not None:
288
if not isinstance(response.body, str):
289
raise AssertionError('body must be a str')
290
if not (response.body_stream is None):
291
raise AssertionError(
292
'body_stream and body cannot both be set')
293
bytes = self._encode_bulk_data(response.body)
294
self._write_func(bytes)
295
elif response.body_stream is not None:
296
_send_stream(response.body_stream, self._write_func)
299
def _send_stream(stream, write_func):
300
write_func('chunked\n')
301
_send_chunks(stream, write_func)
305
def _send_chunks(stream, write_func):
307
if isinstance(chunk, str):
308
bytes = "%x\n%s" % (len(chunk), chunk)
310
elif isinstance(chunk, request.FailedSmartServerResponse):
312
_send_chunks(chunk.args, write_func)
315
raise errors.BzrError(
316
'Chunks must be str or FailedSmartServerResponse, got %r'
320
class _NeedMoreBytes(Exception):
321
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
325
def __init__(self, count=None):
328
:param count: the total number of bytes needed by the current state.
329
May be None if the number of bytes needed is unknown.
334
class _StatefulDecoder(object):
335
"""Base class for writing state machines to decode byte streams.
337
Subclasses should provide a self.state_accept attribute that accepts bytes
338
and, if appropriate, updates self.state_accept to a different function.
339
accept_bytes will call state_accept as often as necessary to make sure the
340
state machine has progressed as far as possible before it returns.
342
See ProtocolThreeDecoder for an example subclass.
346
self.finished_reading = False
347
self._in_buffer_list = []
348
self._in_buffer_len = 0
349
self.unused_data = ''
350
self.bytes_left = None
351
self._number_needed_bytes = None
353
def _get_in_buffer(self):
354
if len(self._in_buffer_list) == 1:
355
return self._in_buffer_list[0]
356
in_buffer = ''.join(self._in_buffer_list)
357
if len(in_buffer) != self._in_buffer_len:
358
raise AssertionError(
359
"Length of buffer did not match expected value: %s != %s"
360
% self._in_buffer_len, len(in_buffer))
361
self._in_buffer_list = [in_buffer]
364
def _get_in_bytes(self, count):
365
"""Grab X bytes from the input_buffer.
367
Callers should have already checked that self._in_buffer_len is >
368
count. Note, this does not consume the bytes from the buffer. The
369
caller will still need to call _get_in_buffer() and then
370
_set_in_buffer() if they actually need to consume the bytes.
372
# check if we can yield the bytes from just the first entry in our list
373
if len(self._in_buffer_list) == 0:
374
raise AssertionError('Callers must be sure we have buffered bytes'
375
' before calling _get_in_bytes')
376
if len(self._in_buffer_list[0]) > count:
377
return self._in_buffer_list[0][:count]
378
# We can't yield it from the first buffer, so collapse all buffers, and
380
in_buf = self._get_in_buffer()
381
return in_buf[:count]
383
def _set_in_buffer(self, new_buf):
384
if new_buf is not None:
385
self._in_buffer_list = [new_buf]
386
self._in_buffer_len = len(new_buf)
388
self._in_buffer_list = []
389
self._in_buffer_len = 0
391
def accept_bytes(self, bytes):
392
"""Decode as much of bytes as possible.
394
If 'bytes' contains too much data it will be appended to
397
finished_reading will be set when no more data is required. Further
398
data will be appended to self.unused_data.
400
# accept_bytes is allowed to change the state
401
self._number_needed_bytes = None
402
# lsprof puts a very large amount of time on this specific call for
404
self._in_buffer_list.append(bytes)
405
self._in_buffer_len += len(bytes)
407
# Run the function for the current state.
408
current_state = self.state_accept
410
while current_state != self.state_accept:
411
# The current state has changed. Run the function for the new
412
# current state, so that it can:
413
# - decode any unconsumed bytes left in a buffer, and
414
# - signal how many more bytes are expected (via raising
416
current_state = self.state_accept
418
except _NeedMoreBytes as e:
419
self._number_needed_bytes = e.count
422
class ChunkedBodyDecoder(_StatefulDecoder):
423
"""Decoder for chunked body data.
425
This is very similar the HTTP's chunked encoding. See the description of
426
streamed body data in `doc/developers/network-protocol.txt` for details.
430
_StatefulDecoder.__init__(self)
431
self.state_accept = self._state_accept_expecting_header
432
self.chunk_in_progress = None
433
self.chunks = collections.deque()
435
self.error_in_progress = None
437
def next_read_size(self):
438
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
439
# end-of-body marker is 4 bytes: 'END\n'.
440
if self.state_accept == self._state_accept_reading_chunk:
441
# We're expecting more chunk content. So we're expecting at least
442
# the rest of this chunk plus an END chunk.
443
return self.bytes_left + 4
444
elif self.state_accept == self._state_accept_expecting_length:
445
if self._in_buffer_len == 0:
446
# We're expecting a chunk length. There's at least two bytes
447
# left: a digit plus '\n'.
450
# We're in the middle of reading a chunk length. So there's at
451
# least one byte left, the '\n' that terminates the length.
453
elif self.state_accept == self._state_accept_reading_unused:
455
elif self.state_accept == self._state_accept_expecting_header:
456
return max(0, len('chunked\n') - self._in_buffer_len)
458
raise AssertionError("Impossible state: %r" % (self.state_accept,))
460
def read_next_chunk(self):
462
return self.chunks.popleft()
466
def _extract_line(self):
467
in_buf = self._get_in_buffer()
468
pos = in_buf.find('\n')
470
# We haven't read a complete line yet, so request more bytes before
472
raise _NeedMoreBytes(1)
474
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
475
self._set_in_buffer(in_buf[pos+1:])
479
self.unused_data = self._get_in_buffer()
480
self._in_buffer_list = []
481
self._in_buffer_len = 0
482
self.state_accept = self._state_accept_reading_unused
484
error_args = tuple(self.error_in_progress)
485
self.chunks.append(request.FailedSmartServerResponse(error_args))
486
self.error_in_progress = None
487
self.finished_reading = True
489
def _state_accept_expecting_header(self):
490
prefix = self._extract_line()
491
if prefix == 'chunked':
492
self.state_accept = self._state_accept_expecting_length
494
raise errors.SmartProtocolError(
495
'Bad chunked body header: "%s"' % (prefix,))
497
def _state_accept_expecting_length(self):
498
prefix = self._extract_line()
501
self.error_in_progress = []
502
self._state_accept_expecting_length()
504
elif prefix == 'END':
505
# We've read the end-of-body marker.
506
# Any further bytes are unused data, including the bytes left in
511
self.bytes_left = int(prefix, 16)
512
self.chunk_in_progress = ''
513
self.state_accept = self._state_accept_reading_chunk
515
def _state_accept_reading_chunk(self):
516
in_buf = self._get_in_buffer()
517
in_buffer_len = len(in_buf)
518
self.chunk_in_progress += in_buf[:self.bytes_left]
519
self._set_in_buffer(in_buf[self.bytes_left:])
520
self.bytes_left -= in_buffer_len
521
if self.bytes_left <= 0:
522
# Finished with chunk
523
self.bytes_left = None
525
self.error_in_progress.append(self.chunk_in_progress)
527
self.chunks.append(self.chunk_in_progress)
528
self.chunk_in_progress = None
529
self.state_accept = self._state_accept_expecting_length
531
def _state_accept_reading_unused(self):
532
self.unused_data += self._get_in_buffer()
533
self._in_buffer_list = []
536
class LengthPrefixedBodyDecoder(_StatefulDecoder):
537
"""Decodes the length-prefixed bulk data."""
540
_StatefulDecoder.__init__(self)
541
self.state_accept = self._state_accept_expecting_length
542
self.state_read = self._state_read_no_data
544
self._trailer_buffer = ''
546
def next_read_size(self):
547
if self.bytes_left is not None:
548
# Ideally we want to read all the remainder of the body and the
550
return self.bytes_left + 5
551
elif self.state_accept == self._state_accept_reading_trailer:
552
# Just the trailer left
553
return 5 - len(self._trailer_buffer)
554
elif self.state_accept == self._state_accept_expecting_length:
555
# There's still at least 6 bytes left ('\n' to end the length, plus
559
# Reading excess data. Either way, 1 byte at a time is fine.
562
def read_pending_data(self):
563
"""Return any pending data that has been decoded."""
564
return self.state_read()
566
def _state_accept_expecting_length(self):
567
in_buf = self._get_in_buffer()
568
pos = in_buf.find('\n')
571
self.bytes_left = int(in_buf[:pos])
572
self._set_in_buffer(in_buf[pos+1:])
573
self.state_accept = self._state_accept_reading_body
574
self.state_read = self._state_read_body_buffer
576
def _state_accept_reading_body(self):
577
in_buf = self._get_in_buffer()
579
self.bytes_left -= len(in_buf)
580
self._set_in_buffer(None)
581
if self.bytes_left <= 0:
583
if self.bytes_left != 0:
584
self._trailer_buffer = self._body[self.bytes_left:]
585
self._body = self._body[:self.bytes_left]
586
self.bytes_left = None
587
self.state_accept = self._state_accept_reading_trailer
589
def _state_accept_reading_trailer(self):
590
self._trailer_buffer += self._get_in_buffer()
591
self._set_in_buffer(None)
592
# TODO: what if the trailer does not match "done\n"? Should this raise
593
# a ProtocolViolation exception?
594
if self._trailer_buffer.startswith('done\n'):
595
self.unused_data = self._trailer_buffer[len('done\n'):]
596
self.state_accept = self._state_accept_reading_unused
597
self.finished_reading = True
599
def _state_accept_reading_unused(self):
600
self.unused_data += self._get_in_buffer()
601
self._set_in_buffer(None)
603
def _state_read_no_data(self):
606
def _state_read_body_buffer(self):
612
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
613
message.ResponseHandler):
614
"""The client-side protocol for smart version 1."""
616
def __init__(self, request):
617
"""Construct a SmartClientRequestProtocolOne.
619
:param request: A SmartClientMediumRequest to serialise onto and
622
self._request = request
623
self._body_buffer = None
624
self._request_start_time = None
625
self._last_verb = None
628
def set_headers(self, headers):
629
self._headers = dict(headers)
631
def call(self, *args):
632
if 'hpss' in debug.debug_flags:
633
mutter('hpss call: %s', repr(args)[1:-1])
634
if getattr(self._request._medium, 'base', None) is not None:
635
mutter(' (to %s)', self._request._medium.base)
636
self._request_start_time = osutils.timer_func()
637
self._write_args(args)
638
self._request.finished_writing()
639
self._last_verb = args[0]
641
def call_with_body_bytes(self, args, body):
642
"""Make a remote call of args with body bytes 'body'.
644
After calling this, call read_response_tuple to find the result out.
646
if 'hpss' in debug.debug_flags:
647
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
648
if getattr(self._request._medium, '_path', None) is not None:
649
mutter(' (to %s)', self._request._medium._path)
650
mutter(' %d bytes', len(body))
651
self._request_start_time = osutils.timer_func()
652
if 'hpssdetail' in debug.debug_flags:
653
mutter('hpss body content: %s', body)
654
self._write_args(args)
655
bytes = self._encode_bulk_data(body)
656
self._request.accept_bytes(bytes)
657
self._request.finished_writing()
658
self._last_verb = args[0]
660
def call_with_body_readv_array(self, args, body):
661
"""Make a remote call with a readv array.
663
The body is encoded with one line per readv offset pair. The numbers in
664
each pair are separated by a comma, and no trailing \\n is emitted.
666
if 'hpss' in debug.debug_flags:
667
mutter('hpss call w/readv: %s', repr(args)[1:-1])
668
if getattr(self._request._medium, '_path', None) is not None:
669
mutter(' (to %s)', self._request._medium._path)
670
self._request_start_time = osutils.timer_func()
671
self._write_args(args)
672
readv_bytes = self._serialise_offsets(body)
673
bytes = self._encode_bulk_data(readv_bytes)
674
self._request.accept_bytes(bytes)
675
self._request.finished_writing()
676
if 'hpss' in debug.debug_flags:
677
mutter(' %d bytes in readv request', len(readv_bytes))
678
self._last_verb = args[0]
680
def call_with_body_stream(self, args, stream):
681
# Protocols v1 and v2 don't support body streams. So it's safe to
682
# assume that a v1/v2 server doesn't support whatever method we're
683
# trying to call with a body stream.
684
self._request.finished_writing()
685
self._request.finished_reading()
686
raise errors.UnknownSmartMethod(args[0])
688
def cancel_read_body(self):
689
"""After expecting a body, a response code may indicate one otherwise.
691
This method lets the domain client inform the protocol that no body
692
will be transmitted. This is a terminal method: after calling it the
693
protocol is not able to be used further.
695
self._request.finished_reading()
697
def _read_response_tuple(self):
698
result = self._recv_tuple()
699
if 'hpss' in debug.debug_flags:
700
if self._request_start_time is not None:
701
mutter(' result: %6.3fs %s',
702
osutils.timer_func() - self._request_start_time,
704
self._request_start_time = None
706
mutter(' result: %s', repr(result)[1:-1])
709
def read_response_tuple(self, expect_body=False):
710
"""Read a response tuple from the wire.
712
This should only be called once.
714
result = self._read_response_tuple()
715
self._response_is_unknown_method(result)
716
self._raise_args_if_error(result)
718
self._request.finished_reading()
721
def _raise_args_if_error(self, result_tuple):
722
# Later protocol versions have an explicit flag in the protocol to say
723
# if an error response is "failed" or not. In version 1 we don't have
724
# that luxury. So here is a complete list of errors that can be
725
# returned in response to existing version 1 smart requests. Responses
726
# starting with these codes are always "failed" responses.
733
'UnicodeEncodeError',
734
'UnicodeDecodeError',
740
'UnlockableTransport',
746
if result_tuple[0] in v1_error_codes:
747
self._request.finished_reading()
748
raise errors.ErrorFromSmartServer(result_tuple)
750
def _response_is_unknown_method(self, result_tuple):
751
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
752
method' response to the request.
754
:param response: The response from a smart client call_expecting_body
756
:param verb: The verb used in that call.
757
:raises: UnexpectedSmartServerResponse
759
if (result_tuple == ('error', "Generic bzr smart protocol error: "
760
"bad request '%s'" % self._last_verb) or
761
result_tuple == ('error', "Generic bzr smart protocol error: "
762
"bad request u'%s'" % self._last_verb)):
763
# The response will have no body, so we've finished reading.
764
self._request.finished_reading()
765
raise errors.UnknownSmartMethod(self._last_verb)
767
def read_body_bytes(self, count=-1):
768
"""Read bytes from the body, decoding into a byte stream.
770
We read all bytes at once to ensure we've checked the trailer for
771
errors, and then feed the buffer back as read_body_bytes is called.
773
if self._body_buffer is not None:
774
return self._body_buffer.read(count)
775
_body_decoder = LengthPrefixedBodyDecoder()
777
while not _body_decoder.finished_reading:
778
bytes = self._request.read_bytes(_body_decoder.next_read_size())
780
# end of file encountered reading from server
781
raise errors.ConnectionReset(
782
"Connection lost while reading response body.")
783
_body_decoder.accept_bytes(bytes)
784
self._request.finished_reading()
785
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
786
# XXX: TODO check the trailer result.
787
if 'hpss' in debug.debug_flags:
788
mutter(' %d body bytes read',
789
len(self._body_buffer.getvalue()))
790
return self._body_buffer.read(count)
792
def _recv_tuple(self):
793
"""Receive a tuple from the medium request."""
794
return _decode_tuple(self._request.read_line())
796
def query_version(self):
797
"""Return protocol version number of the server."""
799
resp = self.read_response_tuple()
800
if resp == ('ok', '1'):
802
elif resp == ('ok', '2'):
805
raise errors.SmartProtocolError("bad response %r" % (resp,))
807
def _write_args(self, args):
808
self._write_protocol_version()
809
bytes = _encode_tuple(args)
810
self._request.accept_bytes(bytes)
812
def _write_protocol_version(self):
813
"""Write any prefixes this protocol requires.
815
Version one doesn't send protocol versions.
819
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
820
"""Version two of the client side of the smart protocol.
822
This prefixes the request with the value of REQUEST_VERSION_TWO.
825
response_marker = RESPONSE_VERSION_TWO
826
request_marker = REQUEST_VERSION_TWO
828
def read_response_tuple(self, expect_body=False):
829
"""Read a response tuple from the wire.
831
This should only be called once.
833
version = self._request.read_line()
834
if version != self.response_marker:
835
self._request.finished_reading()
836
raise errors.UnexpectedProtocolVersionMarker(version)
837
response_status = self._request.read_line()
838
result = SmartClientRequestProtocolOne._read_response_tuple(self)
839
self._response_is_unknown_method(result)
840
if response_status == 'success\n':
841
self.response_status = True
843
self._request.finished_reading()
845
elif response_status == 'failed\n':
846
self.response_status = False
847
self._request.finished_reading()
848
raise errors.ErrorFromSmartServer(result)
850
raise errors.SmartProtocolError(
851
'bad protocol status %r' % response_status)
853
def _write_protocol_version(self):
854
"""Write any prefixes this protocol requires.
856
Version two sends the value of REQUEST_VERSION_TWO.
858
self._request.accept_bytes(self.request_marker)
860
def read_streamed_body(self):
861
"""Read bytes from the body, decoding into a byte stream.
863
# Read no more than 64k at a time so that we don't risk error 10055 (no
864
# buffer space available) on Windows.
865
_body_decoder = ChunkedBodyDecoder()
866
while not _body_decoder.finished_reading:
867
bytes = self._request.read_bytes(_body_decoder.next_read_size())
869
# end of file encountered reading from server
870
raise errors.ConnectionReset(
871
"Connection lost while reading streamed body.")
872
_body_decoder.accept_bytes(bytes)
873
for body_bytes in iter(_body_decoder.read_next_chunk, None):
874
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
875
mutter(' %d byte chunk read',
878
self._request.finished_reading()
881
def build_server_protocol_three(backing_transport, write_func,
882
root_client_path, jail_root=None):
883
request_handler = request.SmartServerRequestHandler(
884
backing_transport, commands=request.request_handlers,
885
root_client_path=root_client_path, jail_root=jail_root)
886
responder = ProtocolThreeResponder(write_func)
887
message_handler = message.ConventionalRequestHandler(request_handler, responder)
888
return ProtocolThreeDecoder(message_handler)
891
class ProtocolThreeDecoder(_StatefulDecoder):
893
response_marker = RESPONSE_VERSION_THREE
894
request_marker = REQUEST_VERSION_THREE
896
def __init__(self, message_handler, expect_version_marker=False):
897
_StatefulDecoder.__init__(self)
898
self._has_dispatched = False
900
if expect_version_marker:
901
self.state_accept = self._state_accept_expecting_protocol_version
902
# We're expecting at least the protocol version marker + some
904
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
906
self.state_accept = self._state_accept_expecting_headers
907
self._number_needed_bytes = 4
908
self.decoding_failed = False
909
self.request_handler = self.message_handler = message_handler
911
def accept_bytes(self, bytes):
912
self._number_needed_bytes = None
914
_StatefulDecoder.accept_bytes(self, bytes)
915
except KeyboardInterrupt:
917
except errors.SmartMessageHandlerError as exception:
918
# We do *not* set self.decoding_failed here. The message handler
919
# has raised an error, but the decoder is still able to parse bytes
920
# and determine when this message ends.
921
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
922
log_exception_quietly()
923
self.message_handler.protocol_error(exception.exc_value)
924
# The state machine is ready to continue decoding, but the
925
# exception has interrupted the loop that runs the state machine.
926
# So we call accept_bytes again to restart it.
927
self.accept_bytes('')
928
except Exception as exception:
929
# The decoder itself has raised an exception. We cannot continue
931
self.decoding_failed = True
932
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
933
# This happens during normal operation when the client tries a
934
# protocol version the server doesn't understand, so no need to
935
# log a traceback every time.
936
# Note that this can only happen when
937
# expect_version_marker=True, which is only the case on the
941
log_exception_quietly()
942
self.message_handler.protocol_error(exception)
944
def _extract_length_prefixed_bytes(self):
945
if self._in_buffer_len < 4:
946
# A length prefix by itself is 4 bytes, and we don't even have that
948
raise _NeedMoreBytes(4)
949
(length,) = struct.unpack('!L', self._get_in_bytes(4))
950
end_of_bytes = 4 + length
951
if self._in_buffer_len < end_of_bytes:
952
# We haven't yet read as many bytes as the length-prefix says there
954
raise _NeedMoreBytes(end_of_bytes)
955
# Extract the bytes from the buffer.
956
in_buf = self._get_in_buffer()
957
bytes = in_buf[4:end_of_bytes]
958
self._set_in_buffer(in_buf[end_of_bytes:])
961
def _extract_prefixed_bencoded_data(self):
962
prefixed_bytes = self._extract_length_prefixed_bytes()
964
decoded = bdecode_as_tuple(prefixed_bytes)
966
raise errors.SmartProtocolError(
967
'Bytes %r not bencoded' % (prefixed_bytes,))
970
def _extract_single_byte(self):
971
if self._in_buffer_len == 0:
972
# The buffer is empty
973
raise _NeedMoreBytes(1)
974
in_buf = self._get_in_buffer()
976
self._set_in_buffer(in_buf[1:])
979
def _state_accept_expecting_protocol_version(self):
980
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
981
in_buf = self._get_in_buffer()
983
# We don't have enough bytes to check if the protocol version
984
# marker is right. But we can check if it is already wrong by
985
# checking that the start of MESSAGE_VERSION_THREE matches what
987
# [In fact, if the remote end isn't bzr we might never receive
988
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
989
# are wrong then we should just raise immediately rather than
991
if not MESSAGE_VERSION_THREE.startswith(in_buf):
992
# We have enough bytes to know the protocol version is wrong
993
raise errors.UnexpectedProtocolVersionMarker(in_buf)
994
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
995
if not in_buf.startswith(MESSAGE_VERSION_THREE):
996
raise errors.UnexpectedProtocolVersionMarker(in_buf)
997
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
998
self.state_accept = self._state_accept_expecting_headers
1000
def _state_accept_expecting_headers(self):
1001
decoded = self._extract_prefixed_bencoded_data()
1002
if not isinstance(decoded, dict):
1003
raise errors.SmartProtocolError(
1004
'Header object %r is not a dict' % (decoded,))
1005
self.state_accept = self._state_accept_expecting_message_part
1007
self.message_handler.headers_received(decoded)
1009
raise errors.SmartMessageHandlerError(sys.exc_info())
1011
def _state_accept_expecting_message_part(self):
1012
message_part_kind = self._extract_single_byte()
1013
if message_part_kind == 'o':
1014
self.state_accept = self._state_accept_expecting_one_byte
1015
elif message_part_kind == 's':
1016
self.state_accept = self._state_accept_expecting_structure
1017
elif message_part_kind == 'b':
1018
self.state_accept = self._state_accept_expecting_bytes
1019
elif message_part_kind == 'e':
1022
raise errors.SmartProtocolError(
1023
'Bad message kind byte: %r' % (message_part_kind,))
1025
def _state_accept_expecting_one_byte(self):
1026
byte = self._extract_single_byte()
1027
self.state_accept = self._state_accept_expecting_message_part
1029
self.message_handler.byte_part_received(byte)
1031
raise errors.SmartMessageHandlerError(sys.exc_info())
1033
def _state_accept_expecting_bytes(self):
1034
# XXX: this should not buffer whole message part, but instead deliver
1035
# the bytes as they arrive.
1036
prefixed_bytes = self._extract_length_prefixed_bytes()
1037
self.state_accept = self._state_accept_expecting_message_part
1039
self.message_handler.bytes_part_received(prefixed_bytes)
1041
raise errors.SmartMessageHandlerError(sys.exc_info())
1043
def _state_accept_expecting_structure(self):
1044
structure = self._extract_prefixed_bencoded_data()
1045
self.state_accept = self._state_accept_expecting_message_part
1047
self.message_handler.structure_part_received(structure)
1049
raise errors.SmartMessageHandlerError(sys.exc_info())
1052
def done(self):
    """Finish the current message.

    Everything still buffered becomes ``unused_data``, the buffer is
    cleared, the decoder switches to collecting trailing bytes and the
    handler is told the message ended.  Handler errors are re-raised as
    SmartMessageHandlerError.
    """
    leftover = self._get_in_buffer()
    self.unused_data = leftover
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
1060
def _state_accept_reading_unused(self):
1061
self.unused_data += self._get_in_buffer()
1062
self._set_in_buffer(None)
1064
def next_read_size(self):
    """Return how many more bytes this decoder wants to read.

    Returns 0 once the message is complete (only unused trailing data is
    being collected) or after decoding failed; otherwise the difference
    between the number of bytes needed and the number already buffered.

    :raises AssertionError: if no needed-byte count is known.
    """
    finished = self.state_accept == self._state_accept_reading_unused
    if finished or self.decoding_failed:
        # After a failure (probably raised by self.message_handler) the state
        # machine may be inconsistent, so signal that we are done, i.e. give
        # up.
        return 0
    if self._number_needed_bytes is None:
        raise AssertionError("don't know how many bytes are expected!")
    return self._number_needed_bytes - self._in_buffer_len
1080
class _ProtocolThreeEncoder(object):
    """Buffered encoder for protocol-three messages.

    Encoded pieces are collected in memory and handed to the real write
    function as one joined string.  This happens once more than BUFFER_SIZE
    bytes are pending, or whenever flush() is called; _write_end always
    flushes.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing

    def __init__(self, write_func):
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Queue ``bytes``, flushing once more than BUFFER_SIZE are pending."""
        # TODO: this could become asynchronous, letting another thread know
        # that bytes are available instead of blocking here.  Note that
        # osutils.send_all always sends 64kB chunks anyway, so smaller pieces
        # might just be pushed out as they come.
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Hand everything buffered so far to the real write function."""
        if not self._buf:
            return
        self._real_write_func(''.join(self._buf))
        del self._buf[:]
        self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list as newline-separated 'start,length'."""
        return '\n'.join(['%d,%d' % (start, length)
                          for start, length in offsets])

    def _write_protocol_version(self):
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by a '!L' packed length."""
        encoded = bencode(structure)
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write an 's' part: ``args`` with unicode items UTF-8 encoded."""
        self._write_func('s')
        encoded_args = [
            arg.encode('utf8') if isinstance(arg, unicode) else arg
            for arg in args]
        self._write_prefixed_bencode(encoded_args)

    def _write_end(self):
        """Write the end-of-message marker and flush the buffer."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a 'b' part: marker, '!L' packed length, then ``bytes``."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        self._write_func('oC')

    def _write_error_status(self):
        self._write_func('oE')

    def _write_success_status(self):
        self._write_func('oS')
1154
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Server-side writer of protocol-three responses.

    Each instance sends at most one response, via send_response or
    send_error; a second attempt raises AssertionError.
    """

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': breezy.__version__}
        # The tracing state is always initialised, so _trace cannot hit a
        # missing attribute even if the 'hpss' debug flag is enabled after
        # this responder was created.
        self._thread_id = _thread.get_ident()
        self._response_start_time = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Log one hpss debug line through mutter.

        :param action: short label for the event (padded to 12 columns).
        :param message: text describing the event.
        :param extra_bytes: optional payload; a truncated repr of its first
            40 items is appended to the line.
        :param include_time: if true, prefix the message with the seconds
            elapsed since the first traced event of this response.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # Measure with the same clock that recorded the start time;
            # time.clock() and osutils.timer_func() are not guaranteed to be
            # the same clock, so mixing them gives a meaningless duration.
            t = '%5.3fs ' % (osutils.timer_func() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send ``exception`` to the client as an error response.

        UnknownSmartMethod becomes an ('UnknownMethod', verb) failure
        response.  Any other exception is sent as ('error', str(exception)).

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send ``response``: status, args, then an optional body.

        The body is either ``response.body`` (written as one part) or
        ``response.body_stream`` (one part per chunk, flushed after each).
        If iterating the stream fails, or it yields a
        FailedSmartServerResponse, an error status and structure are
        written and the stream is abandoned.

        :raises AssertionError: if a response was already sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                num_bytes += len(chunk)
                if first_chunk is None:
                    first_chunk = chunk
                self._write_prefixed_body(chunk)
                self.flush()
                if 'hpssdetail' in debug.debug_flags:
                    # Not worth timing separately, as _write_func is
                    # actually buffered; include_time stays False.  (_trace
                    # takes no suppress_time argument, so passing one would
                    # raise TypeError.)
                    self._trace('body chunk',
                                '%d bytes' % (len(chunk),),
                                chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
1254
def _iter_with_errors(iterable):
1255
"""Handle errors from iterable.next().
1259
for exc_info, value in _iter_with_errors(iterable):
1262
This is a safer alternative to::
1265
for value in iterable:
1270
Because the latter will catch errors from the for-loop body, not just
1273
If an error occurs, exc_info will be a exc_info tuple, and the generator
1274
will terminate. Otherwise exc_info will be None, and value will be the
1275
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1276
will not be itercepted.
1278
iterator = iter(iterable)
1281
yield None, next(iterator)
1282
except StopIteration:
1284
except (KeyboardInterrupt, SystemExit):
1287
mutter('_iter_with_errors caught error')
1288
log_exception_quietly()
1289
yield sys.exc_info(), None
1293
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
    """Client-side writer of protocol-three requests.

    Each call_* method writes a complete request (version marker, headers,
    argument structure, optional body, end marker).  It then tells the
    medium request that writing has finished.
    """

    def __init__(self, medium_request):
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
        self._medium_request = medium_request
        self._headers = {}
        # None until call_with_body_stream runs; False while that method
        # writes the request prelude, True once the body stream has begun.
        self.body_stream_started = None

    def set_headers(self, headers):
        """Use a private copy of ``headers`` for subsequent requests."""
        self._headers = headers.copy()

    def call(self, *args):
        """Send a request consisting of ``args`` only, with no body."""
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            medium_base = getattr(self._medium_request._medium, 'base', None)
            if medium_base is not None:
                mutter(' (to %s)', medium_base)
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Send a request of ``args`` followed by the byte string ``body``.

        The result is read afterwards with read_response_tuple.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            medium_path = getattr(self._medium_request._medium, '_path', None)
            if medium_path is not None:
                mutter(' (to %s)', medium_path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_prefixed_body(body)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Send a request of ``args`` with a readv offset list as its body.

        The body holds one 'start,length' line per offset pair, with no
        trailing newline.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            medium_path = getattr(self._medium_request._medium, '_path', None)
            if medium_path is not None:
                mutter(' (to %s)', medium_path)
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        offsets_text = self._serialise_offsets(body)
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(offsets_text))
        self._write_prefixed_body(offsets_text)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_stream(self, args, stream):
        """Send a request of ``args`` whose body is the chunks of ``stream``.

        If iterating ``stream`` fails, an ('error',) error part is written,
        the request is ended cleanly and the original exception is
        re-raised.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body stream: %r', args)
            medium_path = getattr(self._medium_request._medium, '_path', None)
            if medium_path is not None:
                mutter(' (to %s)', medium_path)
            self._request_start_time = osutils.timer_func()
        self.body_stream_started = False
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        # TODO: notice if the server has sent an early error reply before we
        # have finished sending the stream.  It would be noticed at the end
        # anyway, but if the medium can deliver it early then it's good to
        # short-circuit the whole request...
        # Flushing here provokes any ConnectionReset failures before the
        # body stream starts.
        self.flush()
        self.body_stream_started = True
        for exc_info, chunk in _iter_with_errors(stream):
            if exc_info is None:
                self._write_prefixed_body(chunk)
                self.flush()
                continue
            # Iterating the stream failed: cleanly abort the request.  The
            # client currently always sends ('error',) as the error args.
            self._write_error_status()
            self._write_structure(('error',))
            self._write_end()
            self._medium_request.finished_writing()
            try:
                reraise(*exc_info)
            finally:
                del exc_info
        self._write_end()
        self._medium_request.finished_writing()