/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Robert Collins
  • Date: 2010-05-06 11:08:10 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506110810-h3j07fh5gmw54s25
Cleaner matcher matching revised unlocking protocol.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
27
import os
31
28
import sys
32
 
import time
 
29
import urllib
33
30
 
34
 
import breezy
35
 
from ...lazy_import import lazy_import
 
31
from bzrlib.lazy_import import lazy_import
36
32
lazy_import(globals(), """
37
 
import select
 
33
import atexit
38
34
import socket
39
35
import thread
40
36
import weakref
41
37
 
42
 
from breezy import (
 
38
from bzrlib import (
43
39
    debug,
 
40
    errors,
 
41
    symbol_versioning,
44
42
    trace,
45
 
    transport,
46
43
    ui,
47
44
    urlutils,
48
45
    )
49
 
from breezy.i18n import gettext
50
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
51
 
from breezy.transport import ssh
 
46
from bzrlib.smart import client, protocol, request, vfs
 
47
from bzrlib.transport import ssh
52
48
""")
53
 
from ... import (
54
 
    errors,
55
 
    osutils,
56
 
    )
 
49
from bzrlib import osutils
57
50
 
58
51
# Throughout this module buffer size parameters are either limited to be at
59
52
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
61
54
# from non-sockets as well.
62
55
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
63
56
 
64
 
 
65
 
class HpssVfsRequestNotAllowed(errors.BzrError):
66
 
 
67
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
68
 
            "%(method)s, %(arguments)s.")
69
 
 
70
 
    def __init__(self, method, arguments):
71
 
        self.method = method
72
 
        self.arguments = arguments
73
 
 
74
 
 
75
57
def _get_protocol_factory_for_bytes(bytes):
76
58
    """Determine the right protocol factory for 'bytes'.
77
59
 
194
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
195
177
 
196
178
 
197
 
_bad_file_descriptor = (errno.EBADF,)
198
 
if sys.platform == 'win32':
199
 
    # Given on Windows if you pass a closed socket to select.select. Probably
200
 
    # also given if you pass a file handle to select.
201
 
    WSAENOTSOCK = 10038
202
 
    _bad_file_descriptor += (WSAENOTSOCK,)
203
 
 
204
 
 
205
179
class SmartServerStreamMedium(SmartMedium):
206
180
    """Handles smart commands coming over a stream.
207
181
 
220
194
        the stream.  See also the _push_back method.
221
195
    """
222
196
 
223
 
    _timer = time.time
224
 
 
225
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
226
198
        """Construct new server.
227
199
 
228
200
        :param backing_transport: Transport for the directory served.
231
203
        self.backing_transport = backing_transport
232
204
        self.root_client_path = root_client_path
233
205
        self.finished = False
234
 
        if timeout is None:
235
 
            raise AssertionError('You must supply a timeout.')
236
 
        self._client_timeout = timeout
237
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
238
206
        SmartMedium.__init__(self)
239
207
 
240
208
    def serve(self):
246
214
            while not self.finished:
247
215
                server_protocol = self._build_protocol()
248
216
                self._serve_one_request(server_protocol)
249
 
        except errors.ConnectionTimeout as e:
250
 
            trace.note('%s' % (e,))
251
 
            trace.log_exception_quietly()
252
 
            self._disconnect_client()
253
 
            # We reported it, no reason to make a big fuss.
254
 
            return
255
 
        except Exception as e:
 
217
        except Exception, e:
256
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
257
219
            raise
258
 
        self._disconnect_client()
259
 
 
260
 
    def _stop_gracefully(self):
261
 
        """When we finish this message, stop looking for more."""
262
 
        trace.mutter('Stopping %s' % (self,))
263
 
        self.finished = True
264
 
 
265
 
    def _disconnect_client(self):
266
 
        """Close the current connection. We stopped due to a timeout/etc."""
267
 
        # The default implementation is a no-op, because that is all we used to
268
 
        # do when disconnecting from a client. I suppose we never had the
269
 
        # *server* initiate a disconnect, before
270
 
 
271
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
272
 
        """Wait for more bytes to be read, but timeout if none available.
273
 
 
274
 
        This allows us to detect idle connections, and stop trying to read from
275
 
        them, without setting the socket itself to non-blocking. This also
276
 
        allows us to specify when we watch for idle timeouts.
277
 
 
278
 
        :return: Did we timeout? (True if we timed out, False if there is data
279
 
            to be read)
280
 
        """
281
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
282
220
 
283
221
    def _build_protocol(self):
284
222
        """Identifies the version of the incoming request, and returns an
289
227
 
290
228
        :returns: a SmartServerRequestProtocol.
291
229
        """
292
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
293
 
        if self.finished:
294
 
            # We're stopping, so don't try to do any more work
295
 
            return None
296
230
        bytes = self._get_line()
297
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
298
232
        protocol = protocol_factory(
300
234
        protocol.accept_bytes(unused_bytes)
301
235
        return protocol
302
236
 
303
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
304
 
        """select() on a file descriptor, waiting for nonblocking read()
305
 
 
306
 
        This will raise a ConnectionTimeout exception if we do not get a
307
 
        readable handle before timeout_seconds.
308
 
        :return: None
309
 
        """
310
 
        t_end = self._timer() + timeout_seconds
311
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
312
 
        rs = xs = None
313
 
        while not rs and not xs and self._timer() < t_end:
314
 
            if self.finished:
315
 
                return
316
 
            try:
317
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
318
 
            except (select.error, socket.error) as e:
319
 
                err = getattr(e, 'errno', None)
320
 
                if err is None and getattr(e, 'args', None) is not None:
321
 
                    # select.error doesn't have 'errno', it just has args[0]
322
 
                    err = e.args[0]
323
 
                if err in _bad_file_descriptor:
324
 
                    return # Not a socket indicates read() will fail
325
 
                elif err == errno.EINTR:
326
 
                    # Interrupted, keep looping.
327
 
                    continue
328
 
                raise
329
 
        if rs or xs:
330
 
            return
331
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
332
 
                                       % (timeout_seconds,))
333
 
 
334
237
    def _serve_one_request(self, protocol):
335
238
        """Read one request from input, process, send back a response.
336
239
 
337
240
        :param protocol: a SmartServerRequestProtocol.
338
241
        """
339
 
        if protocol is None:
340
 
            return
341
242
        try:
342
243
            self._serve_one_request_unguarded(protocol)
343
244
        except KeyboardInterrupt:
344
245
            raise
345
 
        except Exception as e:
 
246
        except Exception, e:
346
247
            self.terminate_due_to_error()
347
248
 
348
249
    def terminate_due_to_error(self):
359
260
 
360
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
361
262
 
362
 
    def __init__(self, sock, backing_transport, root_client_path='/',
363
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
364
264
        """Constructor.
365
265
 
366
266
        :param sock: the socket the server will read from.  It will be put
367
267
            into blocking mode.
368
268
        """
369
269
        SmartServerStreamMedium.__init__(
370
 
            self, backing_transport, root_client_path=root_client_path,
371
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
372
271
        sock.setblocking(True)
373
272
        self.socket = sock
374
 
        # Get the getpeername now, as we might be closed later when we care.
375
 
        try:
376
 
            self._client_info = sock.getpeername()
377
 
        except socket.error:
378
 
            self._client_info = '<unknown>'
379
 
 
380
 
    def __str__(self):
381
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
382
 
 
383
 
    def __repr__(self):
384
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
385
 
            self._client_info)
386
273
 
387
274
    def _serve_one_request_unguarded(self, protocol):
388
275
        while protocol.next_read_size():
397
284
 
398
285
        self._push_back(protocol.unused_data)
399
286
 
400
 
    def _disconnect_client(self):
401
 
        """Close the current connection. We stopped due to a timeout/etc."""
402
 
        self.socket.close()
403
 
 
404
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
405
 
        """Wait for more bytes to be read, but timeout if none available.
406
 
 
407
 
        This allows us to detect idle connections, and stop trying to read from
408
 
        them, without setting the socket itself to non-blocking. This also
409
 
        allows us to specify when we watch for idle timeouts.
410
 
 
411
 
        :return: None, this will raise ConnectionTimeout if we time out before
412
 
            data is available.
413
 
        """
414
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
415
 
 
416
287
    def _read_bytes(self, desired_count):
417
288
        return osutils.read_bytes_from_socket(
418
289
            self.socket, self._report_activity)
435
306
 
436
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
437
308
 
438
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
439
310
        """Construct new server.
440
311
 
441
312
        :param in_file: Python file from which requests can be read.
442
313
        :param out_file: Python file to write responses.
443
314
        :param backing_transport: Transport for the directory served.
444
315
        """
445
 
        SmartServerStreamMedium.__init__(self, backing_transport,
446
 
            timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
447
317
        if sys.platform == 'win32':
448
318
            # force binary mode for files
449
319
            import msvcrt
454
324
        self._in = in_file
455
325
        self._out = out_file
456
326
 
457
 
    def serve(self):
458
 
        """See SmartServerStreamMedium.serve"""
459
 
        # This is the regular serve, except it adds signal trapping for soft
460
 
        # shutdown.
461
 
        stop_gracefully = self._stop_gracefully
462
 
        signals.register_on_hangup(id(self), stop_gracefully)
463
 
        try:
464
 
            return super(SmartServerPipeStreamMedium, self).serve()
465
 
        finally:
466
 
            signals.unregister_on_hangup(id(self))
467
 
 
468
327
    def _serve_one_request_unguarded(self, protocol):
469
328
        while True:
470
329
            # We need to be careful not to read past the end of the current
483
342
                return
484
343
            protocol.accept_bytes(bytes)
485
344
 
486
 
    def _disconnect_client(self):
487
 
        self._in.close()
488
 
        self._out.flush()
489
 
        self._out.close()
490
 
 
491
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
492
 
        """Wait for more bytes to be read, but timeout if none available.
493
 
 
494
 
        This allows us to detect idle connections, and stop trying to read from
495
 
        them, without setting the socket itself to non-blocking. This also
496
 
        allows us to specify when we watch for idle timeouts.
497
 
 
498
 
        :return: None, this will raise ConnectionTimeout if we time out before
499
 
            data is available.
500
 
        """
501
 
        if (getattr(self._in, 'fileno', None) is None
502
 
            or sys.platform == 'win32'):
503
 
            # You can't select() file descriptors on Windows.
504
 
            return
505
 
        return self._wait_on_descriptor(self._in, timeout_seconds)
506
 
 
507
345
    def _read_bytes(self, desired_count):
508
346
        return self._in.read(desired_count)
509
347
 
653
491
        return self._medium._get_line()
654
492
 
655
493
 
656
 
class _VfsRefuser(object):
657
 
    """An object that refuses all VFS requests.
658
 
 
659
 
    """
660
 
 
661
 
    def __init__(self):
662
 
        client._SmartClient.hooks.install_named_hook(
663
 
            'call', self.check_vfs, 'vfs refuser')
664
 
 
665
 
    def check_vfs(self, params):
666
 
        try:
667
 
            request_method = request.request_handlers.get(params.method)
668
 
        except KeyError:
669
 
            # A method we don't know about doesn't count as a VFS method.
670
 
            return
671
 
        if issubclass(request_method, vfs.VfsRequest):
672
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
673
 
 
674
 
 
675
494
class _DebugCounter(object):
676
495
    """An object that counts the HPSS calls made to each client medium.
677
496
 
678
 
    When a medium is garbage-collected, or failing that when
679
 
    breezy.global_state exits, the total number of calls made on that medium
680
 
    are reported via trace.note.
 
497
    When a medium is garbage-collected, or failing that when atexit functions
 
498
    are run, the total number of calls made on that medium are reported via
 
499
    trace.note.
681
500
    """
682
501
 
683
502
    def __init__(self):
684
503
        self.counts = weakref.WeakKeyDictionary()
685
504
        client._SmartClient.hooks.install_named_hook(
686
505
            'call', self.increment_call_count, 'hpss call counter')
687
 
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
 
506
        atexit.register(self.flush_all)
688
507
 
689
508
    def track(self, medium):
690
509
        """Start tracking calls made to a medium.
724
543
        value['count'] = 0
725
544
        value['vfs_count'] = 0
726
545
        if count != 0:
727
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
728
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
729
548
 
730
549
    def flush_all(self):
731
550
        for ref in list(self.counts.keys()):
732
551
            self.done(ref)
733
552
 
734
553
_debug_counter = None
735
 
_vfs_refuser = None
736
554
 
737
555
 
738
556
class SmartClientMedium(SmartMedium):
755
573
            if _debug_counter is None:
756
574
                _debug_counter = _DebugCounter()
757
575
            _debug_counter.track(self)
758
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
759
 
            global _vfs_refuser
760
 
            if _vfs_refuser is None:
761
 
                _vfs_refuser = _VfsRefuser()
762
576
 
763
577
    def _is_remote_before(self, version_tuple):
764
578
        """Is it possible the remote side supports RPCs for a given version?
817
631
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
818
632
                client_protocol.query_version()
819
633
                self._done_hello = True
820
 
            except errors.SmartProtocolError as e:
 
634
            except errors.SmartProtocolError, e:
821
635
                # Cache the error, just like we would cache a successful
822
636
                # result.
823
637
                self._protocol_version_error = e
855
669
        """
856
670
        medium_base = urlutils.join(self.base, '/')
857
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
858
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
859
673
 
860
674
 
861
675
class SmartClientStreamMedium(SmartClientMedium):
896
710
        """
897
711
        return SmartClientStreamMediumRequest(self)
898
712
 
899
 
    def reset(self):
900
 
        """We have been disconnected, reset current state.
901
 
 
902
 
        This resets things like _current_request and connected state.
903
 
        """
904
 
        self.disconnect()
905
 
        self._current_request = None
906
 
 
907
713
 
908
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
909
715
    """A client medium using simple pipes.
910
716
 
911
717
    This client does not manage the pipes: it assumes they will always be open.
 
718
 
 
719
    Note that if readable_pipe.read might raise IOError or OSError with errno
 
720
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
 
721
    (such as used for sys.stdin) are safe.
912
722
    """
913
723
 
914
724
    def __init__(self, readable_pipe, writeable_pipe, base):
918
728
 
919
729
    def _accept_bytes(self, bytes):
920
730
        """See SmartClientStreamMedium.accept_bytes."""
921
 
        try:
922
 
            self._writeable_pipe.write(bytes)
923
 
        except IOError as e:
924
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
925
 
                raise errors.ConnectionReset(
926
 
                    "Error trying to write to subprocess", e)
927
 
            raise
 
731
        self._writeable_pipe.write(bytes)
928
732
        self._report_activity(len(bytes), 'write')
929
733
 
930
734
    def _flush(self):
931
735
        """See SmartClientStreamMedium._flush()."""
932
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
933
 
        #       However, testing shows that even when the child process is
934
 
        #       gone, this doesn't error.
935
736
        self._writeable_pipe.flush()
936
737
 
937
738
    def _read_bytes(self, count):
938
739
        """See SmartClientStreamMedium._read_bytes."""
939
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
940
 
        bytes = self._readable_pipe.read(bytes_to_read)
 
740
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
941
741
        self._report_activity(len(bytes), 'read')
942
742
        return bytes
943
743
 
944
744
 
945
 
class SSHParams(object):
946
 
    """A set of parameters for starting a remote bzr via SSH."""
 
745
class SmartSSHClientMedium(SmartClientStreamMedium):
 
746
    """A client medium using SSH."""
947
747
 
948
748
    def __init__(self, host, port=None, username=None, password=None,
949
 
            bzr_remote_path='bzr'):
950
 
        self.host = host
951
 
        self.port = port
952
 
        self.username = username
953
 
        self.password = password
954
 
        self.bzr_remote_path = bzr_remote_path
955
 
 
956
 
 
957
 
class SmartSSHClientMedium(SmartClientStreamMedium):
958
 
    """A client medium using SSH.
959
 
 
960
 
    It delegates IO to a SmartSimplePipesClientMedium or
961
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
962
 
    """
963
 
 
964
 
    def __init__(self, base, ssh_params, vendor=None):
 
749
            base=None, vendor=None, bzr_remote_path=None):
965
750
        """Creates a client that will connect on the first use.
966
751
 
967
 
        :param ssh_params: A SSHParams instance.
968
752
        :param vendor: An optional override for the ssh vendor to use. See
969
 
            breezy.transport.ssh for details on ssh vendors.
 
753
            bzrlib.transport.ssh for details on ssh vendors.
970
754
        """
971
 
        self._real_medium = None
972
 
        self._ssh_params = ssh_params
 
755
        self._connected = False
 
756
        self._host = host
 
757
        self._password = password
 
758
        self._port = port
 
759
        self._username = username
973
760
        # for the benefit of progress making a short description of this
974
761
        # transport
975
762
        self._scheme = 'bzr+ssh'
977
764
        # _DebugCounter so we have to store all the values used in our repr
978
765
        # method before calling the super init.
979
766
        SmartClientStreamMedium.__init__(self, base)
 
767
        self._read_from = None
 
768
        self._ssh_connection = None
980
769
        self._vendor = vendor
981
 
        self._ssh_connection = None
 
770
        self._write_to = None
 
771
        self._bzr_remote_path = bzr_remote_path
982
772
 
983
773
    def __repr__(self):
984
 
        if self._ssh_params.port is None:
 
774
        if self._port is None:
985
775
            maybe_port = ''
986
776
        else:
987
 
            maybe_port = ':%s' % self._ssh_params.port
988
 
        if self._ssh_params.username is None:
989
 
            maybe_user = ''
990
 
        else:
991
 
            maybe_user = '%s@' % self._ssh_params.username
992
 
        return "%s(%s://%s%s%s/)" % (
 
777
            maybe_port = ':%s' % self._port
 
778
        return "%s(%s://%s@%s%s/)" % (
993
779
            self.__class__.__name__,
994
780
            self._scheme,
995
 
            maybe_user,
996
 
            self._ssh_params.host,
 
781
            self._username,
 
782
            self._host,
997
783
            maybe_port)
998
784
 
999
785
    def _accept_bytes(self, bytes):
1000
786
        """See SmartClientStreamMedium.accept_bytes."""
1001
787
        self._ensure_connection()
1002
 
        self._real_medium.accept_bytes(bytes)
 
788
        self._write_to.write(bytes)
 
789
        self._report_activity(len(bytes), 'write')
1003
790
 
1004
791
    def disconnect(self):
1005
792
        """See SmartClientMedium.disconnect()."""
1006
 
        if self._real_medium is not None:
1007
 
            self._real_medium.disconnect()
1008
 
            self._real_medium = None
1009
 
        if self._ssh_connection is not None:
1010
 
            self._ssh_connection.close()
1011
 
            self._ssh_connection = None
 
793
        if not self._connected:
 
794
            return
 
795
        self._read_from.close()
 
796
        self._write_to.close()
 
797
        self._ssh_connection.close()
 
798
        self._connected = False
1012
799
 
1013
800
    def _ensure_connection(self):
1014
801
        """Connect this medium if not already connected."""
1015
 
        if self._real_medium is not None:
 
802
        if self._connected:
1016
803
            return
1017
804
        if self._vendor is None:
1018
805
            vendor = ssh._get_ssh_vendor()
1019
806
        else:
1020
807
            vendor = self._vendor
1021
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1022
 
                self._ssh_params.password, self._ssh_params.host,
1023
 
                self._ssh_params.port,
1024
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
808
        self._ssh_connection = vendor.connect_ssh(self._username,
 
809
                self._password, self._host, self._port,
 
810
                command=[self._bzr_remote_path, 'serve', '--inet',
1025
811
                         '--directory=/', '--allow-writes'])
1026
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1027
 
        if io_kind == 'socket':
1028
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1029
 
                self.base, io_object)
1030
 
        elif io_kind == 'pipes':
1031
 
            read_from, write_to = io_object
1032
 
            self._real_medium = SmartSimplePipesClientMedium(
1033
 
                read_from, write_to, self.base)
1034
 
        else:
1035
 
            raise AssertionError(
1036
 
                "Unexpected io_kind %r from %r"
1037
 
                % (io_kind, self._ssh_connection))
1038
 
        for hook in transport.Transport.hooks["post_connect"]:
1039
 
            hook(self)
 
812
        self._read_from, self._write_to = \
 
813
            self._ssh_connection.get_filelike_channels()
 
814
        self._connected = True
1040
815
 
1041
816
    def _flush(self):
1042
817
        """See SmartClientStreamMedium._flush()."""
1043
 
        self._real_medium._flush()
 
818
        self._write_to.flush()
1044
819
 
1045
820
    def _read_bytes(self, count):
1046
821
        """See SmartClientStreamMedium.read_bytes."""
1047
 
        if self._real_medium is None:
 
822
        if not self._connected:
1048
823
            raise errors.MediumNotConnected(self)
1049
 
        return self._real_medium.read_bytes(count)
 
824
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
825
        bytes = self._read_from.read(bytes_to_read)
 
826
        self._report_activity(len(bytes), 'read')
 
827
        return bytes
1050
828
 
1051
829
 
1052
830
# Port 4155 is the default port for bzr://, registered with IANA.
1054
832
BZR_DEFAULT_PORT = 4155
1055
833
 
1056
834
 
1057
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1058
 
    """A client medium using a socket.
1059
 
 
1060
 
    This class isn't usable directly.  Use one of its subclasses instead.
1061
 
    """
1062
 
 
1063
 
    def __init__(self, base):
 
835
class SmartTCPClientMedium(SmartClientStreamMedium):
 
836
    """A client medium using TCP."""
 
837
 
 
838
    def __init__(self, host, port, base):
 
839
        """Creates a client that will connect on the first use."""
1064
840
        SmartClientStreamMedium.__init__(self, base)
 
841
        self._connected = False
 
842
        self._host = host
 
843
        self._port = port
1065
844
        self._socket = None
1066
 
        self._connected = False
1067
845
 
1068
846
    def _accept_bytes(self, bytes):
1069
847
        """See SmartClientMedium.accept_bytes."""
1070
848
        self._ensure_connection()
1071
849
        osutils.send_all(self._socket, bytes, self._report_activity)
1072
850
 
1073
 
    def _ensure_connection(self):
1074
 
        """Connect this medium if not already connected."""
1075
 
        raise NotImplementedError(self._ensure_connection)
1076
 
 
1077
 
    def _flush(self):
1078
 
        """See SmartClientStreamMedium._flush().
1079
 
 
1080
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1081
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1082
 
        future.
1083
 
        """
1084
 
 
1085
 
    def _read_bytes(self, count):
1086
 
        """See SmartClientMedium.read_bytes."""
1087
 
        if not self._connected:
1088
 
            raise errors.MediumNotConnected(self)
1089
 
        return osutils.read_bytes_from_socket(
1090
 
            self._socket, self._report_activity)
1091
 
 
1092
851
    def disconnect(self):
1093
852
        """See SmartClientMedium.disconnect()."""
1094
853
        if not self._connected:
1097
856
        self._socket = None
1098
857
        self._connected = False
1099
858
 
1100
 
 
1101
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1102
 
    """A client medium that creates a TCP connection."""
1103
 
 
1104
 
    def __init__(self, host, port, base):
1105
 
        """Creates a client that will connect on the first use."""
1106
 
        SmartClientSocketMedium.__init__(self, base)
1107
 
        self._host = host
1108
 
        self._port = port
1109
 
 
1110
859
    def _ensure_connection(self):
1111
860
        """Connect this medium if not already connected."""
1112
861
        if self._connected:
1118
867
        try:
1119
868
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1120
869
                socket.SOCK_STREAM, 0, 0)
1121
 
        except socket.gaierror as xxx_todo_changeme:
1122
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
870
        except socket.gaierror, (err_num, err_msg):
1123
871
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1124
872
                    (self._host, port, err_msg))
1125
873
        # Initialize err in case there are no addresses returned:
1130
878
                self._socket.setsockopt(socket.IPPROTO_TCP,
1131
879
                                        socket.TCP_NODELAY, 1)
1132
880
                self._socket.connect(sockaddr)
1133
 
            except socket.error as err:
 
881
            except socket.error, err:
1134
882
                if self._socket is not None:
1135
883
                    self._socket.close()
1136
884
                self._socket = None
1139
887
        if self._socket is None:
1140
888
            # socket errors either have a (string) or (errno, string) as their
1141
889
            # args.
1142
 
            if isinstance(err.args, str):
 
890
            if type(err.args) is str:
1143
891
                err_msg = err.args
1144
892
            else:
1145
893
                err_msg = err.args[1]
1146
894
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1147
895
                    (self._host, port, err_msg))
1148
896
        self._connected = True
1149
 
        for hook in transport.Transport.hooks["post_connect"]:
1150
 
            hook(self)
1151
 
 
1152
 
 
1153
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1154
 
    """A client medium for an already connected socket.
1155
 
    
1156
 
    Note that this class will assume it "owns" the socket, so it will close it
1157
 
    when its disconnect method is called.
1158
 
    """
1159
 
 
1160
 
    def __init__(self, base, sock):
1161
 
        SmartClientSocketMedium.__init__(self, base)
1162
 
        self._socket = sock
1163
 
        self._connected = True
1164
 
 
1165
 
    def _ensure_connection(self):
1166
 
        # Already connected, by definition!  So nothing to do.
1167
 
        pass
 
897
 
 
898
    def _flush(self):
 
899
        """See SmartClientStreamMedium._flush().
 
900
 
 
901
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
902
        add a means to do a flush, but that can be done in the future.
 
903
        """
 
904
 
 
905
    def _read_bytes(self, count):
 
906
        """See SmartClientMedium.read_bytes."""
 
907
        if not self._connected:
 
908
            raise errors.MediumNotConnected(self)
 
909
        return osutils.read_bytes_from_socket(
 
910
            self._socket, self._report_activity)
1168
911
 
1169
912
 
1170
913
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1205
948
        This invokes self._medium._flush to ensure all bytes are transmitted.
1206
949
        """
1207
950
        self._medium._flush()
 
951
 
 
952