# Copyright (C) 2006-2010 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""

from __future__ import absolute_import

import collections
import struct
import sys
try:
    import _thread
except ImportError:
    import thread as _thread

from ... import (
    debug,
    errors,
    osutils,
    )
from ...sixish import (
    BytesIO,
    reraise,
    )
from . import message, request
from ...sixish import text_type
from ...trace import log_exception_quietly, mutter
from ...bencode import bdecode_as_tuple, bencode

# Protocol version strings.  Each one is sent as a prefix of a bzr request or
# response so that the peer can tell which protocol version is in use.
# Version one has no such strings because that version sends no prefix.
REQUEST_VERSION_TWO = b'bzr request 2\n'
RESPONSE_VERSION_TWO = b'bzr response 2\n'

# Version three uses one marker for both directions.
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
RESPONSE_VERSION_THREE = MESSAGE_VERSION_THREE


def _recv_tuple(from_file):
    """Read one line from ``from_file`` and decode it with _decode_tuple.

    :param from_file: any object with a ``readline()`` method returning bytes.
    :return: whatever _decode_tuple returns for that line.
    """
    return _decode_tuple(from_file.readline())


def _decode_tuple(req_line):
    """Decode a newline-terminated, \\x01-separated request line.

    :param req_line: bytes as read from the wire, or None.
    :return: a tuple of byte strings, or None when req_line is None or empty.
    :raises errors.SmartProtocolError: if req_line does not end with '\\n'.
    """
    if req_line in (None, b''):
        return None
    if not req_line.endswith(b'\n'):
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
    without_newline = req_line[:-1]
    return tuple(without_newline.split(b'\x01'))


def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    Elements are joined with '\\x01' and terminated with '\\n'.  Every
    element must already be a byte string; text elements are rejected.

    :raises TypeError: if any element is a text (unicode) string.
    """
    if any(isinstance(arg, text_type) for arg in args):
        raise TypeError(args)
    return b'\x01'.join(args) + b'\n'


class Requester(object):
    """Abstract base class for an object that can issue requests on a smart
    medium.

    Every method here raises NotImplementedError; subclasses provide the
    actual wire behaviour.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :param body: the body (a byte string) to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers to send with subsequent requests."""
        raise NotImplementedError(self.set_headers)


class SmartProtocolBase(object):
    """Methods common to client and server"""

    # TODO: this only actually accommodates a single block; possibly should
    # support multiple chunks?
    def _encode_bulk_data(self, body):
        """Encode body as a bulk data chunk.

        The layout is ``<decimal length>\\n<body>done\\n``.
        """
        length_line = b'%d\n' % len(body)
        return length_line + body + b'done\n'

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        Each (start, length) pair becomes ``start,length``; pairs are joined
        with '\\n' and no trailing newline is emitted.
        """
        return b'\n'.join(
            b'%d,%d' % (start, length) for start, length in offsets)


class SmartServerRequestProtocolOne(SmartProtocolBase):
    """Server-side encoding and decoding logic for smart version 1."""

    def __init__(self, backing_transport, write_func, root_client_path='/',
                 jail_root=None):
        """Construct a SmartServerRequestProtocolOne.

        :param backing_transport: transport handed to the request handler.
        :param write_func: callable taking bytes; used to emit the response.
        :param root_client_path: passed to SmartServerRequestHandler.
        :param jail_root: passed to SmartServerRequestHandler.
        """
        self._backing_transport = backing_transport
        self._root_client_path = root_client_path
        self._jail_root = jail_root
        # Bytes received after the end of this request.
        self.unused_data = b''
        # Set once a response has been written; only one is allowed.
        self._finished = False
        # Received bytes not yet consumed.
        self.in_buffer = b''
        # Set once the request line has been parsed and dispatched.
        self._has_dispatched = False
        self.request = None
        # Created lazily when the request turns out to have a body.
        self._body_decoder = None
        self._write_func = write_func

    def accept_bytes(self, data):
        """Take bytes, and advance the internal state machine appropriately.

        :param data: must be a byte string
        :raises ValueError: if data is not a byte string.
        """
        if not isinstance(data, bytes):
            raise ValueError(data)
        self.in_buffer += data
        if not self._has_dispatched:
            if b'\n' not in self.in_buffer:
                # no command line yet
                return
            self._has_dispatched = True
            try:
                first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
                # _decode_tuple requires the terminating newline.
                first_line += b'\n'
                req_args = _decode_tuple(first_line)
                self.request = request.SmartServerRequestHandler(
                    self._backing_transport, commands=request.request_handlers,
                    root_client_path=self._root_client_path,
                    jail_root=self._jail_root)
                self.request.args_received(req_args)
                if self.request.finished_reading:
                    # trivial request
                    self.unused_data = self.in_buffer
                    self.in_buffer = b''
                    self._send_response(self.request.response)
            except KeyboardInterrupt:
                raise
            except errors.UnknownSmartMethod as err:
                # Format the verb without a b'' prefix so the message matches
                # what clients look for: "bad request '<verb>'".
                verb = err.verb
                if isinstance(verb, bytes):
                    verb = verb.decode('utf-8', 'replace')
                protocol_error = errors.SmartProtocolError(
                    "bad request '%s'" % (verb,))
                failure = request.FailedSmartServerResponse(
                    (b'error', str(protocol_error).encode('utf-8')))
                self._send_response(failure)
                return
            except Exception as exception:
                # everything else: pass to client, flush, and quit
                log_exception_quietly()
                self._send_response(request.FailedSmartServerResponse(
                    (b'error', str(exception).encode('utf-8'))))
                return

        if self._has_dispatched:
            if self._finished:
                # nothing to do.XXX: this routine should be a single state
                # machine too.
                self.unused_data += self.in_buffer
                self.in_buffer = b''
                return
            if self._body_decoder is None:
                self._body_decoder = LengthPrefixedBodyDecoder()
            self._body_decoder.accept_bytes(self.in_buffer)
            self.in_buffer = self._body_decoder.unused_data
            body_data = self._body_decoder.read_pending_data()
            self.request.accept_body(body_data)
            if self._body_decoder.finished_reading:
                self.request.end_of_body()
                if not self.request.finished_reading:
                    raise AssertionError("no more body, request not finished")
            if self.request.response is not None:
                self._send_response(self.request.response)
                self.unused_data = self.in_buffer
                self.in_buffer = b''
            elif self.request.finished_reading:
                raise AssertionError(
                    "no response and we have finished reading.")

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        :raises AssertionError: if a response has already been sent.
        :raises ValueError: if the response body is not a byte string.
        """
        if self._finished:
            raise AssertionError('response already sent')
        args = response.args
        body = response.body
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(args))
        if body is not None:
            # The wire format is bytes-only.  The former ``str`` check rejected
            # every valid body on Python 3.
            if not isinstance(body, bytes):
                raise ValueError(body)
            self._write_func(self._encode_bulk_data(body))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix.

        For SmartServerRequestProtocolOne this is omitted but we
        call is_successful to ensure that the response is valid.
        """
        response.is_successful()

    def next_read_size(self):
        """Return how many more bytes this protocol wants to read.

        0 once a response has been sent, 1 while waiting for the request
        line, otherwise the body decoder's hint.
        """
        if self._finished:
            return 0
        if self._body_decoder is None:
            return 1
        return self._body_decoder.next_read_size()


class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    Responses start with RESPONSE_VERSION_TWO, then a 'success' or 'failed'
    status line, then the encoded response arguments and optional body.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        """Write 'success\\n' or 'failed\\n' according to the response."""
        if response.is_successful():
            status_line = b'success\n'
        else:
            status_line = b'failed\n'
        self._write_func(status_line)

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        At most one of response.body (bytes) and response.body_stream may be
        set; a body stream is sent using the chunked encoding.

        :raises AssertionError: if a response was already sent, the body is
            not bytes, or both body and body_stream are set.
        """
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        body = response.body
        if body is None:
            if response.body_stream is not None:
                _send_stream(response.body_stream, self._write_func)
            return
        if not isinstance(body, bytes):
            raise AssertionError('body must be bytes')
        if response.body_stream is not None:
            raise AssertionError(
                'body_stream and body cannot both be set')
        self._write_func(self._encode_bulk_data(body))


def _send_stream(stream, write_func):
    """Write ``stream`` as a chunked body.

    Emits the 'chunked\\n' header, each chunk via _send_chunks, and finally
    the 'END\\n' end-of-body marker.
    """
    write_func(b'chunked\n')
    _send_chunks(stream, write_func)
    write_func(b'END\n')


def _send_chunks(stream, write_func):
    """Write every chunk of ``stream`` in the chunked body encoding.

    A byte-string chunk is written as ``<hex length>\\n<bytes>``.  A
    request.FailedSmartServerResponse chunk writes 'ERR\\n' followed by its
    args (encoded the same way) and stops processing the stream.

    :raises errors.BzrError: for any other kind of chunk.
    """
    for chunk in stream:
        if isinstance(chunk, bytes):
            length_header = ("%x\n" % len(chunk)).encode('ascii')
            write_func(length_header + chunk)
            continue
        if isinstance(chunk, request.FailedSmartServerResponse):
            write_func(b'ERR\n')
            _send_chunks(chunk.args, write_func)
            return
        raise errors.BzrError(
            'Chunks must be str or FailedSmartServerResponse, got %r'
            % chunk)


class _NeedMoreBytes(Exception):
    """Raised inside a _StatefulDecoder to pause decoding until more bytes
    have been received.
    """

    def __init__(self, count=None):
        """Record how many bytes the current decoder state needs.

        :param count: the total number of bytes needed by the current state,
            or None if that number is unknown.
        """
        self.count = count


class _StatefulDecoder(object):
    """Base class for writing state machines to decode byte streams.

    Subclasses should provide a self.state_accept attribute that accepts bytes
    and, if appropriate, updates self.state_accept to a different function.
    accept_bytes will call state_accept as often as necessary to make sure the
    state machine has progressed as far as possible before it returns.

    See ProtocolThreeDecoder for an example subclass.
    """

    def __init__(self):
        self.finished_reading = False
        # Received-but-unconsumed bytes, kept as a list of pieces so appends
        # are cheap; _get_in_buffer collapses them on demand.
        self._in_buffer_list = []
        # Total length of the pieces in _in_buffer_list.
        self._in_buffer_len = 0
        self.unused_data = b''
        self.bytes_left = None
        # Set from _NeedMoreBytes.count when a state needs more input.
        self._number_needed_bytes = None

    def _get_in_buffer(self):
        """Return all buffered bytes as one byte string (without consuming).

        :raises AssertionError: if the joined length disagrees with
            _in_buffer_len.
        """
        if len(self._in_buffer_list) == 1:
            return self._in_buffer_list[0]
        in_buffer = b''.join(self._in_buffer_list)
        if len(in_buffer) != self._in_buffer_len:
            # Both values must be passed as a tuple; a bare second argument
            # made this raise TypeError instead of the intended error.
            raise AssertionError(
                "Length of buffer did not match expected value: %s != %s"
                % (self._in_buffer_len, len(in_buffer)))
        self._in_buffer_list = [in_buffer]
        return in_buffer

    def _get_in_bytes(self, count):
        """Grab X bytes from the input_buffer.

        Callers should have already checked that self._in_buffer_len is >=
        count. Note, this does not consume the bytes from the buffer. The
        caller will still need to call _get_in_buffer() and then
        _set_in_buffer() if they actually need to consume the bytes.
        """
        # check if we can yield the bytes from just the first entry in our list
        if len(self._in_buffer_list) == 0:
            raise AssertionError('Callers must be sure we have buffered bytes'
                                 ' before calling _get_in_bytes')
        if len(self._in_buffer_list[0]) > count:
            return self._in_buffer_list[0][:count]
        # We can't yield it from the first buffer, so collapse all buffers, and
        # yield it from that
        in_buf = self._get_in_buffer()
        return in_buf[:count]

    def _set_in_buffer(self, new_buf):
        """Replace the buffer with ``new_buf``, or empty it if None."""
        if new_buf is not None:
            self._in_buffer_list = [new_buf]
            self._in_buffer_len = len(new_buf)
        else:
            self._in_buffer_list = []
            self._in_buffer_len = 0

    def accept_bytes(self, bytes):
        """Decode as much of bytes as possible.

        If 'bytes' contains too much data it will be appended to
        self.unused_data.

        finished_reading will be set when no more data is required.  Further
        data will be appended to self.unused_data.
        """
        # accept_bytes is allowed to change the state
        self._number_needed_bytes = None
        # lsprof puts a very large amount of time on this specific call for
        # large readv arrays
        self._in_buffer_list.append(bytes)
        self._in_buffer_len += len(bytes)
        try:
            # Run the function for the current state.
            current_state = self.state_accept
            self.state_accept()
            while current_state != self.state_accept:
                # The current state has changed.  Run the function for the new
                # current state, so that it can:
                #   - decode any unconsumed bytes left in a buffer, and
                #   - signal how many more bytes are expected (via raising
                #     _NeedMoreBytes).
                current_state = self.state_accept
                self.state_accept()
        except _NeedMoreBytes as e:
            self._number_needed_bytes = e.count


class ChunkedBodyDecoder(_StatefulDecoder):
    """Decoder for chunked body data.

    This is very similar to HTTP's chunked encoding.  See the description of
    streamed body data in `doc/developers/network-protocol.txt` for details.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_header
        # Bytes of the chunk currently being read; None between chunks.
        self.chunk_in_progress = None
        # Decoded chunks waiting to be returned by read_next_chunk.
        self.chunks = collections.deque()
        # Set once an 'ERR' marker has been read.
        self.error = False
        # Chunks read after 'ERR'; they become the error response's args.
        self.error_in_progress = None

    def next_read_size(self):
        """Return a hint for how many more bytes to read."""
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
        # end-of-body marker is 4 bytes: 'END\n'.
        if self.state_accept == self._state_accept_reading_chunk:
            # We're expecting more chunk content.  So we're expecting at least
            # the rest of this chunk plus an END chunk.
            return self.bytes_left + 4
        elif self.state_accept == self._state_accept_expecting_length:
            if self._in_buffer_len == 0:
                # We're expecting a chunk length.  There's at least two bytes
                # left: a digit plus '\n'.
                return 2
            else:
                # We're in the middle of reading a chunk length.  So there's at
                # least one byte left, the '\n' that terminates the length.
                return 1
        elif self.state_accept == self._state_accept_reading_unused:
            return 1
        elif self.state_accept == self._state_accept_expecting_header:
            return max(0, len('chunked\n') - self._in_buffer_len)
        else:
            raise AssertionError("Impossible state: %r" % (self.state_accept,))

    def read_next_chunk(self):
        """Return the next decoded chunk, or None if none is pending.

        A chunk is either a byte string or, after an error marker, a
        request.FailedSmartServerResponse built from the error chunks.
        """
        try:
            return self.chunks.popleft()
        except IndexError:
            return None

    def _extract_line(self):
        """Consume and return one line (without its '\\n') from the buffer."""
        in_buf = self._get_in_buffer()
        pos = in_buf.find(b'\n')
        if pos == -1:
            # We haven't read a complete line yet, so request more bytes before
            # we continue.
            raise _NeedMoreBytes(1)
        line = in_buf[:pos]
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
        self._set_in_buffer(in_buf[pos+1:])
        return line

    def _finished(self):
        """Handle the end-of-body marker: stash leftovers and finish up."""
        self.unused_data = self._get_in_buffer()
        self._set_in_buffer(None)
        self.state_accept = self._state_accept_reading_unused
        if self.error:
            error_args = tuple(self.error_in_progress)
            self.chunks.append(request.FailedSmartServerResponse(error_args))
            self.error_in_progress = None
        self.finished_reading = True

    def _state_accept_expecting_header(self):
        prefix = self._extract_line()
        if prefix == b'chunked':
            self.state_accept = self._state_accept_expecting_length
        else:
            raise errors.SmartProtocolError(
                'Bad chunked body header: "%s"' % (prefix,))

    def _state_accept_expecting_length(self):
        prefix = self._extract_line()
        if prefix == b'ERR':
            # Subsequent chunks are the arguments of an error response.
            self.error = True
            self.error_in_progress = []
            self._state_accept_expecting_length()
            return
        elif prefix == b'END':
            # We've read the end-of-body marker.
            # Any further bytes are unused data, including the bytes left in
            # the _in_buffer.
            self._finished()
            return
        else:
            self.bytes_left = int(prefix, 16)
            self.chunk_in_progress = b''
            self.state_accept = self._state_accept_reading_chunk

    def _state_accept_reading_chunk(self):
        in_buf = self._get_in_buffer()
        in_buffer_len = len(in_buf)
        self.chunk_in_progress += in_buf[:self.bytes_left]
        self._set_in_buffer(in_buf[self.bytes_left:])
        self.bytes_left -= in_buffer_len
        if self.bytes_left <= 0:
            # Finished with chunk
            self.bytes_left = None
            if self.error:
                self.error_in_progress.append(self.chunk_in_progress)
            else:
                self.chunks.append(self.chunk_in_progress)
            self.chunk_in_progress = None
            self.state_accept = self._state_accept_expecting_length

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        # Reset both the list and _in_buffer_len; clearing only the list left
        # _in_buffer_len stale.
        self._set_in_buffer(None)


class LengthPrefixedBodyDecoder(_StatefulDecoder):
    """Decodes the length-prefixed bulk data.

    Input is ``<decimal length>\\n<body>done\\n``; bytes after the trailer end
    up in unused_data.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Decoded body bytes not yet returned by read_pending_data.
        self._body = b''
        # Bytes after the body; expected to start with 'done\n'.
        self._trailer_buffer = b''

    def next_read_size(self):
        """Return a hint for how many more bytes to read."""
        if self.bytes_left is not None:
            # Ideally read the rest of the body plus the 5-byte trailer at once.
            return self.bytes_left + 5
        state = self.state_accept
        if state == self._state_accept_reading_trailer:
            # Only (part of) the trailer remains.
            return 5 - len(self._trailer_buffer)
        if state == self._state_accept_expecting_length:
            # At least the '\n' ending the length plus 'done\n'.
            return 6
        # Reading excess data.  One byte at a time is fine.
        return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self):
        buffered = self._get_in_buffer()
        newline_at = buffered.find(b'\n')
        if newline_at == -1:
            # Length line incomplete; wait for more bytes.
            return
        self.bytes_left = int(buffered[:newline_at])
        self._set_in_buffer(buffered[newline_at + 1:])
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_body_buffer

    def _state_accept_reading_body(self):
        buffered = self._get_in_buffer()
        self._body += buffered
        self.bytes_left -= len(buffered)
        self._set_in_buffer(None)
        if self.bytes_left > 0:
            return
        # Body complete.  A negative bytes_left means that many trailer bytes
        # were appended to _body; move them into the trailer buffer.
        if self.bytes_left != 0:
            self._trailer_buffer = self._body[self.bytes_left:]
            self._body = self._body[:self.bytes_left]
        self.bytes_left = None
        self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self):
        self._trailer_buffer += self._get_in_buffer()
        self._set_in_buffer(None)
        # TODO: what if the trailer does not match "done\n"?  Should this raise
        # a ProtocolViolation exception?
        if not self._trailer_buffer.startswith(b'done\n'):
            return
        self.unused_data = self._trailer_buffer[len(b'done\n'):]
        self.state_accept = self._state_accept_reading_unused
        self.finished_reading = True

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        self._set_in_buffer(None)

    def _state_read_no_data(self):
        return b''

    def _state_read_body_buffer(self):
        pending, self._body = self._body, b''
        return pending


class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
                                    message.ResponseHandler):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        # BytesIO holding the fully read body, once read_body_bytes ran.
        self._body_buffer = None
        # Only set when 'hpss' debugging is on; used to time responses.
        self._request_start_time = None
        # Verb of the last request, for detecting 'unknown method' replies.
        self._last_verb = None
        self._headers = None

    def set_headers(self, headers):
        self._headers = dict(headers)

    def call(self, *args):
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            if getattr(self._request._medium, 'base', None) is not None:
                mutter(' (to %s)', self._request._medium.base)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.timer_func()
            if 'hpssdetail' in debug.debug_flags:
                mutter('hpss body content: %s', body)
        self._write_args(args)
        self._request.accept_bytes(self._encode_bulk_data(body))
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        self._request.accept_bytes(self._encode_bulk_data(readv_bytes))
        self._request.finished_writing()
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._last_verb = args[0]

    def call_with_body_stream(self, args, stream):
        # Protocols v1 and v2 don't support body streams.  So it's safe to
        # assume that a v1/v2 server doesn't support whatever method we're
        # trying to call with a body stream.
        self._request.finished_writing()
        self._request.finished_reading()
        raise errors.UnknownSmartMethod(args[0])

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def _read_response_tuple(self):
        """Read and decode the response tuple, logging it if debugging."""
        result = self._recv_tuple()
        if 'hpss' in debug.debug_flags:
            if self._request_start_time is not None:
                mutter(' result: %6.3fs %s',
                       osutils.timer_func() - self._request_start_time,
                       repr(result)[1:-1])
                self._request_start_time = None
            else:
                mutter(' result: %s', repr(result)[1:-1])
        return result

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if True, leave the request open for reading a
            body afterwards.
        :raises errors.UnknownSmartMethod: if the server did not know the verb.
        :raises errors.ErrorFromSmartServer: for a known version 1 error code.
        """
        result = self._read_response_tuple()
        self._response_is_unknown_method(result)
        self._raise_args_if_error(result)
        if not expect_body:
            self._request.finished_reading()
        return result

    def _raise_args_if_error(self, result_tuple):
        # Later protocol versions have an explicit flag in the protocol to say
        # if an error response is "failed" or not.  In version 1 we don't have
        # that luxury.  So here is a complete list of errors that can be
        # returned in response to existing version 1 smart requests.  Responses
        # starting with these codes are always "failed" responses.
        v1_error_codes = [
            b'norepository',
            b'NoSuchFile',
            b'FileExists',
            b'DirectoryNotEmpty',
            b'ShortReadvError',
            b'UnicodeEncodeError',
            b'UnicodeDecodeError',
            b'ReadOnlyError',
            b'nobranch',
            b'NoSuchRevision',
            b'nosuchrevision',
            b'LockContention',
            b'UnlockableTransport',
            b'LockFailed',
            b'TokenMismatch',
            b'ReadError',
            b'PermissionDenied',
            ]
        if result_tuple[0] in v1_error_codes:
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result_tuple)

    def _response_is_unknown_method(self, result_tuple):
        """Raise UnknownSmartMethod if the response is an 'unknown method'
        response to the last request.

        :param result_tuple: the decoded response tuple.
        :raises errors.UnknownSmartMethod: carrying self._last_verb, after
            marking the request as finished reading (such responses have no
            body).
        """
        if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
                             b"bad request '" + self._last_verb + b"'") or
                result_tuple == (b'error', b"Generic bzr smart protocol error: "
                                 b"bad request u'%s'" % self._last_verb)):
            # The response will have no body, so we've finished reading.
            self._request.finished_reading()
            raise errors.UnknownSmartMethod(self._last_verb)

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :raises errors.ConnectionReset: if the connection closes mid-body.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()

        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == b'':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading response body.")
            _body_decoder.accept_bytes(data)
        self._request.finished_reading()
        self._body_buffer = BytesIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        if 'hpss' in debug.debug_flags:
            mutter(' %d body bytes read',
                   len(self._body_buffer.getvalue()))
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        return _decode_tuple(self._request.read_line())

    def query_version(self):
        """Return protocol version number of the server.

        :raises errors.SmartProtocolError: if the 'hello' reply is not
            recognised.
        """
        self.call(b'hello')
        resp = self.read_response_tuple()
        # _decode_tuple yields byte strings, so compare against bytes; the
        # old text '1'/'2' never matched on Python 3.
        if resp == (b'ok', b'1'):
            return 1
        elif resp == (b'ok', b'2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        """Write the protocol prefix and the encoded request arguments."""
        self._write_protocol_version()
        self._request.accept_bytes(_encode_tuple(args))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """


class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    """Version two of the client side of the smart protocol.

    This prefixes the request with the value of REQUEST_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if True and the status is 'success', leave the
            request open so a body can be read afterwards.
        :return: the decoded response tuple for a 'success' response.
        :raises errors.UnexpectedProtocolVersionMarker: if the response does
            not start with RESPONSE_VERSION_TWO.
        :raises errors.UnknownSmartMethod: if the server did not know the verb.
        :raises errors.ErrorFromSmartServer: if the status is 'failed'.
        :raises errors.SmartProtocolError: for any other status line.
        """
        version = self._request.read_line()
        if version != self.response_marker:
            self._request.finished_reading()
            raise errors.UnexpectedProtocolVersionMarker(version)
        response_status = self._request.read_line()
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
        self._response_is_unknown_method(result)
        if response_status == b'success\n':
            self.response_status = True
            if not expect_body:
                self._request.finished_reading()
            return result
        elif response_status == b'failed\n':
            self.response_status = False
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result)
        else:
            raise errors.SmartProtocolError(
                'bad protocol status %r' % response_status)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version two sends the value of REQUEST_VERSION_TWO.
        """
        self._request.accept_bytes(self.request_marker)

    def read_streamed_body(self):
        """Read a chunked body, yielding each chunk as soon as it is decoded.

        Yields byte strings, or a request.FailedSmartServerResponse if the
        server signalled an error in the stream.

        :raises errors.ConnectionReset: if the connection closes before the
            end-of-body marker.
        """
        # NOTE(review): historically the intent was to read no more than 64k
        # at a time to avoid error 10055 (no buffer space) on Windows; nothing
        # here caps the size, so presumably read_bytes does -- confirm.
        _body_decoder = ChunkedBodyDecoder()
        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == b'':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading streamed body.")
            _body_decoder.accept_bytes(data)
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
                # Chunks are bytes; checking for ``str`` meant nothing was
                # ever logged on Python 3.
                if 'hpss' in debug.debug_flags and isinstance(body_bytes,
                                                              bytes):
                    mutter(' %d byte chunk read',
                           len(body_bytes))
                yield body_bytes
        self._request.finished_reading()


def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Build a version three server-side decoder.

    Wires a SmartServerRequestHandler and a ProtocolThreeResponder (writing
    through ``write_func``) into a ConventionalRequestHandler, and returns a
    ProtocolThreeDecoder that feeds it.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    responder = ProtocolThreeResponder(write_func)
    return ProtocolThreeDecoder(
        message.ConventionalRequestHandler(handler, responder))


class ProtocolThreeDecoder(_StatefulDecoder):
    """Decoder for version three messages.

    Decoded headers and message parts (one byte, structure, bytes, end) are
    passed to ``message_handler``.
    """

    response_marker = RESPONSE_VERSION_THREE
    request_marker = REQUEST_VERSION_THREE

    def __init__(self, message_handler, expect_version_marker=False):
        _StatefulDecoder.__init__(self)
        self._has_dispatched = False
        # Initial state
        if expect_version_marker:
            self.state_accept = self._state_accept_expecting_protocol_version
            # We're expecting at least the protocol version marker + some
            # headers.
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
        else:
            self.state_accept = self._state_accept_expecting_headers
            self._number_needed_bytes = 4
        # Set when the decoder itself failed; see next_read_size.
        self.decoding_failed = False
        self.request_handler = self.message_handler = message_handler

    def accept_bytes(self, bytes):
        """Decode ``bytes``, reporting any failure to the message handler."""
        self._number_needed_bytes = None
        try:
            _StatefulDecoder.accept_bytes(self, bytes)
        except KeyboardInterrupt:
            raise
        except errors.SmartMessageHandlerError as exception:
            # We do *not* set self.decoding_failed here.  The message handler
            # has raised an error, but the decoder is still able to parse bytes
            # and determine when this message ends.
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
                log_exception_quietly()
            self.message_handler.protocol_error(exception.exc_value)
            # The state machine is ready to continue decoding, but the
            # exception has interrupted the loop that runs the state machine.
            # So we call accept_bytes again to restart it.  This must be bytes,
            # not text, because it is appended to the bytes buffer.
            self.accept_bytes(b'')
        except Exception as exception:
            # The decoder itself has raised an exception.  We cannot continue
            # decoding.
            self.decoding_failed = True
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
                # This happens during normal operation when the client tries a
                # protocol version the server doesn't understand, so no need to
                # log a traceback every time.
                # Note that this can only happen when
                # expect_version_marker=True, which is only the case on the
                # client side.
                pass
            else:
                log_exception_quietly()
            self.message_handler.protocol_error(exception)

    def _call_message_handler(self, method_name, *args):
        """Call ``message_handler.<method_name>(*args)``.

        Any failure (including a missing method) is wrapped in
        SmartMessageHandlerError so accept_bytes can tell handler errors apart
        from decoder errors.
        """
        try:
            getattr(self.message_handler, method_name)(*args)
        except:
            # Deliberately broad: every handler failure must be wrapped.
            raise errors.SmartMessageHandlerError(sys.exc_info())

    def _extract_length_prefixed_bytes(self):
        """Consume and return a 4-byte big-endian length-prefixed string."""
        if self._in_buffer_len < 4:
            # A length prefix by itself is 4 bytes, and we don't even have that
            # many yet.
            raise _NeedMoreBytes(4)
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
        end_of_bytes = 4 + length
        if self._in_buffer_len < end_of_bytes:
            # We haven't yet read as many bytes as the length-prefix says there
            # are.
            raise _NeedMoreBytes(end_of_bytes)
        # Extract the bytes from the buffer.
        in_buf = self._get_in_buffer()
        payload = in_buf[4:end_of_bytes]
        self._set_in_buffer(in_buf[end_of_bytes:])
        return payload

    def _extract_prefixed_bencoded_data(self):
        """Consume a length-prefixed bencoded value and return it decoded."""
        prefixed_bytes = self._extract_length_prefixed_bytes()
        try:
            decoded = bdecode_as_tuple(prefixed_bytes)
        except ValueError:
            raise errors.SmartProtocolError(
                'Bytes %r not bencoded' % (prefixed_bytes,))
        return decoded

    def _extract_single_byte(self):
        """Consume and return the next byte as a length-1 byte string."""
        if self._in_buffer_len == 0:
            # The buffer is empty
            raise _NeedMoreBytes(1)
        in_buf = self._get_in_buffer()
        one_byte = in_buf[0:1]
        self._set_in_buffer(in_buf[1:])
        return one_byte

    def _state_accept_expecting_protocol_version(self):
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
        in_buf = self._get_in_buffer()
        if needed_bytes > 0:
            # We don't have enough bytes to check if the protocol version
            # marker is right.  But we can check if it is already wrong by
            # checking that the start of MESSAGE_VERSION_THREE matches what
            # we've read so far.
            # [In fact, if the remote end isn't bzr we might never receive
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
            # are wrong then we should just raise immediately rather than
            # stall.]
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
                # We have enough bytes to know the protocol version is wrong
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
        self.state_accept = self._state_accept_expecting_headers

    def _state_accept_expecting_headers(self):
        decoded = self._extract_prefixed_bencoded_data()
        if not isinstance(decoded, dict):
            raise errors.SmartProtocolError(
                'Header object %r is not a dict' % (decoded,))
        self.state_accept = self._state_accept_expecting_message_part
        self._call_message_handler('headers_received', decoded)

    def _state_accept_expecting_message_part(self):
        message_part_kind = self._extract_single_byte()
        if message_part_kind == b'o':
            self.state_accept = self._state_accept_expecting_one_byte
        elif message_part_kind == b's':
            self.state_accept = self._state_accept_expecting_structure
        elif message_part_kind == b'b':
            self.state_accept = self._state_accept_expecting_bytes
        elif message_part_kind == b'e':
            self.done()
        else:
            raise errors.SmartProtocolError(
                'Bad message kind byte: %r' % (message_part_kind,))

    def _state_accept_expecting_one_byte(self):
        byte = self._extract_single_byte()
        self.state_accept = self._state_accept_expecting_message_part
        self._call_message_handler('byte_part_received', byte)

    def _state_accept_expecting_bytes(self):
        # XXX: this should not buffer whole message part, but instead deliver
        # the bytes as they arrive.
        prefixed_bytes = self._extract_length_prefixed_bytes()
        self.state_accept = self._state_accept_expecting_message_part
        self._call_message_handler('bytes_part_received', prefixed_bytes)

    def _state_accept_expecting_structure(self):
        structure = self._extract_prefixed_bencoded_data()
        self.state_accept = self._state_accept_expecting_message_part
        self._call_message_handler('structure_part_received', structure)

    def done(self):
        """Finish the message: keep leftovers as unused data, notify handler."""
        self.unused_data = self._get_in_buffer()
        self._set_in_buffer(None)
        self.state_accept = self._state_accept_reading_unused
        self._call_message_handler('end_received')

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        self._set_in_buffer(None)

    def next_read_size(self):
        """Return how many more bytes are wanted (0 when done or failed)."""
        if self.state_accept == self._state_accept_reading_unused:
            return 0
        elif self.decoding_failed:
            # An exception occured while processing this message, probably from
            # self.message_handler.  We're not sure that this state machine is
            # in a consistent state, so just signal that we're done (i.e. give
            # up).
            return 0
        else:
            if self._number_needed_bytes is not None:
                return self._number_needed_bytes - self._in_buffer_len
            else:
                raise AssertionError("don't know how many bytes are expected!")


class _ProtocolThreeEncoder(object):
1080
response_marker = request_marker = MESSAGE_VERSION_THREE
1081
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1083
def __init__(self, write_func):
1086
self._real_write_func = write_func
1088
def _write_func(self, bytes):
1089
# TODO: Another possibility would be to turn this into an async model.
1090
# Where we let another thread know that we have some bytes if
1091
# they want it, but we don't actually block for it
1092
# Note that osutils.send_all always sends 64kB chunks anyway, so
1093
# we might just push out smaller bits at a time?
1094
self._buf.append(bytes)
1095
self._buf_len += len(bytes)
1096
if self._buf_len > self.BUFFER_SIZE:
1101
self._real_write_func(b''.join(self._buf))
1105
def _serialise_offsets(self, offsets):
1106
"""Serialise a readv offset list."""
1108
for start, length in offsets:
1109
txt.append(b'%d,%d' % (start, length))
1110
return b'\n'.join(txt)
1112
def _write_protocol_version(self):
1113
self._write_func(MESSAGE_VERSION_THREE)
1115
def _write_prefixed_bencode(self, structure):
1116
bytes = bencode(structure)
1117
self._write_func(struct.pack('!L', len(bytes)))
1118
self._write_func(bytes)
1120
def _write_headers(self, headers):
1121
self._write_prefixed_bencode(headers)
1123
def _write_structure(self, args):
1124
self._write_func(b's')
1127
if isinstance(arg, text_type):
1128
utf8_args.append(arg.encode('utf8'))
1130
utf8_args.append(arg)
1131
self._write_prefixed_bencode(utf8_args)
1133
def _write_end(self):
1134
self._write_func(b'e')
1137
def _write_prefixed_body(self, bytes):
1138
self._write_func(b'b')
1139
self._write_func(struct.pack('!L', len(bytes)))
1140
self._write_func(bytes)
1142
def _write_chunked_body_start(self):
1143
self._write_func(b'oC')
1145
def _write_error_status(self):
1146
self._write_func(b'oE')
1148
def _write_success_status(self):
1149
self._write_func(b'oS')
1152
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1154
def __init__(self, write_func):
1155
_ProtocolThreeEncoder.__init__(self, write_func)
1156
self.response_sent = False
1158
b'Software version': breezy.__version__.encode('utf-8')}
1159
if 'hpss' in debug.debug_flags:
1160
self._thread_id = _thread.get_ident()
1161
self._response_start_time = None
1163
def _trace(self, action, message, extra_bytes=None, include_time=False):
1164
if self._response_start_time is None:
1165
self._response_start_time = osutils.timer_func()
1167
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1170
if extra_bytes is None:
1173
extra = ' ' + repr(extra_bytes[:40])
1175
extra = extra[:29] + extra[-1] + '...'
1176
mutter('%12s: [%s] %s%s%s'
1177
% (action, self._thread_id, t, message, extra))
1179
def send_error(self, exception):
1180
if self.response_sent:
1181
raise AssertionError(
1182
"send_error(%s) called, but response already sent."
1184
if isinstance(exception, errors.UnknownSmartMethod):
1185
failure = request.FailedSmartServerResponse(
1186
(b'UnknownMethod', exception.verb))
1187
self.send_response(failure)
1189
if 'hpss' in debug.debug_flags:
1190
self._trace('error', str(exception))
1191
self.response_sent = True
1192
self._write_protocol_version()
1193
self._write_headers(self._headers)
1194
self._write_error_status()
1195
self._write_structure((b'error', str(exception).encode('utf-8', 'replace')))
1198
def send_response(self, response):
1199
if self.response_sent:
1200
raise AssertionError(
1201
"send_response(%r) called, but response already sent."
1203
self.response_sent = True
1204
self._write_protocol_version()
1205
self._write_headers(self._headers)
1206
if response.is_successful():
1207
self._write_success_status()
1209
self._write_error_status()
1210
if 'hpss' in debug.debug_flags:
1211
self._trace('response', repr(response.args))
1212
self._write_structure(response.args)
1213
if response.body is not None:
1214
self._write_prefixed_body(response.body)
1215
if 'hpss' in debug.debug_flags:
1216
self._trace('body', '%d bytes' % (len(response.body),),
1217
response.body, include_time=True)
1218
elif response.body_stream is not None:
1219
count = num_bytes = 0
1221
for exc_info, chunk in _iter_with_errors(response.body_stream):
1223
if exc_info is not None:
1224
self._write_error_status()
1225
error_struct = request._translate_error(exc_info[1])
1226
self._write_structure(error_struct)
1229
if isinstance(chunk, request.FailedSmartServerResponse):
1230
self._write_error_status()
1231
self._write_structure(chunk.args)
1233
num_bytes += len(chunk)
1234
if first_chunk is None:
1236
self._write_prefixed_body(chunk)
1238
if 'hpssdetail' in debug.debug_flags:
1239
# Not worth timing separately, as _write_func is
1241
self._trace('body chunk',
1242
'%d bytes' % (len(chunk),),
1243
chunk, suppress_time=True)
1244
if 'hpss' in debug.debug_flags:
1245
self._trace('body stream',
1246
'%d bytes %d chunks' % (num_bytes, count),
1249
if 'hpss' in debug.debug_flags:
1250
self._trace('response end', '', include_time=True)
1253
def _iter_with_errors(iterable):
1254
"""Handle errors from iterable.next().
1258
for exc_info, value in _iter_with_errors(iterable):
1261
This is a safer alternative to::
1264
for value in iterable:
1269
Because the latter will catch errors from the for-loop body, not just
1272
If an error occurs, exc_info will be a exc_info tuple, and the generator
1273
will terminate. Otherwise exc_info will be None, and value will be the
1274
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1275
will not be itercepted.
1277
iterator = iter(iterable)
1280
yield None, next(iterator)
1281
except StopIteration:
1283
except (KeyboardInterrupt, SystemExit):
1286
mutter('_iter_with_errors caught error')
1287
log_exception_quietly()
1288
yield sys.exc_info(), None
1292
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1294
def __init__(self, medium_request):
1295
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1296
self._medium_request = medium_request
1298
self.body_stream_started = None
1300
def set_headers(self, headers):
1301
self._headers = headers.copy()
1303
def call(self, *args):
1304
if 'hpss' in debug.debug_flags:
1305
mutter('hpss call: %s', repr(args)[1:-1])
1306
base = getattr(self._medium_request._medium, 'base', None)
1307
if base is not None:
1308
mutter(' (to %s)', base)
1309
self._request_start_time = osutils.timer_func()
1310
self._write_protocol_version()
1311
self._write_headers(self._headers)
1312
self._write_structure(args)
1314
self._medium_request.finished_writing()
1316
def call_with_body_bytes(self, args, body):
1317
"""Make a remote call of args with body bytes 'body'.
1319
After calling this, call read_response_tuple to find the result out.
1321
if 'hpss' in debug.debug_flags:
1322
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1323
path = getattr(self._medium_request._medium, '_path', None)
1324
if path is not None:
1325
mutter(' (to %s)', path)
1326
mutter(' %d bytes', len(body))
1327
self._request_start_time = osutils.timer_func()
1328
self._write_protocol_version()
1329
self._write_headers(self._headers)
1330
self._write_structure(args)
1331
self._write_prefixed_body(body)
1333
self._medium_request.finished_writing()
1335
def call_with_body_readv_array(self, args, body):
1336
"""Make a remote call with a readv array.
1338
The body is encoded with one line per readv offset pair. The numbers in
1339
each pair are separated by a comma, and no trailing \\n is emitted.
1341
if 'hpss' in debug.debug_flags:
1342
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1343
path = getattr(self._medium_request._medium, '_path', None)
1344
if path is not None:
1345
mutter(' (to %s)', path)
1346
self._request_start_time = osutils.timer_func()
1347
self._write_protocol_version()
1348
self._write_headers(self._headers)
1349
self._write_structure(args)
1350
readv_bytes = self._serialise_offsets(body)
1351
if 'hpss' in debug.debug_flags:
1352
mutter(' %d bytes in readv request', len(readv_bytes))
1353
self._write_prefixed_body(readv_bytes)
1355
self._medium_request.finished_writing()
1357
def call_with_body_stream(self, args, stream):
1358
if 'hpss' in debug.debug_flags:
1359
mutter('hpss call w/body stream: %r', args)
1360
path = getattr(self._medium_request._medium, '_path', None)
1361
if path is not None:
1362
mutter(' (to %s)', path)
1363
self._request_start_time = osutils.timer_func()
1364
self.body_stream_started = False
1365
self._write_protocol_version()
1366
self._write_headers(self._headers)
1367
self._write_structure(args)
1368
# TODO: notice if the server has sent an early error reply before we
1369
# have finished sending the stream. We would notice at the end
1370
# anyway, but if the medium can deliver it early then it's good
1371
# to short-circuit the whole request...
1372
# Provoke any ConnectionReset failures before we start the body stream.
1374
self.body_stream_started = True
1375
for exc_info, part in _iter_with_errors(stream):
1376
if exc_info is not None:
1377
# Iterating the stream failed. Cleanly abort the request.
1378
self._write_error_status()
1379
# Currently the client unconditionally sends ('error',) as the
1381
self._write_structure((b'error',))
1383
self._medium_request.finished_writing()
1389
self._write_prefixed_body(part)
1392
self._medium_request.finished_writing()