22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic. See the overview in
24
breezy/transport/smart/__init__.py.
24
bzrlib/transport/smart/__init__.py.
30
from bzrlib import errors
31
from bzrlib.smart.protocol import SmartServerRequestProtocolOne
36
import thread as _thread
39
from ...lazy_import import lazy_import
40
lazy_import(globals(), """
52
from breezy.i18n import gettext
53
from breezy.bzr.smart import client, protocol, request, signals, vfs
54
from breezy.transport import ssh
61
# Throughout this module buffer size parameters are either limited to be at
62
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
63
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
64
# from non-sockets as well.
65
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
68
class HpssVfsRequestNotAllowed(errors.BzrError):
70
_fmt = ("VFS requests over the smart server are not allowed. Encountered: "
71
"%(method)s, %(arguments)s.")
73
def __init__(self, method, arguments):
75
self.arguments = arguments
78
def _get_protocol_factory_for_bytes(bytes):
79
"""Determine the right protocol factory for 'bytes'.
81
This will return an appropriate protocol factory depending on the version
82
of the protocol being used, as determined by inspecting the given bytes.
83
The bytes should have at least one newline byte (i.e. be a whole line),
84
otherwise it's possible that a request will be incorrectly identified as
87
Typical use would be::
89
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
90
server_protocol = factory(transport, write_func, root_client_path)
91
server_protocol.accept_bytes(unused_bytes)
93
:param bytes: a str of bytes of the start of the request.
94
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
95
a callable that takes three args: transport, write_func,
96
root_client_path. unused_bytes are any bytes that were not part of a
97
protocol version marker.
99
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
100
protocol_factory = protocol.build_server_protocol_three
101
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
102
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
103
protocol_factory = protocol.SmartServerRequestProtocolTwo
104
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
106
protocol_factory = protocol.SmartServerRequestProtocolOne
107
return protocol_factory, bytes
110
def _get_line(read_bytes_func):
111
"""Read bytes using read_bytes_func until a newline byte.
113
This isn't particularly efficient, so should only be used when the
114
expected size of the line is quite short.
116
:returns: a tuple of two strs: (line, excess)
120
while newline_pos == -1:
121
new_bytes = read_bytes_func(1)
124
# Ran out of bytes before receiving a complete line.
126
newline_pos = bytes.find(b'\n')
127
line = bytes[:newline_pos + 1]
128
excess = bytes[newline_pos + 1:]
132
class SmartMedium(object):
133
"""Base class for smart protocol media, both client- and server-side."""
136
self._push_back_buffer = None
138
def _push_back(self, data):
139
"""Return unused bytes to the medium, because they belong to the next
142
This sets the _push_back_buffer to the given bytes.
144
if not isinstance(data, bytes):
145
raise TypeError(data)
146
if self._push_back_buffer is not None:
147
raise AssertionError(
148
"_push_back called when self._push_back_buffer is %r"
149
% (self._push_back_buffer,))
152
self._push_back_buffer = data
154
def _get_push_back_buffer(self):
155
if self._push_back_buffer == b'':
156
raise AssertionError(
157
'%s._push_back_buffer should never be the empty string, '
158
'which can be confused with EOF' % (self,))
159
bytes = self._push_back_buffer
160
self._push_back_buffer = None
163
def read_bytes(self, desired_count):
164
"""Read some bytes from this medium.
166
:returns: some bytes, possibly more or less than the number requested
167
in 'desired_count' depending on the medium.
169
if self._push_back_buffer is not None:
170
return self._get_push_back_buffer()
171
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
172
return self._read_bytes(bytes_to_read)
174
def _read_bytes(self, count):
175
raise NotImplementedError(self._read_bytes)
178
"""Read bytes from this request's response until a newline byte.
180
This isn't particularly efficient, so should only be used when the
181
expected size of the line is quite short.
183
:returns: a string of bytes ending in a newline (byte 0x0A).
185
line, excess = _get_line(self.read_bytes)
186
self._push_back(excess)
189
def _report_activity(self, bytes, direction):
190
"""Notify that this medium has activity.
192
Implementations should call this from all methods that actually do IO.
193
Be careful that it's not called twice, if one method is implemented on
196
:param bytes: Number of bytes read or written.
197
:param direction: 'read' or 'write' or None.
199
ui.ui_factory.report_transport_activity(self, bytes, direction)
202
_bad_file_descriptor = (errno.EBADF,)
203
if sys.platform == 'win32':
204
# Given on Windows if you pass a closed socket to select.select. Probably
205
# also given if you pass a file handle to select.
207
_bad_file_descriptor += (WSAENOTSOCK,)
210
class SmartServerStreamMedium(SmartMedium):
34
from bzrlib.transport import ssh
35
except errors.ParamikoNotPresent:
36
# no paramiko. SmartSSHClientMedium will break.
40
class SmartServerStreamMedium(object):
211
41
"""Handles smart commands coming over a stream.
213
43
The stream may be a pipe connected to sshd, or a tcp socket, or an
249
66
from sys import stderr
251
68
while not self.finished:
252
server_protocol = self._build_protocol()
253
self._serve_one_request(server_protocol)
254
except errors.ConnectionTimeout as e:
255
trace.note('%s' % (e,))
256
trace.log_exception_quietly()
257
self._disconnect_client()
258
# We reported it, no reason to make a big fuss.
260
except Exception as e:
69
protocol = SmartServerRequestProtocolOne(self.backing_transport,
71
self._serve_one_request(protocol)
261
73
stderr.write("%s terminating on exception %s\n" % (self, e))
263
self._disconnect_client()
265
def _stop_gracefully(self):
266
"""When we finish this message, stop looking for more."""
267
trace.mutter('Stopping %s' % (self,))
270
def _disconnect_client(self):
271
"""Close the current connection. We stopped due to a timeout/etc."""
272
# The default implementation is a no-op, because that is all we used to
273
# do when disconnecting from a client. I suppose we never had the
274
# *server* initiate a disconnect, before
276
def _wait_for_bytes_with_timeout(self, timeout_seconds):
277
"""Wait for more bytes to be read, but timeout if none available.
279
This allows us to detect idle connections, and stop trying to read from
280
them, without setting the socket itself to non-blocking. This also
281
allows us to specify when we watch for idle timeouts.
283
:return: Did we timeout? (True if we timed out, False if there is data
286
raise NotImplementedError(self._wait_for_bytes_with_timeout)
288
def _build_protocol(self):
289
"""Identifies the version of the incoming request, and returns an
290
a protocol object that can interpret it.
292
If more bytes than the version prefix of the request are read, they will
293
be fed into the protocol before it is returned.
295
:returns: a SmartServerRequestProtocol.
297
self._wait_for_bytes_with_timeout(self._client_timeout)
299
# We're stopping, so don't try to do any more work
301
bytes = self._get_line()
302
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
303
protocol = protocol_factory(
304
self.backing_transport, self._write_out, self.root_client_path)
305
protocol.accept_bytes(unused_bytes)
308
def _wait_on_descriptor(self, fd, timeout_seconds):
309
"""select() on a file descriptor, waiting for nonblocking read()
311
This will raise a ConnectionTimeout exception if we do not get a
312
readable handle before timeout_seconds.
315
t_end = self._timer() + timeout_seconds
316
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
318
while not rs and not xs and self._timer() < t_end:
322
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
323
except (select.error, socket.error) as e:
324
err = getattr(e, 'errno', None)
325
if err is None and getattr(e, 'args', None) is not None:
326
# select.error doesn't have 'errno', it just has args[0]
328
if err in _bad_file_descriptor:
329
return # Not a socket indicates read() will fail
330
elif err == errno.EINTR:
331
# Interrupted, keep looping.
335
return # Socket may already be closed
338
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
339
% (timeout_seconds,))
341
76
def _serve_one_request(self, protocol):
342
77
"""Read one request from input, process, send back a response.
344
79
:param protocol: a SmartServerRequestProtocol.
349
82
self._serve_one_request_unguarded(protocol)
350
83
except KeyboardInterrupt:
352
except Exception as e:
353
86
self.terminate_due_to_error()
355
88
def terminate_due_to_error(self):
356
89
"""Called when an unhandled exception from the protocol occurs."""
357
90
raise NotImplementedError(self.terminate_due_to_error)
359
def _read_bytes(self, desired_count):
360
"""Get some bytes from the medium.
362
:param desired_count: number of bytes we want to read.
364
raise NotImplementedError(self._read_bytes)
367
93
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
369
def __init__(self, sock, backing_transport, root_client_path='/',
95
def __init__(self, sock, backing_transport):
373
98
:param sock: the socket the server will read from. It will be put
374
99
into blocking mode.
376
SmartServerStreamMedium.__init__(
377
self, backing_transport, root_client_path=root_client_path,
101
SmartServerStreamMedium.__init__(self, backing_transport)
379
103
sock.setblocking(True)
380
104
self.socket = sock
381
# Get the getpeername now, as we might be closed later when we care.
383
self._client_info = sock.getpeername()
385
self._client_info = '<unknown>'
388
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
391
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
394
106
def _serve_one_request_unguarded(self, protocol):
395
107
while protocol.next_read_size():
396
# We can safely try to read large chunks. If there is less data
397
# than MAX_SOCKET_CHUNK ready, the socket will just return a
398
# short read immediately rather than block.
399
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
403
protocol.accept_bytes(bytes)
405
self._push_back(protocol.unused_data)
407
def _disconnect_client(self):
408
"""Close the current connection. We stopped due to a timeout/etc."""
411
def _wait_for_bytes_with_timeout(self, timeout_seconds):
412
"""Wait for more bytes to be read, but timeout if none available.
414
This allows us to detect idle connections, and stop trying to read from
415
them, without setting the socket itself to non-blocking. This also
416
allows us to specify when we watch for idle timeouts.
418
:return: None, this will raise ConnectionTimeout if we time out before
421
return self._wait_on_descriptor(self.socket, timeout_seconds)
423
def _read_bytes(self, desired_count):
424
return osutils.read_bytes_from_socket(
425
self.socket, self._report_activity)
109
protocol.accept_bytes(self.push_back)
112
bytes = self.socket.recv(4096)
116
protocol.accept_bytes(bytes)
118
self.push_back = protocol.excess_buffer
427
120
def terminate_due_to_error(self):
121
"""Called when an unhandled exception from the protocol occurs."""
428
122
# TODO: This should log to a server log file, but no such thing
429
123
# exists yet. Andrew Bennetts 2006-09-29.
430
124
self.socket.close()
431
125
self.finished = True
433
127
def _write_out(self, bytes):
434
tstart = osutils.perf_counter()
435
osutils.send_all(self.socket, bytes, self._report_activity)
436
if 'hpss' in debug.debug_flags:
437
thread_id = _thread.get_ident()
438
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
439
% ('wrote', thread_id, len(bytes),
440
osutils.perf_counter() - tstart))
128
self.socket.sendall(bytes)
443
131
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
445
def __init__(self, in_file, out_file, backing_transport, timeout=None):
133
def __init__(self, in_file, out_file, backing_transport):
446
134
"""Construct new server.
448
136
:param in_file: Python file from which requests can be read.
449
137
:param out_file: Python file to write responses.
450
138
:param backing_transport: Transport for the directory served.
452
SmartServerStreamMedium.__init__(self, backing_transport,
140
SmartServerStreamMedium.__init__(self, backing_transport)
454
141
if sys.platform == 'win32':
455
142
# force binary mode for files
634
280
return self._read_bytes(count)
636
282
def _read_bytes(self, count):
637
"""Helper for SmartClientMediumRequest.read_bytes.
283
"""Helper for read_bytes.
639
285
read_bytes checks the state of the request to determing if bytes
640
286
should be read. After that it hands off to _read_bytes to do the
643
By default this forwards to self._medium.read_bytes because we are
644
operating on the medium's stream.
646
return self._medium.read_bytes(count)
649
line = self._read_line()
650
if not line.endswith(b'\n'):
651
# end of file encountered reading from server
652
raise errors.ConnectionReset(
653
"Unexpected end of message. Please check connectivity "
654
"and permissions, and report a bug if problems persist.")
657
def _read_line(self):
658
"""Helper for SmartClientMediumRequest.read_line.
660
By default this forwards to self._medium._get_line because we are
661
operating on the medium's stream.
663
return self._medium._get_line()
666
class _VfsRefuser(object):
667
"""An object that refuses all VFS requests.
672
client._SmartClient.hooks.install_named_hook(
673
'call', self.check_vfs, 'vfs refuser')
675
def check_vfs(self, params):
677
request_method = request.request_handlers.get(params.method)
679
# A method we don't know about doesn't count as a VFS method.
681
if issubclass(request_method, vfs.VfsRequest):
682
raise HpssVfsRequestNotAllowed(params.method, params.args)
685
class _DebugCounter(object):
686
"""An object that counts the HPSS calls made to each client medium.
688
When a medium is garbage-collected, or failing that when
689
breezy.global_state exits, the total number of calls made on that medium
690
are reported via trace.note.
694
self.counts = weakref.WeakKeyDictionary()
695
client._SmartClient.hooks.install_named_hook(
696
'call', self.increment_call_count, 'hpss call counter')
697
breezy.get_global_state().exit_stack.callback(self.flush_all)
699
def track(self, medium):
700
"""Start tracking calls made to a medium.
702
This only keeps a weakref to the medium, so shouldn't affect the
705
medium_repr = repr(medium)
706
# Add this medium to the WeakKeyDictionary
707
self.counts[medium] = dict(count=0, vfs_count=0,
708
medium_repr=medium_repr)
709
# Weakref callbacks are fired in reverse order of their association
710
# with the referenced object. So we add a weakref *after* adding to
711
# the WeakKeyDict so that we can report the value from it before the
712
# entry is removed by the WeakKeyDict's own callback.
713
ref = weakref.ref(medium, self.done)
715
def increment_call_count(self, params):
716
# Increment the count in the WeakKeyDictionary
717
value = self.counts[params.medium]
720
request_method = request.request_handlers.get(params.method)
722
# A method we don't know about doesn't count as a VFS method.
724
if issubclass(request_method, vfs.VfsRequest):
725
value['vfs_count'] += 1
728
value = self.counts[ref]
729
count, vfs_count, medium_repr = (
730
value['count'], value['vfs_count'], value['medium_repr'])
731
# In case this callback is invoked for the same ref twice (by the
732
# weakref callback and by the atexit function), set the call count back
733
# to 0 so this item won't be reported twice.
735
value['vfs_count'] = 0
737
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
738
count, vfs_count, medium_repr))
741
for ref in list(self.counts.keys()):
745
_debug_counter = None
749
class SmartClientMedium(SmartMedium):
289
raise NotImplementedError(self._read_bytes)
292
class SmartClientMedium(object):
750
293
"""Smart client is a medium for sending smart protocol requests over."""
752
def __init__(self, base):
753
super(SmartClientMedium, self).__init__()
755
self._protocol_version_error = None
756
self._protocol_version = None
757
self._done_hello = False
758
# Be optimistic: we assume the remote end can accept new remote
759
# requests until we get an error saying otherwise.
760
# _remote_version_is_before tracks the bzr version the remote side
761
# can be based on what we've seen so far.
762
self._remote_version_is_before = None
763
# Install debug hook function if debug flag is set.
764
if 'hpss' in debug.debug_flags:
765
global _debug_counter
766
if _debug_counter is None:
767
_debug_counter = _DebugCounter()
768
_debug_counter.track(self)
769
if 'hpss_client_no_vfs' in debug.debug_flags:
771
if _vfs_refuser is None:
772
_vfs_refuser = _VfsRefuser()
774
def _is_remote_before(self, version_tuple):
775
"""Is it possible the remote side supports RPCs for a given version?
779
needed_version = (1, 2)
780
if medium._is_remote_before(needed_version):
781
fallback_to_pre_1_2_rpc()
785
except UnknownSmartMethod:
786
medium._remember_remote_is_before(needed_version)
787
fallback_to_pre_1_2_rpc()
789
:seealso: _remember_remote_is_before
791
if self._remote_version_is_before is None:
792
# So far, the remote side seems to support everything
794
return version_tuple >= self._remote_version_is_before
796
def _remember_remote_is_before(self, version_tuple):
797
"""Tell this medium that the remote side is older the given version.
799
:seealso: _is_remote_before
801
if (self._remote_version_is_before is not None and
802
version_tuple > self._remote_version_is_before):
803
# We have been told that the remote side is older than some version
804
# which is newer than a previously supplied older-than version.
805
# This indicates that some smart verb call is not guarded
806
# appropriately (it should simply not have been tried).
808
"_remember_remote_is_before(%r) called, but "
809
"_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
810
if 'hpss' in debug.debug_flags:
811
ui.ui_factory.show_warning(
812
"_remember_remote_is_before(%r) called, but "
813
"_remember_remote_is_before(%r) was called previously."
814
% (version_tuple, self._remote_version_is_before))
816
self._remote_version_is_before = version_tuple
818
def protocol_version(self):
819
"""Find out if 'hello' smart request works."""
820
if self._protocol_version_error is not None:
821
raise self._protocol_version_error
822
if not self._done_hello:
824
medium_request = self.get_request()
825
# Send a 'hello' request in protocol version one, for maximum
826
# backwards compatibility.
827
client_protocol = protocol.SmartClientRequestProtocolOne(
829
client_protocol.query_version()
830
self._done_hello = True
831
except errors.SmartProtocolError as e:
832
# Cache the error, just like we would cache a successful
834
self._protocol_version_error = e
838
def should_probe(self):
839
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
842
Some transports are unambiguously smart-only; there's no need to check
843
if the transport is able to carry smart requests, because that's all
844
it is for. In those cases, this method should return False.
846
But some HTTP transports can sometimes fail to carry smart requests,
847
but still be usuable for accessing remote bzrdirs via plain file
848
accesses. So for those transports, their media should return True here
849
so that RemoteBzrDirFormat can determine if it is appropriate for that
854
295
def disconnect(self):
855
296
"""If this medium maintains a persistent connection, close it.
857
298
The default implementation does nothing.
860
def remote_path_from_transport(self, transport):
861
"""Convert transport into a path suitable for using in a request.
863
Note that the resulting remote path doesn't encode the host name or
864
anything but path, so it is only safe to use it in requests sent over
865
the medium from the matching transport.
867
medium_base = urlutils.join(self.base, '/')
868
rel_url = urlutils.relative_url(medium_base, transport.base)
869
return urlutils.unquote(rel_url)
872
302
class SmartClientStreamMedium(SmartClientMedium):
873
303
"""Stream based medium common class.
908
337
return SmartClientStreamMediumRequest(self)
911
"""We have been disconnected, reset current state.
913
This resets things like _current_request and connected state.
916
self._current_request = None
339
def read_bytes(self, count):
340
return self._read_bytes(count)
919
343
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
920
344
"""A client medium using simple pipes.
922
346
This client does not manage the pipes: it assumes they will always be open.
925
def __init__(self, readable_pipe, writeable_pipe, base):
926
SmartClientStreamMedium.__init__(self, base)
349
def __init__(self, readable_pipe, writeable_pipe):
350
SmartClientStreamMedium.__init__(self)
927
351
self._readable_pipe = readable_pipe
928
352
self._writeable_pipe = writeable_pipe
930
def _accept_bytes(self, data):
354
def _accept_bytes(self, bytes):
931
355
"""See SmartClientStreamMedium.accept_bytes."""
933
self._writeable_pipe.write(data)
935
if e.errno in (errno.EINVAL, errno.EPIPE):
936
raise errors.ConnectionReset(
937
"Error trying to write to subprocess", e)
939
self._report_activity(len(data), 'write')
356
self._writeable_pipe.write(bytes)
941
358
def _flush(self):
942
359
"""See SmartClientStreamMedium._flush()."""
943
# Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
944
# However, testing shows that even when the child process is
945
# gone, this doesn't error.
946
360
self._writeable_pipe.flush()
948
362
def _read_bytes(self, count):
949
363
"""See SmartClientStreamMedium._read_bytes."""
950
bytes_to_read = min(count, _MAX_READ_SIZE)
951
data = self._readable_pipe.read(bytes_to_read)
952
self._report_activity(len(data), 'read')
956
class SSHParams(object):
957
"""A set of parameters for starting a remote bzr via SSH."""
364
return self._readable_pipe.read(count)
367
class SmartSSHClientMedium(SmartClientStreamMedium):
368
"""A client medium using SSH."""
959
370
def __init__(self, host, port=None, username=None, password=None,
960
bzr_remote_path='bzr'):
963
self.username = username
964
self.password = password
965
self.bzr_remote_path = bzr_remote_path
968
class SmartSSHClientMedium(SmartClientStreamMedium):
969
"""A client medium using SSH.
971
It delegates IO to a SmartSimplePipesClientMedium or
972
SmartClientAlreadyConnectedSocketMedium (depending on platform).
975
def __init__(self, base, ssh_params, vendor=None):
976
372
"""Creates a client that will connect on the first use.
978
:param ssh_params: A SSHParams instance.
979
374
:param vendor: An optional override for the ssh vendor to use. See
980
breezy.transport.ssh for details on ssh vendors.
375
bzrlib.transport.ssh for details on ssh vendors.
982
self._real_medium = None
983
self._ssh_params = ssh_params
984
# for the benefit of progress making a short description of this
986
self._scheme = 'bzr+ssh'
987
# SmartClientStreamMedium stores the repr of this object in its
988
# _DebugCounter so we have to store all the values used in our repr
989
# method before calling the super init.
990
SmartClientStreamMedium.__init__(self, base)
377
SmartClientStreamMedium.__init__(self)
378
self._connected = False
380
self._password = password
382
self._username = username
383
self._read_from = None
384
self._ssh_connection = None
991
385
self._vendor = vendor
992
self._ssh_connection = None
995
if self._ssh_params.port is None:
998
maybe_port = ':%s' % self._ssh_params.port
999
if self._ssh_params.username is None:
1002
maybe_user = '%s@' % self._ssh_params.username
1003
return "%s(%s://%s%s%s/)" % (
1004
self.__class__.__name__,
1007
self._ssh_params.host,
386
self._write_to = None
1010
388
def _accept_bytes(self, bytes):
1011
389
"""See SmartClientStreamMedium.accept_bytes."""
1012
390
self._ensure_connection()
1013
self._real_medium.accept_bytes(bytes)
391
self._write_to.write(bytes)
1015
393
def disconnect(self):
1016
394
"""See SmartClientMedium.disconnect()."""
1017
if self._real_medium is not None:
1018
self._real_medium.disconnect()
1019
self._real_medium = None
1020
if self._ssh_connection is not None:
1021
self._ssh_connection.close()
1022
self._ssh_connection = None
395
if not self._connected:
397
self._read_from.close()
398
self._write_to.close()
399
self._ssh_connection.close()
400
self._connected = False
1024
402
def _ensure_connection(self):
1025
403
"""Connect this medium if not already connected."""
1026
if self._real_medium is not None:
406
executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1028
407
if self._vendor is None:
1029
408
vendor = ssh._get_ssh_vendor()
1031
410
vendor = self._vendor
1032
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1033
self._ssh_params.password, self._ssh_params.host,
1034
self._ssh_params.port,
1035
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1036
'--directory=/', '--allow-writes'])
1037
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1038
if io_kind == 'socket':
1039
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1040
self.base, io_object)
1041
elif io_kind == 'pipes':
1042
read_from, write_to = io_object
1043
self._real_medium = SmartSimplePipesClientMedium(
1044
read_from, write_to, self.base)
1046
raise AssertionError(
1047
"Unexpected io_kind %r from %r"
1048
% (io_kind, self._ssh_connection))
1049
for hook in transport.Transport.hooks["post_connect"]:
411
self._ssh_connection = vendor.connect_ssh(self._username,
412
self._password, self._host, self._port,
413
command=[executable, 'serve', '--inet', '--directory=/',
415
self._read_from, self._write_to = \
416
self._ssh_connection.get_filelike_channels()
417
self._connected = True
1052
419
def _flush(self):
1053
420
"""See SmartClientStreamMedium._flush()."""
1054
self._real_medium._flush()
421
self._write_to.flush()
1056
423
def _read_bytes(self, count):
1057
424
"""See SmartClientStreamMedium.read_bytes."""
1058
if self._real_medium is None:
425
if not self._connected:
1059
426
raise errors.MediumNotConnected(self)
1060
return self._real_medium.read_bytes(count)
1063
# Port 4155 is the default port for bzr://, registered with IANA.
1064
BZR_DEFAULT_INTERFACE = None
1065
BZR_DEFAULT_PORT = 4155
1068
class SmartClientSocketMedium(SmartClientStreamMedium):
1069
"""A client medium using a socket.
1071
This class isn't usable directly. Use one of its subclasses instead.
1074
def __init__(self, base):
1075
SmartClientStreamMedium.__init__(self, base)
427
return self._read_from.read(count)
430
class SmartTCPClientMedium(SmartClientStreamMedium):
431
"""A client medium using TCP."""
433
def __init__(self, host, port):
434
"""Creates a client that will connect on the first use."""
435
SmartClientStreamMedium.__init__(self)
436
self._connected = False
1076
439
self._socket = None
1077
self._connected = False
1079
441
def _accept_bytes(self, bytes):
1080
442
"""See SmartClientMedium.accept_bytes."""
1081
443
self._ensure_connection()
1082
osutils.send_all(self._socket, bytes, self._report_activity)
444
self._socket.sendall(bytes)
446
def disconnect(self):
447
"""See SmartClientMedium.disconnect()."""
448
if not self._connected:
452
self._connected = False
1084
454
def _ensure_connection(self):
1085
455
"""Connect this medium if not already connected."""
1086
raise NotImplementedError(self._ensure_connection)
458
self._socket = socket.socket()
459
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
460
result = self._socket.connect_ex((self._host, int(self._port)))
462
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
463
(self._host, self._port, os.strerror(result)))
464
self._connected = True
1088
466
def _flush(self):
1089
467
"""See SmartClientStreamMedium._flush().
1091
For sockets we do no flushing. For TCP sockets we may want to turn off
1092
TCP_NODELAY and add a means to do a flush, but that can be done in the
469
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
470
add a means to do a flush, but that can be done in the future.
1096
473
def _read_bytes(self, count):
1097
474
"""See SmartClientMedium.read_bytes."""
1098
475
if not self._connected:
1099
476
raise errors.MediumNotConnected(self)
1100
return osutils.read_bytes_from_socket(
1101
self._socket, self._report_activity)
1103
def disconnect(self):
1104
"""See SmartClientMedium.disconnect()."""
1105
if not self._connected:
1107
self._socket.close()
1109
self._connected = False
1112
class SmartTCPClientMedium(SmartClientSocketMedium):
1113
"""A client medium that creates a TCP connection."""
1115
def __init__(self, host, port, base):
1116
"""Creates a client that will connect on the first use."""
1117
SmartClientSocketMedium.__init__(self, base)
1121
def _ensure_connection(self):
1122
"""Connect this medium if not already connected."""
1125
if self._port is None:
1126
port = BZR_DEFAULT_PORT
1128
port = int(self._port)
1130
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1131
socket.SOCK_STREAM, 0, 0)
1132
except socket.gaierror as xxx_todo_changeme:
1133
(err_num, err_msg) = xxx_todo_changeme.args
1134
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1135
(self._host, port, err_msg))
1136
# Initialize err in case there are no addresses returned:
1137
last_err = socket.error("no address found for %s" % self._host)
1138
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1140
self._socket = socket.socket(family, socktype, proto)
1141
self._socket.setsockopt(socket.IPPROTO_TCP,
1142
socket.TCP_NODELAY, 1)
1143
self._socket.connect(sockaddr)
1144
except socket.error as err:
1145
if self._socket is not None:
1146
self._socket.close()
1151
if self._socket is None:
1152
# socket errors either have a (string) or (errno, string) as their
1154
if isinstance(last_err.args, str):
1155
err_msg = last_err.args
1157
err_msg = last_err.args[1]
1158
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1159
(self._host, port, err_msg))
1160
self._connected = True
1161
for hook in transport.Transport.hooks["post_connect"]:
1165
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1166
"""A client medium for an already connected socket.
1168
Note that this class will assume it "owns" the socket, so it will close it
1169
when its disconnect method is called.
1172
def __init__(self, base, sock):
1173
SmartClientSocketMedium.__init__(self, base)
1175
self._connected = True
1177
def _ensure_connection(self):
1178
# Already connected, by definition! So nothing to do.
477
return self._socket.recv(count)
1182
480
class SmartClientStreamMediumRequest(SmartClientMediumRequest):