/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/protocol.py

  • Committer: Jelmer Vernooij
  • Date: 2020-01-31 17:43:44 UTC
  • mto: This revision was merged to the branch mainline in revision 7478.
  • Revision ID: jelmer@jelmer.uk-20200131174344-qjhgqm7bdkuqj9sj
Default to running Python 3.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
try:
 
24
    from collections.abc import deque
 
25
except ImportError:  # python < 3.7
 
26
    from collections import deque
 
27
 
 
28
import struct
 
29
import sys
 
30
try:
 
31
    import _thread
 
32
except ImportError:
 
33
    import thread as _thread
 
34
import time
 
35
 
 
36
import breezy
 
37
from ... import (
 
38
    debug,
 
39
    errors,
 
40
    osutils,
 
41
    )
 
42
from ...sixish import (
 
43
    BytesIO,
 
44
    reraise,
 
45
)
 
46
from . import message, request
 
47
from ...sixish import text_type
 
48
from ...trace import log_exception_quietly, mutter
 
49
from ...bencode import bdecode_as_tuple, bencode
 
50
 
 
51
 
 
52
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
53
# responses to identify the protocol version being used. (There are no version
 
54
# one strings because that version doesn't send any).
 
55
REQUEST_VERSION_TWO = b'bzr request 2\n'
 
56
RESPONSE_VERSION_TWO = b'bzr response 2\n'
 
57
 
 
58
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
 
59
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
60
 
 
61
 
 
62
def _recv_tuple(from_file):
 
63
    req_line = from_file.readline()
 
64
    return _decode_tuple(req_line)
 
65
 
 
66
 
 
67
def _decode_tuple(req_line):
 
68
    if req_line is None or req_line == b'':
 
69
        return None
 
70
    if not req_line.endswith(b'\n'):
 
71
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
72
    return tuple(req_line[:-1].split(b'\x01'))
 
73
 
 
74
 
 
75
def _encode_tuple(args):
 
76
    """Encode the tuple args to a bytestream."""
 
77
    for arg in args:
 
78
        if isinstance(arg, text_type):
 
79
            raise TypeError(args)
 
80
    return b'\x01'.join(args) + b'\n'
 
81
 
 
82
 
 
83
class Requester(object):
 
84
    """Abstract base class for an object that can issue requests on a smart
 
85
    medium.
 
86
    """
 
87
 
 
88
    def call(self, *args):
 
89
        """Make a remote call.
 
90
 
 
91
        :param args: the arguments of this call.
 
92
        """
 
93
        raise NotImplementedError(self.call)
 
94
 
 
95
    def call_with_body_bytes(self, args, body):
 
96
        """Make a remote call with a body.
 
97
 
 
98
        :param args: the arguments of this call.
 
99
        :type body: str
 
100
        :param body: the body to send with the request.
 
101
        """
 
102
        raise NotImplementedError(self.call_with_body_bytes)
 
103
 
 
104
    def call_with_body_readv_array(self, args, body):
 
105
        """Make a remote call with a readv array.
 
106
 
 
107
        :param args: the arguments of this call.
 
108
        :type body: iterable of (start, length) tuples.
 
109
        :param body: the readv ranges to send with this request.
 
110
        """
 
111
        raise NotImplementedError(self.call_with_body_readv_array)
 
112
 
 
113
    def set_headers(self, headers):
 
114
        raise NotImplementedError(self.set_headers)
 
115
 
 
116
 
 
117
class SmartProtocolBase(object):
 
118
    """Methods common to client and server"""
 
119
 
 
120
    # TODO: this only actually accomodates a single block; possibly should
 
121
    # support multiple chunks?
 
122
    def _encode_bulk_data(self, body):
 
123
        """Encode body as a bulk data chunk."""
 
124
        return b''.join((b'%d\n' % len(body), body, b'done\n'))
 
125
 
 
126
    def _serialise_offsets(self, offsets):
 
127
        """Serialise a readv offset list."""
 
128
        txt = []
 
129
        for start, length in offsets:
 
130
            txt.append(b'%d,%d' % (start, length))
 
131
        return b'\n'.join(txt)
 
132
 
 
133
 
 
134
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
135
    """Server-side encoding and decoding logic for smart version 1."""
 
136
 
 
137
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
138
                 jail_root=None):
 
139
        self._backing_transport = backing_transport
 
140
        self._root_client_path = root_client_path
 
141
        self._jail_root = jail_root
 
142
        self.unused_data = b''
 
143
        self._finished = False
 
144
        self.in_buffer = b''
 
145
        self._has_dispatched = False
 
146
        self.request = None
 
147
        self._body_decoder = None
 
148
        self._write_func = write_func
 
149
 
 
150
    def accept_bytes(self, data):
 
151
        """Take bytes, and advance the internal state machine appropriately.
 
152
 
 
153
        :param data: must be a byte string
 
154
        """
 
155
        if not isinstance(data, bytes):
 
156
            raise ValueError(data)
 
157
        self.in_buffer += data
 
158
        if not self._has_dispatched:
 
159
            if b'\n' not in self.in_buffer:
 
160
                # no command line yet
 
161
                return
 
162
            self._has_dispatched = True
 
163
            try:
 
164
                first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
 
165
                first_line += b'\n'
 
166
                req_args = _decode_tuple(first_line)
 
167
                self.request = request.SmartServerRequestHandler(
 
168
                    self._backing_transport, commands=request.request_handlers,
 
169
                    root_client_path=self._root_client_path,
 
170
                    jail_root=self._jail_root)
 
171
                self.request.args_received(req_args)
 
172
                if self.request.finished_reading:
 
173
                    # trivial request
 
174
                    self.unused_data = self.in_buffer
 
175
                    self.in_buffer = b''
 
176
                    self._send_response(self.request.response)
 
177
            except KeyboardInterrupt:
 
178
                raise
 
179
            except errors.UnknownSmartMethod as err:
 
180
                protocol_error = errors.SmartProtocolError(
 
181
                    "bad request '%s'" % (err.verb.decode('ascii'),))
 
182
                failure = request.FailedSmartServerResponse(
 
183
                    (b'error', str(protocol_error).encode('utf-8')))
 
184
                self._send_response(failure)
 
185
                return
 
186
            except Exception as exception:
 
187
                # everything else: pass to client, flush, and quit
 
188
                log_exception_quietly()
 
189
                self._send_response(request.FailedSmartServerResponse(
 
190
                    (b'error', str(exception).encode('utf-8'))))
 
191
                return
 
192
 
 
193
        if self._has_dispatched:
 
194
            if self._finished:
 
195
                # nothing to do.XXX: this routine should be a single state
 
196
                # machine too.
 
197
                self.unused_data += self.in_buffer
 
198
                self.in_buffer = b''
 
199
                return
 
200
            if self._body_decoder is None:
 
201
                self._body_decoder = LengthPrefixedBodyDecoder()
 
202
            self._body_decoder.accept_bytes(self.in_buffer)
 
203
            self.in_buffer = self._body_decoder.unused_data
 
204
            body_data = self._body_decoder.read_pending_data()
 
205
            self.request.accept_body(body_data)
 
206
            if self._body_decoder.finished_reading:
 
207
                self.request.end_of_body()
 
208
                if not self.request.finished_reading:
 
209
                    raise AssertionError("no more body, request not finished")
 
210
            if self.request.response is not None:
 
211
                self._send_response(self.request.response)
 
212
                self.unused_data = self.in_buffer
 
213
                self.in_buffer = b''
 
214
            else:
 
215
                if self.request.finished_reading:
 
216
                    raise AssertionError(
 
217
                        "no response and we have finished reading.")
 
218
 
 
219
    def _send_response(self, response):
 
220
        """Send a smart server response down the output stream."""
 
221
        if self._finished:
 
222
            raise AssertionError('response already sent')
 
223
        args = response.args
 
224
        body = response.body
 
225
        self._finished = True
 
226
        self._write_protocol_version()
 
227
        self._write_success_or_failure_prefix(response)
 
228
        self._write_func(_encode_tuple(args))
 
229
        if body is not None:
 
230
            if not isinstance(body, bytes):
 
231
                raise ValueError(body)
 
232
            data = self._encode_bulk_data(body)
 
233
            self._write_func(data)
 
234
 
 
235
    def _write_protocol_version(self):
 
236
        """Write any prefixes this protocol requires.
 
237
 
 
238
        Version one doesn't send protocol versions.
 
239
        """
 
240
 
 
241
    def _write_success_or_failure_prefix(self, response):
 
242
        """Write the protocol specific success/failure prefix.
 
243
 
 
244
        For SmartServerRequestProtocolOne this is omitted but we
 
245
        call is_successful to ensure that the response is valid.
 
246
        """
 
247
        response.is_successful()
 
248
 
 
249
    def next_read_size(self):
 
250
        if self._finished:
 
251
            return 0
 
252
        if self._body_decoder is None:
 
253
            return 1
 
254
        else:
 
255
            return self._body_decoder.next_read_size()
 
256
 
 
257
 
 
258
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
259
    r"""Version two of the server side of the smart protocol.
 
260
 
 
261
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
262
    """
 
263
 
 
264
    response_marker = RESPONSE_VERSION_TWO
 
265
    request_marker = REQUEST_VERSION_TWO
 
266
 
 
267
    def _write_success_or_failure_prefix(self, response):
 
268
        """Write the protocol specific success/failure prefix."""
 
269
        if response.is_successful():
 
270
            self._write_func(b'success\n')
 
271
        else:
 
272
            self._write_func(b'failed\n')
 
273
 
 
274
    def _write_protocol_version(self):
 
275
        r"""Write any prefixes this protocol requires.
 
276
 
 
277
        Version two sends the value of RESPONSE_VERSION_TWO.
 
278
        """
 
279
        self._write_func(self.response_marker)
 
280
 
 
281
    def _send_response(self, response):
 
282
        """Send a smart server response down the output stream."""
 
283
        if (self._finished):
 
284
            raise AssertionError('response already sent')
 
285
        self._finished = True
 
286
        self._write_protocol_version()
 
287
        self._write_success_or_failure_prefix(response)
 
288
        self._write_func(_encode_tuple(response.args))
 
289
        if response.body is not None:
 
290
            if not isinstance(response.body, bytes):
 
291
                raise AssertionError('body must be bytes')
 
292
            if not (response.body_stream is None):
 
293
                raise AssertionError(
 
294
                    'body_stream and body cannot both be set')
 
295
            data = self._encode_bulk_data(response.body)
 
296
            self._write_func(data)
 
297
        elif response.body_stream is not None:
 
298
            _send_stream(response.body_stream, self._write_func)
 
299
 
 
300
 
 
301
def _send_stream(stream, write_func):
 
302
    write_func(b'chunked\n')
 
303
    _send_chunks(stream, write_func)
 
304
    write_func(b'END\n')
 
305
 
 
306
 
 
307
def _send_chunks(stream, write_func):
 
308
    for chunk in stream:
 
309
        if isinstance(chunk, bytes):
 
310
            data = ("%x\n" % len(chunk)).encode('ascii') + chunk
 
311
            write_func(data)
 
312
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
313
            write_func(b'ERR\n')
 
314
            _send_chunks(chunk.args, write_func)
 
315
            return
 
316
        else:
 
317
            raise errors.BzrError(
 
318
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
319
                % chunk)
 
320
 
 
321
 
 
322
class _NeedMoreBytes(Exception):
 
323
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
324
    have been received.
 
325
    """
 
326
 
 
327
    def __init__(self, count=None):
 
328
        """Constructor.
 
329
 
 
330
        :param count: the total number of bytes needed by the current state.
 
331
            May be None if the number of bytes needed is unknown.
 
332
        """
 
333
        self.count = count
 
334
 
 
335
 
 
336
class _StatefulDecoder(object):
 
337
    """Base class for writing state machines to decode byte streams.
 
338
 
 
339
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
340
    and, if appropriate, updates self.state_accept to a different function.
 
341
    accept_bytes will call state_accept as often as necessary to make sure the
 
342
    state machine has progressed as far as possible before it returns.
 
343
 
 
344
    See ProtocolThreeDecoder for an example subclass.
 
345
    """
 
346
 
 
347
    def __init__(self):
 
348
        self.finished_reading = False
 
349
        self._in_buffer_list = []
 
350
        self._in_buffer_len = 0
 
351
        self.unused_data = b''
 
352
        self.bytes_left = None
 
353
        self._number_needed_bytes = None
 
354
 
 
355
    def _get_in_buffer(self):
 
356
        if len(self._in_buffer_list) == 1:
 
357
            return self._in_buffer_list[0]
 
358
        in_buffer = b''.join(self._in_buffer_list)
 
359
        if len(in_buffer) != self._in_buffer_len:
 
360
            raise AssertionError(
 
361
                "Length of buffer did not match expected value: %s != %s"
 
362
                % self._in_buffer_len, len(in_buffer))
 
363
        self._in_buffer_list = [in_buffer]
 
364
        return in_buffer
 
365
 
 
366
    def _get_in_bytes(self, count):
 
367
        """Grab X bytes from the input_buffer.
 
368
 
 
369
        Callers should have already checked that self._in_buffer_len is >
 
370
        count. Note, this does not consume the bytes from the buffer. The
 
371
        caller will still need to call _get_in_buffer() and then
 
372
        _set_in_buffer() if they actually need to consume the bytes.
 
373
        """
 
374
        # check if we can yield the bytes from just the first entry in our list
 
375
        if len(self._in_buffer_list) == 0:
 
376
            raise AssertionError('Callers must be sure we have buffered bytes'
 
377
                                 ' before calling _get_in_bytes')
 
378
        if len(self._in_buffer_list[0]) > count:
 
379
            return self._in_buffer_list[0][:count]
 
380
        # We can't yield it from the first buffer, so collapse all buffers, and
 
381
        # yield it from that
 
382
        in_buf = self._get_in_buffer()
 
383
        return in_buf[:count]
 
384
 
 
385
    def _set_in_buffer(self, new_buf):
 
386
        if new_buf is not None:
 
387
            if not isinstance(new_buf, bytes):
 
388
                raise TypeError(new_buf)
 
389
            self._in_buffer_list = [new_buf]
 
390
            self._in_buffer_len = len(new_buf)
 
391
        else:
 
392
            self._in_buffer_list = []
 
393
            self._in_buffer_len = 0
 
394
 
 
395
    def accept_bytes(self, new_buf):
 
396
        """Decode as much of bytes as possible.
 
397
 
 
398
        If 'new_buf' contains too much data it will be appended to
 
399
        self.unused_data.
 
400
 
 
401
        finished_reading will be set when no more data is required.  Further
 
402
        data will be appended to self.unused_data.
 
403
        """
 
404
        if not isinstance(new_buf, bytes):
 
405
            raise TypeError(new_buf)
 
406
        # accept_bytes is allowed to change the state
 
407
        self._number_needed_bytes = None
 
408
        # lsprof puts a very large amount of time on this specific call for
 
409
        # large readv arrays
 
410
        self._in_buffer_list.append(new_buf)
 
411
        self._in_buffer_len += len(new_buf)
 
412
        try:
 
413
            # Run the function for the current state.
 
414
            current_state = self.state_accept
 
415
            self.state_accept()
 
416
            while current_state != self.state_accept:
 
417
                # The current state has changed.  Run the function for the new
 
418
                # current state, so that it can:
 
419
                #   - decode any unconsumed bytes left in a buffer, and
 
420
                #   - signal how many more bytes are expected (via raising
 
421
                #     _NeedMoreBytes).
 
422
                current_state = self.state_accept
 
423
                self.state_accept()
 
424
        except _NeedMoreBytes as e:
 
425
            self._number_needed_bytes = e.count
 
426
 
 
427
 
 
428
class ChunkedBodyDecoder(_StatefulDecoder):
 
429
    """Decoder for chunked body data.
 
430
 
 
431
    This is very similar the HTTP's chunked encoding.  See the description of
 
432
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
433
    """
 
434
 
 
435
    def __init__(self):
 
436
        _StatefulDecoder.__init__(self)
 
437
        self.state_accept = self._state_accept_expecting_header
 
438
        self.chunk_in_progress = None
 
439
        self.chunks = deque()
 
440
        self.error = False
 
441
        self.error_in_progress = None
 
442
 
 
443
    def next_read_size(self):
 
444
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
445
        # end-of-body marker is 4 bytes: 'END\n'.
 
446
        if self.state_accept == self._state_accept_reading_chunk:
 
447
            # We're expecting more chunk content.  So we're expecting at least
 
448
            # the rest of this chunk plus an END chunk.
 
449
            return self.bytes_left + 4
 
450
        elif self.state_accept == self._state_accept_expecting_length:
 
451
            if self._in_buffer_len == 0:
 
452
                # We're expecting a chunk length.  There's at least two bytes
 
453
                # left: a digit plus '\n'.
 
454
                return 2
 
455
            else:
 
456
                # We're in the middle of reading a chunk length.  So there's at
 
457
                # least one byte left, the '\n' that terminates the length.
 
458
                return 1
 
459
        elif self.state_accept == self._state_accept_reading_unused:
 
460
            return 1
 
461
        elif self.state_accept == self._state_accept_expecting_header:
 
462
            return max(0, len('chunked\n') - self._in_buffer_len)
 
463
        else:
 
464
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
465
 
 
466
    def read_next_chunk(self):
 
467
        try:
 
468
            return self.chunks.popleft()
 
469
        except IndexError:
 
470
            return None
 
471
 
 
472
    def _extract_line(self):
 
473
        in_buf = self._get_in_buffer()
 
474
        pos = in_buf.find(b'\n')
 
475
        if pos == -1:
 
476
            # We haven't read a complete line yet, so request more bytes before
 
477
            # we continue.
 
478
            raise _NeedMoreBytes(1)
 
479
        line = in_buf[:pos]
 
480
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
481
        self._set_in_buffer(in_buf[pos + 1:])
 
482
        return line
 
483
 
 
484
    def _finished(self):
 
485
        self.unused_data = self._get_in_buffer()
 
486
        self._in_buffer_list = []
 
487
        self._in_buffer_len = 0
 
488
        self.state_accept = self._state_accept_reading_unused
 
489
        if self.error:
 
490
            error_args = tuple(self.error_in_progress)
 
491
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
492
            self.error_in_progress = None
 
493
        self.finished_reading = True
 
494
 
 
495
    def _state_accept_expecting_header(self):
 
496
        prefix = self._extract_line()
 
497
        if prefix == b'chunked':
 
498
            self.state_accept = self._state_accept_expecting_length
 
499
        else:
 
500
            raise errors.SmartProtocolError(
 
501
                'Bad chunked body header: "%s"' % (prefix,))
 
502
 
 
503
    def _state_accept_expecting_length(self):
 
504
        prefix = self._extract_line()
 
505
        if prefix == b'ERR':
 
506
            self.error = True
 
507
            self.error_in_progress = []
 
508
            self._state_accept_expecting_length()
 
509
            return
 
510
        elif prefix == b'END':
 
511
            # We've read the end-of-body marker.
 
512
            # Any further bytes are unused data, including the bytes left in
 
513
            # the _in_buffer.
 
514
            self._finished()
 
515
            return
 
516
        else:
 
517
            self.bytes_left = int(prefix, 16)
 
518
            self.chunk_in_progress = b''
 
519
            self.state_accept = self._state_accept_reading_chunk
 
520
 
 
521
    def _state_accept_reading_chunk(self):
 
522
        in_buf = self._get_in_buffer()
 
523
        in_buffer_len = len(in_buf)
 
524
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
525
        self._set_in_buffer(in_buf[self.bytes_left:])
 
526
        self.bytes_left -= in_buffer_len
 
527
        if self.bytes_left <= 0:
 
528
            # Finished with chunk
 
529
            self.bytes_left = None
 
530
            if self.error:
 
531
                self.error_in_progress.append(self.chunk_in_progress)
 
532
            else:
 
533
                self.chunks.append(self.chunk_in_progress)
 
534
            self.chunk_in_progress = None
 
535
            self.state_accept = self._state_accept_expecting_length
 
536
 
 
537
    def _state_accept_reading_unused(self):
 
538
        self.unused_data += self._get_in_buffer()
 
539
        self._in_buffer_list = []
 
540
 
 
541
 
 
542
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
543
    """Decodes the length-prefixed bulk data."""
 
544
 
 
545
    def __init__(self):
 
546
        _StatefulDecoder.__init__(self)
 
547
        self.state_accept = self._state_accept_expecting_length
 
548
        self.state_read = self._state_read_no_data
 
549
        self._body = b''
 
550
        self._trailer_buffer = b''
 
551
 
 
552
    def next_read_size(self):
 
553
        if self.bytes_left is not None:
 
554
            # Ideally we want to read all the remainder of the body and the
 
555
            # trailer in one go.
 
556
            return self.bytes_left + 5
 
557
        elif self.state_accept == self._state_accept_reading_trailer:
 
558
            # Just the trailer left
 
559
            return 5 - len(self._trailer_buffer)
 
560
        elif self.state_accept == self._state_accept_expecting_length:
 
561
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
562
            # 'done\n').
 
563
            return 6
 
564
        else:
 
565
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
566
            return 1
 
567
 
 
568
    def read_pending_data(self):
 
569
        """Return any pending data that has been decoded."""
 
570
        return self.state_read()
 
571
 
 
572
    def _state_accept_expecting_length(self):
 
573
        in_buf = self._get_in_buffer()
 
574
        pos = in_buf.find(b'\n')
 
575
        if pos == -1:
 
576
            return
 
577
        self.bytes_left = int(in_buf[:pos])
 
578
        self._set_in_buffer(in_buf[pos + 1:])
 
579
        self.state_accept = self._state_accept_reading_body
 
580
        self.state_read = self._state_read_body_buffer
 
581
 
 
582
    def _state_accept_reading_body(self):
 
583
        in_buf = self._get_in_buffer()
 
584
        self._body += in_buf
 
585
        self.bytes_left -= len(in_buf)
 
586
        self._set_in_buffer(None)
 
587
        if self.bytes_left <= 0:
 
588
            # Finished with body
 
589
            if self.bytes_left != 0:
 
590
                self._trailer_buffer = self._body[self.bytes_left:]
 
591
                self._body = self._body[:self.bytes_left]
 
592
            self.bytes_left = None
 
593
            self.state_accept = self._state_accept_reading_trailer
 
594
 
 
595
    def _state_accept_reading_trailer(self):
 
596
        self._trailer_buffer += self._get_in_buffer()
 
597
        self._set_in_buffer(None)
 
598
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
599
        # a ProtocolViolation exception?
 
600
        if self._trailer_buffer.startswith(b'done\n'):
 
601
            self.unused_data = self._trailer_buffer[len(b'done\n'):]
 
602
            self.state_accept = self._state_accept_reading_unused
 
603
            self.finished_reading = True
 
604
 
 
605
    def _state_accept_reading_unused(self):
 
606
        self.unused_data += self._get_in_buffer()
 
607
        self._set_in_buffer(None)
 
608
 
 
609
    def _state_read_no_data(self):
 
610
        return b''
 
611
 
 
612
    def _state_read_body_buffer(self):
 
613
        result = self._body
 
614
        self._body = b''
 
615
        return result
 
616
 
 
617
 
 
618
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
619
                                    message.ResponseHandler):
 
620
    """The client-side protocol for smart version 1."""
 
621
 
 
622
    def __init__(self, request):
 
623
        """Construct a SmartClientRequestProtocolOne.
 
624
 
 
625
        :param request: A SmartClientMediumRequest to serialise onto and
 
626
            deserialise from.
 
627
        """
 
628
        self._request = request
 
629
        self._body_buffer = None
 
630
        self._request_start_time = None
 
631
        self._last_verb = None
 
632
        self._headers = None
 
633
 
 
634
    def set_headers(self, headers):
 
635
        self._headers = dict(headers)
 
636
 
 
637
    def call(self, *args):
 
638
        if 'hpss' in debug.debug_flags:
 
639
            mutter('hpss call:   %s', repr(args)[1:-1])
 
640
            if getattr(self._request._medium, 'base', None) is not None:
 
641
                mutter('             (to %s)', self._request._medium.base)
 
642
            self._request_start_time = osutils.perf_counter()
 
643
        self._write_args(args)
 
644
        self._request.finished_writing()
 
645
        self._last_verb = args[0]
 
646
 
 
647
    def call_with_body_bytes(self, args, body):
 
648
        """Make a remote call of args with body bytes 'body'.
 
649
 
 
650
        After calling this, call read_response_tuple to find the result out.
 
651
        """
 
652
        if 'hpss' in debug.debug_flags:
 
653
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
654
            if getattr(self._request._medium, '_path', None) is not None:
 
655
                mutter('                  (to %s)',
 
656
                       self._request._medium._path)
 
657
            mutter('              %d bytes', len(body))
 
658
            self._request_start_time = osutils.perf_counter()
 
659
            if 'hpssdetail' in debug.debug_flags:
 
660
                mutter('hpss body content: %s', body)
 
661
        self._write_args(args)
 
662
        bytes = self._encode_bulk_data(body)
 
663
        self._request.accept_bytes(bytes)
 
664
        self._request.finished_writing()
 
665
        self._last_verb = args[0]
 
666
 
 
667
    def call_with_body_readv_array(self, args, body):
 
668
        """Make a remote call with a readv array.
 
669
 
 
670
        The body is encoded with one line per readv offset pair. The numbers in
 
671
        each pair are separated by a comma, and no trailing \\n is emitted.
 
672
        """
 
673
        if 'hpss' in debug.debug_flags:
 
674
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
675
            if getattr(self._request._medium, '_path', None) is not None:
 
676
                mutter('                  (to %s)',
 
677
                       self._request._medium._path)
 
678
            self._request_start_time = osutils.perf_counter()
 
679
        self._write_args(args)
 
680
        readv_bytes = self._serialise_offsets(body)
 
681
        bytes = self._encode_bulk_data(readv_bytes)
 
682
        self._request.accept_bytes(bytes)
 
683
        self._request.finished_writing()
 
684
        if 'hpss' in debug.debug_flags:
 
685
            mutter('              %d bytes in readv request', len(readv_bytes))
 
686
        self._last_verb = args[0]
 
687
 
 
688
    def call_with_body_stream(self, args, stream):
 
689
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
690
        # assume that a v1/v2 server doesn't support whatever method we're
 
691
        # trying to call with a body stream.
 
692
        self._request.finished_writing()
 
693
        self._request.finished_reading()
 
694
        raise errors.UnknownSmartMethod(args[0])
 
695
 
 
696
    def cancel_read_body(self):
 
697
        """After expecting a body, a response code may indicate one otherwise.
 
698
 
 
699
        This method lets the domain client inform the protocol that no body
 
700
        will be transmitted. This is a terminal method: after calling it the
 
701
        protocol is not able to be used further.
 
702
        """
 
703
        self._request.finished_reading()
 
704
 
 
705
    def _read_response_tuple(self):
 
706
        result = self._recv_tuple()
 
707
        if 'hpss' in debug.debug_flags:
 
708
            if self._request_start_time is not None:
 
709
                mutter('   result:   %6.3fs  %s',
 
710
                       osutils.perf_counter() - self._request_start_time,
 
711
                       repr(result)[1:-1])
 
712
                self._request_start_time = None
 
713
            else:
 
714
                mutter('   result:   %s', repr(result)[1:-1])
 
715
        return result
 
716
 
 
717
    def read_response_tuple(self, expect_body=False):
 
718
        """Read a response tuple from the wire.
 
719
 
 
720
        This should only be called once.
 
721
        """
 
722
        result = self._read_response_tuple()
 
723
        self._response_is_unknown_method(result)
 
724
        self._raise_args_if_error(result)
 
725
        if not expect_body:
 
726
            self._request.finished_reading()
 
727
        return result
 
728
 
 
729
    def _raise_args_if_error(self, result_tuple):
 
730
        # Later protocol versions have an explicit flag in the protocol to say
 
731
        # if an error response is "failed" or not.  In version 1 we don't have
 
732
        # that luxury.  So here is a complete list of errors that can be
 
733
        # returned in response to existing version 1 smart requests.  Responses
 
734
        # starting with these codes are always "failed" responses.
 
735
        v1_error_codes = [
 
736
            b'norepository',
 
737
            b'NoSuchFile',
 
738
            b'FileExists',
 
739
            b'DirectoryNotEmpty',
 
740
            b'ShortReadvError',
 
741
            b'UnicodeEncodeError',
 
742
            b'UnicodeDecodeError',
 
743
            b'ReadOnlyError',
 
744
            b'nobranch',
 
745
            b'NoSuchRevision',
 
746
            b'nosuchrevision',
 
747
            b'LockContention',
 
748
            b'UnlockableTransport',
 
749
            b'LockFailed',
 
750
            b'TokenMismatch',
 
751
            b'ReadError',
 
752
            b'PermissionDenied',
 
753
            ]
 
754
        if result_tuple[0] in v1_error_codes:
 
755
            self._request.finished_reading()
 
756
            raise errors.ErrorFromSmartServer(result_tuple)
 
757
 
 
758
    def _response_is_unknown_method(self, result_tuple):
 
759
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
760
        method' response to the request.
 
761
 
 
762
        :param response: The response from a smart client call_expecting_body
 
763
            call.
 
764
        :param verb: The verb used in that call.
 
765
        :raises: UnexpectedSmartServerResponse
 
766
        """
 
767
        if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
768
                             b"bad request '" + self._last_verb + b"'")
 
769
            or result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
770
                                b"bad request u'%s'" % self._last_verb)):
 
771
            # The response will have no body, so we've finished reading.
 
772
            self._request.finished_reading()
 
773
            raise errors.UnknownSmartMethod(self._last_verb)
 
774
 
 
775
    def read_body_bytes(self, count=-1):
 
776
        """Read bytes from the body, decoding into a byte stream.
 
777
 
 
778
        We read all bytes at once to ensure we've checked the trailer for
 
779
        errors, and then feed the buffer back as read_body_bytes is called.
 
780
        """
 
781
        if self._body_buffer is not None:
 
782
            return self._body_buffer.read(count)
 
783
        _body_decoder = LengthPrefixedBodyDecoder()
 
784
 
 
785
        while not _body_decoder.finished_reading:
 
786
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
787
            if bytes == b'':
 
788
                # end of file encountered reading from server
 
789
                raise errors.ConnectionReset(
 
790
                    "Connection lost while reading response body.")
 
791
            _body_decoder.accept_bytes(bytes)
 
792
        self._request.finished_reading()
 
793
        self._body_buffer = BytesIO(_body_decoder.read_pending_data())
 
794
        # XXX: TODO check the trailer result.
 
795
        if 'hpss' in debug.debug_flags:
 
796
            mutter('              %d body bytes read',
 
797
                   len(self._body_buffer.getvalue()))
 
798
        return self._body_buffer.read(count)
 
799
 
 
800
    def _recv_tuple(self):
 
801
        """Receive a tuple from the medium request."""
 
802
        return _decode_tuple(self._request.read_line())
 
803
 
 
804
    def query_version(self):
 
805
        """Return protocol version number of the server."""
 
806
        self.call(b'hello')
 
807
        resp = self.read_response_tuple()
 
808
        if resp == (b'ok', b'1'):
 
809
            return 1
 
810
        elif resp == (b'ok', b'2'):
 
811
            return 2
 
812
        else:
 
813
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
814
 
 
815
    def _write_args(self, args):
 
816
        self._write_protocol_version()
 
817
        bytes = _encode_tuple(args)
 
818
        self._request.accept_bytes(bytes)
 
819
 
 
820
    def _write_protocol_version(self):
 
821
        """Write any prefixes this protocol requires.
 
822
 
 
823
        Version one doesn't send protocol versions.
 
824
        """
 
825
 
 
826
 
 
827
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
828
    """Version two of the client side of the smart protocol.
 
829
 
 
830
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
831
    """
 
832
 
 
833
    response_marker = RESPONSE_VERSION_TWO
 
834
    request_marker = REQUEST_VERSION_TWO
 
835
 
 
836
    def read_response_tuple(self, expect_body=False):
 
837
        """Read a response tuple from the wire.
 
838
 
 
839
        This should only be called once.
 
840
        """
 
841
        version = self._request.read_line()
 
842
        if version != self.response_marker:
 
843
            self._request.finished_reading()
 
844
            raise errors.UnexpectedProtocolVersionMarker(version)
 
845
        response_status = self._request.read_line()
 
846
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
847
        self._response_is_unknown_method(result)
 
848
        if response_status == b'success\n':
 
849
            self.response_status = True
 
850
            if not expect_body:
 
851
                self._request.finished_reading()
 
852
            return result
 
853
        elif response_status == b'failed\n':
 
854
            self.response_status = False
 
855
            self._request.finished_reading()
 
856
            raise errors.ErrorFromSmartServer(result)
 
857
        else:
 
858
            raise errors.SmartProtocolError(
 
859
                'bad protocol status %r' % response_status)
 
860
 
 
861
    def _write_protocol_version(self):
 
862
        """Write any prefixes this protocol requires.
 
863
 
 
864
        Version two sends the value of REQUEST_VERSION_TWO.
 
865
        """
 
866
        self._request.accept_bytes(self.request_marker)
 
867
 
 
868
    def read_streamed_body(self):
 
869
        """Read bytes from the body, decoding into a byte stream.
 
870
        """
 
871
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
872
        # buffer space available) on Windows.
 
873
        _body_decoder = ChunkedBodyDecoder()
 
874
        while not _body_decoder.finished_reading:
 
875
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
876
            if bytes == b'':
 
877
                # end of file encountered reading from server
 
878
                raise errors.ConnectionReset(
 
879
                    "Connection lost while reading streamed body.")
 
880
            _body_decoder.accept_bytes(bytes)
 
881
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
882
                if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
 
883
                    mutter('              %d byte chunk read',
 
884
                           len(body_bytes))
 
885
                yield body_bytes
 
886
        self._request.finished_reading()
 
887
 
 
888
 
 
889
def build_server_protocol_three(backing_transport, write_func,
 
890
                                root_client_path, jail_root=None):
 
891
    request_handler = request.SmartServerRequestHandler(
 
892
        backing_transport, commands=request.request_handlers,
 
893
        root_client_path=root_client_path, jail_root=jail_root)
 
894
    responder = ProtocolThreeResponder(write_func)
 
895
    message_handler = message.ConventionalRequestHandler(
 
896
        request_handler, responder)
 
897
    return ProtocolThreeDecoder(message_handler)
 
898
 
 
899
 
 
900
class ProtocolThreeDecoder(_StatefulDecoder):
 
901
 
 
902
    response_marker = RESPONSE_VERSION_THREE
 
903
    request_marker = REQUEST_VERSION_THREE
 
904
 
 
905
    def __init__(self, message_handler, expect_version_marker=False):
 
906
        _StatefulDecoder.__init__(self)
 
907
        self._has_dispatched = False
 
908
        # Initial state
 
909
        if expect_version_marker:
 
910
            self.state_accept = self._state_accept_expecting_protocol_version
 
911
            # We're expecting at least the protocol version marker + some
 
912
            # headers.
 
913
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
914
        else:
 
915
            self.state_accept = self._state_accept_expecting_headers
 
916
            self._number_needed_bytes = 4
 
917
        self.decoding_failed = False
 
918
        self.request_handler = self.message_handler = message_handler
 
919
 
 
920
    def accept_bytes(self, bytes):
 
921
        self._number_needed_bytes = None
 
922
        try:
 
923
            _StatefulDecoder.accept_bytes(self, bytes)
 
924
        except KeyboardInterrupt:
 
925
            raise
 
926
        except errors.SmartMessageHandlerError as exception:
 
927
            # We do *not* set self.decoding_failed here.  The message handler
 
928
            # has raised an error, but the decoder is still able to parse bytes
 
929
            # and determine when this message ends.
 
930
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
931
                log_exception_quietly()
 
932
            self.message_handler.protocol_error(exception.exc_value)
 
933
            # The state machine is ready to continue decoding, but the
 
934
            # exception has interrupted the loop that runs the state machine.
 
935
            # So we call accept_bytes again to restart it.
 
936
            self.accept_bytes(b'')
 
937
        except Exception as exception:
 
938
            # The decoder itself has raised an exception.  We cannot continue
 
939
            # decoding.
 
940
            self.decoding_failed = True
 
941
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
942
                # This happens during normal operation when the client tries a
 
943
                # protocol version the server doesn't understand, so no need to
 
944
                # log a traceback every time.
 
945
                # Note that this can only happen when
 
946
                # expect_version_marker=True, which is only the case on the
 
947
                # client side.
 
948
                pass
 
949
            else:
 
950
                log_exception_quietly()
 
951
            self.message_handler.protocol_error(exception)
 
952
 
 
953
    def _extract_length_prefixed_bytes(self):
 
954
        if self._in_buffer_len < 4:
 
955
            # A length prefix by itself is 4 bytes, and we don't even have that
 
956
            # many yet.
 
957
            raise _NeedMoreBytes(4)
 
958
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
959
        end_of_bytes = 4 + length
 
960
        if self._in_buffer_len < end_of_bytes:
 
961
            # We haven't yet read as many bytes as the length-prefix says there
 
962
            # are.
 
963
            raise _NeedMoreBytes(end_of_bytes)
 
964
        # Extract the bytes from the buffer.
 
965
        in_buf = self._get_in_buffer()
 
966
        bytes = in_buf[4:end_of_bytes]
 
967
        self._set_in_buffer(in_buf[end_of_bytes:])
 
968
        return bytes
 
969
 
 
970
    def _extract_prefixed_bencoded_data(self):
 
971
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
972
        try:
 
973
            decoded = bdecode_as_tuple(prefixed_bytes)
 
974
        except ValueError:
 
975
            raise errors.SmartProtocolError(
 
976
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
977
        return decoded
 
978
 
 
979
    def _extract_single_byte(self):
 
980
        if self._in_buffer_len == 0:
 
981
            # The buffer is empty
 
982
            raise _NeedMoreBytes(1)
 
983
        in_buf = self._get_in_buffer()
 
984
        one_byte = in_buf[0:1]
 
985
        self._set_in_buffer(in_buf[1:])
 
986
        return one_byte
 
987
 
 
988
    def _state_accept_expecting_protocol_version(self):
 
989
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
990
        in_buf = self._get_in_buffer()
 
991
        if needed_bytes > 0:
 
992
            # We don't have enough bytes to check if the protocol version
 
993
            # marker is right.  But we can check if it is already wrong by
 
994
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
995
            # we've read so far.
 
996
            # [In fact, if the remote end isn't bzr we might never receive
 
997
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
998
            # are wrong then we should just raise immediately rather than
 
999
            # stall.]
 
1000
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
1001
                # We have enough bytes to know the protocol version is wrong
 
1002
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
1003
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
1004
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
1005
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
1006
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
1007
        self.state_accept = self._state_accept_expecting_headers
 
1008
 
 
1009
    def _state_accept_expecting_headers(self):
 
1010
        decoded = self._extract_prefixed_bencoded_data()
 
1011
        if not isinstance(decoded, dict):
 
1012
            raise errors.SmartProtocolError(
 
1013
                'Header object %r is not a dict' % (decoded,))
 
1014
        self.state_accept = self._state_accept_expecting_message_part
 
1015
        try:
 
1016
            self.message_handler.headers_received(decoded)
 
1017
        except:
 
1018
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1019
 
 
1020
    def _state_accept_expecting_message_part(self):
 
1021
        message_part_kind = self._extract_single_byte()
 
1022
        if message_part_kind == b'o':
 
1023
            self.state_accept = self._state_accept_expecting_one_byte
 
1024
        elif message_part_kind == b's':
 
1025
            self.state_accept = self._state_accept_expecting_structure
 
1026
        elif message_part_kind == b'b':
 
1027
            self.state_accept = self._state_accept_expecting_bytes
 
1028
        elif message_part_kind == b'e':
 
1029
            self.done()
 
1030
        else:
 
1031
            raise errors.SmartProtocolError(
 
1032
                'Bad message kind byte: %r' % (message_part_kind,))
 
1033
 
 
1034
    def _state_accept_expecting_one_byte(self):
 
1035
        byte = self._extract_single_byte()
 
1036
        self.state_accept = self._state_accept_expecting_message_part
 
1037
        try:
 
1038
            self.message_handler.byte_part_received(byte)
 
1039
        except:
 
1040
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1041
 
 
1042
    def _state_accept_expecting_bytes(self):
 
1043
        # XXX: this should not buffer whole message part, but instead deliver
 
1044
        # the bytes as they arrive.
 
1045
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1046
        self.state_accept = self._state_accept_expecting_message_part
 
1047
        try:
 
1048
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1049
        except:
 
1050
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1051
 
 
1052
    def _state_accept_expecting_structure(self):
 
1053
        structure = self._extract_prefixed_bencoded_data()
 
1054
        self.state_accept = self._state_accept_expecting_message_part
 
1055
        try:
 
1056
            self.message_handler.structure_part_received(structure)
 
1057
        except:
 
1058
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1059
 
 
1060
    def done(self):
 
1061
        self.unused_data = self._get_in_buffer()
 
1062
        self._set_in_buffer(None)
 
1063
        self.state_accept = self._state_accept_reading_unused
 
1064
        try:
 
1065
            self.message_handler.end_received()
 
1066
        except:
 
1067
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1068
 
 
1069
    def _state_accept_reading_unused(self):
 
1070
        self.unused_data += self._get_in_buffer()
 
1071
        self._set_in_buffer(None)
 
1072
 
 
1073
    def next_read_size(self):
 
1074
        if self.state_accept == self._state_accept_reading_unused:
 
1075
            return 0
 
1076
        elif self.decoding_failed:
 
1077
            # An exception occured while processing this message, probably from
 
1078
            # self.message_handler.  We're not sure that this state machine is
 
1079
            # in a consistent state, so just signal that we're done (i.e. give
 
1080
            # up).
 
1081
            return 0
 
1082
        else:
 
1083
            if self._number_needed_bytes is not None:
 
1084
                return self._number_needed_bytes - self._in_buffer_len
 
1085
            else:
 
1086
                raise AssertionError("don't know how many bytes are expected!")
 
1087
 
 
1088
 
 
1089
class _ProtocolThreeEncoder(object):
 
1090
 
 
1091
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1092
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing
 
1093
 
 
1094
    def __init__(self, write_func):
 
1095
        self._buf = []
 
1096
        self._buf_len = 0
 
1097
        self._real_write_func = write_func
 
1098
 
 
1099
    def _write_func(self, bytes):
 
1100
        # TODO: Another possibility would be to turn this into an async model.
 
1101
        #       Where we let another thread know that we have some bytes if
 
1102
        #       they want it, but we don't actually block for it
 
1103
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1104
        #       we might just push out smaller bits at a time?
 
1105
        self._buf.append(bytes)
 
1106
        self._buf_len += len(bytes)
 
1107
        if self._buf_len > self.BUFFER_SIZE:
 
1108
            self.flush()
 
1109
 
 
1110
    def flush(self):
 
1111
        if self._buf:
 
1112
            self._real_write_func(b''.join(self._buf))
 
1113
            del self._buf[:]
 
1114
            self._buf_len = 0
 
1115
 
 
1116
    def _serialise_offsets(self, offsets):
 
1117
        """Serialise a readv offset list."""
 
1118
        txt = []
 
1119
        for start, length in offsets:
 
1120
            txt.append(b'%d,%d' % (start, length))
 
1121
        return b'\n'.join(txt)
 
1122
 
 
1123
    def _write_protocol_version(self):
 
1124
        self._write_func(MESSAGE_VERSION_THREE)
 
1125
 
 
1126
    def _write_prefixed_bencode(self, structure):
 
1127
        bytes = bencode(structure)
 
1128
        self._write_func(struct.pack('!L', len(bytes)))
 
1129
        self._write_func(bytes)
 
1130
 
 
1131
    def _write_headers(self, headers):
 
1132
        self._write_prefixed_bencode(headers)
 
1133
 
 
1134
    def _write_structure(self, args):
 
1135
        self._write_func(b's')
 
1136
        utf8_args = []
 
1137
        for arg in args:
 
1138
            if isinstance(arg, text_type):
 
1139
                utf8_args.append(arg.encode('utf8'))
 
1140
            else:
 
1141
                utf8_args.append(arg)
 
1142
        self._write_prefixed_bencode(utf8_args)
 
1143
 
 
1144
    def _write_end(self):
 
1145
        self._write_func(b'e')
 
1146
        self.flush()
 
1147
 
 
1148
    def _write_prefixed_body(self, bytes):
 
1149
        self._write_func(b'b')
 
1150
        self._write_func(struct.pack('!L', len(bytes)))
 
1151
        self._write_func(bytes)
 
1152
 
 
1153
    def _write_chunked_body_start(self):
 
1154
        self._write_func(b'oC')
 
1155
 
 
1156
    def _write_error_status(self):
 
1157
        self._write_func(b'oE')
 
1158
 
 
1159
    def _write_success_status(self):
 
1160
        self._write_func(b'oS')
 
1161
 
 
1162
 
 
1163
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1164
 
 
1165
    def __init__(self, write_func):
 
1166
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1167
        self.response_sent = False
 
1168
        self._headers = {
 
1169
            b'Software version': breezy.__version__.encode('utf-8')}
 
1170
        if 'hpss' in debug.debug_flags:
 
1171
            self._thread_id = _thread.get_ident()
 
1172
            self._response_start_time = None
 
1173
 
 
1174
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1175
        if self._response_start_time is None:
 
1176
            self._response_start_time = osutils.perf_counter()
 
1177
        if include_time:
 
1178
            t = '%5.3fs ' % (osutils.perf_counter() - self._response_start_time)
 
1179
        else:
 
1180
            t = ''
 
1181
        if extra_bytes is None:
 
1182
            extra = ''
 
1183
        else:
 
1184
            extra = ' ' + repr(extra_bytes[:40])
 
1185
            if len(extra) > 33:
 
1186
                extra = extra[:29] + extra[-1] + '...'
 
1187
        mutter('%12s: [%s] %s%s%s'
 
1188
               % (action, self._thread_id, t, message, extra))
 
1189
 
 
1190
    def send_error(self, exception):
 
1191
        if self.response_sent:
 
1192
            raise AssertionError(
 
1193
                "send_error(%s) called, but response already sent."
 
1194
                % (exception,))
 
1195
        if isinstance(exception, errors.UnknownSmartMethod):
 
1196
            failure = request.FailedSmartServerResponse(
 
1197
                (b'UnknownMethod', exception.verb))
 
1198
            self.send_response(failure)
 
1199
            return
 
1200
        if 'hpss' in debug.debug_flags:
 
1201
            self._trace('error', str(exception))
 
1202
        self.response_sent = True
 
1203
        self._write_protocol_version()
 
1204
        self._write_headers(self._headers)
 
1205
        self._write_error_status()
 
1206
        self._write_structure(
 
1207
            (b'error', str(exception).encode('utf-8', 'replace')))
 
1208
        self._write_end()
 
1209
 
 
1210
    def send_response(self, response):
 
1211
        if self.response_sent:
 
1212
            raise AssertionError(
 
1213
                "send_response(%r) called, but response already sent."
 
1214
                % (response,))
 
1215
        self.response_sent = True
 
1216
        self._write_protocol_version()
 
1217
        self._write_headers(self._headers)
 
1218
        if response.is_successful():
 
1219
            self._write_success_status()
 
1220
        else:
 
1221
            self._write_error_status()
 
1222
        if 'hpss' in debug.debug_flags:
 
1223
            self._trace('response', repr(response.args))
 
1224
        self._write_structure(response.args)
 
1225
        if response.body is not None:
 
1226
            self._write_prefixed_body(response.body)
 
1227
            if 'hpss' in debug.debug_flags:
 
1228
                self._trace('body', '%d bytes' % (len(response.body),),
 
1229
                            response.body, include_time=True)
 
1230
        elif response.body_stream is not None:
 
1231
            count = num_bytes = 0
 
1232
            first_chunk = None
 
1233
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1234
                count += 1
 
1235
                if exc_info is not None:
 
1236
                    self._write_error_status()
 
1237
                    error_struct = request._translate_error(exc_info[1])
 
1238
                    self._write_structure(error_struct)
 
1239
                    break
 
1240
                else:
 
1241
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1242
                        self._write_error_status()
 
1243
                        self._write_structure(chunk.args)
 
1244
                        break
 
1245
                    num_bytes += len(chunk)
 
1246
                    if first_chunk is None:
 
1247
                        first_chunk = chunk
 
1248
                    self._write_prefixed_body(chunk)
 
1249
                    self.flush()
 
1250
                    if 'hpssdetail' in debug.debug_flags:
 
1251
                        # Not worth timing separately, as _write_func is
 
1252
                        # actually buffered
 
1253
                        self._trace('body chunk',
 
1254
                                    '%d bytes' % (len(chunk),),
 
1255
                                    chunk, suppress_time=True)
 
1256
            if 'hpss' in debug.debug_flags:
 
1257
                self._trace('body stream',
 
1258
                            '%d bytes %d chunks' % (num_bytes, count),
 
1259
                            first_chunk)
 
1260
        self._write_end()
 
1261
        if 'hpss' in debug.debug_flags:
 
1262
            self._trace('response end', '', include_time=True)
 
1263
 
 
1264
 
 
1265
def _iter_with_errors(iterable):
 
1266
    """Handle errors from iterable.next().
 
1267
 
 
1268
    Use like::
 
1269
 
 
1270
        for exc_info, value in _iter_with_errors(iterable):
 
1271
            ...
 
1272
 
 
1273
    This is a safer alternative to::
 
1274
 
 
1275
        try:
 
1276
            for value in iterable:
 
1277
               ...
 
1278
        except:
 
1279
            ...
 
1280
 
 
1281
    Because the latter will catch errors from the for-loop body, not just
 
1282
    iterable.next()
 
1283
 
 
1284
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1285
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1286
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1287
    will not be itercepted.
 
1288
    """
 
1289
    iterator = iter(iterable)
 
1290
    while True:
 
1291
        try:
 
1292
            yield None, next(iterator)
 
1293
        except StopIteration:
 
1294
            return
 
1295
        except (KeyboardInterrupt, SystemExit):
 
1296
            raise
 
1297
        except Exception:
 
1298
            mutter('_iter_with_errors caught error')
 
1299
            log_exception_quietly()
 
1300
            yield sys.exc_info(), None
 
1301
            return
 
1302
 
 
1303
 
 
1304
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1305
 
 
1306
    def __init__(self, medium_request):
 
1307
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1308
        self._medium_request = medium_request
 
1309
        self._headers = {}
 
1310
        self.body_stream_started = None
 
1311
 
 
1312
    def set_headers(self, headers):
 
1313
        self._headers = headers.copy()
 
1314
 
 
1315
    def call(self, *args):
 
1316
        if 'hpss' in debug.debug_flags:
 
1317
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1318
            base = getattr(self._medium_request._medium, 'base', None)
 
1319
            if base is not None:
 
1320
                mutter('             (to %s)', base)
 
1321
            self._request_start_time = osutils.perf_counter()
 
1322
        self._write_protocol_version()
 
1323
        self._write_headers(self._headers)
 
1324
        self._write_structure(args)
 
1325
        self._write_end()
 
1326
        self._medium_request.finished_writing()
 
1327
 
 
1328
    def call_with_body_bytes(self, args, body):
 
1329
        """Make a remote call of args with body bytes 'body'.
 
1330
 
 
1331
        After calling this, call read_response_tuple to find the result out.
 
1332
        """
 
1333
        if 'hpss' in debug.debug_flags:
 
1334
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1335
            path = getattr(self._medium_request._medium, '_path', None)
 
1336
            if path is not None:
 
1337
                mutter('                  (to %s)', path)
 
1338
            mutter('              %d bytes', len(body))
 
1339
            self._request_start_time = osutils.perf_counter()
 
1340
        self._write_protocol_version()
 
1341
        self._write_headers(self._headers)
 
1342
        self._write_structure(args)
 
1343
        self._write_prefixed_body(body)
 
1344
        self._write_end()
 
1345
        self._medium_request.finished_writing()
 
1346
 
 
1347
    def call_with_body_readv_array(self, args, body):
 
1348
        """Make a remote call with a readv array.
 
1349
 
 
1350
        The body is encoded with one line per readv offset pair. The numbers in
 
1351
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1352
        """
 
1353
        if 'hpss' in debug.debug_flags:
 
1354
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1355
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
            if path is not None:
 
1357
                mutter('                  (to %s)', path)
 
1358
            self._request_start_time = osutils.perf_counter()
 
1359
        self._write_protocol_version()
 
1360
        self._write_headers(self._headers)
 
1361
        self._write_structure(args)
 
1362
        readv_bytes = self._serialise_offsets(body)
 
1363
        if 'hpss' in debug.debug_flags:
 
1364
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1365
        self._write_prefixed_body(readv_bytes)
 
1366
        self._write_end()
 
1367
        self._medium_request.finished_writing()
 
1368
 
 
1369
    def call_with_body_stream(self, args, stream):
 
1370
        if 'hpss' in debug.debug_flags:
 
1371
            mutter('hpss call w/body stream: %r', args)
 
1372
            path = getattr(self._medium_request._medium, '_path', None)
 
1373
            if path is not None:
 
1374
                mutter('                  (to %s)', path)
 
1375
            self._request_start_time = osutils.perf_counter()
 
1376
        self.body_stream_started = False
 
1377
        self._write_protocol_version()
 
1378
        self._write_headers(self._headers)
 
1379
        self._write_structure(args)
 
1380
        # TODO: notice if the server has sent an early error reply before we
 
1381
        #       have finished sending the stream.  We would notice at the end
 
1382
        #       anyway, but if the medium can deliver it early then it's good
 
1383
        #       to short-circuit the whole request...
 
1384
        # Provoke any ConnectionReset failures before we start the body stream.
 
1385
        self.flush()
 
1386
        self.body_stream_started = True
 
1387
        for exc_info, part in _iter_with_errors(stream):
 
1388
            if exc_info is not None:
 
1389
                # Iterating the stream failed.  Cleanly abort the request.
 
1390
                self._write_error_status()
 
1391
                # Currently the client unconditionally sends ('error',) as the
 
1392
                # error args.
 
1393
                self._write_structure((b'error',))
 
1394
                self._write_end()
 
1395
                self._medium_request.finished_writing()
 
1396
                try:
 
1397
                    reraise(*exc_info)
 
1398
                finally:
 
1399
                    del exc_info
 
1400
            else:
 
1401
                self._write_prefixed_body(part)
 
1402
                self.flush()
 
1403
        self._write_end()
 
1404
        self._medium_request.finished_writing()