/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to fetch.py

Revert locking fixes for now, as they break bzr 1.13 compatibility.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2008 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
from bzrlib import osutils, ui, urlutils
 
18
from bzrlib.errors import InvalidRevisionId, NoSuchRevision
 
19
from bzrlib.inventory import Inventory
 
20
from bzrlib.repository import InterRepository
 
21
from bzrlib.trace import info
 
22
from bzrlib.tsort import topo_sort
 
23
 
 
24
from bzrlib.plugins.git.repository import (
 
25
        LocalGitRepository, 
 
26
        GitRepository, 
 
27
        GitFormat,
 
28
        )
 
29
from bzrlib.plugins.git.converter import GitObjectConverter
 
30
from bzrlib.plugins.git.remote import RemoteGitRepository
 
31
 
 
32
import dulwich as git
 
33
from dulwich.client import SimpleFetchGraphWalker
 
34
from dulwich.objects import Commit
 
35
 
 
36
from cStringIO import StringIO
 
37
 
 
38
 
 
39
class BzrFetchGraphWalker(object):
 
40
    """GraphWalker implementation that uses a Bazaar repository."""
 
41
 
 
42
    def __init__(self, repository, mapping):
 
43
        self.repository = repository
 
44
        self.mapping = mapping
 
45
        self.done = set()
 
46
        self.heads = set(repository.all_revision_ids())
 
47
        self.parents = {}
 
48
 
 
49
    def __iter__(self):
 
50
        return iter(self.next, None)
 
51
 
 
52
    def ack(self, sha):
 
53
        revid = self.mapping.revision_id_foreign_to_bzr(sha)
 
54
        self.remove(revid)
 
55
 
 
56
    def remove(self, revid):
 
57
        self.done.add(revid)
 
58
        if revid in self.heads:
 
59
            self.heads.remove(revid)
 
60
        if revid in self.parents:
 
61
            for p in self.parents[revid]:
 
62
                self.remove(p)
 
63
 
 
64
    def next(self):
 
65
        while self.heads:
 
66
            ret = self.heads.pop()
 
67
            ps = self.repository.get_parent_map([ret])[ret]
 
68
            self.parents[ret] = ps
 
69
            self.heads.update([p for p in ps if not p in self.done])
 
70
            try:
 
71
                self.done.add(ret)
 
72
                return self.mapping.revision_id_bzr_to_foreign(ret)[0]
 
73
            except InvalidRevisionId:
 
74
                pass
 
75
        return None
 
76
 
 
77
 
 
78
def import_git_blob(repo, mapping, path, blob, inv, parent_invs, gitmap, executable):
 
79
    """Import a git blob object into a bzr repository.
 
80
 
 
81
    :param repo: bzr repository
 
82
    :param path: Path in the tree
 
83
    :param blob: A git blob
 
84
    """
 
85
    file_id = mapping.generate_file_id(path)
 
86
    text_revision = inv.revision_id
 
87
    repo.texts.add_lines((file_id, text_revision),
 
88
        [(file_id, p[file_id].revision) for p in parent_invs if file_id in p],
 
89
        osutils.split_lines(blob.data))
 
90
    ie = inv.add_path(path, "file", file_id)
 
91
    ie.revision = text_revision
 
92
    ie.text_size = len(blob.data)
 
93
    ie.text_sha1 = osutils.sha_string(blob.data)
 
94
    ie.executable = executable
 
95
    gitmap._idmap.add_entry(blob.sha().hexdigest(), "blob", (ie.file_id, ie.revision))
 
96
 
 
97
 
 
98
def import_git_tree(repo, mapping, path, tree, inv, parent_invs, 
 
99
                    gitmap, lookup_object):
 
100
    """Import a git tree object into a bzr repository.
 
101
 
 
102
    :param repo: A Bzr repository object
 
103
    :param path: Path in the tree
 
104
    :param tree: A git tree object
 
105
    :param inv: Inventory object
 
106
    """
 
107
    file_id = mapping.generate_file_id(path)
 
108
    text_revision = inv.revision_id
 
109
    repo.texts.add_lines((file_id, text_revision),
 
110
        [(file_id, p[file_id].revision) for p in parent_invs if file_id in p],
 
111
        [])
 
112
    ie = inv.add_path(path, "directory", file_id)
 
113
    ie.revision = text_revision
 
114
    gitmap._idmap.add_entry(tree.sha().hexdigest(), "tree", (file_id, text_revision))
 
115
    for mode, name, hexsha in tree.entries():
 
116
        entry_kind = (mode & 0700000) / 0100000
 
117
        basename = name.decode("utf-8")
 
118
        if path == "":
 
119
            child_path = name
 
120
        else:
 
121
            child_path = urlutils.join(path, name)
 
122
        if entry_kind == 0:
 
123
            tree = lookup_object(hexsha)
 
124
            import_git_tree(repo, mapping, child_path, tree, inv, parent_invs, gitmap, lookup_object)
 
125
        elif entry_kind == 1:
 
126
            blob = lookup_object(hexsha)
 
127
            fs_mode = mode & 0777
 
128
            import_git_blob(repo, mapping, child_path, blob, inv, parent_invs, gitmap, bool(fs_mode & 0111))
 
129
        else:
 
130
            raise AssertionError("Unknown blob kind, perms=%r." % (mode,))
 
131
 
 
132
 
 
133
def import_git_objects(repo, mapping, object_iter, target_git_object_retriever, 
 
134
        pb=None):
 
135
    """Import a set of git objects into a bzr repository.
 
136
 
 
137
    :param repo: Bazaar repository
 
138
    :param mapping: Mapping to use
 
139
    :param object_iter: Iterator over Git objects.
 
140
    """
 
141
    # TODO: a more (memory-)efficient implementation of this
 
142
    graph = []
 
143
    root_trees = {}
 
144
    revisions = {}
 
145
    # Find and convert commit objects
 
146
    for o in object_iter.iterobjects():
 
147
        if isinstance(o, Commit):
 
148
            rev = mapping.import_commit(o)
 
149
            root_trees[rev.revision_id] = object_iter[o.tree]
 
150
            revisions[rev.revision_id] = rev
 
151
            graph.append((rev.revision_id, rev.parent_ids))
 
152
            target_git_object_retriever._idmap.add_entry(o.sha().hexdigest(), "commit", (rev.revision_id, o._tree))
 
153
    # Order the revisions
 
154
    # Create the inventory objects
 
155
    for i, revid in enumerate(topo_sort(graph)):
 
156
        if pb is not None:
 
157
            pb.update("fetching revisions", i, len(graph))
 
158
        root_tree = root_trees[revid]
 
159
        rev = revisions[revid]
 
160
        # We have to do this here, since we have to walk the tree and 
 
161
        # we need to make sure to import the blobs / trees with the riht 
 
162
        # path; this may involve adding them more than once.
 
163
        inv = Inventory()
 
164
        inv.revision_id = rev.revision_id
 
165
        def lookup_object(sha):
 
166
            if sha in object_iter:
 
167
                return object_iter[sha]
 
168
            return target_git_object_retriever[sha]
 
169
        parent_invs = [repo.get_inventory(r) for r in rev.parent_ids]
 
170
        import_git_tree(repo, mapping, "", root_tree, inv, parent_invs, 
 
171
            target_git_object_retriever, lookup_object)
 
172
        repo.add_revision(rev.revision_id, rev, inv)
 
173
 
 
174
 
 
175
class InterGitNonGitRepository(InterRepository):
 
176
 
 
177
    _matching_repo_format = GitFormat()
 
178
 
 
179
    @staticmethod
 
180
    def _get_repo_format_to_test():
 
181
        return None
 
182
 
 
183
    def copy_content(self, revision_id=None, pb=None):
 
184
        """See InterRepository.copy_content."""
 
185
        self.fetch(revision_id, pb, find_ghosts=False)
 
186
 
 
187
    def fetch_objects(self, determine_wants, mapping, pb=None):
 
188
        def progress(text):
 
189
            pb.update("git: %s" % text.rstrip("\r\n"), 0, 0)
 
190
        graph_walker = BzrFetchGraphWalker(self.target, mapping)
 
191
        create_pb = None
 
192
        if pb is None:
 
193
            create_pb = pb = ui.ui_factory.nested_progress_bar()
 
194
        target_git_object_retriever = GitObjectConverter(self.target, mapping)
 
195
        
 
196
        try:
 
197
            self.target.lock_write()
 
198
            try:
 
199
                self.target.start_write_group()
 
200
                try:
 
201
                    objects_iter = self.source.fetch_objects(determine_wants, 
 
202
                                graph_walker, 
 
203
                                target_git_object_retriever.__getitem__, 
 
204
                                progress)
 
205
                    import_git_objects(self.target, mapping, objects_iter, 
 
206
                            target_git_object_retriever, pb)
 
207
                finally:
 
208
                    self.target.commit_write_group()
 
209
            finally:
 
210
                self.target.unlock()
 
211
        finally:
 
212
            if create_pb:
 
213
                create_pb.finished()
 
214
 
 
215
    def fetch(self, revision_id=None, pb=None, find_ghosts=False, 
 
216
              mapping=None, fetch_spec=None):
 
217
        self.fetch_refs(revision_id=revision_id, pb=pb, find_ghosts=find_ghosts,
 
218
                mapping=mapping, fetch_spec=fetch_spec)
 
219
 
 
220
    def fetch_refs(self, revision_id=None, pb=None, find_ghosts=False, 
 
221
              mapping=None, fetch_spec=None):
 
222
        if mapping is None:
 
223
            mapping = self.source.get_mapping()
 
224
        if revision_id is not None:
 
225
            interesting_heads = [revision_id]
 
226
        elif fetch_spec is not None:
 
227
            interesting_heads = fetch_spec.heads
 
228
        else:
 
229
            interesting_heads = None
 
230
        self._refs = {}
 
231
        def determine_wants(refs):
 
232
            self._refs = refs
 
233
            if interesting_heads is None:
 
234
                ret = [sha for (ref, sha) in refs.iteritems() if not ref.endswith("^{}")]
 
235
            else:
 
236
                ret = [mapping.revision_id_bzr_to_foreign(revid)[0] for revid in interesting_heads]
 
237
            return [rev for rev in ret if not self.target.has_revision(mapping.revision_id_foreign_to_bzr(rev))]
 
238
        self.fetch_objects(determine_wants, mapping, pb)
 
239
        return self._refs
 
240
 
 
241
    @staticmethod
 
242
    def is_compatible(source, target):
 
243
        """Be compatible with GitRepository."""
 
244
        # FIXME: Also check target uses VersionedFile
 
245
        return (isinstance(source, GitRepository) and 
 
246
                target.supports_rich_root() and
 
247
                not isinstance(target, GitRepository))
 
248
 
 
249
 
 
250
class InterGitRepository(InterRepository):
 
251
 
 
252
    _matching_repo_format = GitFormat()
 
253
 
 
254
    @staticmethod
 
255
    def _get_repo_format_to_test():
 
256
        return None
 
257
 
 
258
    def copy_content(self, revision_id=None, pb=None):
 
259
        """See InterRepository.copy_content."""
 
260
        self.fetch(revision_id, pb, find_ghosts=False)
 
261
 
 
262
    def fetch(self, revision_id=None, pb=None, find_ghosts=False, 
 
263
              mapping=None, fetch_spec=None):
 
264
        if mapping is None:
 
265
            mapping = self.source.get_mapping()
 
266
        def progress(text):
 
267
            info("git: %s", text)
 
268
        r = self.target._git
 
269
        if revision_id is not None:
 
270
            args = [mapping.revision_id_bzr_to_foreign(revision_id)[0]]
 
271
        elif fetch_spec is not None:
 
272
            args = [mapping.revision_id_bzr_to_foreign(revid)[0] for revid in fetch_spec.heads]
 
273
        if fetch_spec is None and revision_id is None:
 
274
            determine_wants = r.object_store.determine_wants_all
 
275
        else:
 
276
            determine_wants = lambda x: [y for y in args if not y in r.object_store]
 
277
 
 
278
        graphwalker = SimpleFetchGraphWalker(r.heads().values(), r.get_parents)
 
279
        f, commit = r.object_store.add_pack()
 
280
        try:
 
281
            self.source._git.fetch_pack(path, determine_wants, graphwalker, f.write, progress)
 
282
            f.close()
 
283
            commit()
 
284
        except:
 
285
            f.close()
 
286
            raise
 
287
 
 
288
    @staticmethod
 
289
    def is_compatible(source, target):
 
290
        """Be compatible with GitRepository."""
 
291
        return (isinstance(source, GitRepository) and 
 
292
                isinstance(target, GitRepository))