676
809
while not _body_decoder.finished_reading:
677
810
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
678
811
bytes = self._request.read_bytes(bytes_wanted)
813
# end of file encountered reading from server
814
raise errors.ConnectionReset(
815
"Connection lost while reading streamed body.")
679
816
_body_decoder.accept_bytes(bytes)
680
817
for body_bytes in iter(_body_decoder.read_next_chunk, None):
681
if 'hpss' in debug.debug_flags:
818
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
682
819
mutter(' %d byte chunk read',
685
822
self._request.finished_reading()
825
def build_server_protocol_three(backing_transport, write_func,
827
request_handler = request.SmartServerRequestHandler(
828
backing_transport, commands=request.request_handlers,
829
root_client_path=root_client_path)
830
responder = ProtocolThreeResponder(write_func)
831
message_handler = message.ConventionalRequestHandler(request_handler, responder)
832
return ProtocolThreeDecoder(message_handler)
835
class ProtocolThreeDecoder(_StatefulDecoder):
837
response_marker = RESPONSE_VERSION_THREE
838
request_marker = REQUEST_VERSION_THREE
840
def __init__(self, message_handler, expect_version_marker=False):
841
_StatefulDecoder.__init__(self)
842
self._has_dispatched = False
844
if expect_version_marker:
845
self.state_accept = self._state_accept_expecting_protocol_version
846
# We're expecting at least the protocol version marker + some
848
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
850
self.state_accept = self._state_accept_expecting_headers
851
self._number_needed_bytes = 4
852
self.decoding_failed = False
853
self.request_handler = self.message_handler = message_handler
855
def accept_bytes(self, bytes):
856
self._number_needed_bytes = None
858
_StatefulDecoder.accept_bytes(self, bytes)
859
except KeyboardInterrupt:
861
except errors.SmartMessageHandlerError, exception:
862
# We do *not* set self.decoding_failed here. The message handler
863
# has raised an error, but the decoder is still able to parse bytes
864
# and determine when this message ends.
865
log_exception_quietly()
866
self.message_handler.protocol_error(exception.exc_value)
867
# The state machine is ready to continue decoding, but the
868
# exception has interrupted the loop that runs the state machine.
869
# So we call accept_bytes again to restart it.
870
self.accept_bytes('')
871
except Exception, exception:
872
# The decoder itself has raised an exception. We cannot continue
874
self.decoding_failed = True
875
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
876
# This happens during normal operation when the client tries a
877
# protocol version the server doesn't understand, so no need to
878
# log a traceback every time.
879
# Note that this can only happen when
880
# expect_version_marker=True, which is only the case on the
884
log_exception_quietly()
885
self.message_handler.protocol_error(exception)
887
def _extract_length_prefixed_bytes(self):
888
if len(self._in_buffer) < 4:
889
# A length prefix by itself is 4 bytes, and we don't even have that
891
raise _NeedMoreBytes(4)
892
(length,) = struct.unpack('!L', self._in_buffer[:4])
893
end_of_bytes = 4 + length
894
if len(self._in_buffer) < end_of_bytes:
895
# We haven't yet read as many bytes as the length-prefix says there
897
raise _NeedMoreBytes(end_of_bytes)
898
# Extract the bytes from the buffer.
899
bytes = self._in_buffer[4:end_of_bytes]
900
self._in_buffer = self._in_buffer[end_of_bytes:]
903
def _extract_prefixed_bencoded_data(self):
904
prefixed_bytes = self._extract_length_prefixed_bytes()
906
decoded = bdecode(prefixed_bytes)
908
raise errors.SmartProtocolError(
909
'Bytes %r not bencoded' % (prefixed_bytes,))
912
def _extract_single_byte(self):
913
if self._in_buffer == '':
914
# The buffer is empty
915
raise _NeedMoreBytes(1)
916
one_byte = self._in_buffer[0]
917
self._in_buffer = self._in_buffer[1:]
920
def _state_accept_expecting_protocol_version(self):
921
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
923
# We don't have enough bytes to check if the protocol version
924
# marker is right. But we can check if it is already wrong by
925
# checking that the start of MESSAGE_VERSION_THREE matches what
927
# [In fact, if the remote end isn't bzr we might never receive
928
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
929
# are wrong then we should just raise immediately rather than
931
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
932
# We have enough bytes to know the protocol version is wrong
933
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
934
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
935
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
936
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
937
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
938
self.state_accept = self._state_accept_expecting_headers
940
def _state_accept_expecting_headers(self):
941
decoded = self._extract_prefixed_bencoded_data()
942
if type(decoded) is not dict:
943
raise errors.SmartProtocolError(
944
'Header object %r is not a dict' % (decoded,))
945
self.state_accept = self._state_accept_expecting_message_part
947
self.message_handler.headers_received(decoded)
949
raise errors.SmartMessageHandlerError(sys.exc_info())
951
def _state_accept_expecting_message_part(self):
952
message_part_kind = self._extract_single_byte()
953
if message_part_kind == 'o':
954
self.state_accept = self._state_accept_expecting_one_byte
955
elif message_part_kind == 's':
956
self.state_accept = self._state_accept_expecting_structure
957
elif message_part_kind == 'b':
958
self.state_accept = self._state_accept_expecting_bytes
959
elif message_part_kind == 'e':
962
raise errors.SmartProtocolError(
963
'Bad message kind byte: %r' % (message_part_kind,))
965
def _state_accept_expecting_one_byte(self):
966
byte = self._extract_single_byte()
967
self.state_accept = self._state_accept_expecting_message_part
969
self.message_handler.byte_part_received(byte)
971
raise errors.SmartMessageHandlerError(sys.exc_info())
973
def _state_accept_expecting_bytes(self):
974
# XXX: this should not buffer whole message part, but instead deliver
975
# the bytes as they arrive.
976
prefixed_bytes = self._extract_length_prefixed_bytes()
977
self.state_accept = self._state_accept_expecting_message_part
979
self.message_handler.bytes_part_received(prefixed_bytes)
981
raise errors.SmartMessageHandlerError(sys.exc_info())
983
def _state_accept_expecting_structure(self):
984
structure = self._extract_prefixed_bencoded_data()
985
self.state_accept = self._state_accept_expecting_message_part
987
self.message_handler.structure_part_received(structure)
989
raise errors.SmartMessageHandlerError(sys.exc_info())
992
self.unused_data = self._in_buffer
994
self.state_accept = self._state_accept_reading_unused
996
self.message_handler.end_received()
998
raise errors.SmartMessageHandlerError(sys.exc_info())
1000
def _state_accept_reading_unused(self):
1001
self.unused_data += self._in_buffer
1002
self._in_buffer = ''
1004
def next_read_size(self):
1005
if self.state_accept == self._state_accept_reading_unused:
1007
elif self.decoding_failed:
1008
# An exception occured while processing this message, probably from
1009
# self.message_handler. We're not sure that this state machine is
1010
# in a consistent state, so just signal that we're done (i.e. give
1014
if self._number_needed_bytes is not None:
1015
return self._number_needed_bytes - len(self._in_buffer)
1017
raise AssertionError("don't know how many bytes are expected!")
1020
class _ProtocolThreeEncoder(object):
1022
response_marker = request_marker = MESSAGE_VERSION_THREE
1024
def __init__(self, write_func):
1026
self._real_write_func = write_func
1028
def _write_func(self, bytes):
1033
self._real_write_func(self._buf)
1036
def _serialise_offsets(self, offsets):
1037
"""Serialise a readv offset list."""
1039
for start, length in offsets:
1040
txt.append('%d,%d' % (start, length))
1041
return '\n'.join(txt)
1043
def _write_protocol_version(self):
1044
self._write_func(MESSAGE_VERSION_THREE)
1046
def _write_prefixed_bencode(self, structure):
1047
bytes = bencode(structure)
1048
self._write_func(struct.pack('!L', len(bytes)))
1049
self._write_func(bytes)
1051
def _write_headers(self, headers):
1052
self._write_prefixed_bencode(headers)
1054
def _write_structure(self, args):
1055
self._write_func('s')
1058
if type(arg) is unicode:
1059
utf8_args.append(arg.encode('utf8'))
1061
utf8_args.append(arg)
1062
self._write_prefixed_bencode(utf8_args)
1064
def _write_end(self):
1065
self._write_func('e')
1068
def _write_prefixed_body(self, bytes):
1069
self._write_func('b')
1070
self._write_func(struct.pack('!L', len(bytes)))
1071
self._write_func(bytes)
1073
def _write_error_status(self):
1074
self._write_func('oE')
1076
def _write_success_status(self):
1077
self._write_func('oS')
1080
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1082
def __init__(self, write_func):
1083
_ProtocolThreeEncoder.__init__(self, write_func)
1084
self.response_sent = False
1085
self._headers = {'Software version': bzrlib.__version__}
1087
def send_error(self, exception):
1088
if self.response_sent:
1089
raise AssertionError(
1090
"send_error(%s) called, but response already sent."
1092
if isinstance(exception, errors.UnknownSmartMethod):
1093
failure = request.FailedSmartServerResponse(
1094
('UnknownMethod', exception.verb))
1095
self.send_response(failure)
1097
self.response_sent = True
1098
self._write_protocol_version()
1099
self._write_headers(self._headers)
1100
self._write_error_status()
1101
self._write_structure(('error', str(exception)))
1104
def send_response(self, response):
1105
if self.response_sent:
1106
raise AssertionError(
1107
"send_response(%r) called, but response already sent."
1109
self.response_sent = True
1110
self._write_protocol_version()
1111
self._write_headers(self._headers)
1112
if response.is_successful():
1113
self._write_success_status()
1115
self._write_error_status()
1116
self._write_structure(response.args)
1117
if response.body is not None:
1118
self._write_prefixed_body(response.body)
1119
elif response.body_stream is not None:
1120
for chunk in response.body_stream:
1121
self._write_prefixed_body(chunk)
1126
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1128
def __init__(self, medium_request):
1129
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1130
self._medium_request = medium_request
1133
def set_headers(self, headers):
1134
self._headers = headers.copy()
1136
def call(self, *args):
1137
if 'hpss' in debug.debug_flags:
1138
mutter('hpss call: %s', repr(args)[1:-1])
1139
base = getattr(self._medium_request._medium, 'base', None)
1140
if base is not None:
1141
mutter(' (to %s)', base)
1142
self._request_start_time = time.time()
1143
self._write_protocol_version()
1144
self._write_headers(self._headers)
1145
self._write_structure(args)
1147
self._medium_request.finished_writing()
1149
def call_with_body_bytes(self, args, body):
1150
"""Make a remote call of args with body bytes 'body'.
1152
After calling this, call read_response_tuple to find the result out.
1154
if 'hpss' in debug.debug_flags:
1155
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1156
path = getattr(self._medium_request._medium, '_path', None)
1157
if path is not None:
1158
mutter(' (to %s)', path)
1159
mutter(' %d bytes', len(body))
1160
self._request_start_time = time.time()
1161
self._write_protocol_version()
1162
self._write_headers(self._headers)
1163
self._write_structure(args)
1164
self._write_prefixed_body(body)
1166
self._medium_request.finished_writing()
1168
def call_with_body_readv_array(self, args, body):
1169
"""Make a remote call with a readv array.
1171
The body is encoded with one line per readv offset pair. The numbers in
1172
each pair are separated by a comma, and no trailing \n is emitted.
1174
if 'hpss' in debug.debug_flags:
1175
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1176
path = getattr(self._medium_request._medium, '_path', None)
1177
if path is not None:
1178
mutter(' (to %s)', path)
1179
self._request_start_time = time.time()
1180
self._write_protocol_version()
1181
self._write_headers(self._headers)
1182
self._write_structure(args)
1183
readv_bytes = self._serialise_offsets(body)
1184
if 'hpss' in debug.debug_flags:
1185
mutter(' %d bytes in readv request', len(readv_bytes))
1186
self._write_prefixed_body(readv_bytes)
1188
self._medium_request.finished_writing()