# Copyright (C) 2006-2011 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

"""The 'medium' layer for the smart servers and clients.

"Medium" here is the noun meaning "a means of transmission", not the adjective
for "the quality between big and small."

Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
over SSH), and pass them to and from the protocol logic.  See the overview in
breezy/transport/smart/__init__.py.
"""

from __future__ import absolute_import

import errno
import io
import os
import sys
import time

try:
    import _thread
except ImportError:
    # Python 2 name of the same module.
    import thread as _thread

import breezy
from ...lazy_import import lazy_import
lazy_import(globals(), """
import select
import socket
import weakref

from breezy import (
    debug,
    trace,
    transport,
    ui,
    urlutils,
    )
from breezy.i18n import gettext
from breezy.bzr.smart import client, protocol, request, signals, vfs
from breezy.transport import ssh
""")
from ... import (
    errors,
    osutils,
    )

# Throughout this module buffer size parameters are either limited to be at
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
# from non-sockets as well.
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK


class HpssVfsRequestNotAllowed(errors.BzrError):
    """Raised when a VFS request is attempted while VFS calls are refused.

    Both attributes set in __init__ are interpolated into _fmt, so both must
    be present for the message to render.
    """

    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
            "%(method)s, %(arguments)s.")

    def __init__(self, method, arguments):
        """
        :param method: name of the refused smart request method.
        :param arguments: the arguments the request was made with.
        """
        self.method = method
        self.arguments = arguments


def _get_protocol_factory_for_bytes(bytes):
    """Determine the right protocol factory for 'bytes'.

    This will return an appropriate protocol factory depending on the version
    of the protocol being used, as determined by inspecting the given bytes.
    The bytes should have at least one newline byte (i.e. be a whole line),
    otherwise it's possible that a request will be incorrectly identified as
    version 1.

    Typical use would be::

         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
         server_protocol = factory(transport, write_func, root_client_path)
         server_protocol.accept_bytes(unused_bytes)

    :param bytes: a str of bytes of the start of the request.
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
        a callable that takes three args: transport, write_func,
        root_client_path.  unused_bytes are any bytes that were not part of a
        protocol version marker.
    """
    # Check the newest marker first.  Version one has no marker, so it is the
    # fallback and consumes nothing from the input.
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
        return (protocol.build_server_protocol_three,
                bytes[len(protocol.MESSAGE_VERSION_THREE):])
    if bytes.startswith(protocol.REQUEST_VERSION_TWO):
        return (protocol.SmartServerRequestProtocolTwo,
                bytes[len(protocol.REQUEST_VERSION_TWO):])
    return protocol.SmartServerRequestProtocolOne, bytes


def _get_line(read_bytes_func):
113
"""Read bytes using read_bytes_func until a newline byte.
115
This isn't particularly efficient, so should only be used when the
116
expected size of the line is quite short.
118
:returns: a tuple of two strs: (line, excess)
122
while newline_pos == -1:
123
new_bytes = read_bytes_func(1)
126
# Ran out of bytes before receiving a complete line.
128
newline_pos = bytes.find(b'\n')
129
line = bytes[:newline_pos + 1]
130
excess = bytes[newline_pos + 1:]
134
class SmartMedium(object):
    """Base class for smart protocol media, both client- and server-side."""

    def __init__(self):
        # Bytes read from the underlying stream but not yet consumed, or None.
        self._push_back_buffer = None

    def _push_back(self, data):
        """Return unused bytes to the medium, because they belong to the next
        request(s).

        This sets the _push_back_buffer to the given bytes.  Pushing back b''
        is a no-op, so the buffer never holds an empty string (which could be
        confused with EOF).

        :raises TypeError: if data is not bytes.
        :raises AssertionError: if a previous push-back is still unconsumed.
        """
        if not isinstance(data, bytes):
            raise TypeError(data)
        if self._push_back_buffer is not None:
            raise AssertionError(
                "_push_back called when self._push_back_buffer is %r"
                % (self._push_back_buffer,))
        if data == b'':
            return
        self._push_back_buffer = data

    def _get_push_back_buffer(self):
        """Take and clear the pushed-back bytes."""
        if self._push_back_buffer == b'':
            raise AssertionError(
                '%s._push_back_buffer should never be the empty string, '
                'which can be confused with EOF' % (self,))
        pending = self._push_back_buffer
        self._push_back_buffer = None
        return pending

    def read_bytes(self, desired_count):
        """Read some bytes from this medium.

        Pushed-back bytes are returned first, whatever their length.

        :returns: some bytes, possibly more or less than the number requested
            in 'desired_count' depending on the medium.
        """
        if self._push_back_buffer is not None:
            return self._get_push_back_buffer()
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
        return self._read_bytes(bytes_to_read)

    def _read_bytes(self, count):
        """Read from the underlying stream; implemented by subclasses."""
        raise NotImplementedError(self._read_bytes)

    def _get_line(self):
        """Read bytes from this request's response until a newline byte.

        This isn't particularly efficient, so should only be used when the
        expected size of the line is quite short.

        Any bytes read past the newline are pushed back for the next read.

        :returns: a string of bytes ending in a newline (byte 0x0A).
        """
        line, excess = _get_line(self.read_bytes)
        self._push_back(excess)
        return line

    def _report_activity(self, bytes, direction):
        """Notify that this medium has activity.

        Implementations should call this from all methods that actually do IO.
        Be careful that it's not called twice, if one method is implemented on
        top of another.

        :param bytes: Number of bytes read or written.
        :param direction: 'read' or 'write' or None.
        """
        ui.ui_factory.report_transport_activity(self, bytes, direction)


_bad_file_descriptor = (errno.EBADF,)
205
if sys.platform == 'win32':
206
# Given on Windows if you pass a closed socket to select.select. Probably
207
# also given if you pass a file handle to select.
209
_bad_file_descriptor += (WSAENOTSOCK,)
212
class SmartServerStreamMedium(SmartMedium):
    """Handles smart commands coming over a stream.

    The stream may be a pipe connected to sshd, or a tcp socket, or an
    in-process fifo for testing.

    One instance is created for each connected client; it can serve multiple
    requests in the lifetime of the connection.

    The server passes requests through to an underlying backing transport,
    which will typically be a LocalTransport looking at the server's filesystem.

    :ivar _push_back_buffer: a str of bytes that have been read from the stream
        but not used yet, or None if there are no buffered bytes.  Subclasses
        should make sure to exhaust this buffer before reading more bytes from
        the stream.  See also the _push_back method.
    """

    # Clock used by _wait_on_descriptor for the idle deadline; kept as an
    # attribute so it can be overridden.
    _timer = time.time

    def __init__(self, backing_transport, root_client_path='/', timeout=None):
        """Construct new server.

        :param backing_transport: Transport for the directory served.
        :param root_client_path: client-visible path of the served root.
        :param timeout: seconds a client may stay idle before it is
            disconnected.  Required.
        :raises AssertionError: if timeout is None.
        """
        # backing_transport could be passed to serve instead of __init__
        self.backing_transport = backing_transport
        self.root_client_path = root_client_path
        self.finished = False
        if timeout is None:
            raise AssertionError('You must supply a timeout.')
        self._client_timeout = timeout
        # Poll at least once a second (or every tenth of the timeout) so a
        # graceful stop request is noticed promptly.
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
        SmartMedium.__init__(self)

    def serve(self):
        """Serve requests until the client disconnects.

        An idle-timeout (ConnectionTimeout) is noted, logged quietly and ends
        the session.  Any other exception is written to stderr and re-raised.
        """
        # Keep a reference to stderr because the sys module's globals get set to
        # None during interpreter shutdown.
        from sys import stderr
        try:
            while not self.finished:
                server_protocol = self._build_protocol()
                self._serve_one_request(server_protocol)
        except errors.ConnectionTimeout as e:
            trace.note('%s' % (e,))
            trace.log_exception_quietly()
            self._disconnect_client()
            # We reported it, no reason to make a big fuss.
            return
        except Exception as e:
            stderr.write("%s terminating on exception %s\n" % (self, e))
            raise
        self._disconnect_client()

    def _stop_gracefully(self):
        """When we finish this message, stop looking for more."""
        trace.mutter('Stopping %s' % (self,))
        self.finished = True

    def _disconnect_client(self):
        """Close the current connection. We stopped due to a timeout/etc."""
        # The default implementation is a no-op, because that is all we used to
        # do when disconnecting from a client. I suppose we never had the
        # *server* initiate a disconnect, before

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: Did we timeout? (True if we timed out, False if there is data
            to be read)
        """
        raise NotImplementedError(self._wait_for_bytes_with_timeout)

    def _build_protocol(self):
        """Identifies the version of the incoming request, and returns a
        protocol object that can interpret it.

        If more bytes than the version prefix of the request are read, they will
        be fed into the protocol before it is returned.

        :returns: a SmartServerRequestProtocol, or None if the medium was told
            to stop while waiting for the request.
        """
        self._wait_for_bytes_with_timeout(self._client_timeout)
        if self.finished:
            # We're stopping, so don't try to do any more work
            return None
        first_line = self._get_line()
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(
            first_line)
        server_protocol = protocol_factory(
            self.backing_transport, self._write_out, self.root_client_path)
        server_protocol.accept_bytes(unused_bytes)
        return server_protocol

    def _wait_on_descriptor(self, fd, timeout_seconds):
        """select() on a file descriptor, waiting for nonblocking read()

        This will raise a ConnectionTimeout exception if we do not get a
        readable handle before timeout_seconds.  It returns early (without
        raising) if the medium is finished, if the descriptor is not
        selectable, or if it is already closed.

        :return: None
        """
        t_end = self._timer() + timeout_seconds
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
        rs = xs = None
        while not rs and not xs and self._timer() < t_end:
            if self.finished:
                return
            try:
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
            except (select.error, socket.error) as e:
                err = getattr(e, 'errno', None)
                if err is None and getattr(e, 'args', None):
                    # select.error doesn't have 'errno', it just has args[0]
                    err = e.args[0]
                if err in _bad_file_descriptor:
                    return  # Not a socket indicates read() will fail
                elif err == errno.EINTR:
                    # Interrupted, keep looping.
                    continue
                raise
            except ValueError:
                return  # Socket may already be closed
        if rs or xs:
            return
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
                                       % (timeout_seconds,))

    def _serve_one_request(self, protocol):
        """Read one request from input, process, send back a response.

        Errors other than KeyboardInterrupt are turned into a call to
        terminate_due_to_error().

        :param protocol: a SmartServerRequestProtocol, or None (no-op).
        """
        if protocol is None:
            return
        try:
            self._serve_one_request_unguarded(protocol)
        except KeyboardInterrupt:
            raise
        except Exception:
            self.terminate_due_to_error()

    def terminate_due_to_error(self):
        """Called when an unhandled exception from the protocol occurs."""
        raise NotImplementedError(self.terminate_due_to_error)

    def _read_bytes(self, desired_count):
        """Get some bytes from the medium.

        :param desired_count: number of bytes we want to read.
        """
        raise NotImplementedError(self._read_bytes)


class SmartServerSocketStreamMedium(SmartServerStreamMedium):
    """Server-side medium serving requests that arrive on a connected socket."""

    def __init__(self, sock, backing_transport, root_client_path='/',
                 timeout=None):
        """Constructor.

        :param sock: the socket the server will read from.  It will be put
            into blocking mode.
        :param backing_transport: Transport for the directory served.
        :param root_client_path: passed through to SmartServerStreamMedium.
        :param timeout: idle timeout in seconds, passed through to
            SmartServerStreamMedium (which requires it).
        """
        SmartServerStreamMedium.__init__(
            self, backing_transport, root_client_path=root_client_path,
            timeout=timeout)
        sock.setblocking(True)
        self.socket = sock
        # Get the getpeername now, as we might be closed later when we care.
        try:
            self._client_info = sock.getpeername()
        except socket.error:
            self._client_info = '<unknown>'

    def __str__(self):
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)

    def __repr__(self):
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
                                     self._client_info)

    def _serve_one_request_unguarded(self, protocol):
        """Feed socket data to protocol until it needs no more.

        An empty read means the client closed the connection; the medium is
        then marked finished.  Bytes left over after the request are pushed
        back so the next request starts with them.
        """
        while protocol.next_read_size():
            # We can safely try to read large chunks.  If there is less data
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
            # short read immediately rather than block.
            data = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
            if data == b'':
                self.finished = True
                return
            protocol.accept_bytes(data)
        self._push_back(protocol.unused_data)

    def _disconnect_client(self):
        """Close the current connection. We stopped due to a timeout/etc."""
        self.socket.close()

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: None, this will raise ConnectionTimeout if we time out before
            data is available.
        """
        return self._wait_on_descriptor(self.socket, timeout_seconds)

    def _read_bytes(self, desired_count):
        # desired_count is deliberately not passed on: per the module comment
        # on _MAX_READ_SIZE, socket reads use the helper's own chunk size.
        return osutils.read_bytes_from_socket(
            self.socket, self._report_activity)

    def terminate_due_to_error(self):
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self.socket.close()
        self.finished = True

    def _write_out(self, bytes):
        """Send all of bytes to the client, logging timing under -Dhpss."""
        tstart = osutils.timer_func()
        osutils.send_all(self.socket, bytes, self._report_activity)
        if 'hpss' in debug.debug_flags:
            thread_id = _thread.get_ident()
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
                         % ('wrote', thread_id, len(bytes),
                            osutils.timer_func() - tstart))


class SmartServerPipeStreamMedium(SmartServerStreamMedium):
    """Server-side medium serving requests over a pair of file-like pipes."""

    def __init__(self, in_file, out_file, backing_transport, timeout=None):
        """Construct new server.

        :param in_file: Python file from which requests can be read.
        :param out_file: Python file to write responses.
        :param backing_transport: Transport for the directory served.
        :param timeout: idle timeout in seconds, passed through to
            SmartServerStreamMedium (which requires it).
        """
        SmartServerStreamMedium.__init__(self, backing_transport,
                                         timeout=timeout)
        if sys.platform == 'win32':
            # force binary mode for files
            import msvcrt
            for f in (in_file, out_file):
                fileno = getattr(f, 'fileno', None)
                if fileno:
                    msvcrt.setmode(fileno(), os.O_BINARY)
        self._in = in_file
        self._out = out_file

    def serve(self):
        """See SmartServerStreamMedium.serve"""
        # This is the regular serve, except it adds signal trapping for soft
        # shutdown.
        stop_gracefully = self._stop_gracefully
        signals.register_on_hangup(id(self), stop_gracefully)
        try:
            return super(SmartServerPipeStreamMedium, self).serve()
        finally:
            signals.unregister_on_hangup(id(self))

    def _serve_one_request_unguarded(self, protocol):
        """Feed pipe data to protocol until the request is complete or EOF."""
        while True:
            # We need to be careful not to read past the end of the current
            # request, or else the read from the pipe will block, so we use
            # protocol.next_read_size().
            bytes_to_read = protocol.next_read_size()
            if bytes_to_read == 0:
                # Finished serving this request.
                self._out.flush()
                return
            data = self.read_bytes(bytes_to_read)
            if data == b'':
                # Connection has been closed.
                self.finished = True
                self._out.flush()
                return
            protocol.accept_bytes(data)

    def _disconnect_client(self):
        """Close the stdin/stdout handles."""
        self._in.close()
        self._out.flush()
        self._out.close()

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: None, this will raise ConnectionTimeout if we time out before
            data is available.
        """
        if (getattr(self._in, 'fileno', None) is None
                or sys.platform == 'win32'):
            # You can't select() file descriptors on Windows.
            return
        try:
            return self._wait_on_descriptor(self._in, timeout_seconds)
        except io.UnsupportedOperation:
            # The file object has no real OS-level descriptor to select on.
            return

    def _read_bytes(self, desired_count):
        return self._in.read(desired_count)

    def terminate_due_to_error(self):
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self._out.close()
        self.finished = True

    def _write_out(self, bytes):
        self._out.write(bytes)


class SmartClientMediumRequest(object):
    """A request on a SmartClientMedium.

    Each request allows bytes to be provided to it via accept_bytes, and then
    the response bytes to be read via read_bytes.

    For instance:
    request.accept_bytes('123')
    request.finished_writing()
    result = request.read_bytes(3)
    request.finished_reading()

    It is up to the individual SmartClientMedium whether multiple concurrent
    requests can exist. See SmartClientMedium.get_request to obtain instances
    of SmartClientMediumRequest, and the concrete Medium you are using for
    details on concurrency and pipelining.
    """

    def __init__(self, medium):
        """Construct a SmartClientMediumRequest for the medium medium."""
        self._medium = medium
        # we track state by constants - we may want to use the same
        # pattern as BodyReader if it gets more complex.
        # valid states are: "writing", "reading", "done"
        self._state = "writing"

    def accept_bytes(self, bytes):
        """Accept bytes for inclusion in this request.

        This method may not be called after finished_writing() has been
        called.  It depends upon the Medium whether or not the bytes will be
        immediately transmitted. Message based Mediums will tend to buffer the
        bytes until finished_writing() is called.

        :param bytes: A bytestring.
        :raises WritingCompleted: if finished_writing() was already called.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._accept_bytes(bytes)

    def _accept_bytes(self, bytes):
        """Helper for accept_bytes.

        Accept_bytes checks the state of the request to determine if bytes
        should be accepted. After that it hands off to _accept_bytes to do the
        actual acceptance.
        """
        raise NotImplementedError(self._accept_bytes)

    def finished_reading(self):
        """Inform the request that all desired data has been read.

        This will remove the request from the pipeline for its medium (if the
        medium supports pipelining) and any further calls to methods on the
        request will raise ReadingCompleted.

        :raises WritingNotComplete: if still in the writing phase.
        :raises ReadingCompleted: if reading was already finished.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            # already done
            raise errors.ReadingCompleted(self)
        self._state = "done"
        self._finished_reading()

    def _finished_reading(self):
        """Helper for finished_reading.

        finished_reading checks the state of the request to determine if
        finished_reading is allowed, and if it is hands off to _finished_reading
        to perform the action.
        """
        raise NotImplementedError(self._finished_reading)

    def finished_writing(self):
        """Finish the writing phase of this request.

        This will flush all pending data for this request along the medium.
        After calling finished_writing, you may not call accept_bytes anymore.

        :raises WritingCompleted: if called more than once.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._state = "reading"
        self._finished_writing()

    def _finished_writing(self):
        """Helper for finished_writing.

        finished_writing checks the state of the request to determine if
        finished_writing is allowed, and if it is hands off to _finished_writing
        to perform the action.
        """
        raise NotImplementedError(self._finished_writing)

    def read_bytes(self, count):
        """Read bytes from this request's response.

        This method will block and wait for count bytes to be read. It may not
        be invoked until finished_writing() has been called - this is to ensure
        a message-based approach to requests, for compatibility with message
        based mediums like HTTP.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        return self._read_bytes(count)

    def _read_bytes(self, count):
        """Helper for SmartClientMediumRequest.read_bytes.

        read_bytes checks the state of the request to determine if bytes
        should be read. After that it hands off to _read_bytes to do the
        actual read.

        By default this forwards to self._medium.read_bytes because we are
        operating on the medium's stream.
        """
        return self._medium.read_bytes(count)

    def read_line(self):
        """Read one newline-terminated line of the response.

        :raises ConnectionReset: if the stream ends before a newline.
        """
        line = self._read_line()
        if not line.endswith(b'\n'):
            # end of file encountered reading from server
            raise errors.ConnectionReset(
                "Unexpected end of message. Please check connectivity "
                "and permissions, and report a bug if problems persist.")
        return line

    def _read_line(self):
        """Helper for SmartClientMediumRequest.read_line.

        By default this forwards to self._medium._get_line because we are
        operating on the medium's stream.
        """
        return self._medium._get_line()


class _VfsRefuser(object):
    """An object that refuses all VFS requests.

    Installing it adds a 'call' hook on client._SmartClient that raises
    HpssVfsRequestNotAllowed for any known request that is a VFS request.
    """

    def __init__(self):
        client._SmartClient.hooks.install_named_hook(
            'call', self.check_vfs, 'vfs refuser')

    def check_vfs(self, params):
        """Raise HpssVfsRequestNotAllowed if params.method is a VFS method."""
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            return
        if issubclass(request_method, vfs.VfsRequest):
            raise HpssVfsRequestNotAllowed(params.method, params.args)


class _DebugCounter(object):
688
"""An object that counts the HPSS calls made to each client medium.
690
When a medium is garbage-collected, or failing that when
691
breezy.global_state exits, the total number of calls made on that medium
692
are reported via trace.note.
696
self.counts = weakref.WeakKeyDictionary()
697
client._SmartClient.hooks.install_named_hook(
698
'call', self.increment_call_count, 'hpss call counter')
699
breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
701
def track(self, medium):
702
"""Start tracking calls made to a medium.
704
This only keeps a weakref to the medium, so shouldn't affect the
707
medium_repr = repr(medium)
708
# Add this medium to the WeakKeyDictionary
709
self.counts[medium] = dict(count=0, vfs_count=0,
710
medium_repr=medium_repr)
711
# Weakref callbacks are fired in reverse order of their association
712
# with the referenced object. So we add a weakref *after* adding to
713
# the WeakKeyDict so that we can report the value from it before the
714
# entry is removed by the WeakKeyDict's own callback.
715
ref = weakref.ref(medium, self.done)
717
def increment_call_count(self, params):
718
# Increment the count in the WeakKeyDictionary
719
value = self.counts[params.medium]
722
request_method = request.request_handlers.get(params.method)
724
# A method we don't know about doesn't count as a VFS method.
726
if issubclass(request_method, vfs.VfsRequest):
727
value['vfs_count'] += 1
730
value = self.counts[ref]
731
count, vfs_count, medium_repr = (
732
value['count'], value['vfs_count'], value['medium_repr'])
733
# In case this callback is invoked for the same ref twice (by the
734
# weakref callback and by the atexit function), set the call count back
735
# to 0 so this item won't be reported twice.
737
value['vfs_count'] = 0
739
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
740
count, vfs_count, medium_repr))
743
for ref in list(self.counts.keys()):
747
# Lazily created singletons, installed by SmartClientMedium.__init__ when the
# 'hpss' / 'hpss_client_no_vfs' debug flags are set.
_debug_counter = None
_vfs_refuser = None


class SmartClientMedium(SmartMedium):
    """Smart client is a medium for sending smart protocol requests over."""

    def __init__(self, base):
        super(SmartClientMedium, self).__init__()
        self.base = base
        self._protocol_version_error = None
        self._protocol_version = None
        self._done_hello = False
        # Be optimistic: we assume the remote end can accept new remote
        # requests until we get an error saying otherwise.
        # _remote_version_is_before tracks the bzr version the remote side
        # can be based on what we've seen so far.
        self._remote_version_is_before = None
        # Install debug hook function if debug flag is set.
        if 'hpss' in debug.debug_flags:
            global _debug_counter
            if _debug_counter is None:
                _debug_counter = _DebugCounter()
            _debug_counter.track(self)
        if 'hpss_client_no_vfs' in debug.debug_flags:
            global _vfs_refuser
            if _vfs_refuser is None:
                _vfs_refuser = _VfsRefuser()

    def _is_remote_before(self, version_tuple):
        """Is it possible the remote side supports RPCs for a given version?

        Typical use::

            needed_version = (1, 2)
            if medium._is_remote_before(needed_version):
                fallback_to_pre_1_2_rpc()
            else:
                try:
                    do_1_2_rpc()
                except UnknownSmartMethod:
                    medium._remember_remote_is_before(needed_version)
                    fallback_to_pre_1_2_rpc()

        :returns: True if the remote side is known to be older than
            version_tuple, False otherwise.
        :seealso: _remember_remote_is_before
        """
        if self._remote_version_is_before is None:
            # So far, the remote side seems to support everything
            return False
        return version_tuple >= self._remote_version_is_before

    def _remember_remote_is_before(self, version_tuple):
        """Tell this medium that the remote side is older than the given version.

        A call with a version newer than one previously remembered is ignored
        (and logged), since it only narrows nothing.

        :seealso: _is_remote_before
        """
        if (self._remote_version_is_before is not None and
                version_tuple > self._remote_version_is_before):
            # We have been told that the remote side is older than some version
            # which is newer than a previously supplied older-than version.
            # This indicates that some smart verb call is not guarded
            # appropriately (it should simply not have been tried).
            trace.mutter(
                "_remember_remote_is_before(%r) called, but "
                "_remember_remote_is_before(%r) was called previously.",
                version_tuple, self._remote_version_is_before)
            if 'hpss' in debug.debug_flags:
                ui.ui_factory.show_warning(
                    "_remember_remote_is_before(%r) called, but "
                    "_remember_remote_is_before(%r) was called previously."
                    % (version_tuple, self._remote_version_is_before))
            return
        self._remote_version_is_before = version_tuple

    def protocol_version(self):
        """Find out if 'hello' smart request works.

        The outcome is cached: a SmartProtocolError from the first attempt is
        re-raised on every later call.

        :returns: '2' once a hello request has succeeded.
        """
        if self._protocol_version_error is not None:
            raise self._protocol_version_error
        if not self._done_hello:
            try:
                medium_request = self.get_request()
                # Send a 'hello' request in protocol version one, for maximum
                # backwards compatibility.
                client_protocol = protocol.SmartClientRequestProtocolOne(
                    medium_request)
                client_protocol.query_version()
                self._done_hello = True
            except errors.SmartProtocolError as e:
                # Cache the error, just like we would cache a successful
                # result.
                self._protocol_version_error = e
                raise
        return '2'

    def should_probe(self):
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
        this medium?

        Some transports are unambiguously smart-only; there's no need to check
        if the transport is able to carry smart requests, because that's all
        it is for.  In those cases, this method should return False.

        But some HTTP transports can sometimes fail to carry smart requests,
        but still be usable for accessing remote bzrdirs via plain file
        accesses.  So for those transports, their media should return True here
        so that RemoteBzrDirFormat can determine if it is appropriate for that
        transport.
        """
        return False

    def disconnect(self):
        """If this medium maintains a persistent connection, close it.

        The default implementation does nothing.
        """

    def remote_path_from_transport(self, transport):
        """Convert transport into a path suitable for using in a request.

        Note that the resulting remote path doesn't encode the host name or
        anything but path, so it is only safe to use it in requests sent over
        the medium from the matching transport.
        """
        medium_base = urlutils.join(self.base, '/')
        rel_url = urlutils.relative_url(medium_base, transport.base)
        return urlutils.unquote(rel_url)


class SmartClientStreamMedium(SmartClientMedium):
    """Stream based medium common class.

    SmartClientStreamMediums operate on a stream. All subclasses use a common
    SmartClientStreamMediumRequest for their requests, and should implement
    _accept_bytes and _read_bytes to allow the request objects to send and
    receive bytes.
    """

    def __init__(self, base):
        SmartClientMedium.__init__(self, base)
        self._current_request = None

    def accept_bytes(self, bytes):
        """Send bytes via the subclass's _accept_bytes."""
        self._accept_bytes(bytes)

    def __del__(self):
        """The SmartClientStreamMedium knows how to close the stream when it is
        finished with it.
        """
        self.disconnect()

    def _flush(self):
        """Flush the output stream.

        This method is used by the SmartClientStreamMediumRequest to ensure that
        all data for a request is sent, to avoid long timeouts or deadlocks.
        """
        raise NotImplementedError(self._flush)

    def get_request(self):
        """See SmartClientMedium.get_request().

        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
        for get_request.
        """
        return SmartClientStreamMediumRequest(self)

    def reset(self):
        """We have been disconnected, reset current state.

        This resets things like _current_request and connected state.
        """
        self.disconnect()
        self._current_request = None


class SmartSimplePipesClientMedium(SmartClientStreamMedium):
    """A client medium using simple pipes.

    This client does not manage the pipes: it assumes they will always be open.
    """

    def __init__(self, readable_pipe, writeable_pipe, base):
        SmartClientStreamMedium.__init__(self, base)
        self._readable_pipe = readable_pipe
        self._writeable_pipe = writeable_pipe

    def _accept_bytes(self, data):
        """See SmartClientStreamMedium.accept_bytes.

        :raises ConnectionReset: if the write fails with EINVAL or EPIPE
            (e.g. the other end has gone away).
        """
        try:
            self._writeable_pipe.write(data)
        except IOError as e:
            if e.errno in (errno.EINVAL, errno.EPIPE):
                raise errors.ConnectionReset(
                    "Error trying to write to subprocess", e)
            raise
        self._report_activity(len(data), 'write')

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
        #       However, testing shows that even when the child process is
        #       gone, this doesn't error.
        self._writeable_pipe.flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium._read_bytes."""
        bytes_to_read = min(count, _MAX_READ_SIZE)
        data = self._readable_pipe.read(bytes_to_read)
        self._report_activity(len(data), 'read')
        return data


class SSHParams(object):
    """A set of parameters for starting a remote bzr via SSH."""

    def __init__(self, host, port=None, username=None, password=None,
                 bzr_remote_path='bzr'):
        """
        :param host: host to connect to.
        :param port: SSH port, or None for the vendor's default.
        :param username: login name, or None.
        :param password: password, or None.
        :param bzr_remote_path: command run on the remote side.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.bzr_remote_path = bzr_remote_path


class SmartSSHClientMedium(SmartClientStreamMedium):
    """A client medium using SSH.

    It delegates IO to a SmartSimplePipesClientMedium or
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
    """

    def __init__(self, base, ssh_params, vendor=None):
        """Creates a client that will connect on the first use.

        :param ssh_params: A SSHParams instance.
        :param vendor: An optional override for the ssh vendor to use. See
            breezy.transport.ssh for details on ssh vendors.
        """
        self._real_medium = None
        self._ssh_params = ssh_params
        # for the benefit of progress making a short description of this
        # transport
        self._scheme = 'bzr+ssh'
        # SmartClientStreamMedium stores the repr of this object in its
        # _DebugCounter so we have to store all the values used in our repr
        # method before calling the super init.
        SmartClientStreamMedium.__init__(self, base)
        self._vendor = vendor
        self._ssh_connection = None

    def __repr__(self):
        if self._ssh_params.port is None:
            maybe_port = ''
        else:
            maybe_port = ':%s' % self._ssh_params.port
        if self._ssh_params.username is None:
            maybe_user = ''
        else:
            maybe_user = '%s@' % self._ssh_params.username
        return "%s(%s://%s%s%s/)" % (
            self.__class__.__name__,
            self._scheme,
            maybe_user,
            self._ssh_params.host,
            maybe_port)

    def _accept_bytes(self, bytes):
        """See SmartClientStreamMedium.accept_bytes."""
        self._ensure_connection()
        self._real_medium.accept_bytes(bytes)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if self._real_medium is not None:
            self._real_medium.disconnect()
            self._real_medium = None
        if self._ssh_connection is not None:
            self._ssh_connection.close()
            self._ssh_connection = None

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Starts '<bzr_remote_path> serve --inet' over SSH, wraps whichever IO
        channel the vendor provides, then runs the transport post_connect
        hooks.
        """
        if self._real_medium is not None:
            return
        if self._vendor is None:
            vendor = ssh._get_ssh_vendor()
        else:
            vendor = self._vendor
        self._ssh_connection = vendor.connect_ssh(
            self._ssh_params.username,
            self._ssh_params.password, self._ssh_params.host,
            self._ssh_params.port,
            command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
                     '--directory=/', '--allow-writes'])
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
        if io_kind == 'socket':
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
                self.base, io_object)
        elif io_kind == 'pipes':
            read_from, write_to = io_object
            self._real_medium = SmartSimplePipesClientMedium(
                read_from, write_to, self.base)
        else:
            raise AssertionError(
                "Unexpected io_kind %r from %r"
                % (io_kind, self._ssh_connection))
        for hook in transport.Transport.hooks["post_connect"]:
            hook(self)

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        # Consistent with _read_bytes: report a clear error instead of an
        # AttributeError on None when nothing has been connected yet.
        if self._real_medium is None:
            raise errors.MediumNotConnected(self)
        self._real_medium._flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium.read_bytes."""
        if self._real_medium is None:
            raise errors.MediumNotConnected(self)
        return self._real_medium.read_bytes(count)


# No default interface is configured for bzr:// (None).
BZR_DEFAULT_INTERFACE = None
# The IANA-registered default port for bzr://.
BZR_DEFAULT_PORT = 4155


class SmartClientSocketMedium(SmartClientStreamMedium):
    """A client medium using a socket.

    This class isn't usable directly. Use one of its subclasses instead.
    Subclasses override _ensure_connection and are responsible for setting
    ``self._socket`` and ``self._connected``.
    """

    def __init__(self, base):
        SmartClientStreamMedium.__init__(self, base)
        # No socket until a subclass connects (or is handed) one.
        self._socket = None
        self._connected = False

    def _accept_bytes(self, bytes):
        """See SmartClientMedium.accept_bytes."""
        self._ensure_connection()
        osutils.send_all(self._socket, bytes, self._report_activity)

    def _ensure_connection(self):
        """Connect this medium if not already connected."""
        raise NotImplementedError(self._ensure_connection)

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        For sockets we do no flushing. For TCP sockets we may want to turn off
        TCP_NODELAY and add a means to do a flush, but that can be done in the
        future.
        """

    def _read_bytes(self, count):
        """See SmartClientMedium.read_bytes.

        ``count`` is accepted for interface compatibility but is not passed
        on; how much is read is left to osutils.read_bytes_from_socket.

        :raises errors.MediumNotConnected: if the medium is not connected.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return osutils.read_bytes_from_socket(
            self._socket, self._report_activity)

    def disconnect(self):
        """See SmartClientMedium.disconnect().

        Does nothing if not connected.  The medium is marked disconnected even
        if closing the socket raises, so it is never left half torn down.
        """
        if not self._connected:
            return
        try:
            self._socket.close()
        finally:
            self._socket = None
            self._connected = False


class SmartTCPClientMedium(SmartClientSocketMedium):
    """A client medium that creates a TCP connection.

    The connection is opened lazily, on first use, by _ensure_connection.
    """

    def __init__(self, host, port, base):
        """Creates a client that will connect on the first use.

        :param host: host name or address to connect to.
        :param port: port to connect to; None means BZR_DEFAULT_PORT, any
            other value is converted with int().
        :param base: passed through to SmartClientSocketMedium.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._host = host
        self._port = port

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Resolves the host and tries each returned address in turn until one
        accepts the connection.  TCP_NODELAY is set on each socket tried.
        Once connected, the ``post_connect`` transport hooks are called with
        this medium.

        :raises errors.ConnectionError: if the host cannot be resolved or no
            resolved address accepts the connection.
        """
        if self._connected:
            return
        if self._port is None:
            port = BZR_DEFAULT_PORT
        else:
            port = int(self._port)
        try:
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
                                           socket.SOCK_STREAM, 0, 0)
        except socket.gaierror as err:
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
                                         (self._host, port,
                                          self._socket_error_message(err)))
        # Initialize err in case there are no addresses returned:
        last_err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                self._socket = socket.socket(family, socktype, proto)
                self._socket.setsockopt(socket.IPPROTO_TCP,
                                        socket.TCP_NODELAY, 1)
                self._socket.connect(sockaddr)
            except socket.error as err:
                # Discard the half-set-up socket and try the next address.
                if self._socket is not None:
                    self._socket.close()
                self._socket = None
                last_err = err
                continue
            break
        if self._socket is None:
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                                         (self._host, port,
                                          self._socket_error_message(last_err)))
        self._connected = True
        for hook in transport.Transport.hooks["post_connect"]:
            hook(self)

    @staticmethod
    def _socket_error_message(err):
        """Return the human-readable message carried by a socket error.

        Socket errors have either ``(message,)`` or ``(errno, message)`` as
        their ``args`` (``args`` is always a tuple).  Single-argument errors,
        such as the "no address found" placeholder built in
        _ensure_connection, must not be indexed at ``[1]``.
        """
        args = err.args
        if len(args) >= 2:
            return args[1]
        if args:
            return args[0]
        return str(err)


class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
    """A client medium for an already connected socket.

    Note that this class will assume it "owns" the socket, so it will close it
    when its disconnect method is called.
    """

    def __init__(self, base, sock):
        """Wrap *sock*, an already connected socket, as a smart medium.

        :param base: passed through to SmartClientSocketMedium.
        :param sock: the connected socket; this medium takes ownership of it.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._socket = sock
        self._connected = True

    def _ensure_connection(self):
        # Already connected, by definition! So nothing to do.
        pass


class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1185
"""A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
1187
def __init__(self, medium):
    """Create a request on *medium* and claim it as the medium's only
    in-flight request.

    :raises errors.TooManyConcurrentRequests: if *medium* already has a
        current request.
    """
    SmartClientMediumRequest.__init__(self, medium)
    # Stream media carry one request at a time.  If some streams start
    # allowing concurrent requests (e.g. via multiplexing), this check and
    # the setting/clearing of _current_request should move into
    # SmartClientStreamMedium.get_request; for now that would be unneeded
    # overhead. (RBC 20060922)
    # NOTE(review): self._medium is presumably set by the base initializer.
    owner = self._medium
    if owner._current_request is not None:
        raise errors.TooManyConcurrentRequests(owner)
    owner._current_request = self

def _accept_bytes(self, bytes):
1199
"""See SmartClientMediumRequest._accept_bytes.
1201
This forwards to self._medium._accept_bytes because we are operating
1202
on the mediums stream.
1204
self._medium._accept_bytes(bytes)
1206
def _finished_reading(self):
1207
"""See SmartClientMediumRequest._finished_reading.
1209
This clears the _current_request on self._medium to allow a new
1210
request to be created.
1212
if self._medium._current_request is not self:
1213
raise AssertionError()
1214
self._medium._current_request = None
1216
def _finished_writing(self):
1217
"""See SmartClientMediumRequest._finished_writing.
1219
This invokes self._medium._flush to ensure all bytes are transmitted.
1221
self._medium._flush()