/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2019-03-13 23:24:13 UTC
  • mto: (7290.1.23 work)
  • mto: This revision was merged to the branch mainline in revision 7311.
  • Revision ID: jelmer@jelmer.uk-20190313232413-y1c951be4surcc9g
Fix formatting.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2011 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""The 'medium' layer for the smart servers and clients.
 
18
 
 
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
 
20
for "the quality between big and small."
 
21
 
 
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
 
23
over SSH), and pass them to and from the protocol logic.  See the overview in
 
24
breezy/transport/smart/__init__.py.
 
25
"""
 
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
 
30
import io
 
31
import os
 
32
import sys
 
33
import time
 
34
 
 
35
try:
 
36
    import _thread
 
37
except ImportError:
 
38
    import thread as _thread
 
39
 
 
40
import breezy
 
41
from ...lazy_import import lazy_import
 
42
lazy_import(globals(), """
 
43
import select
 
44
import socket
 
45
import weakref
 
46
 
 
47
from breezy import (
 
48
    debug,
 
49
    trace,
 
50
    transport,
 
51
    ui,
 
52
    urlutils,
 
53
    )
 
54
from breezy.i18n import gettext
 
55
from breezy.bzr.smart import client, protocol, request, signals, vfs
 
56
from breezy.transport import ssh
 
57
""")
 
58
from ... import (
 
59
    errors,
 
60
    osutils,
 
61
    )
 
62
 
 
63
# Throughout this module buffer size parameters are either limited to be at
 
64
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
 
65
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
 
66
# from non-sockets as well.
 
67
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
 
68
 
 
69
 
 
70
class HpssVfsRequestNotAllowed(errors.BzrError):
 
71
 
 
72
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
 
73
            "%(method)s, %(arguments)s.")
 
74
 
 
75
    def __init__(self, method, arguments):
 
76
        self.method = method
 
77
        self.arguments = arguments
 
78
 
 
79
 
 
80
def _get_protocol_factory_for_bytes(bytes):
 
81
    """Determine the right protocol factory for 'bytes'.
 
82
 
 
83
    This will return an appropriate protocol factory depending on the version
 
84
    of the protocol being used, as determined by inspecting the given bytes.
 
85
    The bytes should have at least one newline byte (i.e. be a whole line),
 
86
    otherwise it's possible that a request will be incorrectly identified as
 
87
    version 1.
 
88
 
 
89
    Typical use would be::
 
90
 
 
91
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
92
         server_protocol = factory(transport, write_func, root_client_path)
 
93
         server_protocol.accept_bytes(unused_bytes)
 
94
 
 
95
    :param bytes: a str of bytes of the start of the request.
 
96
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
 
97
        a callable that takes three args: transport, write_func,
 
98
        root_client_path.  unused_bytes are any bytes that were not part of a
 
99
        protocol version marker.
 
100
    """
 
101
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
 
102
        protocol_factory = protocol.build_server_protocol_three
 
103
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
 
104
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
 
105
        protocol_factory = protocol.SmartServerRequestProtocolTwo
 
106
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
 
107
    else:
 
108
        protocol_factory = protocol.SmartServerRequestProtocolOne
 
109
    return protocol_factory, bytes
 
110
 
 
111
 
 
112
def _get_line(read_bytes_func):
 
113
    """Read bytes using read_bytes_func until a newline byte.
 
114
 
 
115
    This isn't particularly efficient, so should only be used when the
 
116
    expected size of the line is quite short.
 
117
 
 
118
    :returns: a tuple of two strs: (line, excess)
 
119
    """
 
120
    newline_pos = -1
 
121
    bytes = b''
 
122
    while newline_pos == -1:
 
123
        new_bytes = read_bytes_func(1)
 
124
        bytes += new_bytes
 
125
        if new_bytes == b'':
 
126
            # Ran out of bytes before receiving a complete line.
 
127
            return bytes, b''
 
128
        newline_pos = bytes.find(b'\n')
 
129
    line = bytes[:newline_pos + 1]
 
130
    excess = bytes[newline_pos + 1:]
 
131
    return line, excess
 
132
 
 
133
 
 
134
class SmartMedium(object):
 
135
    """Base class for smart protocol media, both client- and server-side."""
 
136
 
 
137
    def __init__(self):
 
138
        self._push_back_buffer = None
 
139
 
 
140
    def _push_back(self, data):
 
141
        """Return unused bytes to the medium, because they belong to the next
 
142
        request(s).
 
143
 
 
144
        This sets the _push_back_buffer to the given bytes.
 
145
        """
 
146
        if not isinstance(data, bytes):
 
147
            raise TypeError(data)
 
148
        if self._push_back_buffer is not None:
 
149
            raise AssertionError(
 
150
                "_push_back called when self._push_back_buffer is %r"
 
151
                % (self._push_back_buffer,))
 
152
        if data == b'':
 
153
            return
 
154
        self._push_back_buffer = data
 
155
 
 
156
    def _get_push_back_buffer(self):
 
157
        if self._push_back_buffer == b'':
 
158
            raise AssertionError(
 
159
                '%s._push_back_buffer should never be the empty string, '
 
160
                'which can be confused with EOF' % (self,))
 
161
        bytes = self._push_back_buffer
 
162
        self._push_back_buffer = None
 
163
        return bytes
 
164
 
 
165
    def read_bytes(self, desired_count):
 
166
        """Read some bytes from this medium.
 
167
 
 
168
        :returns: some bytes, possibly more or less than the number requested
 
169
            in 'desired_count' depending on the medium.
 
170
        """
 
171
        if self._push_back_buffer is not None:
 
172
            return self._get_push_back_buffer()
 
173
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
 
174
        return self._read_bytes(bytes_to_read)
 
175
 
 
176
    def _read_bytes(self, count):
 
177
        raise NotImplementedError(self._read_bytes)
 
178
 
 
179
    def _get_line(self):
 
180
        """Read bytes from this request's response until a newline byte.
 
181
 
 
182
        This isn't particularly efficient, so should only be used when the
 
183
        expected size of the line is quite short.
 
184
 
 
185
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
186
        """
 
187
        line, excess = _get_line(self.read_bytes)
 
188
        self._push_back(excess)
 
189
        return line
 
190
 
 
191
    def _report_activity(self, bytes, direction):
 
192
        """Notify that this medium has activity.
 
193
 
 
194
        Implementations should call this from all methods that actually do IO.
 
195
        Be careful that it's not called twice, if one method is implemented on
 
196
        top of another.
 
197
 
 
198
        :param bytes: Number of bytes read or written.
 
199
        :param direction: 'read' or 'write' or None.
 
200
        """
 
201
        ui.ui_factory.report_transport_activity(self, bytes, direction)
 
202
 
 
203
 
 
204
_bad_file_descriptor = (errno.EBADF,)
 
205
if sys.platform == 'win32':
 
206
    # Given on Windows if you pass a closed socket to select.select. Probably
 
207
    # also given if you pass a file handle to select.
 
208
    WSAENOTSOCK = 10038
 
209
    _bad_file_descriptor += (WSAENOTSOCK,)
 
210
 
 
211
 
 
212
class SmartServerStreamMedium(SmartMedium):
 
213
    """Handles smart commands coming over a stream.
 
214
 
 
215
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
216
    in-process fifo for testing.
 
217
 
 
218
    One instance is created for each connected client; it can serve multiple
 
219
    requests in the lifetime of the connection.
 
220
 
 
221
    The server passes requests through to an underlying backing transport,
 
222
    which will typically be a LocalTransport looking at the server's filesystem.
 
223
 
 
224
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
225
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
226
        should make sure to exhaust this buffer before reading more bytes from
 
227
        the stream.  See also the _push_back method.
 
228
    """
 
229
 
 
230
    _timer = time.time
 
231
 
 
232
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
233
        """Construct new server.
 
234
 
 
235
        :param backing_transport: Transport for the directory served.
 
236
        """
 
237
        # backing_transport could be passed to serve instead of __init__
 
238
        self.backing_transport = backing_transport
 
239
        self.root_client_path = root_client_path
 
240
        self.finished = False
 
241
        if timeout is None:
 
242
            raise AssertionError('You must supply a timeout.')
 
243
        self._client_timeout = timeout
 
244
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
 
245
        SmartMedium.__init__(self)
 
246
 
 
247
    def serve(self):
 
248
        """Serve requests until the client disconnects."""
 
249
        # Keep a reference to stderr because the sys module's globals get set to
 
250
        # None during interpreter shutdown.
 
251
        from sys import stderr
 
252
        try:
 
253
            while not self.finished:
 
254
                server_protocol = self._build_protocol()
 
255
                self._serve_one_request(server_protocol)
 
256
        except errors.ConnectionTimeout as e:
 
257
            trace.note('%s' % (e,))
 
258
            trace.log_exception_quietly()
 
259
            self._disconnect_client()
 
260
            # We reported it, no reason to make a big fuss.
 
261
            return
 
262
        except Exception as e:
 
263
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
264
            raise
 
265
        self._disconnect_client()
 
266
 
 
267
    def _stop_gracefully(self):
 
268
        """When we finish this message, stop looking for more."""
 
269
        trace.mutter('Stopping %s' % (self,))
 
270
        self.finished = True
 
271
 
 
272
    def _disconnect_client(self):
 
273
        """Close the current connection. We stopped due to a timeout/etc."""
 
274
        # The default implementation is a no-op, because that is all we used to
 
275
        # do when disconnecting from a client. I suppose we never had the
 
276
        # *server* initiate a disconnect, before
 
277
 
 
278
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
279
        """Wait for more bytes to be read, but timeout if none available.
 
280
 
 
281
        This allows us to detect idle connections, and stop trying to read from
 
282
        them, without setting the socket itself to non-blocking. This also
 
283
        allows us to specify when we watch for idle timeouts.
 
284
 
 
285
        :return: Did we timeout? (True if we timed out, False if there is data
 
286
            to be read)
 
287
        """
 
288
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
 
289
 
 
290
    def _build_protocol(self):
 
291
        """Identifies the version of the incoming request, and returns an
 
292
        a protocol object that can interpret it.
 
293
 
 
294
        If more bytes than the version prefix of the request are read, they will
 
295
        be fed into the protocol before it is returned.
 
296
 
 
297
        :returns: a SmartServerRequestProtocol.
 
298
        """
 
299
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
300
        if self.finished:
 
301
            # We're stopping, so don't try to do any more work
 
302
            return None
 
303
        bytes = self._get_line()
 
304
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
305
        protocol = protocol_factory(
 
306
            self.backing_transport, self._write_out, self.root_client_path)
 
307
        protocol.accept_bytes(unused_bytes)
 
308
        return protocol
 
309
 
 
310
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
311
        """select() on a file descriptor, waiting for nonblocking read()
 
312
 
 
313
        This will raise a ConnectionTimeout exception if we do not get a
 
314
        readable handle before timeout_seconds.
 
315
        :return: None
 
316
        """
 
317
        t_end = self._timer() + timeout_seconds
 
318
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
319
        rs = xs = None
 
320
        while not rs and not xs and self._timer() < t_end:
 
321
            if self.finished:
 
322
                return
 
323
            try:
 
324
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
325
            except (select.error, socket.error) as e:
 
326
                err = getattr(e, 'errno', None)
 
327
                if err is None and getattr(e, 'args', None) is not None:
 
328
                    # select.error doesn't have 'errno', it just has args[0]
 
329
                    err = e.args[0]
 
330
                if err in _bad_file_descriptor:
 
331
                    return  # Not a socket indicates read() will fail
 
332
                elif err == errno.EINTR:
 
333
                    # Interrupted, keep looping.
 
334
                    continue
 
335
                raise
 
336
            except ValueError:
 
337
                return  # Socket may already be closed
 
338
        if rs or xs:
 
339
            return
 
340
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
341
                                       % (timeout_seconds,))
 
342
 
 
343
    def _serve_one_request(self, protocol):
 
344
        """Read one request from input, process, send back a response.
 
345
 
 
346
        :param protocol: a SmartServerRequestProtocol.
 
347
        """
 
348
        if protocol is None:
 
349
            return
 
350
        try:
 
351
            self._serve_one_request_unguarded(protocol)
 
352
        except KeyboardInterrupt:
 
353
            raise
 
354
        except Exception as e:
 
355
            self.terminate_due_to_error()
 
356
 
 
357
    def terminate_due_to_error(self):
 
358
        """Called when an unhandled exception from the protocol occurs."""
 
359
        raise NotImplementedError(self.terminate_due_to_error)
 
360
 
 
361
    def _read_bytes(self, desired_count):
 
362
        """Get some bytes from the medium.
 
363
 
 
364
        :param desired_count: number of bytes we want to read.
 
365
        """
 
366
        raise NotImplementedError(self._read_bytes)
 
367
 
 
368
 
 
369
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
370
 
 
371
    def __init__(self, sock, backing_transport, root_client_path='/',
 
372
                 timeout=None):
 
373
        """Constructor.
 
374
 
 
375
        :param sock: the socket the server will read from.  It will be put
 
376
            into blocking mode.
 
377
        """
 
378
        SmartServerStreamMedium.__init__(
 
379
            self, backing_transport, root_client_path=root_client_path,
 
380
            timeout=timeout)
 
381
        sock.setblocking(True)
 
382
        self.socket = sock
 
383
        # Get the getpeername now, as we might be closed later when we care.
 
384
        try:
 
385
            self._client_info = sock.getpeername()
 
386
        except socket.error:
 
387
            self._client_info = '<unknown>'
 
388
 
 
389
    def __str__(self):
 
390
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
391
 
 
392
    def __repr__(self):
 
393
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
394
                                     self._client_info)
 
395
 
 
396
    def _serve_one_request_unguarded(self, protocol):
 
397
        while protocol.next_read_size():
 
398
            # We can safely try to read large chunks.  If there is less data
 
399
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
 
400
            # short read immediately rather than block.
 
401
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
 
402
            if bytes == b'':
 
403
                self.finished = True
 
404
                return
 
405
            protocol.accept_bytes(bytes)
 
406
 
 
407
        self._push_back(protocol.unused_data)
 
408
 
 
409
    def _disconnect_client(self):
 
410
        """Close the current connection. We stopped due to a timeout/etc."""
 
411
        self.socket.close()
 
412
 
 
413
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
414
        """Wait for more bytes to be read, but timeout if none available.
 
415
 
 
416
        This allows us to detect idle connections, and stop trying to read from
 
417
        them, without setting the socket itself to non-blocking. This also
 
418
        allows us to specify when we watch for idle timeouts.
 
419
 
 
420
        :return: None, this will raise ConnectionTimeout if we time out before
 
421
            data is available.
 
422
        """
 
423
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
424
 
 
425
    def _read_bytes(self, desired_count):
 
426
        return osutils.read_bytes_from_socket(
 
427
            self.socket, self._report_activity)
 
428
 
 
429
    def terminate_due_to_error(self):
 
430
        # TODO: This should log to a server log file, but no such thing
 
431
        # exists yet.  Andrew Bennetts 2006-09-29.
 
432
        self.socket.close()
 
433
        self.finished = True
 
434
 
 
435
    def _write_out(self, bytes):
 
436
        tstart = osutils.perf_counter()
 
437
        osutils.send_all(self.socket, bytes, self._report_activity)
 
438
        if 'hpss' in debug.debug_flags:
 
439
            thread_id = _thread.get_ident()
 
440
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
 
441
                         % ('wrote', thread_id, len(bytes),
 
442
                            osutils.perf_counter() - tstart))
 
443
 
 
444
 
 
445
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
446
 
 
447
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
448
        """Construct new server.
 
449
 
 
450
        :param in_file: Python file from which requests can be read.
 
451
        :param out_file: Python file to write responses.
 
452
        :param backing_transport: Transport for the directory served.
 
453
        """
 
454
        SmartServerStreamMedium.__init__(self, backing_transport,
 
455
                                         timeout=timeout)
 
456
        if sys.platform == 'win32':
 
457
            # force binary mode for files
 
458
            import msvcrt
 
459
            for f in (in_file, out_file):
 
460
                fileno = getattr(f, 'fileno', None)
 
461
                if fileno:
 
462
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
463
        self._in = in_file
 
464
        self._out = out_file
 
465
 
 
466
    def serve(self):
 
467
        """See SmartServerStreamMedium.serve"""
 
468
        # This is the regular serve, except it adds signal trapping for soft
 
469
        # shutdown.
 
470
        stop_gracefully = self._stop_gracefully
 
471
        signals.register_on_hangup(id(self), stop_gracefully)
 
472
        try:
 
473
            return super(SmartServerPipeStreamMedium, self).serve()
 
474
        finally:
 
475
            signals.unregister_on_hangup(id(self))
 
476
 
 
477
    def _serve_one_request_unguarded(self, protocol):
 
478
        while True:
 
479
            # We need to be careful not to read past the end of the current
 
480
            # request, or else the read from the pipe will block, so we use
 
481
            # protocol.next_read_size().
 
482
            bytes_to_read = protocol.next_read_size()
 
483
            if bytes_to_read == 0:
 
484
                # Finished serving this request.
 
485
                self._out.flush()
 
486
                return
 
487
            bytes = self.read_bytes(bytes_to_read)
 
488
            if bytes == b'':
 
489
                # Connection has been closed.
 
490
                self.finished = True
 
491
                self._out.flush()
 
492
                return
 
493
            protocol.accept_bytes(bytes)
 
494
 
 
495
    def _disconnect_client(self):
 
496
        self._in.close()
 
497
        self._out.flush()
 
498
        self._out.close()
 
499
 
 
500
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
501
        """Wait for more bytes to be read, but timeout if none available.
 
502
 
 
503
        This allows us to detect idle connections, and stop trying to read from
 
504
        them, without setting the socket itself to non-blocking. This also
 
505
        allows us to specify when we watch for idle timeouts.
 
506
 
 
507
        :return: None, this will raise ConnectionTimeout if we time out before
 
508
            data is available.
 
509
        """
 
510
        if (getattr(self._in, 'fileno', None) is None
 
511
                or sys.platform == 'win32'):
 
512
            # You can't select() file descriptors on Windows.
 
513
            return
 
514
        try:
 
515
            return self._wait_on_descriptor(self._in, timeout_seconds)
 
516
        except io.UnsupportedOperation:
 
517
            return
 
518
 
 
519
    def _read_bytes(self, desired_count):
 
520
        return self._in.read(desired_count)
 
521
 
 
522
    def terminate_due_to_error(self):
 
523
        # TODO: This should log to a server log file, but no such thing
 
524
        # exists yet.  Andrew Bennetts 2006-09-29.
 
525
        self._out.close()
 
526
        self.finished = True
 
527
 
 
528
    def _write_out(self, bytes):
 
529
        self._out.write(bytes)
 
530
 
 
531
 
 
532
class SmartClientMediumRequest(object):
 
533
    """A request on a SmartClientMedium.
 
534
 
 
535
    Each request allows bytes to be provided to it via accept_bytes, and then
 
536
    the response bytes to be read via read_bytes.
 
537
 
 
538
    For instance:
 
539
    request.accept_bytes('123')
 
540
    request.finished_writing()
 
541
    result = request.read_bytes(3)
 
542
    request.finished_reading()
 
543
 
 
544
    It is up to the individual SmartClientMedium whether multiple concurrent
 
545
    requests can exist. See SmartClientMedium.get_request to obtain instances
 
546
    of SmartClientMediumRequest, and the concrete Medium you are using for
 
547
    details on concurrency and pipelining.
 
548
    """
 
549
 
 
550
    def __init__(self, medium):
 
551
        """Construct a SmartClientMediumRequest for the medium medium."""
 
552
        self._medium = medium
 
553
        # we track state by constants - we may want to use the same
 
554
        # pattern as BodyReader if it gets more complex.
 
555
        # valid states are: "writing", "reading", "done"
 
556
        self._state = "writing"
 
557
 
 
558
    def accept_bytes(self, bytes):
 
559
        """Accept bytes for inclusion in this request.
 
560
 
 
561
        This method may not be called after finished_writing() has been
 
562
        called.  It depends upon the Medium whether or not the bytes will be
 
563
        immediately transmitted. Message based Mediums will tend to buffer the
 
564
        bytes until finished_writing() is called.
 
565
 
 
566
        :param bytes: A bytestring.
 
567
        """
 
568
        if self._state != "writing":
 
569
            raise errors.WritingCompleted(self)
 
570
        self._accept_bytes(bytes)
 
571
 
 
572
    def _accept_bytes(self, bytes):
 
573
        """Helper for accept_bytes.
 
574
 
 
575
        Accept_bytes checks the state of the request to determing if bytes
 
576
        should be accepted. After that it hands off to _accept_bytes to do the
 
577
        actual acceptance.
 
578
        """
 
579
        raise NotImplementedError(self._accept_bytes)
 
580
 
 
581
    def finished_reading(self):
 
582
        """Inform the request that all desired data has been read.
 
583
 
 
584
        This will remove the request from the pipeline for its medium (if the
 
585
        medium supports pipelining) and any further calls to methods on the
 
586
        request will raise ReadingCompleted.
 
587
        """
 
588
        if self._state == "writing":
 
589
            raise errors.WritingNotComplete(self)
 
590
        if self._state != "reading":
 
591
            raise errors.ReadingCompleted(self)
 
592
        self._state = "done"
 
593
        self._finished_reading()
 
594
 
 
595
    def _finished_reading(self):
 
596
        """Helper for finished_reading.
 
597
 
 
598
        finished_reading checks the state of the request to determine if
 
599
        finished_reading is allowed, and if it is hands off to _finished_reading
 
600
        to perform the action.
 
601
        """
 
602
        raise NotImplementedError(self._finished_reading)
 
603
 
 
604
    def finished_writing(self):
 
605
        """Finish the writing phase of this request.
 
606
 
 
607
        This will flush all pending data for this request along the medium.
 
608
        After calling finished_writing, you may not call accept_bytes anymore.
 
609
        """
 
610
        if self._state != "writing":
 
611
            raise errors.WritingCompleted(self)
 
612
        self._state = "reading"
 
613
        self._finished_writing()
 
614
 
 
615
    def _finished_writing(self):
 
616
        """Helper for finished_writing.
 
617
 
 
618
        finished_writing checks the state of the request to determine if
 
619
        finished_writing is allowed, and if it is hands off to _finished_writing
 
620
        to perform the action.
 
621
        """
 
622
        raise NotImplementedError(self._finished_writing)
 
623
 
 
624
    def read_bytes(self, count):
 
625
        """Read bytes from this requests response.
 
626
 
 
627
        This method will block and wait for count bytes to be read. It may not
 
628
        be invoked until finished_writing() has been called - this is to ensure
 
629
        a message-based approach to requests, for compatibility with message
 
630
        based mediums like HTTP.
 
631
        """
 
632
        if self._state == "writing":
 
633
            raise errors.WritingNotComplete(self)
 
634
        if self._state != "reading":
 
635
            raise errors.ReadingCompleted(self)
 
636
        return self._read_bytes(count)
 
637
 
 
638
    def _read_bytes(self, count):
 
639
        """Helper for SmartClientMediumRequest.read_bytes.
 
640
 
 
641
        read_bytes checks the state of the request to determing if bytes
 
642
        should be read. After that it hands off to _read_bytes to do the
 
643
        actual read.
 
644
 
 
645
        By default this forwards to self._medium.read_bytes because we are
 
646
        operating on the medium's stream.
 
647
        """
 
648
        return self._medium.read_bytes(count)
 
649
 
 
650
    def read_line(self):
 
651
        line = self._read_line()
 
652
        if not line.endswith(b'\n'):
 
653
            # end of file encountered reading from server
 
654
            raise errors.ConnectionReset(
 
655
                "Unexpected end of message. Please check connectivity "
 
656
                "and permissions, and report a bug if problems persist.")
 
657
        return line
 
658
 
 
659
    def _read_line(self):
 
660
        """Helper for SmartClientMediumRequest.read_line.
 
661
 
 
662
        By default this forwards to self._medium._get_line because we are
 
663
        operating on the medium's stream.
 
664
        """
 
665
        return self._medium._get_line()
 
666
 
 
667
 
 
668
class _VfsRefuser(object):
 
669
    """An object that refuses all VFS requests.
 
670
 
 
671
    """
 
672
 
 
673
    def __init__(self):
 
674
        client._SmartClient.hooks.install_named_hook(
 
675
            'call', self.check_vfs, 'vfs refuser')
 
676
 
 
677
    def check_vfs(self, params):
 
678
        try:
 
679
            request_method = request.request_handlers.get(params.method)
 
680
        except KeyError:
 
681
            # A method we don't know about doesn't count as a VFS method.
 
682
            return
 
683
        if issubclass(request_method, vfs.VfsRequest):
 
684
            raise HpssVfsRequestNotAllowed(params.method, params.args)
 
685
 
 
686
 
 
687
class _DebugCounter(object):
 
688
    """An object that counts the HPSS calls made to each client medium.
 
689
 
 
690
    When a medium is garbage-collected, or failing that when
 
691
    breezy.global_state exits, the total number of calls made on that medium
 
692
    are reported via trace.note.
 
693
    """
 
694
 
 
695
    def __init__(self):
 
696
        self.counts = weakref.WeakKeyDictionary()
 
697
        client._SmartClient.hooks.install_named_hook(
 
698
            'call', self.increment_call_count, 'hpss call counter')
 
699
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
 
700
 
 
701
    def track(self, medium):
 
702
        """Start tracking calls made to a medium.
 
703
 
 
704
        This only keeps a weakref to the medium, so shouldn't affect the
 
705
        medium's lifetime.
 
706
        """
 
707
        medium_repr = repr(medium)
 
708
        # Add this medium to the WeakKeyDictionary
 
709
        self.counts[medium] = dict(count=0, vfs_count=0,
 
710
                                   medium_repr=medium_repr)
 
711
        # Weakref callbacks are fired in reverse order of their association
 
712
        # with the referenced object.  So we add a weakref *after* adding to
 
713
        # the WeakKeyDict so that we can report the value from it before the
 
714
        # entry is removed by the WeakKeyDict's own callback.
 
715
        ref = weakref.ref(medium, self.done)
 
716
 
 
717
    def increment_call_count(self, params):
 
718
        # Increment the count in the WeakKeyDictionary
 
719
        value = self.counts[params.medium]
 
720
        value['count'] += 1
 
721
        try:
 
722
            request_method = request.request_handlers.get(params.method)
 
723
        except KeyError:
 
724
            # A method we don't know about doesn't count as a VFS method.
 
725
            return
 
726
        if issubclass(request_method, vfs.VfsRequest):
 
727
            value['vfs_count'] += 1
 
728
 
 
729
    def done(self, ref):
 
730
        value = self.counts[ref]
 
731
        count, vfs_count, medium_repr = (
 
732
            value['count'], value['vfs_count'], value['medium_repr'])
 
733
        # In case this callback is invoked for the same ref twice (by the
 
734
        # weakref callback and by the atexit function), set the call count back
 
735
        # to 0 so this item won't be reported twice.
 
736
        value['count'] = 0
 
737
        value['vfs_count'] = 0
 
738
        if count != 0:
 
739
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
740
                       count, vfs_count, medium_repr))
 
741
 
 
742
    def flush_all(self):
 
743
        for ref in list(self.counts.keys()):
 
744
            self.done(ref)
 
745
 
 
746
 
 
747
_debug_counter = None
 
748
_vfs_refuser = None
 
749
 
 
750
 
 
751
class SmartClientMedium(SmartMedium):
 
752
    """Smart client is a medium for sending smart protocol requests over."""
 
753
 
 
754
    def __init__(self, base):
 
755
        super(SmartClientMedium, self).__init__()
 
756
        self.base = base
 
757
        self._protocol_version_error = None
 
758
        self._protocol_version = None
 
759
        self._done_hello = False
 
760
        # Be optimistic: we assume the remote end can accept new remote
 
761
        # requests until we get an error saying otherwise.
 
762
        # _remote_version_is_before tracks the bzr version the remote side
 
763
        # can be based on what we've seen so far.
 
764
        self._remote_version_is_before = None
 
765
        # Install debug hook function if debug flag is set.
 
766
        if 'hpss' in debug.debug_flags:
 
767
            global _debug_counter
 
768
            if _debug_counter is None:
 
769
                _debug_counter = _DebugCounter()
 
770
            _debug_counter.track(self)
 
771
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
772
            global _vfs_refuser
 
773
            if _vfs_refuser is None:
 
774
                _vfs_refuser = _VfsRefuser()
 
775
 
 
776
    def _is_remote_before(self, version_tuple):
 
777
        """Is it possible the remote side supports RPCs for a given version?
 
778
 
 
779
        Typical use::
 
780
 
 
781
            needed_version = (1, 2)
 
782
            if medium._is_remote_before(needed_version):
 
783
                fallback_to_pre_1_2_rpc()
 
784
            else:
 
785
                try:
 
786
                    do_1_2_rpc()
 
787
                except UnknownSmartMethod:
 
788
                    medium._remember_remote_is_before(needed_version)
 
789
                    fallback_to_pre_1_2_rpc()
 
790
 
 
791
        :seealso: _remember_remote_is_before
 
792
        """
 
793
        if self._remote_version_is_before is None:
 
794
            # So far, the remote side seems to support everything
 
795
            return False
 
796
        return version_tuple >= self._remote_version_is_before
 
797
 
 
798
    def _remember_remote_is_before(self, version_tuple):
 
799
        """Tell this medium that the remote side is older the given version.
 
800
 
 
801
        :seealso: _is_remote_before
 
802
        """
 
803
        if (self._remote_version_is_before is not None and
 
804
                version_tuple > self._remote_version_is_before):
 
805
            # We have been told that the remote side is older than some version
 
806
            # which is newer than a previously supplied older-than version.
 
807
            # This indicates that some smart verb call is not guarded
 
808
            # appropriately (it should simply not have been tried).
 
809
            trace.mutter(
 
810
                "_remember_remote_is_before(%r) called, but "
 
811
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
 
812
            if 'hpss' in debug.debug_flags:
 
813
                ui.ui_factory.show_warning(
 
814
                    "_remember_remote_is_before(%r) called, but "
 
815
                    "_remember_remote_is_before(%r) was called previously."
 
816
                    % (version_tuple, self._remote_version_is_before))
 
817
            return
 
818
        self._remote_version_is_before = version_tuple
 
819
 
 
820
    def protocol_version(self):
 
821
        """Find out if 'hello' smart request works."""
 
822
        if self._protocol_version_error is not None:
 
823
            raise self._protocol_version_error
 
824
        if not self._done_hello:
 
825
            try:
 
826
                medium_request = self.get_request()
 
827
                # Send a 'hello' request in protocol version one, for maximum
 
828
                # backwards compatibility.
 
829
                client_protocol = protocol.SmartClientRequestProtocolOne(
 
830
                    medium_request)
 
831
                client_protocol.query_version()
 
832
                self._done_hello = True
 
833
            except errors.SmartProtocolError as e:
 
834
                # Cache the error, just like we would cache a successful
 
835
                # result.
 
836
                self._protocol_version_error = e
 
837
                raise
 
838
        return '2'
 
839
 
 
840
    def should_probe(self):
 
841
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
 
842
        this medium?
 
843
 
 
844
        Some transports are unambiguously smart-only; there's no need to check
 
845
        if the transport is able to carry smart requests, because that's all
 
846
        it is for.  In those cases, this method should return False.
 
847
 
 
848
        But some HTTP transports can sometimes fail to carry smart requests,
 
849
        but still be usuable for accessing remote bzrdirs via plain file
 
850
        accesses.  So for those transports, their media should return True here
 
851
        so that RemoteBzrDirFormat can determine if it is appropriate for that
 
852
        transport.
 
853
        """
 
854
        return False
 
855
 
 
856
    def disconnect(self):
 
857
        """If this medium maintains a persistent connection, close it.
 
858
 
 
859
        The default implementation does nothing.
 
860
        """
 
861
 
 
862
    def remote_path_from_transport(self, transport):
 
863
        """Convert transport into a path suitable for using in a request.
 
864
 
 
865
        Note that the resulting remote path doesn't encode the host name or
 
866
        anything but path, so it is only safe to use it in requests sent over
 
867
        the medium from the matching transport.
 
868
        """
 
869
        medium_base = urlutils.join(self.base, '/')
 
870
        rel_url = urlutils.relative_url(medium_base, transport.base)
 
871
        return urlutils.unquote(rel_url)
 
872
 
 
873
 
 
874
class SmartClientStreamMedium(SmartClientMedium):
 
875
    """Stream based medium common class.
 
876
 
 
877
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
878
    SmartClientStreamMediumRequest for their requests, and should implement
 
879
    _accept_bytes and _read_bytes to allow the request objects to send and
 
880
    receive bytes.
 
881
    """
 
882
 
 
883
    def __init__(self, base):
 
884
        SmartClientMedium.__init__(self, base)
 
885
        self._current_request = None
 
886
 
 
887
    def accept_bytes(self, bytes):
 
888
        self._accept_bytes(bytes)
 
889
 
 
890
    def __del__(self):
 
891
        """The SmartClientStreamMedium knows how to close the stream when it is
 
892
        finished with it.
 
893
        """
 
894
        self.disconnect()
 
895
 
 
896
    def _flush(self):
 
897
        """Flush the output stream.
 
898
 
 
899
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
900
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
901
        """
 
902
        raise NotImplementedError(self._flush)
 
903
 
 
904
    def get_request(self):
 
905
        """See SmartClientMedium.get_request().
 
906
 
 
907
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
908
        for get_request.
 
909
        """
 
910
        return SmartClientStreamMediumRequest(self)
 
911
 
 
912
    def reset(self):
 
913
        """We have been disconnected, reset current state.
 
914
 
 
915
        This resets things like _current_request and connected state.
 
916
        """
 
917
        self.disconnect()
 
918
        self._current_request = None
 
919
 
 
920
 
 
921
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
922
    """A client medium using simple pipes.
 
923
 
 
924
    This client does not manage the pipes: it assumes they will always be open.
 
925
    """
 
926
 
 
927
    def __init__(self, readable_pipe, writeable_pipe, base):
 
928
        SmartClientStreamMedium.__init__(self, base)
 
929
        self._readable_pipe = readable_pipe
 
930
        self._writeable_pipe = writeable_pipe
 
931
 
 
932
    def _accept_bytes(self, data):
 
933
        """See SmartClientStreamMedium.accept_bytes."""
 
934
        try:
 
935
            self._writeable_pipe.write(data)
 
936
        except IOError as e:
 
937
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
938
                raise errors.ConnectionReset(
 
939
                    "Error trying to write to subprocess", e)
 
940
            raise
 
941
        self._report_activity(len(data), 'write')
 
942
 
 
943
    def _flush(self):
 
944
        """See SmartClientStreamMedium._flush()."""
 
945
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
946
        #       However, testing shows that even when the child process is
 
947
        #       gone, this doesn't error.
 
948
        self._writeable_pipe.flush()
 
949
 
 
950
    def _read_bytes(self, count):
 
951
        """See SmartClientStreamMedium._read_bytes."""
 
952
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
953
        data = self._readable_pipe.read(bytes_to_read)
 
954
        self._report_activity(len(data), 'read')
 
955
        return data
 
956
 
 
957
 
 
958
class SSHParams(object):
 
959
    """A set of parameters for starting a remote bzr via SSH."""
 
960
 
 
961
    def __init__(self, host, port=None, username=None, password=None,
 
962
                 bzr_remote_path='bzr'):
 
963
        self.host = host
 
964
        self.port = port
 
965
        self.username = username
 
966
        self.password = password
 
967
        self.bzr_remote_path = bzr_remote_path
 
968
 
 
969
 
 
970
class SmartSSHClientMedium(SmartClientStreamMedium):
 
971
    """A client medium using SSH.
 
972
 
 
973
    It delegates IO to a SmartSimplePipesClientMedium or
 
974
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
975
    """
 
976
 
 
977
    def __init__(self, base, ssh_params, vendor=None):
 
978
        """Creates a client that will connect on the first use.
 
979
 
 
980
        :param ssh_params: A SSHParams instance.
 
981
        :param vendor: An optional override for the ssh vendor to use. See
 
982
            breezy.transport.ssh for details on ssh vendors.
 
983
        """
 
984
        self._real_medium = None
 
985
        self._ssh_params = ssh_params
 
986
        # for the benefit of progress making a short description of this
 
987
        # transport
 
988
        self._scheme = 'bzr+ssh'
 
989
        # SmartClientStreamMedium stores the repr of this object in its
 
990
        # _DebugCounter so we have to store all the values used in our repr
 
991
        # method before calling the super init.
 
992
        SmartClientStreamMedium.__init__(self, base)
 
993
        self._vendor = vendor
 
994
        self._ssh_connection = None
 
995
 
 
996
    def __repr__(self):
 
997
        if self._ssh_params.port is None:
 
998
            maybe_port = ''
 
999
        else:
 
1000
            maybe_port = ':%s' % self._ssh_params.port
 
1001
        if self._ssh_params.username is None:
 
1002
            maybe_user = ''
 
1003
        else:
 
1004
            maybe_user = '%s@' % self._ssh_params.username
 
1005
        return "%s(%s://%s%s%s/)" % (
 
1006
            self.__class__.__name__,
 
1007
            self._scheme,
 
1008
            maybe_user,
 
1009
            self._ssh_params.host,
 
1010
            maybe_port)
 
1011
 
 
1012
    def _accept_bytes(self, bytes):
 
1013
        """See SmartClientStreamMedium.accept_bytes."""
 
1014
        self._ensure_connection()
 
1015
        self._real_medium.accept_bytes(bytes)
 
1016
 
 
1017
    def disconnect(self):
 
1018
        """See SmartClientMedium.disconnect()."""
 
1019
        if self._real_medium is not None:
 
1020
            self._real_medium.disconnect()
 
1021
            self._real_medium = None
 
1022
        if self._ssh_connection is not None:
 
1023
            self._ssh_connection.close()
 
1024
            self._ssh_connection = None
 
1025
 
 
1026
    def _ensure_connection(self):
 
1027
        """Connect this medium if not already connected."""
 
1028
        if self._real_medium is not None:
 
1029
            return
 
1030
        if self._vendor is None:
 
1031
            vendor = ssh._get_ssh_vendor()
 
1032
        else:
 
1033
            vendor = self._vendor
 
1034
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1035
                                                  self._ssh_params.password, self._ssh_params.host,
 
1036
                                                  self._ssh_params.port,
 
1037
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
1038
                                                           '--directory=/', '--allow-writes'])
 
1039
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1040
        if io_kind == 'socket':
 
1041
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1042
                self.base, io_object)
 
1043
        elif io_kind == 'pipes':
 
1044
            read_from, write_to = io_object
 
1045
            self._real_medium = SmartSimplePipesClientMedium(
 
1046
                read_from, write_to, self.base)
 
1047
        else:
 
1048
            raise AssertionError(
 
1049
                "Unexpected io_kind %r from %r"
 
1050
                % (io_kind, self._ssh_connection))
 
1051
        for hook in transport.Transport.hooks["post_connect"]:
 
1052
            hook(self)
 
1053
 
 
1054
    def _flush(self):
 
1055
        """See SmartClientStreamMedium._flush()."""
 
1056
        self._real_medium._flush()
 
1057
 
 
1058
    def _read_bytes(self, count):
 
1059
        """See SmartClientStreamMedium.read_bytes."""
 
1060
        if self._real_medium is None:
 
1061
            raise errors.MediumNotConnected(self)
 
1062
        return self._real_medium.read_bytes(count)
 
1063
 
 
1064
 
 
1065
# Port 4155 is the default port for bzr://, registered with IANA.
 
1066
BZR_DEFAULT_INTERFACE = None
 
1067
BZR_DEFAULT_PORT = 4155
 
1068
 
 
1069
 
 
1070
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1071
    """A client medium using a socket.
 
1072
 
 
1073
    This class isn't usable directly.  Use one of its subclasses instead.
 
1074
    """
 
1075
 
 
1076
    def __init__(self, base):
 
1077
        SmartClientStreamMedium.__init__(self, base)
 
1078
        self._socket = None
 
1079
        self._connected = False
 
1080
 
 
1081
    def _accept_bytes(self, bytes):
 
1082
        """See SmartClientMedium.accept_bytes."""
 
1083
        self._ensure_connection()
 
1084
        osutils.send_all(self._socket, bytes, self._report_activity)
 
1085
 
 
1086
    def _ensure_connection(self):
 
1087
        """Connect this medium if not already connected."""
 
1088
        raise NotImplementedError(self._ensure_connection)
 
1089
 
 
1090
    def _flush(self):
 
1091
        """See SmartClientStreamMedium._flush().
 
1092
 
 
1093
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1094
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1095
        future.
 
1096
        """
 
1097
 
 
1098
    def _read_bytes(self, count):
 
1099
        """See SmartClientMedium.read_bytes."""
 
1100
        if not self._connected:
 
1101
            raise errors.MediumNotConnected(self)
 
1102
        return osutils.read_bytes_from_socket(
 
1103
            self._socket, self._report_activity)
 
1104
 
 
1105
    def disconnect(self):
 
1106
        """See SmartClientMedium.disconnect()."""
 
1107
        if not self._connected:
 
1108
            return
 
1109
        self._socket.close()
 
1110
        self._socket = None
 
1111
        self._connected = False
 
1112
 
 
1113
 
 
1114
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1115
    """A client medium that creates a TCP connection."""
 
1116
 
 
1117
    def __init__(self, host, port, base):
 
1118
        """Creates a client that will connect on the first use."""
 
1119
        SmartClientSocketMedium.__init__(self, base)
 
1120
        self._host = host
 
1121
        self._port = port
 
1122
 
 
1123
    def _ensure_connection(self):
 
1124
        """Connect this medium if not already connected."""
 
1125
        if self._connected:
 
1126
            return
 
1127
        if self._port is None:
 
1128
            port = BZR_DEFAULT_PORT
 
1129
        else:
 
1130
            port = int(self._port)
 
1131
        try:
 
1132
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
 
1133
                                           socket.SOCK_STREAM, 0, 0)
 
1134
        except socket.gaierror as xxx_todo_changeme:
 
1135
            (err_num, err_msg) = xxx_todo_changeme.args
 
1136
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
 
1137
                                         (self._host, port, err_msg))
 
1138
        # Initialize err in case there are no addresses returned:
 
1139
        last_err = socket.error("no address found for %s" % self._host)
 
1140
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
 
1141
            try:
 
1142
                self._socket = socket.socket(family, socktype, proto)
 
1143
                self._socket.setsockopt(socket.IPPROTO_TCP,
 
1144
                                        socket.TCP_NODELAY, 1)
 
1145
                self._socket.connect(sockaddr)
 
1146
            except socket.error as err:
 
1147
                if self._socket is not None:
 
1148
                    self._socket.close()
 
1149
                self._socket = None
 
1150
                last_err = err
 
1151
                continue
 
1152
            break
 
1153
        if self._socket is None:
 
1154
            # socket errors either have a (string) or (errno, string) as their
 
1155
            # args.
 
1156
            if isinstance(last_err.args, str):
 
1157
                err_msg = last_err.args
 
1158
            else:
 
1159
                err_msg = last_err.args[1]
 
1160
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
1161
                                         (self._host, port, err_msg))
 
1162
        self._connected = True
 
1163
        for hook in transport.Transport.hooks["post_connect"]:
 
1164
            hook(self)
 
1165
 
 
1166
 
 
1167
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1168
    """A client medium for an already connected socket.
 
1169
 
 
1170
    Note that this class will assume it "owns" the socket, so it will close it
 
1171
    when its disconnect method is called.
 
1172
    """
 
1173
 
 
1174
    def __init__(self, base, sock):
 
1175
        SmartClientSocketMedium.__init__(self, base)
 
1176
        self._socket = sock
 
1177
        self._connected = True
 
1178
 
 
1179
    def _ensure_connection(self):
 
1180
        # Already connected, by definition!  So nothing to do.
 
1181
        pass
 
1182
 
 
1183
 
 
1184
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
1185
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
1186
 
 
1187
    def __init__(self, medium):
 
1188
        SmartClientMediumRequest.__init__(self, medium)
 
1189
        # check that we are safe concurrency wise. If some streams start
 
1190
        # allowing concurrent requests - i.e. via multiplexing - then this
 
1191
        # assert should be moved to SmartClientStreamMedium.get_request,
 
1192
        # and the setting/unsetting of _current_request likewise moved into
 
1193
        # that class : but its unneeded overhead for now. RBC 20060922
 
1194
        if self._medium._current_request is not None:
 
1195
            raise errors.TooManyConcurrentRequests(self._medium)
 
1196
        self._medium._current_request = self
 
1197
 
 
1198
    def _accept_bytes(self, bytes):
 
1199
        """See SmartClientMediumRequest._accept_bytes.
 
1200
 
 
1201
        This forwards to self._medium._accept_bytes because we are operating
 
1202
        on the mediums stream.
 
1203
        """
 
1204
        self._medium._accept_bytes(bytes)
 
1205
 
 
1206
    def _finished_reading(self):
 
1207
        """See SmartClientMediumRequest._finished_reading.
 
1208
 
 
1209
        This clears the _current_request on self._medium to allow a new
 
1210
        request to be created.
 
1211
        """
 
1212
        if self._medium._current_request is not self:
 
1213
            raise AssertionError()
 
1214
        self._medium._current_request = None
 
1215
 
 
1216
    def _finished_writing(self):
 
1217
        """See SmartClientMediumRequest._finished_writing.
 
1218
 
 
1219
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
1220
        """
 
1221
        self._medium._flush()