1
# Copyright (C) 2006-2010 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
24
from collections.abc import deque
25
except ImportError: # python < 3.7
26
from collections import deque
33
import thread as _thread
42
from ...sixish import (
46
from . import message, request
47
from ...sixish import text_type
48
from ...trace import log_exception_quietly, mutter
49
from ...bencode import bdecode_as_tuple, bencode
52
# Protocol version strings. These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used. (There are no version
# one strings because that version doesn't send any).
REQUEST_VERSION_TWO = b'bzr request 2\n'
RESPONSE_VERSION_TWO = b'bzr response 2\n'

# Version three uses one marker for both requests and responses.
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
62
def _recv_tuple(from_file):
    """Read one line from ``from_file`` and decode it as a request tuple.

    :param from_file: a file-like object providing ``readline()``.
    :return: the result of ``_decode_tuple`` for the line that was read.
    """
    req_line = from_file.readline()
    return _decode_tuple(req_line)
67
def _decode_tuple(req_line):
68
if req_line is None or req_line == b'':
70
if not req_line.endswith(b'\n'):
71
raise errors.SmartProtocolError("request %r not terminated" % req_line)
72
return tuple(req_line[:-1].split(b'\x01'))
75
def _encode_tuple(args):
76
"""Encode the tuple args to a bytestream."""
78
if isinstance(arg, text_type):
80
return b'\x01'.join(args) + b'\n'
83
class Requester(object):
    """Abstract interface for objects that issue smart-protocol requests.

    Every method here raises NotImplementedError; concrete requesters
    (such as the client protocol classes) override them.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers to send with requests.

        :param headers: header names and values; implementations may copy
            them (SmartClientRequestProtocolOne stores ``dict(headers)``).
        """
        raise NotImplementedError(self.set_headers)
117
class SmartProtocolBase(object):
    """Methods common to client and server"""

    # TODO: this only actually accommodates a single block; possibly should
    # support multiple chunks?
    def _encode_bulk_data(self, body):
        r"""Encode body as a bulk data chunk.

        The result is the decimal length of ``body`` and ``\n``, then the
        body bytes, then the trailer ``done\n``.

        :param body: the bytes to encode.
        :return: the encoded bytes.
        """
        return b''.join((b'%d\n' % len(body), body, b'done\n'))

    def _serialise_offsets(self, offsets):
        r"""Serialise a readv offset list.

        :param offsets: iterable of (start, length) integer pairs.
        :return: one ``start,length`` entry per pair, joined with ``\n`` and
            with no trailing newline (``b''`` for an empty list).
        """
        return b'\n'.join(b'%d,%d' % (start, length)
                          for start, length in offsets)
134
class SmartServerRequestProtocolOne(SmartProtocolBase):
135
"""Server-side encoding and decoding logic for smart version 1."""
137
def __init__(self, backing_transport, write_func, root_client_path='/',
139
self._backing_transport = backing_transport
140
self._root_client_path = root_client_path
141
self._jail_root = jail_root
142
self.unused_data = b''
143
self._finished = False
145
self._has_dispatched = False
147
self._body_decoder = None
148
self._write_func = write_func
150
def accept_bytes(self, data):
151
"""Take bytes, and advance the internal state machine appropriately.
153
:param data: must be a byte string
155
if not isinstance(data, bytes):
156
raise ValueError(data)
157
self.in_buffer += data
158
if not self._has_dispatched:
159
if b'\n' not in self.in_buffer:
160
# no command line yet
162
self._has_dispatched = True
164
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
166
req_args = _decode_tuple(first_line)
167
self.request = request.SmartServerRequestHandler(
168
self._backing_transport, commands=request.request_handlers,
169
root_client_path=self._root_client_path,
170
jail_root=self._jail_root)
171
self.request.args_received(req_args)
172
if self.request.finished_reading:
174
self.unused_data = self.in_buffer
176
self._send_response(self.request.response)
177
except KeyboardInterrupt:
179
except errors.UnknownSmartMethod as err:
180
protocol_error = errors.SmartProtocolError(
181
"bad request '%s'" % (err.verb.decode('ascii'),))
182
failure = request.FailedSmartServerResponse(
183
(b'error', str(protocol_error).encode('utf-8')))
184
self._send_response(failure)
186
except Exception as exception:
187
# everything else: pass to client, flush, and quit
188
log_exception_quietly()
189
self._send_response(request.FailedSmartServerResponse(
190
(b'error', str(exception).encode('utf-8'))))
193
if self._has_dispatched:
195
# nothing to do.XXX: this routine should be a single state
197
self.unused_data += self.in_buffer
200
if self._body_decoder is None:
201
self._body_decoder = LengthPrefixedBodyDecoder()
202
self._body_decoder.accept_bytes(self.in_buffer)
203
self.in_buffer = self._body_decoder.unused_data
204
body_data = self._body_decoder.read_pending_data()
205
self.request.accept_body(body_data)
206
if self._body_decoder.finished_reading:
207
self.request.end_of_body()
208
if not self.request.finished_reading:
209
raise AssertionError("no more body, request not finished")
210
if self.request.response is not None:
211
self._send_response(self.request.response)
212
self.unused_data = self.in_buffer
215
if self.request.finished_reading:
216
raise AssertionError(
217
"no response and we have finished reading.")
219
def _send_response(self, response):
220
"""Send a smart server response down the output stream."""
222
raise AssertionError('response already sent')
225
self._finished = True
226
self._write_protocol_version()
227
self._write_success_or_failure_prefix(response)
228
self._write_func(_encode_tuple(args))
230
if not isinstance(body, bytes):
231
raise ValueError(body)
232
data = self._encode_bulk_data(body)
233
self._write_func(data)
235
def _write_protocol_version(self):
236
"""Write any prefixes this protocol requires.
238
Version one doesn't send protocol versions.
241
def _write_success_or_failure_prefix(self, response):
242
"""Write the protocol specific success/failure prefix.
244
For SmartServerRequestProtocolOne this is omitted but we
245
call is_successful to ensure that the response is valid.
247
response.is_successful()
249
def next_read_size(self):
252
if self._body_decoder is None:
255
return self._body_decoder.next_read_size()
258
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        r"""Write the protocol specific success/failure prefix.

        Writes ``success\n`` when ``response.is_successful()`` is true and
        ``failed\n`` otherwise.
        """
        if response.is_successful():
            self._write_func(b'success\n')
        else:
            self._write_func(b'failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        Writes, in order: the protocol version marker, the success/failure
        line, the encoded ``response.args``, and then either the
        length-prefixed ``response.body`` or the chunked
        ``response.body_stream`` (at most one of them may be set).

        :raises AssertionError: if a response was already sent, if the body
            is not bytes, or if both body and body_stream are set.
        """
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        if response.body is not None:
            if not isinstance(response.body, bytes):
                raise AssertionError('body must be bytes')
            if response.body_stream is not None:
                raise AssertionError(
                    'body_stream and body cannot both be set')
            data = self._encode_bulk_data(response.body)
            self._write_func(data)
        elif response.body_stream is not None:
            _send_stream(response.body_stream, self._write_func)
301
def _send_stream(stream, write_func):
    """Write ``stream`` to ``write_func`` as a chunked body.

    The output is the ``chunked`` header line, the chunks written by
    ``_send_chunks``, and finally the ``END`` marker line.
    ChunkedBodyDecoder only sets ``finished_reading`` once it reads that
    ``END`` line, so omitting it would leave the reader waiting forever.
    """
    write_func(b'chunked\n')
    _send_chunks(stream, write_func)
    write_func(b'END\n')
307
def _send_chunks(stream, write_func):
309
if isinstance(chunk, bytes):
310
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
312
elif isinstance(chunk, request.FailedSmartServerResponse):
314
_send_chunks(chunk.args, write_func)
317
raise errors.BzrError(
318
'Chunks must be str or FailedSmartServerResponse, got %r'
322
class _NeedMoreBytes(Exception):
    """Raise this inside a _StatefulDecoder state to stop decoding until more
    bytes are available.

    _StatefulDecoder.accept_bytes catches it and stores ``count`` in its
    ``_number_needed_bytes`` attribute.
    """

    def __init__(self, count=None):
        """Constructor.

        :param count: the total number of bytes needed by the current state.
            May be None if the number of bytes needed is unknown.
        """
        self.count = count
336
class _StatefulDecoder(object):
    """Base class for writing state machines to decode byte streams.

    Subclasses should provide a self.state_accept attribute that accepts bytes
    and, if appropriate, updates self.state_accept to a different function.
    accept_bytes will call state_accept as often as necessary to make sure the
    state machine has progressed as far as possible before it returns.

    See ProtocolThreeDecoder for an example subclass.
    """

    def __init__(self):
        self.finished_reading = False
        # Received-but-unconsumed input, kept as a list of pieces that
        # _get_in_buffer joins lazily; _in_buffer_len tracks their total size.
        self._in_buffer_list = []
        self._in_buffer_len = 0
        self.unused_data = b''
        self.bytes_left = None
        self._number_needed_bytes = None

    def _get_in_buffer(self):
        """Return all unconsumed input as a single bytes object.

        Pending pieces are joined and the list is collapsed to that single
        piece, so repeated calls are cheap. Nothing is consumed.

        :raises AssertionError: if the joined length disagrees with
            ``_in_buffer_len``.
        """
        if len(self._in_buffer_list) == 1:
            return self._in_buffer_list[0]
        in_buffer = b''.join(self._in_buffer_list)
        if len(in_buffer) != self._in_buffer_len:
            # Both values must be passed as one tuple; passing them as
            # separate arguments made the % operator raise TypeError instead
            # of producing this AssertionError.
            raise AssertionError(
                "Length of buffer did not match expected value: %s != %s"
                % (self._in_buffer_len, len(in_buffer)))
        self._in_buffer_list = [in_buffer]
        return in_buffer

    def _get_in_bytes(self, count):
        """Grab X bytes from the input_buffer.

        Callers should have already checked that self._in_buffer_len is >=
        count. Note, this does not consume the bytes from the buffer. The
        caller will still need to call _get_in_buffer() and then
        _set_in_buffer() if they actually need to consume the bytes.
        """
        # check if we can yield the bytes from just the first entry in our list
        if len(self._in_buffer_list) == 0:
            raise AssertionError('Callers must be sure we have buffered bytes'
                                 ' before calling _get_in_bytes')
        if len(self._in_buffer_list[0]) >= count:
            return self._in_buffer_list[0][:count]
        # We can't yield it from the first buffer, so collapse all buffers
        # and take the prefix of the combined buffer.
        in_buf = self._get_in_buffer()
        return in_buf[:count]

    def _set_in_buffer(self, new_buf):
        """Replace all unconsumed input with ``new_buf``.

        :param new_buf: the bytes to keep, or None to empty the buffer.
        :raises TypeError: if new_buf is neither None nor bytes.
        """
        if new_buf is not None:
            if not isinstance(new_buf, bytes):
                raise TypeError(new_buf)
            self._in_buffer_list = [new_buf]
            self._in_buffer_len = len(new_buf)
        else:
            self._in_buffer_list = []
            self._in_buffer_len = 0

    def accept_bytes(self, new_buf):
        """Decode as much of bytes as possible.

        The new bytes are appended to the input buffer and the current state
        function is run; whenever a state switches self.state_accept, the new
        state is run as well. A state that needs more input raises
        _NeedMoreBytes, whose count is stored in _number_needed_bytes.

        finished_reading will be set when no more data is required. Further
        data will be appended to self.unused_data.

        :raises TypeError: if new_buf is not bytes.
        """
        if not isinstance(new_buf, bytes):
            raise TypeError(new_buf)
        # accept_bytes is allowed to change the state
        self._number_needed_bytes = None
        # lsprof put a very large amount of time on this call, so new data is
        # appended to a list of pieces and only joined by _get_in_buffer.
        self._in_buffer_list.append(new_buf)
        self._in_buffer_len += len(new_buf)
        try:
            # Run the function for the current state.
            current_state = self.state_accept
            current_state()
            while current_state != self.state_accept:
                # The current state has changed.  Run the function for the new
                # current state, so that it can:
                #   - decode any unconsumed bytes left in a buffer, and
                #   - signal how many more bytes are expected (via raising
                #     _NeedMoreBytes).
                current_state = self.state_accept
                current_state()
        except _NeedMoreBytes as e:
            self._number_needed_bytes = e.count
428
class ChunkedBodyDecoder(_StatefulDecoder):
429
"""Decoder for chunked body data.
431
This is very similar the HTTP's chunked encoding. See the description of
432
streamed body data in `doc/developers/network-protocol.txt` for details.
436
_StatefulDecoder.__init__(self)
437
self.state_accept = self._state_accept_expecting_header
438
self.chunk_in_progress = None
439
self.chunks = deque()
441
self.error_in_progress = None
443
def next_read_size(self):
444
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
445
# end-of-body marker is 4 bytes: 'END\n'.
446
if self.state_accept == self._state_accept_reading_chunk:
447
# We're expecting more chunk content. So we're expecting at least
448
# the rest of this chunk plus an END chunk.
449
return self.bytes_left + 4
450
elif self.state_accept == self._state_accept_expecting_length:
451
if self._in_buffer_len == 0:
452
# We're expecting a chunk length. There's at least two bytes
453
# left: a digit plus '\n'.
456
# We're in the middle of reading a chunk length. So there's at
457
# least one byte left, the '\n' that terminates the length.
459
elif self.state_accept == self._state_accept_reading_unused:
461
elif self.state_accept == self._state_accept_expecting_header:
462
return max(0, len('chunked\n') - self._in_buffer_len)
464
raise AssertionError("Impossible state: %r" % (self.state_accept,))
466
def read_next_chunk(self):
468
return self.chunks.popleft()
472
def _extract_line(self):
473
in_buf = self._get_in_buffer()
474
pos = in_buf.find(b'\n')
476
# We haven't read a complete line yet, so request more bytes before
478
raise _NeedMoreBytes(1)
480
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
481
self._set_in_buffer(in_buf[pos + 1:])
485
self.unused_data = self._get_in_buffer()
486
self._in_buffer_list = []
487
self._in_buffer_len = 0
488
self.state_accept = self._state_accept_reading_unused
490
error_args = tuple(self.error_in_progress)
491
self.chunks.append(request.FailedSmartServerResponse(error_args))
492
self.error_in_progress = None
493
self.finished_reading = True
495
def _state_accept_expecting_header(self):
496
prefix = self._extract_line()
497
if prefix == b'chunked':
498
self.state_accept = self._state_accept_expecting_length
500
raise errors.SmartProtocolError(
501
'Bad chunked body header: "%s"' % (prefix,))
503
def _state_accept_expecting_length(self):
504
prefix = self._extract_line()
507
self.error_in_progress = []
508
self._state_accept_expecting_length()
510
elif prefix == b'END':
511
# We've read the end-of-body marker.
512
# Any further bytes are unused data, including the bytes left in
517
self.bytes_left = int(prefix, 16)
518
self.chunk_in_progress = b''
519
self.state_accept = self._state_accept_reading_chunk
521
def _state_accept_reading_chunk(self):
522
in_buf = self._get_in_buffer()
523
in_buffer_len = len(in_buf)
524
self.chunk_in_progress += in_buf[:self.bytes_left]
525
self._set_in_buffer(in_buf[self.bytes_left:])
526
self.bytes_left -= in_buffer_len
527
if self.bytes_left <= 0:
528
# Finished with chunk
529
self.bytes_left = None
531
self.error_in_progress.append(self.chunk_in_progress)
533
self.chunks.append(self.chunk_in_progress)
534
self.chunk_in_progress = None
535
self.state_accept = self._state_accept_expecting_length
537
def _state_accept_reading_unused(self):
538
self.unused_data += self._get_in_buffer()
539
self._in_buffer_list = []
542
class LengthPrefixedBodyDecoder(_StatefulDecoder):
543
"""Decodes the length-prefixed bulk data."""
546
_StatefulDecoder.__init__(self)
547
self.state_accept = self._state_accept_expecting_length
548
self.state_read = self._state_read_no_data
550
self._trailer_buffer = b''
552
def next_read_size(self):
553
if self.bytes_left is not None:
554
# Ideally we want to read all the remainder of the body and the
556
return self.bytes_left + 5
557
elif self.state_accept == self._state_accept_reading_trailer:
558
# Just the trailer left
559
return 5 - len(self._trailer_buffer)
560
elif self.state_accept == self._state_accept_expecting_length:
561
# There's still at least 6 bytes left ('\n' to end the length, plus
565
# Reading excess data. Either way, 1 byte at a time is fine.
568
def read_pending_data(self):
569
"""Return any pending data that has been decoded."""
570
return self.state_read()
572
def _state_accept_expecting_length(self):
573
in_buf = self._get_in_buffer()
574
pos = in_buf.find(b'\n')
577
self.bytes_left = int(in_buf[:pos])
578
self._set_in_buffer(in_buf[pos + 1:])
579
self.state_accept = self._state_accept_reading_body
580
self.state_read = self._state_read_body_buffer
582
def _state_accept_reading_body(self):
583
in_buf = self._get_in_buffer()
585
self.bytes_left -= len(in_buf)
586
self._set_in_buffer(None)
587
if self.bytes_left <= 0:
589
if self.bytes_left != 0:
590
self._trailer_buffer = self._body[self.bytes_left:]
591
self._body = self._body[:self.bytes_left]
592
self.bytes_left = None
593
self.state_accept = self._state_accept_reading_trailer
595
def _state_accept_reading_trailer(self):
596
self._trailer_buffer += self._get_in_buffer()
597
self._set_in_buffer(None)
598
# TODO: what if the trailer does not match "done\n"? Should this raise
599
# a ProtocolViolation exception?
600
if self._trailer_buffer.startswith(b'done\n'):
601
self.unused_data = self._trailer_buffer[len(b'done\n'):]
602
self.state_accept = self._state_accept_reading_unused
603
self.finished_reading = True
605
def _state_accept_reading_unused(self):
606
self.unused_data += self._get_in_buffer()
607
self._set_in_buffer(None)
609
def _state_read_no_data(self):
612
def _state_read_body_buffer(self):
618
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
619
message.ResponseHandler):
620
"""The client-side protocol for smart version 1."""
622
def __init__(self, request):
623
"""Construct a SmartClientRequestProtocolOne.
625
:param request: A SmartClientMediumRequest to serialise onto and
628
self._request = request
629
self._body_buffer = None
630
self._request_start_time = None
631
self._last_verb = None
634
def set_headers(self, headers):
635
self._headers = dict(headers)
637
def call(self, *args):
638
if 'hpss' in debug.debug_flags:
639
mutter('hpss call: %s', repr(args)[1:-1])
640
if getattr(self._request._medium, 'base', None) is not None:
641
mutter(' (to %s)', self._request._medium.base)
642
self._request_start_time = osutils.perf_counter()
643
self._write_args(args)
644
self._request.finished_writing()
645
self._last_verb = args[0]
647
def call_with_body_bytes(self, args, body):
648
"""Make a remote call of args with body bytes 'body'.
650
After calling this, call read_response_tuple to find the result out.
652
if 'hpss' in debug.debug_flags:
653
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
654
if getattr(self._request._medium, '_path', None) is not None:
656
self._request._medium._path)
657
mutter(' %d bytes', len(body))
658
self._request_start_time = osutils.perf_counter()
659
if 'hpssdetail' in debug.debug_flags:
660
mutter('hpss body content: %s', body)
661
self._write_args(args)
662
bytes = self._encode_bulk_data(body)
663
self._request.accept_bytes(bytes)
664
self._request.finished_writing()
665
self._last_verb = args[0]
667
def call_with_body_readv_array(self, args, body):
668
"""Make a remote call with a readv array.
670
The body is encoded with one line per readv offset pair. The numbers in
671
each pair are separated by a comma, and no trailing \\n is emitted.
673
if 'hpss' in debug.debug_flags:
674
mutter('hpss call w/readv: %s', repr(args)[1:-1])
675
if getattr(self._request._medium, '_path', None) is not None:
677
self._request._medium._path)
678
self._request_start_time = osutils.perf_counter()
679
self._write_args(args)
680
readv_bytes = self._serialise_offsets(body)
681
bytes = self._encode_bulk_data(readv_bytes)
682
self._request.accept_bytes(bytes)
683
self._request.finished_writing()
684
if 'hpss' in debug.debug_flags:
685
mutter(' %d bytes in readv request', len(readv_bytes))
686
self._last_verb = args[0]
688
def call_with_body_stream(self, args, stream):
689
# Protocols v1 and v2 don't support body streams. So it's safe to
690
# assume that a v1/v2 server doesn't support whatever method we're
691
# trying to call with a body stream.
692
self._request.finished_writing()
693
self._request.finished_reading()
694
raise errors.UnknownSmartMethod(args[0])
696
def cancel_read_body(self):
697
"""After expecting a body, a response code may indicate one otherwise.
699
This method lets the domain client inform the protocol that no body
700
will be transmitted. This is a terminal method: after calling it the
701
protocol is not able to be used further.
703
self._request.finished_reading()
705
def _read_response_tuple(self):
706
result = self._recv_tuple()
707
if 'hpss' in debug.debug_flags:
708
if self._request_start_time is not None:
709
mutter(' result: %6.3fs %s',
710
osutils.perf_counter() - self._request_start_time,
712
self._request_start_time = None
714
mutter(' result: %s', repr(result)[1:-1])
717
def read_response_tuple(self, expect_body=False):
718
"""Read a response tuple from the wire.
720
This should only be called once.
722
result = self._read_response_tuple()
723
self._response_is_unknown_method(result)
724
self._raise_args_if_error(result)
726
self._request.finished_reading()
729
def _raise_args_if_error(self, result_tuple):
730
# Later protocol versions have an explicit flag in the protocol to say
731
# if an error response is "failed" or not. In version 1 we don't have
732
# that luxury. So here is a complete list of errors that can be
733
# returned in response to existing version 1 smart requests. Responses
734
# starting with these codes are always "failed" responses.
739
b'DirectoryNotEmpty',
741
b'UnicodeEncodeError',
742
b'UnicodeDecodeError',
748
b'UnlockableTransport',
754
if result_tuple[0] in v1_error_codes:
755
self._request.finished_reading()
756
raise errors.ErrorFromSmartServer(result_tuple)
758
def _response_is_unknown_method(self, result_tuple):
759
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
760
method' response to the request.
762
:param response: The response from a smart client call_expecting_body
764
:param verb: The verb used in that call.
765
:raises: UnexpectedSmartServerResponse
767
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
768
b"bad request '" + self._last_verb + b"'")
769
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
770
b"bad request u'%s'" % self._last_verb)):
771
# The response will have no body, so we've finished reading.
772
self._request.finished_reading()
773
raise errors.UnknownSmartMethod(self._last_verb)
775
def read_body_bytes(self, count=-1):
776
"""Read bytes from the body, decoding into a byte stream.
778
We read all bytes at once to ensure we've checked the trailer for
779
errors, and then feed the buffer back as read_body_bytes is called.
781
if self._body_buffer is not None:
782
return self._body_buffer.read(count)
783
_body_decoder = LengthPrefixedBodyDecoder()
785
while not _body_decoder.finished_reading:
786
bytes = self._request.read_bytes(_body_decoder.next_read_size())
788
# end of file encountered reading from server
789
raise errors.ConnectionReset(
790
"Connection lost while reading response body.")
791
_body_decoder.accept_bytes(bytes)
792
self._request.finished_reading()
793
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
794
# XXX: TODO check the trailer result.
795
if 'hpss' in debug.debug_flags:
796
mutter(' %d body bytes read',
797
len(self._body_buffer.getvalue()))
798
return self._body_buffer.read(count)
800
def _recv_tuple(self):
801
"""Receive a tuple from the medium request."""
802
return _decode_tuple(self._request.read_line())
804
def query_version(self):
805
"""Return protocol version number of the server."""
807
resp = self.read_response_tuple()
808
if resp == (b'ok', b'1'):
810
elif resp == (b'ok', b'2'):
813
raise errors.SmartProtocolError("bad response %r" % (resp,))
815
def _write_args(self, args):
816
self._write_protocol_version()
817
bytes = _encode_tuple(args)
818
self._request.accept_bytes(bytes)
820
def _write_protocol_version(self):
821
"""Write any prefixes this protocol requires.
823
Version one doesn't send protocol versions.
827
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
828
"""Version two of the client side of the smart protocol.
830
This prefixes the request with the value of REQUEST_VERSION_TWO.
833
response_marker = RESPONSE_VERSION_TWO
834
request_marker = REQUEST_VERSION_TWO
836
def read_response_tuple(self, expect_body=False):
837
"""Read a response tuple from the wire.
839
This should only be called once.
841
version = self._request.read_line()
842
if version != self.response_marker:
843
self._request.finished_reading()
844
raise errors.UnexpectedProtocolVersionMarker(version)
845
response_status = self._request.read_line()
846
result = SmartClientRequestProtocolOne._read_response_tuple(self)
847
self._response_is_unknown_method(result)
848
if response_status == b'success\n':
849
self.response_status = True
851
self._request.finished_reading()
853
elif response_status == b'failed\n':
854
self.response_status = False
855
self._request.finished_reading()
856
raise errors.ErrorFromSmartServer(result)
858
raise errors.SmartProtocolError(
859
'bad protocol status %r' % response_status)
861
def _write_protocol_version(self):
862
"""Write any prefixes this protocol requires.
864
Version two sends the value of REQUEST_VERSION_TWO.
866
self._request.accept_bytes(self.request_marker)
868
def read_streamed_body(self):
869
"""Read bytes from the body, decoding into a byte stream.
871
# Read no more than 64k at a time so that we don't risk error 10055 (no
872
# buffer space available) on Windows.
873
_body_decoder = ChunkedBodyDecoder()
874
while not _body_decoder.finished_reading:
875
bytes = self._request.read_bytes(_body_decoder.next_read_size())
877
# end of file encountered reading from server
878
raise errors.ConnectionReset(
879
"Connection lost while reading streamed body.")
880
_body_decoder.accept_bytes(bytes)
881
for body_bytes in iter(_body_decoder.read_next_chunk, None):
882
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
883
mutter(' %d byte chunk read',
886
self._request.finished_reading()
889
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Build a server-side protocol-three decoder.

    A SmartServerRequestHandler (dispatching via request.request_handlers)
    and a ProtocolThreeResponder writing through ``write_func`` are combined
    into a ConventionalRequestHandler, which the returned decoder feeds.

    :param backing_transport: transport the request handler operates on.
    :param write_func: callable used by the responder to write output bytes.
    :param root_client_path: passed through to SmartServerRequestHandler.
    :param jail_root: passed through to SmartServerRequestHandler.
    :return: a ProtocolThreeDecoder ready to accept request bytes.
    """
    request_handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    responder = ProtocolThreeResponder(write_func)
    message_handler = message.ConventionalRequestHandler(
        request_handler, responder)
    return ProtocolThreeDecoder(message_handler)
900
class ProtocolThreeDecoder(_StatefulDecoder):
902
response_marker = RESPONSE_VERSION_THREE
903
request_marker = REQUEST_VERSION_THREE
905
def __init__(self, message_handler, expect_version_marker=False):
906
_StatefulDecoder.__init__(self)
907
self._has_dispatched = False
909
if expect_version_marker:
910
self.state_accept = self._state_accept_expecting_protocol_version
911
# We're expecting at least the protocol version marker + some
913
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
915
self.state_accept = self._state_accept_expecting_headers
916
self._number_needed_bytes = 4
917
self.decoding_failed = False
918
self.request_handler = self.message_handler = message_handler
920
def accept_bytes(self, bytes):
921
self._number_needed_bytes = None
923
_StatefulDecoder.accept_bytes(self, bytes)
924
except KeyboardInterrupt:
926
except errors.SmartMessageHandlerError as exception:
927
# We do *not* set self.decoding_failed here. The message handler
928
# has raised an error, but the decoder is still able to parse bytes
929
# and determine when this message ends.
930
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
931
log_exception_quietly()
932
self.message_handler.protocol_error(exception.exc_value)
933
# The state machine is ready to continue decoding, but the
934
# exception has interrupted the loop that runs the state machine.
935
# So we call accept_bytes again to restart it.
936
self.accept_bytes(b'')
937
except Exception as exception:
938
# The decoder itself has raised an exception. We cannot continue
940
self.decoding_failed = True
941
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
942
# This happens during normal operation when the client tries a
943
# protocol version the server doesn't understand, so no need to
944
# log a traceback every time.
945
# Note that this can only happen when
946
# expect_version_marker=True, which is only the case on the
950
log_exception_quietly()
951
self.message_handler.protocol_error(exception)
953
def _extract_length_prefixed_bytes(self):
954
if self._in_buffer_len < 4:
955
# A length prefix by itself is 4 bytes, and we don't even have that
957
raise _NeedMoreBytes(4)
958
(length,) = struct.unpack('!L', self._get_in_bytes(4))
959
end_of_bytes = 4 + length
960
if self._in_buffer_len < end_of_bytes:
961
# We haven't yet read as many bytes as the length-prefix says there
963
raise _NeedMoreBytes(end_of_bytes)
964
# Extract the bytes from the buffer.
965
in_buf = self._get_in_buffer()
966
bytes = in_buf[4:end_of_bytes]
967
self._set_in_buffer(in_buf[end_of_bytes:])
970
def _extract_prefixed_bencoded_data(self):
971
prefixed_bytes = self._extract_length_prefixed_bytes()
973
decoded = bdecode_as_tuple(prefixed_bytes)
975
raise errors.SmartProtocolError(
976
'Bytes %r not bencoded' % (prefixed_bytes,))
979
def _extract_single_byte(self):
980
if self._in_buffer_len == 0:
981
# The buffer is empty
982
raise _NeedMoreBytes(1)
983
in_buf = self._get_in_buffer()
984
one_byte = in_buf[0:1]
985
self._set_in_buffer(in_buf[1:])
988
def _state_accept_expecting_protocol_version(self):
989
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
990
in_buf = self._get_in_buffer()
992
# We don't have enough bytes to check if the protocol version
993
# marker is right. But we can check if it is already wrong by
994
# checking that the start of MESSAGE_VERSION_THREE matches what
996
# [In fact, if the remote end isn't bzr we might never receive
997
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
998
# are wrong then we should just raise immediately rather than
1000
if not MESSAGE_VERSION_THREE.startswith(in_buf):
1001
# We have enough bytes to know the protocol version is wrong
1002
raise errors.UnexpectedProtocolVersionMarker(in_buf)
1003
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
1004
if not in_buf.startswith(MESSAGE_VERSION_THREE):
1005
raise errors.UnexpectedProtocolVersionMarker(in_buf)
1006
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
1007
self.state_accept = self._state_accept_expecting_headers
1009
def _state_accept_expecting_headers(self):
1010
decoded = self._extract_prefixed_bencoded_data()
1011
if not isinstance(decoded, dict):
1012
raise errors.SmartProtocolError(
1013
'Header object %r is not a dict' % (decoded,))
1014
self.state_accept = self._state_accept_expecting_message_part
1016
self.message_handler.headers_received(decoded)
1018
raise errors.SmartMessageHandlerError(sys.exc_info())
1020
def _state_accept_expecting_message_part(self):
    """Read a one-byte message-part kind and switch to the matching state.

    ``b'o'`` introduces a one-byte part, ``b's'`` a bencoded structure,
    ``b'b'`` a length-prefixed bytes part, and ``b'e'`` ends the message.

    :raises SmartProtocolError: for any other kind byte.
    """
    message_part_kind = self._extract_single_byte()
    if message_part_kind == b'o':
        self.state_accept = self._state_accept_expecting_one_byte
    elif message_part_kind == b's':
        self.state_accept = self._state_accept_expecting_structure
    elif message_part_kind == b'b':
        self.state_accept = self._state_accept_expecting_bytes
    elif message_part_kind == b'e':
        # End of message: finish immediately instead of waiting for a part.
        self.done()
    else:
        raise errors.SmartProtocolError(
            'Bad message kind byte: %r' % (message_part_kind,))
def _state_accept_expecting_one_byte(self):
    """Deliver a single-byte message part to the message handler.

    The decoder then returns to expecting the next message part.

    :raises SmartMessageHandlerError: wrapping the exc_info of any exception
        raised by ``message_handler.byte_part_received``.
    """
    byte = self._extract_single_byte()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.byte_part_received(byte)
    except:
        # Deliberately broad: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_expecting_bytes(self):
    """Deliver a length-prefixed bytes part to the message handler.

    :raises SmartMessageHandlerError: wrapping the exc_info of any exception
        raised by ``message_handler.bytes_part_received``.
    """
    # XXX: this should not buffer whole message part, but instead deliver
    # the bytes as they arrive.
    prefixed_bytes = self._extract_length_prefixed_bytes()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.bytes_part_received(prefixed_bytes)
    except:
        # Deliberately broad: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_expecting_structure(self):
    """Deliver a bencoded structure part to the message handler.

    :raises SmartMessageHandlerError: wrapping the exc_info of any exception
        raised by ``message_handler.structure_part_received``.
    """
    structure = self._extract_prefixed_bencoded_data()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.structure_part_received(structure)
    except:
        # Deliberately broad: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def done(self):
    """Finish the current message.

    Whatever remains in the input buffer becomes ``unused_data``, the
    decoder switches to collecting any further bytes as unused data, and
    the message handler is told the message has ended.

    :raises SmartMessageHandlerError: wrapping the exc_info of any exception
        raised by ``message_handler.end_received``.
    """
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        # Deliberately broad: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_reading_unused(self):
    """Collect bytes that arrive after the message has finished.

    They are not part of the message, so they are appended to
    ``unused_data`` and the input buffer is emptied.
    """
    pending = self._get_in_buffer()
    self.unused_data += pending
    self._set_in_buffer(None)
def next_read_size(self):
    """Return how many more bytes the decoder wants.

    :returns: 0 once the message is complete (reading unused data) or after
        decoding has failed; otherwise ``_number_needed_bytes`` minus the
        number of bytes already buffered.
    :raises AssertionError: if no needed-byte count has been recorded.
    """
    if self.state_accept == self._state_accept_reading_unused:
        return 0
    if self.decoding_failed:
        # An exception occurred while processing this message, probably from
        # self.message_handler.  We're not sure that this state machine is
        # in a consistent state, so just signal that we're done (i.e. give
        # up).
        return 0
    if self._number_needed_bytes is not None:
        return self._number_needed_bytes - self._in_buffer_len
    raise AssertionError("don't know how many bytes are expected!")
class _ProtocolThreeEncoder(object):
    """Serialise protocol-three message parts onto a write function.

    Output is collected in an in-memory buffer.  The buffer is passed to the
    real write function as a single joined bytes object when it grows past
    ``BUFFER_SIZE`` bytes or when ``flush`` is called; ``_write_end`` always
    flushes.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing

    def __init__(self, write_func):
        # Pending output chunks and their total length; drained by flush().
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Buffer ``bytes``, flushing once the buffer exceeds BUFFER_SIZE."""
        # TODO: Another possibility would be to turn this into an async model.
        #       Where we let another thread know that we have some bytes if
        #       they want it, but we don't actually block for it
        # Note that osutils.send_all always sends 64kB chunks anyway, so
        # we might just push out smaller bits at a time?
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Write all buffered output with the real write function."""
        if self._buf:
            self._real_write_func(b''.join(self._buf))
            del self._buf[:]
            self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        Each ``(start, length)`` pair becomes ``b'start,length'``; pairs are
        separated by newlines with no trailing newline.
        """
        return b'\n'.join(
            b'%d,%d' % (start, length) for start, length in offsets)

    def _write_protocol_version(self):
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by a 4-byte length."""
        encoded = bencode(structure)
        # '!L': network byte order (big-endian) unsigned 32-bit length.
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write a structure part (``b's'``); text args are UTF-8 encoded."""
        self._write_func(b's')
        utf8_args = [
            arg.encode('utf8') if isinstance(arg, text_type) else arg
            for arg in args]
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker and flush all buffered output."""
        self._write_func(b'e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a bytes part: ``b'b'``, a 4-byte length, then the data."""
        self._write_func(b'b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        self._write_func(b'oC')

    def _write_error_status(self):
        self._write_func(b'oE')

    def _write_success_status(self):
        self._write_func(b'oS')
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Write a single protocol-three response (success, error or stream)."""

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {
            b'Software version': breezy.__version__.encode('utf-8')}
        # _trace() reads both attributes.  They are set unconditionally
        # because the per-chunk trace in send_response is guarded only by
        # 'hpssdetail'; setting them only under 'hpss' left them missing
        # (AttributeError) when 'hpssdetail' was enabled on its own.
        self._thread_id = _thread.get_ident()
        self._response_start_time = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Log one line describing response activity via ``mutter``.

        :param action: short event label, right-aligned in 12 columns.
        :param message: free-form description of the event.
        :param extra_bytes: optional payload; a truncated repr of its first
            40 bytes is appended.
        :param include_time: if true, prefix the message with the seconds
            elapsed since the first traced event of this response.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.perf_counter()
        if include_time:
            t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                # Keep the repr's closing quote, then mark the truncation.
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send ``exception`` as an error response.

        ``UnknownSmartMethod`` is sent as a failed ``(b'UnknownMethod', verb)``
        response; any other exception is sent as ``(b'error', message)``.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                (b'UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(
            (b'error', str(exception).encode('utf-8', 'replace')))
        self._write_end()

    def send_response(self, response):
        """Send ``response``, including its body or body stream if present.

        If iterating the body stream raises, or yields a
        ``FailedSmartServerResponse``, an error status and structure are
        written and streaming stops.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                num_bytes += len(chunk)
                if first_chunk is None:
                    first_chunk = chunk
                self._write_prefixed_body(chunk)
                # Push each chunk out as it is produced rather than letting
                # the stream accumulate in the buffer.
                self.flush()
                if 'hpssdetail' in debug.debug_flags:
                    # Not worth timing separately, as _write_func is
                    # actually buffered.  (include_time defaults to False;
                    # _trace has no 'suppress_time' parameter, so passing
                    # it raised TypeError.)
                    self._trace('body chunk',
                                '%d bytes' % (len(chunk),),
                                chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
def _iter_with_errors(iterable):
1266
"""Handle errors from iterable.next().
1270
for exc_info, value in _iter_with_errors(iterable):
1273
This is a safer alternative to::
1276
for value in iterable:
1281
Because the latter will catch errors from the for-loop body, not just
1284
If an error occurs, exc_info will be a exc_info tuple, and the generator
1285
will terminate. Otherwise exc_info will be None, and value will be the
1286
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1287
will not be itercepted.
1289
iterator = iter(iterable)
1292
yield None, next(iterator)
1293
except StopIteration:
1295
except (KeyboardInterrupt, SystemExit):
1298
mutter('_iter_with_errors caught error')
1299
log_exception_quietly()
1300
yield sys.exc_info(), None
1304
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
    """Write protocol-three requests to a medium request."""

    def __init__(self, medium_request):
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
        self._medium_request = medium_request
        self._headers = {}
        # None until call_with_body_stream runs; then False until the stream
        # body starts being written, and True afterwards.
        self.body_stream_started = None

    def set_headers(self, headers):
        """Use a copy of ``headers`` for subsequent requests."""
        self._headers = headers.copy()

    def call(self, *args):
        """Send a request consisting of ``args`` only (no body)."""
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            base = getattr(self._medium_request._medium, 'base', None)
            if base is not None:
                mutter(' (to %s)', base)
            self._request_start_time = osutils.perf_counter()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.perf_counter()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_prefixed_body(body)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            self._request_start_time = osutils.perf_counter()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        readv_bytes = self._serialise_offsets(body)
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._write_prefixed_body(readv_bytes)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_stream(self, args, stream):
        """Send a request whose body is the chunks produced by ``stream``.

        Each chunk is written as a separate bytes part and flushed.  If
        iterating ``stream`` raises, the request is cleanly terminated with
        an error status and ``(b'error',)`` structure, and the original
        exception is then re-raised.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body stream: %r', args)
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            self._request_start_time = osutils.perf_counter()
        self.body_stream_started = False
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        # TODO: notice if the server has sent an early error reply before we
        #       have finished sending the stream.  We would notice at the end
        #       anyway, but if the medium can deliver it early then it's good
        #       to short-circuit the whole request...
        # Provoke any ConnectionReset failures before we start the body stream.
        self.flush()
        self.body_stream_started = True
        for exc_info, part in _iter_with_errors(stream):
            if exc_info is not None:
                # Iterating the stream failed.  Cleanly abort the request.
                self._write_error_status()
                # Currently the client unconditionally sends ('error',) as the
                # error args.
                self._write_structure((b'error',))
                self._write_end()
                self._medium_request.finished_writing()
                try:
                    reraise(*exc_info)
                finally:
                    # Drop the traceback reference to avoid a reference cycle.
                    del exc_info
            else:
                self._write_prefixed_body(part)
                self.flush()
        self._write_end()
        self._medium_request.finished_writing()