1
# Copyright (C) 2006-2010 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
29
import thread as _thread
38
from ...sixish import (
42
from . import message, request
43
from ...sixish import text_type
44
from ...trace import log_exception_quietly, mutter
45
from ...bencode import bdecode_as_tuple, bencode
48
# Protocol version markers.  Each one is sent as a prefix of a bzr request
# or response so the peer can identify the protocol version in use.  There
# are no version-one markers because that version does not send any.
REQUEST_VERSION_TWO = b'bzr request 2\n'
RESPONSE_VERSION_TWO = b'bzr response 2\n'

# Version three uses a single marker for requests and responses alike.
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
RESPONSE_VERSION_THREE = MESSAGE_VERSION_THREE
58
def _recv_tuple(from_file):
    """Read one line from ``from_file`` and decode it with _decode_tuple.

    :param from_file: an object providing ``readline()``.
    :return: whatever _decode_tuple returns for the line that was read.
    """
    return _decode_tuple(from_file.readline())
63
def _decode_tuple(req_line):
64
if req_line is None or req_line == b'':
66
if not req_line.endswith(b'\n'):
67
raise errors.SmartProtocolError("request %r not terminated" % req_line)
68
return tuple(req_line[:-1].split(b'\x01'))
71
def _encode_tuple(args):
72
"""Encode the tuple args to a bytestream."""
74
if isinstance(arg, text_type):
76
return b'\x01'.join(args) + b'\n'
79
class Requester(object):
    """Abstract base class for an object that can issue requests on a smart
    medium.

    Every method in this class raises NotImplementedError, passing the bound
    method itself as the exception argument.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers for requests; not implemented in this base class.

        :param headers: the headers to use.
        """
        raise NotImplementedError(self.set_headers)
113
class SmartProtocolBase(object):
114
"""Methods common to client and server"""
116
# TODO: this only actually accomodates a single block; possibly should
117
# support multiple chunks?
118
def _encode_bulk_data(self, body):
119
"""Encode body as a bulk data chunk."""
120
return b''.join((b'%d\n' % len(body), body, b'done\n'))
122
def _serialise_offsets(self, offsets):
123
"""Serialise a readv offset list."""
125
for start, length in offsets:
126
txt.append(b'%d,%d' % (start, length))
127
return b'\n'.join(txt)
130
class SmartServerRequestProtocolOne(SmartProtocolBase):
131
"""Server-side encoding and decoding logic for smart version 1."""
133
def __init__(self, backing_transport, write_func, root_client_path='/',
135
self._backing_transport = backing_transport
136
self._root_client_path = root_client_path
137
self._jail_root = jail_root
138
self.unused_data = b''
139
self._finished = False
141
self._has_dispatched = False
143
self._body_decoder = None
144
self._write_func = write_func
146
def accept_bytes(self, data):
147
"""Take bytes, and advance the internal state machine appropriately.
149
:param data: must be a byte string
151
if not isinstance(data, bytes):
152
raise ValueError(data)
153
self.in_buffer += data
154
if not self._has_dispatched:
155
if b'\n' not in self.in_buffer:
156
# no command line yet
158
self._has_dispatched = True
160
first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
162
req_args = _decode_tuple(first_line)
163
self.request = request.SmartServerRequestHandler(
164
self._backing_transport, commands=request.request_handlers,
165
root_client_path=self._root_client_path,
166
jail_root=self._jail_root)
167
self.request.args_received(req_args)
168
if self.request.finished_reading:
170
self.unused_data = self.in_buffer
172
self._send_response(self.request.response)
173
except KeyboardInterrupt:
175
except errors.UnknownSmartMethod as err:
176
protocol_error = errors.SmartProtocolError(
177
"bad request '%s'" % (err.verb.decode('ascii'),))
178
failure = request.FailedSmartServerResponse(
179
(b'error', str(protocol_error).encode('utf-8')))
180
self._send_response(failure)
182
except Exception as exception:
183
# everything else: pass to client, flush, and quit
184
log_exception_quietly()
185
self._send_response(request.FailedSmartServerResponse(
186
(b'error', str(exception).encode('utf-8'))))
189
if self._has_dispatched:
191
# nothing to do. XXX: this routine should be a single state
193
self.unused_data += self.in_buffer
196
if self._body_decoder is None:
197
self._body_decoder = LengthPrefixedBodyDecoder()
198
self._body_decoder.accept_bytes(self.in_buffer)
199
self.in_buffer = self._body_decoder.unused_data
200
body_data = self._body_decoder.read_pending_data()
201
self.request.accept_body(body_data)
202
if self._body_decoder.finished_reading:
203
self.request.end_of_body()
204
if not self.request.finished_reading:
205
raise AssertionError("no more body, request not finished")
206
if self.request.response is not None:
207
self._send_response(self.request.response)
208
self.unused_data = self.in_buffer
211
if self.request.finished_reading:
212
raise AssertionError(
213
"no response and we have finished reading.")
215
def _send_response(self, response):
216
"""Send a smart server response down the output stream."""
218
raise AssertionError('response already sent')
221
self._finished = True
222
self._write_protocol_version()
223
self._write_success_or_failure_prefix(response)
224
self._write_func(_encode_tuple(args))
226
if not isinstance(body, bytes):
227
raise ValueError(body)
228
data = self._encode_bulk_data(body)
229
self._write_func(data)
231
def _write_protocol_version(self):
232
"""Write any prefixes this protocol requires.
234
Version one doesn't send protocol versions.
237
def _write_success_or_failure_prefix(self, response):
238
"""Write the protocol specific success/failure prefix.
240
For SmartServerRequestProtocolOne this is omitted but we
241
call is_successful to ensure that the response is valid.
243
response.is_successful()
245
def next_read_size(self):
248
if self._body_decoder is None:
251
return self._body_decoder.next_read_size()
254
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
255
r"""Version two of the server side of the smart protocol.
257
This prefixes responses with the value of RESPONSE_VERSION_TWO.
260
response_marker = RESPONSE_VERSION_TWO
261
request_marker = REQUEST_VERSION_TWO
263
def _write_success_or_failure_prefix(self, response):
264
"""Write the protocol specific success/failure prefix."""
265
if response.is_successful():
266
self._write_func(b'success\n')
268
self._write_func(b'failed\n')
270
def _write_protocol_version(self):
271
r"""Write any prefixes this protocol requires.
273
Version two sends the value of RESPONSE_VERSION_TWO.
275
self._write_func(self.response_marker)
277
def _send_response(self, response):
278
"""Send a smart server response down the output stream."""
280
raise AssertionError('response already sent')
281
self._finished = True
282
self._write_protocol_version()
283
self._write_success_or_failure_prefix(response)
284
self._write_func(_encode_tuple(response.args))
285
if response.body is not None:
286
if not isinstance(response.body, bytes):
287
raise AssertionError('body must be bytes')
288
if not (response.body_stream is None):
289
raise AssertionError(
290
'body_stream and body cannot both be set')
291
data = self._encode_bulk_data(response.body)
292
self._write_func(data)
293
elif response.body_stream is not None:
294
_send_stream(response.body_stream, self._write_func)
297
def _send_stream(stream, write_func):
298
write_func(b'chunked\n')
299
_send_chunks(stream, write_func)
303
def _send_chunks(stream, write_func):
305
if isinstance(chunk, bytes):
306
data = ("%x\n" % len(chunk)).encode('ascii') + chunk
308
elif isinstance(chunk, request.FailedSmartServerResponse):
310
_send_chunks(chunk.args, write_func)
313
raise errors.BzrError(
314
'Chunks must be str or FailedSmartServerResponse, got %r'
318
class _NeedMoreBytes(Exception):
319
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
323
def __init__(self, count=None):
326
:param count: the total number of bytes needed by the current state.
327
May be None if the number of bytes needed is unknown.
332
class _StatefulDecoder(object):
333
"""Base class for writing state machines to decode byte streams.
335
Subclasses should provide a self.state_accept attribute that accepts bytes
336
and, if appropriate, updates self.state_accept to a different function.
337
accept_bytes will call state_accept as often as necessary to make sure the
338
state machine has progressed as far as possible before it returns.
340
See ProtocolThreeDecoder for an example subclass.
344
self.finished_reading = False
345
self._in_buffer_list = []
346
self._in_buffer_len = 0
347
self.unused_data = b''
348
self.bytes_left = None
349
self._number_needed_bytes = None
351
def _get_in_buffer(self):
352
if len(self._in_buffer_list) == 1:
353
return self._in_buffer_list[0]
354
in_buffer = b''.join(self._in_buffer_list)
355
if len(in_buffer) != self._in_buffer_len:
356
raise AssertionError(
357
"Length of buffer did not match expected value: %s != %s"
358
% self._in_buffer_len, len(in_buffer))
359
self._in_buffer_list = [in_buffer]
362
def _get_in_bytes(self, count):
363
"""Grab X bytes from the input_buffer.
365
Callers should have already checked that self._in_buffer_len is >
366
count. Note, this does not consume the bytes from the buffer. The
367
caller will still need to call _get_in_buffer() and then
368
_set_in_buffer() if they actually need to consume the bytes.
370
# check if we can yield the bytes from just the first entry in our list
371
if len(self._in_buffer_list) == 0:
372
raise AssertionError('Callers must be sure we have buffered bytes'
373
' before calling _get_in_bytes')
374
if len(self._in_buffer_list[0]) > count:
375
return self._in_buffer_list[0][:count]
376
# We can't yield it from the first buffer, so collapse all buffers, and
378
in_buf = self._get_in_buffer()
379
return in_buf[:count]
381
def _set_in_buffer(self, new_buf):
382
if new_buf is not None:
383
if not isinstance(new_buf, bytes):
384
raise TypeError(new_buf)
385
self._in_buffer_list = [new_buf]
386
self._in_buffer_len = len(new_buf)
388
self._in_buffer_list = []
389
self._in_buffer_len = 0
391
def accept_bytes(self, new_buf):
392
"""Decode as much of bytes as possible.
394
If 'new_buf' contains too much data it will be appended to
397
finished_reading will be set when no more data is required. Further
398
data will be appended to self.unused_data.
400
if not isinstance(new_buf, bytes):
401
raise TypeError(new_buf)
402
# accept_bytes is allowed to change the state
403
self._number_needed_bytes = None
404
# lsprof puts a very large amount of time on this specific call for
406
self._in_buffer_list.append(new_buf)
407
self._in_buffer_len += len(new_buf)
409
# Run the function for the current state.
410
current_state = self.state_accept
412
while current_state != self.state_accept:
413
# The current state has changed. Run the function for the new
414
# current state, so that it can:
415
# - decode any unconsumed bytes left in a buffer, and
416
# - signal how many more bytes are expected (via raising
418
current_state = self.state_accept
420
except _NeedMoreBytes as e:
421
self._number_needed_bytes = e.count
424
class ChunkedBodyDecoder(_StatefulDecoder):
425
"""Decoder for chunked body data.
427
This is very similar to HTTP's chunked encoding. See the description of
428
streamed body data in `doc/developers/network-protocol.txt` for details.
432
_StatefulDecoder.__init__(self)
433
self.state_accept = self._state_accept_expecting_header
434
self.chunk_in_progress = None
435
self.chunks = collections.deque()
437
self.error_in_progress = None
439
def next_read_size(self):
440
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
441
# end-of-body marker is 4 bytes: 'END\n'.
442
if self.state_accept == self._state_accept_reading_chunk:
443
# We're expecting more chunk content. So we're expecting at least
444
# the rest of this chunk plus an END chunk.
445
return self.bytes_left + 4
446
elif self.state_accept == self._state_accept_expecting_length:
447
if self._in_buffer_len == 0:
448
# We're expecting a chunk length. There's at least two bytes
449
# left: a digit plus '\n'.
452
# We're in the middle of reading a chunk length. So there's at
453
# least one byte left, the '\n' that terminates the length.
455
elif self.state_accept == self._state_accept_reading_unused:
457
elif self.state_accept == self._state_accept_expecting_header:
458
return max(0, len('chunked\n') - self._in_buffer_len)
460
raise AssertionError("Impossible state: %r" % (self.state_accept,))
462
def read_next_chunk(self):
464
return self.chunks.popleft()
468
def _extract_line(self):
469
in_buf = self._get_in_buffer()
470
pos = in_buf.find(b'\n')
472
# We haven't read a complete line yet, so request more bytes before
474
raise _NeedMoreBytes(1)
476
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
477
self._set_in_buffer(in_buf[pos + 1:])
481
self.unused_data = self._get_in_buffer()
482
self._in_buffer_list = []
483
self._in_buffer_len = 0
484
self.state_accept = self._state_accept_reading_unused
486
error_args = tuple(self.error_in_progress)
487
self.chunks.append(request.FailedSmartServerResponse(error_args))
488
self.error_in_progress = None
489
self.finished_reading = True
491
def _state_accept_expecting_header(self):
492
prefix = self._extract_line()
493
if prefix == b'chunked':
494
self.state_accept = self._state_accept_expecting_length
496
raise errors.SmartProtocolError(
497
'Bad chunked body header: "%s"' % (prefix,))
499
def _state_accept_expecting_length(self):
500
prefix = self._extract_line()
503
self.error_in_progress = []
504
self._state_accept_expecting_length()
506
elif prefix == b'END':
507
# We've read the end-of-body marker.
508
# Any further bytes are unused data, including the bytes left in
513
self.bytes_left = int(prefix, 16)
514
self.chunk_in_progress = b''
515
self.state_accept = self._state_accept_reading_chunk
517
def _state_accept_reading_chunk(self):
518
in_buf = self._get_in_buffer()
519
in_buffer_len = len(in_buf)
520
self.chunk_in_progress += in_buf[:self.bytes_left]
521
self._set_in_buffer(in_buf[self.bytes_left:])
522
self.bytes_left -= in_buffer_len
523
if self.bytes_left <= 0:
524
# Finished with chunk
525
self.bytes_left = None
527
self.error_in_progress.append(self.chunk_in_progress)
529
self.chunks.append(self.chunk_in_progress)
530
self.chunk_in_progress = None
531
self.state_accept = self._state_accept_expecting_length
533
def _state_accept_reading_unused(self):
534
self.unused_data += self._get_in_buffer()
535
self._in_buffer_list = []
538
class LengthPrefixedBodyDecoder(_StatefulDecoder):
539
"""Decodes the length-prefixed bulk data."""
542
_StatefulDecoder.__init__(self)
543
self.state_accept = self._state_accept_expecting_length
544
self.state_read = self._state_read_no_data
546
self._trailer_buffer = b''
548
def next_read_size(self):
549
if self.bytes_left is not None:
550
# Ideally we want to read all the remainder of the body and the
552
return self.bytes_left + 5
553
elif self.state_accept == self._state_accept_reading_trailer:
554
# Just the trailer left
555
return 5 - len(self._trailer_buffer)
556
elif self.state_accept == self._state_accept_expecting_length:
557
# There's still at least 6 bytes left ('\n' to end the length, plus
561
# Reading excess data. Either way, 1 byte at a time is fine.
564
def read_pending_data(self):
565
"""Return any pending data that has been decoded."""
566
return self.state_read()
568
def _state_accept_expecting_length(self):
569
in_buf = self._get_in_buffer()
570
pos = in_buf.find(b'\n')
573
self.bytes_left = int(in_buf[:pos])
574
self._set_in_buffer(in_buf[pos + 1:])
575
self.state_accept = self._state_accept_reading_body
576
self.state_read = self._state_read_body_buffer
578
def _state_accept_reading_body(self):
579
in_buf = self._get_in_buffer()
581
self.bytes_left -= len(in_buf)
582
self._set_in_buffer(None)
583
if self.bytes_left <= 0:
585
if self.bytes_left != 0:
586
self._trailer_buffer = self._body[self.bytes_left:]
587
self._body = self._body[:self.bytes_left]
588
self.bytes_left = None
589
self.state_accept = self._state_accept_reading_trailer
591
def _state_accept_reading_trailer(self):
592
self._trailer_buffer += self._get_in_buffer()
593
self._set_in_buffer(None)
594
# TODO: what if the trailer does not match "done\n"? Should this raise
595
# a ProtocolViolation exception?
596
if self._trailer_buffer.startswith(b'done\n'):
597
self.unused_data = self._trailer_buffer[len(b'done\n'):]
598
self.state_accept = self._state_accept_reading_unused
599
self.finished_reading = True
601
def _state_accept_reading_unused(self):
602
self.unused_data += self._get_in_buffer()
603
self._set_in_buffer(None)
605
def _state_read_no_data(self):
608
def _state_read_body_buffer(self):
614
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
615
message.ResponseHandler):
616
"""The client-side protocol for smart version 1."""
618
def __init__(self, request):
619
"""Construct a SmartClientRequestProtocolOne.
621
:param request: A SmartClientMediumRequest to serialise onto and
624
self._request = request
625
self._body_buffer = None
626
self._request_start_time = None
627
self._last_verb = None
630
def set_headers(self, headers):
631
self._headers = dict(headers)
633
def call(self, *args):
634
if 'hpss' in debug.debug_flags:
635
mutter('hpss call: %s', repr(args)[1:-1])
636
if getattr(self._request._medium, 'base', None) is not None:
637
mutter(' (to %s)', self._request._medium.base)
638
self._request_start_time = osutils.timer_func()
639
self._write_args(args)
640
self._request.finished_writing()
641
self._last_verb = args[0]
643
def call_with_body_bytes(self, args, body):
644
"""Make a remote call of args with body bytes 'body'.
646
After calling this, call read_response_tuple to find the result out.
648
if 'hpss' in debug.debug_flags:
649
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
650
if getattr(self._request._medium, '_path', None) is not None:
652
self._request._medium._path)
653
mutter(' %d bytes', len(body))
654
self._request_start_time = osutils.timer_func()
655
if 'hpssdetail' in debug.debug_flags:
656
mutter('hpss body content: %s', body)
657
self._write_args(args)
658
bytes = self._encode_bulk_data(body)
659
self._request.accept_bytes(bytes)
660
self._request.finished_writing()
661
self._last_verb = args[0]
663
def call_with_body_readv_array(self, args, body):
664
"""Make a remote call with a readv array.
666
The body is encoded with one line per readv offset pair. The numbers in
667
each pair are separated by a comma, and no trailing \\n is emitted.
669
if 'hpss' in debug.debug_flags:
670
mutter('hpss call w/readv: %s', repr(args)[1:-1])
671
if getattr(self._request._medium, '_path', None) is not None:
673
self._request._medium._path)
674
self._request_start_time = osutils.timer_func()
675
self._write_args(args)
676
readv_bytes = self._serialise_offsets(body)
677
bytes = self._encode_bulk_data(readv_bytes)
678
self._request.accept_bytes(bytes)
679
self._request.finished_writing()
680
if 'hpss' in debug.debug_flags:
681
mutter(' %d bytes in readv request', len(readv_bytes))
682
self._last_verb = args[0]
684
def call_with_body_stream(self, args, stream):
    """Refuse a call that needs a body stream.

    The medium request is marked as finished writing and finished reading,
    then UnknownSmartMethod is raised for the verb in ``args[0]``.

    :param args: the arguments of the call; ``args[0]`` is the verb.
    :param stream: the body stream; it is not used.
    :raises errors.UnknownSmartMethod: always.
    """
    # Protocols v1 and v2 don't support body streams.  So it's safe to
    # assume that a v1/v2 server doesn't support whatever method we're
    # trying to call with a body stream.
    medium_request = self._request
    medium_request.finished_writing()
    medium_request.finished_reading()
    raise errors.UnknownSmartMethod(args[0])
692
def cancel_read_body(self):
693
"""After expecting a body, a response code may indicate one otherwise.
695
This method lets the domain client inform the protocol that no body
696
will be transmitted. This is a terminal method: after calling it the
697
protocol is not able to be used further.
699
self._request.finished_reading()
701
def _read_response_tuple(self):
702
result = self._recv_tuple()
703
if 'hpss' in debug.debug_flags:
704
if self._request_start_time is not None:
705
mutter(' result: %6.3fs %s',
706
osutils.timer_func() - self._request_start_time,
708
self._request_start_time = None
710
mutter(' result: %s', repr(result)[1:-1])
713
def read_response_tuple(self, expect_body=False):
714
"""Read a response tuple from the wire.
716
This should only be called once.
718
result = self._read_response_tuple()
719
self._response_is_unknown_method(result)
720
self._raise_args_if_error(result)
722
self._request.finished_reading()
725
def _raise_args_if_error(self, result_tuple):
726
# Later protocol versions have an explicit flag in the protocol to say
727
# if an error response is "failed" or not. In version 1 we don't have
728
# that luxury. So here is a complete list of errors that can be
729
# returned in response to existing version 1 smart requests. Responses
730
# starting with these codes are always "failed" responses.
735
b'DirectoryNotEmpty',
737
b'UnicodeEncodeError',
738
b'UnicodeDecodeError',
744
b'UnlockableTransport',
750
if result_tuple[0] in v1_error_codes:
751
self._request.finished_reading()
752
raise errors.ErrorFromSmartServer(result_tuple)
754
def _response_is_unknown_method(self, result_tuple):
755
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
756
method' response to the request.
758
:param response: The response from a smart client call_expecting_body
760
:param verb: The verb used in that call.
761
:raises: UnexpectedSmartServerResponse
763
if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
764
b"bad request '" + self._last_verb + b"'")
765
or result_tuple == (b'error', b"Generic bzr smart protocol error: "
766
b"bad request u'%s'" % self._last_verb)):
767
# The response will have no body, so we've finished reading.
768
self._request.finished_reading()
769
raise errors.UnknownSmartMethod(self._last_verb)
771
def read_body_bytes(self, count=-1):
772
"""Read bytes from the body, decoding into a byte stream.
774
We read all bytes at once to ensure we've checked the trailer for
775
errors, and then feed the buffer back as read_body_bytes is called.
777
if self._body_buffer is not None:
778
return self._body_buffer.read(count)
779
_body_decoder = LengthPrefixedBodyDecoder()
781
while not _body_decoder.finished_reading:
782
bytes = self._request.read_bytes(_body_decoder.next_read_size())
784
# end of file encountered reading from server
785
raise errors.ConnectionReset(
786
"Connection lost while reading response body.")
787
_body_decoder.accept_bytes(bytes)
788
self._request.finished_reading()
789
self._body_buffer = BytesIO(_body_decoder.read_pending_data())
790
# XXX: TODO check the trailer result.
791
if 'hpss' in debug.debug_flags:
792
mutter(' %d body bytes read',
793
len(self._body_buffer.getvalue()))
794
return self._body_buffer.read(count)
796
def _recv_tuple(self):
    """Receive a tuple from the medium request.

    :return: the result of _decode_tuple on one line read from the request.
    """
    line = self._request.read_line()
    return _decode_tuple(line)
800
def query_version(self):
801
"""Return protocol version number of the server."""
803
resp = self.read_response_tuple()
804
if resp == (b'ok', b'1'):
806
elif resp == (b'ok', b'2'):
809
raise errors.SmartProtocolError("bad response %r" % (resp,))
811
def _write_args(self, args):
    """Write the protocol version prefix, then the encoded ``args``.

    :param args: the argument tuple, encoded with _encode_tuple and handed
        to the medium request's accept_bytes.
    """
    self._write_protocol_version()
    self._request.accept_bytes(_encode_tuple(args))
816
def _write_protocol_version(self):
817
"""Write any prefixes this protocol requires.
819
Version one doesn't send protocol versions.
823
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
824
"""Version two of the client side of the smart protocol.
826
This prefixes the request with the value of REQUEST_VERSION_TWO.
829
response_marker = RESPONSE_VERSION_TWO
830
request_marker = REQUEST_VERSION_TWO
832
def read_response_tuple(self, expect_body=False):
833
"""Read a response tuple from the wire.
835
This should only be called once.
837
version = self._request.read_line()
838
if version != self.response_marker:
839
self._request.finished_reading()
840
raise errors.UnexpectedProtocolVersionMarker(version)
841
response_status = self._request.read_line()
842
result = SmartClientRequestProtocolOne._read_response_tuple(self)
843
self._response_is_unknown_method(result)
844
if response_status == b'success\n':
845
self.response_status = True
847
self._request.finished_reading()
849
elif response_status == b'failed\n':
850
self.response_status = False
851
self._request.finished_reading()
852
raise errors.ErrorFromSmartServer(result)
854
raise errors.SmartProtocolError(
855
'bad protocol status %r' % response_status)
857
def _write_protocol_version(self):
858
"""Write any prefixes this protocol requires.
860
Version two sends the value of REQUEST_VERSION_TWO.
862
self._request.accept_bytes(self.request_marker)
864
def read_streamed_body(self):
865
"""Read bytes from the body, decoding into a byte stream.
867
# Read no more than 64k at a time so that we don't risk error 10055 (no
868
# buffer space available) on Windows.
869
_body_decoder = ChunkedBodyDecoder()
870
while not _body_decoder.finished_reading:
871
bytes = self._request.read_bytes(_body_decoder.next_read_size())
873
# end of file encountered reading from server
874
raise errors.ConnectionReset(
875
"Connection lost while reading streamed body.")
876
_body_decoder.accept_bytes(bytes)
877
for body_bytes in iter(_body_decoder.read_next_chunk, None):
878
if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
879
mutter(' %d byte chunk read',
882
self._request.finished_reading()
885
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Assemble the server-side decoder for protocol version three.

    A SmartServerRequestHandler (using request.request_handlers as its
    commands) and a ProtocolThreeResponder wrapping ``write_func`` are
    combined into a ConventionalRequestHandler, which becomes the message
    handler of the returned ProtocolThreeDecoder.

    :param backing_transport: passed to the SmartServerRequestHandler.
    :param write_func: passed to the ProtocolThreeResponder.
    :param root_client_path: passed to the SmartServerRequestHandler.
    :param jail_root: passed to the SmartServerRequestHandler.
    :return: a ProtocolThreeDecoder.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport,
        commands=request.request_handlers,
        root_client_path=root_client_path,
        jail_root=jail_root)
    response_writer = ProtocolThreeResponder(write_func)
    conventional = message.ConventionalRequestHandler(
        handler, response_writer)
    return ProtocolThreeDecoder(conventional)
896
class ProtocolThreeDecoder(_StatefulDecoder):
898
response_marker = RESPONSE_VERSION_THREE
899
request_marker = REQUEST_VERSION_THREE
901
def __init__(self, message_handler, expect_version_marker=False):
902
_StatefulDecoder.__init__(self)
903
self._has_dispatched = False
905
if expect_version_marker:
906
self.state_accept = self._state_accept_expecting_protocol_version
907
# We're expecting at least the protocol version marker + some
909
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
911
self.state_accept = self._state_accept_expecting_headers
912
self._number_needed_bytes = 4
913
self.decoding_failed = False
914
self.request_handler = self.message_handler = message_handler
916
def accept_bytes(self, bytes):
917
self._number_needed_bytes = None
919
_StatefulDecoder.accept_bytes(self, bytes)
920
except KeyboardInterrupt:
922
except errors.SmartMessageHandlerError as exception:
923
# We do *not* set self.decoding_failed here. The message handler
924
# has raised an error, but the decoder is still able to parse bytes
925
# and determine when this message ends.
926
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
927
log_exception_quietly()
928
self.message_handler.protocol_error(exception.exc_value)
929
# The state machine is ready to continue decoding, but the
930
# exception has interrupted the loop that runs the state machine.
931
# So we call accept_bytes again to restart it.
932
self.accept_bytes(b'')
933
except Exception as exception:
934
# The decoder itself has raised an exception. We cannot continue
936
self.decoding_failed = True
937
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
938
# This happens during normal operation when the client tries a
939
# protocol version the server doesn't understand, so no need to
940
# log a traceback every time.
941
# Note that this can only happen when
942
# expect_version_marker=True, which is only the case on the
946
log_exception_quietly()
947
self.message_handler.protocol_error(exception)
949
def _extract_length_prefixed_bytes(self):
950
if self._in_buffer_len < 4:
951
# A length prefix by itself is 4 bytes, and we don't even have that
953
raise _NeedMoreBytes(4)
954
(length,) = struct.unpack('!L', self._get_in_bytes(4))
955
end_of_bytes = 4 + length
956
if self._in_buffer_len < end_of_bytes:
957
# We haven't yet read as many bytes as the length-prefix says there
959
raise _NeedMoreBytes(end_of_bytes)
960
# Extract the bytes from the buffer.
961
in_buf = self._get_in_buffer()
962
bytes = in_buf[4:end_of_bytes]
963
self._set_in_buffer(in_buf[end_of_bytes:])
966
def _extract_prefixed_bencoded_data(self):
967
prefixed_bytes = self._extract_length_prefixed_bytes()
969
decoded = bdecode_as_tuple(prefixed_bytes)
971
raise errors.SmartProtocolError(
972
'Bytes %r not bencoded' % (prefixed_bytes,))
975
def _extract_single_byte(self):
976
if self._in_buffer_len == 0:
977
# The buffer is empty
978
raise _NeedMoreBytes(1)
979
in_buf = self._get_in_buffer()
980
one_byte = in_buf[0:1]
981
self._set_in_buffer(in_buf[1:])
984
def _state_accept_expecting_protocol_version(self):
985
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
986
in_buf = self._get_in_buffer()
988
# We don't have enough bytes to check if the protocol version
989
# marker is right. But we can check if it is already wrong by
990
# checking that the start of MESSAGE_VERSION_THREE matches what
992
# [In fact, if the remote end isn't bzr we might never receive
993
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
994
# are wrong then we should just raise immediately rather than
996
if not MESSAGE_VERSION_THREE.startswith(in_buf):
997
# We have enough bytes to know the protocol version is wrong
998
raise errors.UnexpectedProtocolVersionMarker(in_buf)
999
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
1000
if not in_buf.startswith(MESSAGE_VERSION_THREE):
1001
raise errors.UnexpectedProtocolVersionMarker(in_buf)
1002
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
1003
self.state_accept = self._state_accept_expecting_headers
1005
def _state_accept_expecting_headers(self):
1006
decoded = self._extract_prefixed_bencoded_data()
1007
if not isinstance(decoded, dict):
1008
raise errors.SmartProtocolError(
1009
'Header object %r is not a dict' % (decoded,))
1010
self.state_accept = self._state_accept_expecting_message_part
1012
self.message_handler.headers_received(decoded)
1014
raise errors.SmartMessageHandlerError(sys.exc_info())
1016
def _state_accept_expecting_message_part(self):
1017
message_part_kind = self._extract_single_byte()
1018
if message_part_kind == b'o':
1019
self.state_accept = self._state_accept_expecting_one_byte
1020
elif message_part_kind == b's':
1021
self.state_accept = self._state_accept_expecting_structure
1022
elif message_part_kind == b'b':
1023
self.state_accept = self._state_accept_expecting_bytes
1024
elif message_part_kind == b'e':
1027
raise errors.SmartProtocolError(
1028
'Bad message kind byte: %r' % (message_part_kind,))
1030
def _state_accept_expecting_one_byte(self):
1031
byte = self._extract_single_byte()
1032
self.state_accept = self._state_accept_expecting_message_part
1034
self.message_handler.byte_part_received(byte)
1036
raise errors.SmartMessageHandlerError(sys.exc_info())
1038
def _state_accept_expecting_bytes(self):
1039
# XXX: this should not buffer whole message part, but instead deliver
1040
# the bytes as they arrive.
1041
prefixed_bytes = self._extract_length_prefixed_bytes()
1042
self.state_accept = self._state_accept_expecting_message_part
1044
self.message_handler.bytes_part_received(prefixed_bytes)
1046
raise errors.SmartMessageHandlerError(sys.exc_info())
1048
def _state_accept_expecting_structure(self):
1049
structure = self._extract_prefixed_bencoded_data()
1050
self.state_accept = self._state_accept_expecting_message_part
1052
self.message_handler.structure_part_received(structure)
1054
raise errors.SmartMessageHandlerError(sys.exc_info())
1057
def done(self):
    """Finish the current message and notify the message handler.

    Bytes still in the input buffer are moved to ``unused_data``, the
    buffer is cleared, and the decoder switches to the reading-unused
    state before ``message_handler.end_received()`` is called.

    :raises errors.SmartMessageHandlerError: the message handler raised;
        the original ``sys.exc_info()`` is carried by the new error.
    """
    # NOTE(review): the ``def`` line and the ``try:``/``except:`` lines
    # were missing from the reviewed copy.  The name ``done`` is presumed
    # (nothing visible names this method) -- confirm against version
    # control before applying.
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:  # noqa: E722 -- every handler failure is wrapped and re-raised
        raise errors.SmartMessageHandlerError(sys.exc_info())
1065
def _state_accept_reading_unused(self):
1066
self.unused_data += self._get_in_buffer()
1067
self._set_in_buffer(None)
1069
def next_read_size(self):
    """Return how many more bytes the decoder wants to read.

    :return: 0 once the decoder is in the reading-unused state or decoding
        has failed; otherwise the number of needed bytes minus the number
        already buffered.
    :raises AssertionError: no byte count is known for the current state.
    """
    # NOTE(review): both ``return 0`` lines and the ``else:`` lines were
    # missing from the reviewed copy; they are restored here as guard
    # clauses -- confirm the returned value against version control.
    if self.state_accept == self._state_accept_reading_unused:
        return 0
    if self.decoding_failed:
        # An exception occurred while processing this message, probably
        # from self.message_handler.  We're not sure that this state
        # machine is in a consistent state, so just signal that we're done
        # (i.e. give up).
        return 0
    if self._number_needed_bytes is not None:
        return self._number_needed_bytes - self._in_buffer_len
    raise AssertionError("don't know how many bytes are expected!")
1085
class _ProtocolThreeEncoder(object):
    """Buffered writer for the pieces of a version-three message.

    Everything is queued through ``_write_func`` and handed to the real
    write function by ``flush()``.

    NOTE(review): the reviewed copy of this class had lost its indentation
    and several short lines (buffer initialisation, ``flush()``, loop
    headers).  They are restored below from the surviving code -- confirm
    against version control.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing

    def __init__(self, write_func):
        """Create an encoder that emits bytes through ``write_func``.

        :param write_func: callable taking one bytes argument.
        """
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Queue ``bytes``; flush once more than BUFFER_SIZE is pending."""
        # TODO: Another possibility would be to turn this into an async model.
        # Where we let another thread know that we have some bytes if
        # they want it, but we don't actually block for it
        # Note that osutils.send_all always sends 64kB chunks anyway, so
        # we might just push out smaller bits at a time?
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Send all queued bytes in one call and empty the queue."""
        if self._buf:
            self._real_write_func(b''.join(self._buf))
            del self._buf[:]
            self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: iterable of (start, length) integer pairs.
        :return: one ``start,length`` entry per pair, joined with newlines
            and without a trailing newline.
        """
        return b'\n'.join(
            [b'%d,%d' % (start, length) for start, length in offsets])

    def _write_protocol_version(self):
        """Queue the version-three marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Queue ``structure`` bencoded, preceded by its 4-byte length."""
        encoded = bencode(structure)
        # '!L' packs the length as a network-order unsigned 32-bit int.
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        """Queue the headers as a length-prefixed bencoded structure."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Queue a structure part (kind byte b's') holding ``args``.

        Text arguments are encoded as UTF-8; other values pass through
        unchanged.
        """
        self._write_func(b's')
        utf8_args = [
            arg.encode('utf8') if isinstance(arg, text_type) else arg
            for arg in args]
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Queue the end-of-message byte b'e' and flush."""
        self._write_func(b'e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Queue a bytes part: b'b', 4-byte length, then the payload."""
        self._write_func(b'b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        """Queue the one-byte part b'o' followed by b'C'."""
        self._write_func(b'oC')

    def _write_error_status(self):
        """Queue the one-byte part b'o' followed by b'E' (error status)."""
        self._write_func(b'oE')

    def _write_success_status(self):
        """Queue the one-byte part b'o' followed by b'S' (success status)."""
        self._write_func(b'oS')
1159
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Encode and send one version-three response.

    NOTE(review): the reviewed copy of this class had lost its indentation
    and several short lines (``else:``, ``break``, ``return``, ``flush()``,
    ``_write_end()`` and similar).  They are restored below from the
    surviving code -- confirm against version control.
    """

    def __init__(self, write_func):
        """Create a responder that emits bytes through ``write_func``."""
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {
            b'Software version': breezy.__version__.encode('utf-8')}
        # Always create the tracing state, so that _trace() also works when
        # the 'hpss' debug flag is switched on after construction.
        self._thread_id = _thread.get_ident()
        self._response_start_time = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Write one trace line with mutter().

        :param action: short label for the event.
        :param message: text describing the event.
        :param extra_bytes: optional bytes; the repr of the first 40 is
            appended, shortened to 33 characters when longer.
        :param include_time: when true, prefix the time elapsed since the
            first ``_trace`` call on this responder.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # Read the same clock that set the start time; time.clock()
            # no longer exists on Python 3.8+.
            t = '%5.3fs ' % (
                osutils.timer_func() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            # NOTE(review): threshold restored as 33, the length of the
            # shortened form built below -- confirm.
            if len(extra) > 33:
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send ``exception`` to the client as an error response.

        An ``errors.UnknownSmartMethod`` is sent as a failed
        ``UnknownMethod`` response via ``send_response``; anything else is
        sent as an ``(b'error', message)`` structure.

        :raises AssertionError: a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                (b'UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(
            (b'error', str(exception).encode('utf-8', 'replace')))
        self._write_end()

    def send_response(self, response):
        """Send ``response``: status, args, then its body or body stream.

        A body stream is written chunk by chunk and flushed after each
        chunk.  If iterating the stream fails, or a chunk is a
        ``FailedSmartServerResponse``, an error status and structure are
        written and the stream is abandoned.

        :raises AssertionError: a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                num_bytes += len(chunk)
                if first_chunk is None:
                    first_chunk = chunk
                self._write_prefixed_body(chunk)
                self.flush()
                if 'hpssdetail' in debug.debug_flags:
                    # Not worth timing separately, as _write_func is
                    # buffered.  _trace() has no ``suppress_time``
                    # parameter (passing it raised TypeError); leaving
                    # include_time at its default omits the time.
                    self._trace('body chunk',
                                '%d bytes' % (len(chunk),),
                                chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
1261
def _iter_with_errors(iterable):
1262
"""Handle errors from iterable.next().
1266
for exc_info, value in _iter_with_errors(iterable):
1269
This is a safer alternative to::
1272
for value in iterable:
1277
Because the latter will catch errors from the for-loop body, not just
1280
If an error occurs, exc_info will be a exc_info tuple, and the generator
1281
will terminate. Otherwise exc_info will be None, and value will be the
1282
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1283
will not be itercepted.
1285
iterator = iter(iterable)
1288
yield None, next(iterator)
1289
except StopIteration:
1291
except (KeyboardInterrupt, SystemExit):
1294
mutter('_iter_with_errors caught error')
1295
log_exception_quietly()
1296
yield sys.exc_info(), None
1300
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
    """Encode and send version-three requests through a medium request.

    NOTE(review): the reviewed copy of this class had lost its indentation
    and several short lines (``_headers`` initialisation, ``_write_end()``,
    ``flush()``, ``else:`` and the re-raise in ``call_with_body_stream``).
    They are restored below from the surviving code -- confirm against
    version control.
    """

    def __init__(self, medium_request):
        """Create a requester writing to ``medium_request.accept_bytes``."""
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
        self._medium_request = medium_request
        self._headers = {}
        # None until call_with_body_stream() is used; then False while the
        # request prologue is written and True once streaming has begun.
        self.body_stream_started = None

    def set_headers(self, headers):
        """Use a copy of ``headers`` for subsequent requests."""
        self._headers = headers.copy()

    def call(self, *args):
        """Send a request consisting only of ``args``."""
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            base = getattr(self._medium_request._medium, 'base', None)
            if base is not None:
                mutter(' (to %s)', base)
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_prefixed_body(body)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            self._request_start_time = osutils.timer_func()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        readv_bytes = self._serialise_offsets(body)
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._write_prefixed_body(readv_bytes)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_stream(self, args, stream):
        """Send ``args`` followed by each part of ``stream`` as a body part.

        Every part is flushed as soon as it is written.  If iterating
        ``stream`` fails, an error status and an ``(b'error',)`` structure
        are sent, the request is finished, and the original exception is
        re-raised.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body stream: %r', args)
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            self._request_start_time = osutils.timer_func()
        self.body_stream_started = False
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        # TODO: notice if the server has sent an early error reply before we
        # have finished sending the stream. We would notice at the end
        # anyway, but if the medium can deliver it early then it's good
        # to short-circuit the whole request...
        # Provoke any ConnectionReset failures before we start the body stream.
        self.flush()
        self.body_stream_started = True
        for exc_info, part in _iter_with_errors(stream):
            if exc_info is not None:
                # Iterating the stream failed. Cleanly abort the request.
                self._write_error_status()
                # Currently the client unconditionally sends ('error',) as the
                # error args.
                self._write_structure((b'error',))
                self._write_end()
                self._medium_request.finished_writing()
                # NOTE(review): the re-raise was missing from the reviewed
                # copy; ``reraise`` is presumed to come from the file's
                # ``sixish`` import -- confirm before applying.
                try:
                    reraise(*exc_info)
                finally:
                    # Drop the traceback reference held by this frame.
                    del exc_info
            else:
                self._write_prefixed_body(part)
                self.flush()
        self._write_end()
        self._medium_request.finished_writing()