/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Robert Collins
  • Date: 2010-05-06 11:08:10 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506110810-h3j07fh5gmw54s25
Cleaner matcher matching revised unlocking protocol.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
27
import os
31
28
import sys
32
 
import time
 
29
import urllib
33
30
 
34
 
import bzrlib
35
31
from bzrlib.lazy_import import lazy_import
36
32
lazy_import(globals(), """
37
 
import select
 
33
import atexit
38
34
import socket
39
35
import thread
40
36
import weakref
42
38
from bzrlib import (
43
39
    debug,
44
40
    errors,
 
41
    symbol_versioning,
45
42
    trace,
46
43
    ui,
47
44
    urlutils,
48
45
    )
49
 
from bzrlib.i18n import gettext
50
 
from bzrlib.smart import client, protocol, request, signals, vfs
 
46
from bzrlib.smart import client, protocol, request, vfs
51
47
from bzrlib.transport import ssh
52
48
""")
53
49
from bzrlib import osutils
180
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
181
177
 
182
178
 
183
 
_bad_file_descriptor = (errno.EBADF,)
184
 
if sys.platform == 'win32':
185
 
    # Given on Windows if you pass a closed socket to select.select. Probably
186
 
    # also given if you pass a file handle to select.
187
 
    WSAENOTSOCK = 10038
188
 
    _bad_file_descriptor += (WSAENOTSOCK,)
189
 
 
190
 
 
191
179
class SmartServerStreamMedium(SmartMedium):
192
180
    """Handles smart commands coming over a stream.
193
181
 
206
194
        the stream.  See also the _push_back method.
207
195
    """
208
196
 
209
 
    _timer = time.time
210
 
 
211
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
212
198
        """Construct new server.
213
199
 
214
200
        :param backing_transport: Transport for the directory served.
217
203
        self.backing_transport = backing_transport
218
204
        self.root_client_path = root_client_path
219
205
        self.finished = False
220
 
        if timeout is None:
221
 
            raise AssertionError('You must supply a timeout.')
222
 
        self._client_timeout = timeout
223
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
224
206
        SmartMedium.__init__(self)
225
207
 
226
208
    def serve(self):
232
214
            while not self.finished:
233
215
                server_protocol = self._build_protocol()
234
216
                self._serve_one_request(server_protocol)
235
 
        except errors.ConnectionTimeout, e:
236
 
            trace.note('%s' % (e,))
237
 
            trace.log_exception_quietly()
238
 
            self._disconnect_client()
239
 
            # We reported it, no reason to make a big fuss.
240
 
            return
241
217
        except Exception, e:
242
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
243
219
            raise
244
 
        self._disconnect_client()
245
 
 
246
 
    def _stop_gracefully(self):
247
 
        """When we finish this message, stop looking for more."""
248
 
        trace.mutter('Stopping %s' % (self,))
249
 
        self.finished = True
250
 
 
251
 
    def _disconnect_client(self):
252
 
        """Close the current connection. We stopped due to a timeout/etc."""
253
 
        # The default implementation is a no-op, because that is all we used to
254
 
        # do when disconnecting from a client. I suppose we never had the
255
 
        # *server* initiate a disconnect, before
256
 
 
257
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
258
 
        """Wait for more bytes to be read, but timeout if none available.
259
 
 
260
 
        This allows us to detect idle connections, and stop trying to read from
261
 
        them, without setting the socket itself to non-blocking. This also
262
 
        allows us to specify when we watch for idle timeouts.
263
 
 
264
 
        :return: Did we timeout? (True if we timed out, False if there is data
265
 
            to be read)
266
 
        """
267
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
268
220
 
269
221
    def _build_protocol(self):
270
222
        """Identifies the version of the incoming request, and returns an
275
227
 
276
228
        :returns: a SmartServerRequestProtocol.
277
229
        """
278
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
279
 
        if self.finished:
280
 
            # We're stopping, so don't try to do any more work
281
 
            return None
282
230
        bytes = self._get_line()
283
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
284
232
        protocol = protocol_factory(
286
234
        protocol.accept_bytes(unused_bytes)
287
235
        return protocol
288
236
 
289
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
290
 
        """select() on a file descriptor, waiting for nonblocking read()
291
 
 
292
 
        This will raise a ConnectionTimeout exception if we do not get a
293
 
        readable handle before timeout_seconds.
294
 
        :return: None
295
 
        """
296
 
        t_end = self._timer() + timeout_seconds
297
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
298
 
        rs = xs = None
299
 
        while not rs and not xs and self._timer() < t_end:
300
 
            if self.finished:
301
 
                return
302
 
            try:
303
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
304
 
            except (select.error, socket.error) as e:
305
 
                err = getattr(e, 'errno', None)
306
 
                if err is None and getattr(e, 'args', None) is not None:
307
 
                    # select.error doesn't have 'errno', it just has args[0]
308
 
                    err = e.args[0]
309
 
                if err in _bad_file_descriptor:
310
 
                    return # Not a socket indicates read() will fail
311
 
                elif err == errno.EINTR:
312
 
                    # Interrupted, keep looping.
313
 
                    continue
314
 
                raise
315
 
        if rs or xs:
316
 
            return
317
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
318
 
                                       % (timeout_seconds,))
319
 
 
320
237
    def _serve_one_request(self, protocol):
321
238
        """Read one request from input, process, send back a response.
322
239
 
323
240
        :param protocol: a SmartServerRequestProtocol.
324
241
        """
325
 
        if protocol is None:
326
 
            return
327
242
        try:
328
243
            self._serve_one_request_unguarded(protocol)
329
244
        except KeyboardInterrupt:
345
260
 
346
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
347
262
 
348
 
    def __init__(self, sock, backing_transport, root_client_path='/',
349
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
350
264
        """Constructor.
351
265
 
352
266
        :param sock: the socket the server will read from.  It will be put
353
267
            into blocking mode.
354
268
        """
355
269
        SmartServerStreamMedium.__init__(
356
 
            self, backing_transport, root_client_path=root_client_path,
357
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
358
271
        sock.setblocking(True)
359
272
        self.socket = sock
360
 
        # Get the getpeername now, as we might be closed later when we care.
361
 
        try:
362
 
            self._client_info = sock.getpeername()
363
 
        except socket.error:
364
 
            self._client_info = '<unknown>'
365
 
 
366
 
    def __str__(self):
367
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
368
 
 
369
 
    def __repr__(self):
370
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
371
 
            self._client_info)
372
273
 
373
274
    def _serve_one_request_unguarded(self, protocol):
374
275
        while protocol.next_read_size():
383
284
 
384
285
        self._push_back(protocol.unused_data)
385
286
 
386
 
    def _disconnect_client(self):
387
 
        """Close the current connection. We stopped due to a timeout/etc."""
388
 
        self.socket.close()
389
 
 
390
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
391
 
        """Wait for more bytes to be read, but timeout if none available.
392
 
 
393
 
        This allows us to detect idle connections, and stop trying to read from
394
 
        them, without setting the socket itself to non-blocking. This also
395
 
        allows us to specify when we watch for idle timeouts.
396
 
 
397
 
        :return: None, this will raise ConnectionTimeout if we time out before
398
 
            data is available.
399
 
        """
400
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
401
 
 
402
287
    def _read_bytes(self, desired_count):
403
288
        return osutils.read_bytes_from_socket(
404
289
            self.socket, self._report_activity)
421
306
 
422
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
423
308
 
424
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
425
310
        """Construct new server.
426
311
 
427
312
        :param in_file: Python file from which requests can be read.
428
313
        :param out_file: Python file to write responses.
429
314
        :param backing_transport: Transport for the directory served.
430
315
        """
431
 
        SmartServerStreamMedium.__init__(self, backing_transport,
432
 
            timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
433
317
        if sys.platform == 'win32':
434
318
            # force binary mode for files
435
319
            import msvcrt
440
324
        self._in = in_file
441
325
        self._out = out_file
442
326
 
443
 
    def serve(self):
444
 
        """See SmartServerStreamMedium.serve"""
445
 
        # This is the regular serve, except it adds signal trapping for soft
446
 
        # shutdown.
447
 
        stop_gracefully = self._stop_gracefully
448
 
        signals.register_on_hangup(id(self), stop_gracefully)
449
 
        try:
450
 
            return super(SmartServerPipeStreamMedium, self).serve()
451
 
        finally:
452
 
            signals.unregister_on_hangup(id(self))
453
 
 
454
327
    def _serve_one_request_unguarded(self, protocol):
455
328
        while True:
456
329
            # We need to be careful not to read past the end of the current
469
342
                return
470
343
            protocol.accept_bytes(bytes)
471
344
 
472
 
    def _disconnect_client(self):
473
 
        self._in.close()
474
 
        self._out.flush()
475
 
        self._out.close()
476
 
 
477
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
478
 
        """Wait for more bytes to be read, but timeout if none available.
479
 
 
480
 
        This allows us to detect idle connections, and stop trying to read from
481
 
        them, without setting the socket itself to non-blocking. This also
482
 
        allows us to specify when we watch for idle timeouts.
483
 
 
484
 
        :return: None, this will raise ConnectionTimeout if we time out before
485
 
            data is available.
486
 
        """
487
 
        if (getattr(self._in, 'fileno', None) is None
488
 
            or sys.platform == 'win32'):
489
 
            # You can't select() file descriptors on Windows.
490
 
            return
491
 
        return self._wait_on_descriptor(self._in, timeout_seconds)
492
 
 
493
345
    def _read_bytes(self, desired_count):
494
346
        return self._in.read(desired_count)
495
347
 
639
491
        return self._medium._get_line()
640
492
 
641
493
 
642
 
class _VfsRefuser(object):
643
 
    """An object that refuses all VFS requests.
644
 
 
645
 
    """
646
 
 
647
 
    def __init__(self):
648
 
        client._SmartClient.hooks.install_named_hook(
649
 
            'call', self.check_vfs, 'vfs refuser')
650
 
 
651
 
    def check_vfs(self, params):
652
 
        try:
653
 
            request_method = request.request_handlers.get(params.method)
654
 
        except KeyError:
655
 
            # A method we don't know about doesn't count as a VFS method.
656
 
            return
657
 
        if issubclass(request_method, vfs.VfsRequest):
658
 
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
659
 
 
660
 
 
661
494
class _DebugCounter(object):
662
495
    """An object that counts the HPSS calls made to each client medium.
663
496
 
664
 
    When a medium is garbage-collected, or failing that when
665
 
    bzrlib.global_state exits, the total number of calls made on that medium
666
 
    are reported via trace.note.
 
497
    When a medium is garbage-collected, or failing that when atexit functions
 
498
    are run, the total number of calls made on that medium are reported via
 
499
    trace.note.
667
500
    """
668
501
 
669
502
    def __init__(self):
670
503
        self.counts = weakref.WeakKeyDictionary()
671
504
        client._SmartClient.hooks.install_named_hook(
672
505
            'call', self.increment_call_count, 'hpss call counter')
673
 
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
 
506
        atexit.register(self.flush_all)
674
507
 
675
508
    def track(self, medium):
676
509
        """Start tracking calls made to a medium.
710
543
        value['count'] = 0
711
544
        value['vfs_count'] = 0
712
545
        if count != 0:
713
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
714
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
715
548
 
716
549
    def flush_all(self):
717
550
        for ref in list(self.counts.keys()):
718
551
            self.done(ref)
719
552
 
720
553
_debug_counter = None
721
 
_vfs_refuser = None
722
554
 
723
555
 
724
556
class SmartClientMedium(SmartMedium):
741
573
            if _debug_counter is None:
742
574
                _debug_counter = _DebugCounter()
743
575
            _debug_counter.track(self)
744
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
745
 
            global _vfs_refuser
746
 
            if _vfs_refuser is None:
747
 
                _vfs_refuser = _VfsRefuser()
748
576
 
749
577
    def _is_remote_before(self, version_tuple):
750
578
        """Is it possible the remote side supports RPCs for a given version?
841
669
        """
842
670
        medium_base = urlutils.join(self.base, '/')
843
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
844
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
845
673
 
846
674
 
847
675
class SmartClientStreamMedium(SmartClientMedium):
882
710
        """
883
711
        return SmartClientStreamMediumRequest(self)
884
712
 
885
 
    def reset(self):
886
 
        """We have been disconnected, reset current state.
887
 
 
888
 
        This resets things like _current_request and connected state.
889
 
        """
890
 
        self.disconnect()
891
 
        self._current_request = None
892
 
 
893
713
 
894
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
895
715
    """A client medium using simple pipes.
896
716
 
897
717
    This client does not manage the pipes: it assumes they will always be open.
 
718
 
 
719
    Note that if readable_pipe.read might raise IOError or OSError with errno
 
720
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
 
721
    (such as used for sys.stdin) are safe.
898
722
    """
899
723
 
900
724
    def __init__(self, readable_pipe, writeable_pipe, base):
904
728
 
905
729
    def _accept_bytes(self, bytes):
906
730
        """See SmartClientStreamMedium.accept_bytes."""
907
 
        try:
908
 
            self._writeable_pipe.write(bytes)
909
 
        except IOError, e:
910
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
911
 
                raise errors.ConnectionReset(
912
 
                    "Error trying to write to subprocess:\n%s" % (e,))
913
 
            raise
 
731
        self._writeable_pipe.write(bytes)
914
732
        self._report_activity(len(bytes), 'write')
915
733
 
916
734
    def _flush(self):
917
735
        """See SmartClientStreamMedium._flush()."""
918
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
919
 
        #       However, testing shows that even when the child process is
920
 
        #       gone, this doesn't error.
921
736
        self._writeable_pipe.flush()
922
737
 
923
738
    def _read_bytes(self, count):
924
739
        """See SmartClientStreamMedium._read_bytes."""
925
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
926
 
        bytes = self._readable_pipe.read(bytes_to_read)
 
740
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
927
741
        self._report_activity(len(bytes), 'read')
928
742
        return bytes
929
743
 
930
744
 
931
 
class SSHParams(object):
932
 
    """A set of parameters for starting a remote bzr via SSH."""
 
745
class SmartSSHClientMedium(SmartClientStreamMedium):
 
746
    """A client medium using SSH."""
933
747
 
934
748
    def __init__(self, host, port=None, username=None, password=None,
935
 
            bzr_remote_path='bzr'):
936
 
        self.host = host
937
 
        self.port = port
938
 
        self.username = username
939
 
        self.password = password
940
 
        self.bzr_remote_path = bzr_remote_path
941
 
 
942
 
 
943
 
class SmartSSHClientMedium(SmartClientStreamMedium):
944
 
    """A client medium using SSH.
945
 
 
946
 
    It delegates IO to a SmartSimplePipesClientMedium or
947
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
948
 
    """
949
 
 
950
 
    def __init__(self, base, ssh_params, vendor=None):
 
749
            base=None, vendor=None, bzr_remote_path=None):
951
750
        """Creates a client that will connect on the first use.
952
751
 
953
 
        :param ssh_params: A SSHParams instance.
954
752
        :param vendor: An optional override for the ssh vendor to use. See
955
753
            bzrlib.transport.ssh for details on ssh vendors.
956
754
        """
957
 
        self._real_medium = None
958
 
        self._ssh_params = ssh_params
 
755
        self._connected = False
 
756
        self._host = host
 
757
        self._password = password
 
758
        self._port = port
 
759
        self._username = username
959
760
        # for the benefit of progress making a short description of this
960
761
        # transport
961
762
        self._scheme = 'bzr+ssh'
963
764
        # _DebugCounter so we have to store all the values used in our repr
964
765
        # method before calling the super init.
965
766
        SmartClientStreamMedium.__init__(self, base)
 
767
        self._read_from = None
 
768
        self._ssh_connection = None
966
769
        self._vendor = vendor
967
 
        self._ssh_connection = None
 
770
        self._write_to = None
 
771
        self._bzr_remote_path = bzr_remote_path
968
772
 
969
773
    def __repr__(self):
970
 
        if self._ssh_params.port is None:
 
774
        if self._port is None:
971
775
            maybe_port = ''
972
776
        else:
973
 
            maybe_port = ':%s' % self._ssh_params.port
974
 
        if self._ssh_params.username is None:
975
 
            maybe_user = ''
976
 
        else:
977
 
            maybe_user = '%s@' % self._ssh_params.username
978
 
        return "%s(%s://%s%s%s/)" % (
 
777
            maybe_port = ':%s' % self._port
 
778
        return "%s(%s://%s@%s%s/)" % (
979
779
            self.__class__.__name__,
980
780
            self._scheme,
981
 
            maybe_user,
982
 
            self._ssh_params.host,
 
781
            self._username,
 
782
            self._host,
983
783
            maybe_port)
984
784
 
985
785
    def _accept_bytes(self, bytes):
986
786
        """See SmartClientStreamMedium.accept_bytes."""
987
787
        self._ensure_connection()
988
 
        self._real_medium.accept_bytes(bytes)
 
788
        self._write_to.write(bytes)
 
789
        self._report_activity(len(bytes), 'write')
989
790
 
990
791
    def disconnect(self):
991
792
        """See SmartClientMedium.disconnect()."""
992
 
        if self._real_medium is not None:
993
 
            self._real_medium.disconnect()
994
 
            self._real_medium = None
995
 
        if self._ssh_connection is not None:
996
 
            self._ssh_connection.close()
997
 
            self._ssh_connection = None
 
793
        if not self._connected:
 
794
            return
 
795
        self._read_from.close()
 
796
        self._write_to.close()
 
797
        self._ssh_connection.close()
 
798
        self._connected = False
998
799
 
999
800
    def _ensure_connection(self):
1000
801
        """Connect this medium if not already connected."""
1001
 
        if self._real_medium is not None:
 
802
        if self._connected:
1002
803
            return
1003
804
        if self._vendor is None:
1004
805
            vendor = ssh._get_ssh_vendor()
1005
806
        else:
1006
807
            vendor = self._vendor
1007
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1008
 
                self._ssh_params.password, self._ssh_params.host,
1009
 
                self._ssh_params.port,
1010
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
808
        self._ssh_connection = vendor.connect_ssh(self._username,
 
809
                self._password, self._host, self._port,
 
810
                command=[self._bzr_remote_path, 'serve', '--inet',
1011
811
                         '--directory=/', '--allow-writes'])
1012
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1013
 
        if io_kind == 'socket':
1014
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1015
 
                self.base, io_object)
1016
 
        elif io_kind == 'pipes':
1017
 
            read_from, write_to = io_object
1018
 
            self._real_medium = SmartSimplePipesClientMedium(
1019
 
                read_from, write_to, self.base)
1020
 
        else:
1021
 
            raise AssertionError(
1022
 
                "Unexpected io_kind %r from %r"
1023
 
                % (io_kind, self._ssh_connection))
 
812
        self._read_from, self._write_to = \
 
813
            self._ssh_connection.get_filelike_channels()
 
814
        self._connected = True
1024
815
 
1025
816
    def _flush(self):
1026
817
        """See SmartClientStreamMedium._flush()."""
1027
 
        self._real_medium._flush()
 
818
        self._write_to.flush()
1028
819
 
1029
820
    def _read_bytes(self, count):
1030
821
        """See SmartClientStreamMedium.read_bytes."""
1031
 
        if self._real_medium is None:
 
822
        if not self._connected:
1032
823
            raise errors.MediumNotConnected(self)
1033
 
        return self._real_medium.read_bytes(count)
 
824
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
825
        bytes = self._read_from.read(bytes_to_read)
 
826
        self._report_activity(len(bytes), 'read')
 
827
        return bytes
1034
828
 
1035
829
 
1036
830
# Port 4155 is the default port for bzr://, registered with IANA.
1038
832
BZR_DEFAULT_PORT = 4155
1039
833
 
1040
834
 
1041
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1042
 
    """A client medium using a socket.
1043
 
    
1044
 
    This class isn't usable directly.  Use one of its subclasses instead.
1045
 
    """
 
835
class SmartTCPClientMedium(SmartClientStreamMedium):
 
836
    """A client medium using TCP."""
1046
837
 
1047
 
    def __init__(self, base):
 
838
    def __init__(self, host, port, base):
 
839
        """Creates a client that will connect on the first use."""
1048
840
        SmartClientStreamMedium.__init__(self, base)
 
841
        self._connected = False
 
842
        self._host = host
 
843
        self._port = port
1049
844
        self._socket = None
1050
 
        self._connected = False
1051
845
 
1052
846
    def _accept_bytes(self, bytes):
1053
847
        """See SmartClientMedium.accept_bytes."""
1054
848
        self._ensure_connection()
1055
849
        osutils.send_all(self._socket, bytes, self._report_activity)
1056
850
 
1057
 
    def _ensure_connection(self):
1058
 
        """Connect this medium if not already connected."""
1059
 
        raise NotImplementedError(self._ensure_connection)
1060
 
 
1061
 
    def _flush(self):
1062
 
        """See SmartClientStreamMedium._flush().
1063
 
 
1064
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1065
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1066
 
        future.
1067
 
        """
1068
 
 
1069
 
    def _read_bytes(self, count):
1070
 
        """See SmartClientMedium.read_bytes."""
1071
 
        if not self._connected:
1072
 
            raise errors.MediumNotConnected(self)
1073
 
        return osutils.read_bytes_from_socket(
1074
 
            self._socket, self._report_activity)
1075
 
 
1076
851
    def disconnect(self):
1077
852
        """See SmartClientMedium.disconnect()."""
1078
853
        if not self._connected:
1081
856
        self._socket = None
1082
857
        self._connected = False
1083
858
 
1084
 
 
1085
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1086
 
    """A client medium that creates a TCP connection."""
1087
 
 
1088
 
    def __init__(self, host, port, base):
1089
 
        """Creates a client that will connect on the first use."""
1090
 
        SmartClientSocketMedium.__init__(self, base)
1091
 
        self._host = host
1092
 
        self._port = port
1093
 
 
1094
859
    def _ensure_connection(self):
1095
860
        """Connect this medium if not already connected."""
1096
861
        if self._connected:
1130
895
                    (self._host, port, err_msg))
1131
896
        self._connected = True
1132
897
 
1133
 
 
1134
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1135
 
    """A client medium for an already connected socket.
1136
 
    
1137
 
    Note that this class will assume it "owns" the socket, so it will close it
1138
 
    when its disconnect method is called.
1139
 
    """
1140
 
 
1141
 
    def __init__(self, base, sock):
1142
 
        SmartClientSocketMedium.__init__(self, base)
1143
 
        self._socket = sock
1144
 
        self._connected = True
1145
 
 
1146
 
    def _ensure_connection(self):
1147
 
        # Already connected, by definition!  So nothing to do.
1148
 
        pass
 
898
    def _flush(self):
 
899
        """See SmartClientStreamMedium._flush().
 
900
 
 
901
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
902
        add a means to do a flush, but that can be done in the future.
 
903
        """
 
904
 
 
905
    def _read_bytes(self, count):
 
906
        """See SmartClientMedium.read_bytes."""
 
907
        if not self._connected:
 
908
            raise errors.MediumNotConnected(self)
 
909
        return osutils.read_bytes_from_socket(
 
910
            self._socket, self._report_activity)
1149
911
 
1150
912
 
1151
913
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1186
948
        This invokes self._medium._flush to ensure all bytes are transmitted.
1187
949
        """
1188
950
        self._medium._flush()
 
951
 
 
952