/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to fetch.py

Add simple tests and docstrings for GraphWalker.

Show diffs side-by-side

added added

removed removed

Lines of Context:
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
 
from bzrlib import osutils, urlutils
 
17
from bzrlib import osutils, ui, urlutils
18
18
from bzrlib.errors import InvalidRevisionId
19
19
from bzrlib.inventory import Inventory
20
20
from bzrlib.repository import InterRepository
21
21
from bzrlib.trace import info
 
22
from bzrlib.tsort import topo_sort
22
23
 
23
 
from bzrlib.plugins.git import git
24
 
from bzrlib.plugins.git.repository import LocalGitRepository, GitRepository, GitFormat
 
24
from bzrlib.plugins.git.repository import (
 
25
        LocalGitRepository, 
 
26
        GitRepository, 
 
27
        GitFormat,
 
28
        )
25
29
from bzrlib.plugins.git.remote import RemoteGitRepository
26
30
 
 
31
import dulwich as git
 
32
from dulwich.client import SimpleFetchGraphWalker
27
33
from dulwich.objects import Commit
28
34
 
29
35
from cStringIO import StringIO
30
36
 
31
37
 
32
38
class BzrFetchGraphWalker(object):
 
39
    """GraphWalker implementation that uses a Bazaar repository."""
33
40
 
34
41
    def __init__(self, repository, mapping):
35
42
        self.repository = repository
38
45
        self.heads = set(repository.all_revision_ids())
39
46
        self.parents = {}
40
47
 
 
48
    def __iter__(self):
 
49
        return iter(self.next, None)
 
50
 
41
51
    def ack(self, sha):
42
52
        revid = self.mapping.revision_id_foreign_to_bzr(sha)
43
53
        self.remove(revid)
44
54
 
45
55
    def remove(self, revid):
46
56
        self.done.add(revid)
47
 
        if ref in self.heads:
 
57
        if revid in self.heads:
48
58
            self.heads.remove(revid)
49
59
        if revid in self.parents:
50
60
            for p in self.parents[revid]:
58
68
            self.heads.update([p for p in ps if not p in self.done])
59
69
            try:
60
70
                self.done.add(ret)
61
 
                return self.mapping.revision_id_bzr_to_foreign(ret)
 
71
                return self.mapping.revision_id_bzr_to_foreign(ret)[0]
62
72
            except InvalidRevisionId:
63
73
                pass
64
74
        return None
65
75
 
66
76
 
67
 
def import_git_blob(repo, mapping, path, blob, inv):
 
77
def import_git_blob(repo, mapping, path, blob, inv, parent_invs, executable):
68
78
    """Import a git blob object into a bzr repository.
69
79
 
70
80
    :param repo: bzr repository
72
82
    :param blob: A git blob
73
83
    """
74
84
    file_id = mapping.generate_file_id(path)
75
 
    repo.texts.add_lines((file_id, blob.id),
76
 
        [], #FIXME 
 
85
    text_revision = inv.revision_id
 
86
    repo.texts.add_lines((file_id, text_revision),
 
87
        [(file_id, p[file_id].revision) for p in parent_invs if file_id in p],
77
88
        osutils.split_lines(blob.data))
78
89
    ie = inv.add_path(path, "file", file_id)
79
 
 
80
 
 
81
 
def import_git_tree(repo, mapping, path, tree, inv, lookup_object):
 
90
    ie.revision = text_revision
 
91
    ie.text_size = len(blob.data)
 
92
    ie.text_sha1 = osutils.sha_string(blob.data)
 
93
    ie.executable = executable
 
94
 
 
95
 
 
96
def import_git_tree(repo, mapping, path, tree, inv, parent_invs, lookup_object):
82
97
    """Import a git tree object into a bzr repository.
83
98
 
84
99
    :param repo: A Bzr repository object
87
102
    :param inv: Inventory object
88
103
    """
89
104
    file_id = mapping.generate_file_id(path)
90
 
    repo.texts.add_lines((file_id, tree.id),
91
 
        [], #FIXME 
 
105
    text_revision = inv.revision_id
 
106
    repo.texts.add_lines((file_id, text_revision),
 
107
        [(file_id, p[file_id].revision) for p in parent_invs if file_id in p],
92
108
        [])
93
 
    inv.add_path(path, "directory", file_id)
 
109
    ie = inv.add_path(path, "directory", file_id)
 
110
    ie.revision = text_revision
94
111
    for mode, name, hexsha in tree.entries():
95
112
        entry_kind = (mode & 0700000) / 0100000
96
113
        basename = name.decode("utf-8")
100
117
            child_path = urlutils.join(path, name)
101
118
        if entry_kind == 0:
102
119
            tree = lookup_object(hexsha)
103
 
            import_git_tree(repo, mapping, child_path, tree, inv, lookup_object)
 
120
            import_git_tree(repo, mapping, child_path, tree, inv, parent_invs, lookup_object)
104
121
        elif entry_kind == 1:
105
122
            blob = lookup_object(hexsha)
106
 
            import_git_blob(repo, mapping, child_path, blob, inv)
 
123
            fs_mode = mode & 0777
 
124
            import_git_blob(repo, mapping, child_path, blob, inv, parent_invs, bool(fs_mode & 0111))
107
125
        else:
108
126
            raise AssertionError("Unknown blob kind, perms=%r." % (mode,))
109
127
 
110
128
 
111
 
def import_git_objects(repo, mapping, object_iter):
 
129
def import_git_objects(repo, mapping, num_objects, object_iter, pb=None):
112
130
    """Import a set of git objects into a bzr repository.
113
131
 
114
132
    :param repo: Bazaar repository
115
133
    :param mapping: Mapping to use
 
134
    :param num_objects: Number of objects.
116
135
    :param object_iter: Iterator over Git objects.
117
136
    """
118
137
    # TODO: a more (memory-)efficient implementation of this
119
138
    objects = {}
120
 
    for o in object_iter:
 
139
    for i, (o, _) in enumerate(object_iter):
 
140
        if pb is not None:
 
141
            pb.update("fetching objects", i, num_objects) 
121
142
        objects[o.id] = o
 
143
    graph = []
122
144
    root_trees = {}
 
145
    revisions = {}
123
146
    # Find and convert commit objects
124
147
    for o in objects.itervalues():
125
148
        if isinstance(o, Commit):
126
149
            rev = mapping.import_commit(o)
127
 
            root_trees[rev] = objects[o.tree]
 
150
            root_trees[rev.revision_id] = objects[o.tree]
 
151
            revisions[rev.revision_id] = rev
 
152
            graph.append((rev.revision_id, rev.parent_ids))
 
153
    # Order the revisions
128
154
    # Create the inventory objects
129
 
    for rev, root_tree in root_trees.iteritems():
 
155
    for i, revid in enumerate(topo_sort(graph)):
 
156
        if pb is not None:
 
157
            pb.update("fetching revisions", i, len(graph))
 
158
        root_tree = root_trees[revid]
 
159
        rev = revisions[revid]
130
160
        # We have to do this here, since we have to walk the tree and 
131
161
        # we need to make sure to import the blobs / trees with the riht 
132
162
        # path; this may involve adding them more than once.
136
166
            if sha in objects:
137
167
                return objects[sha]
138
168
            return reconstruct_git_object(repo, mapping, sha)
139
 
        import_git_tree(repo, mapping, "", root_tree, inv, lookup_object)
 
169
        parent_invs = [repo.get_inventory(r) for r in rev.parent_ids]
 
170
        import_git_tree(repo, mapping, "", root_tree, inv, parent_invs, 
 
171
            lookup_object)
140
172
        repo.add_revision(rev.revision_id, rev, inv)
141
173
 
142
174
 
159
191
    raise KeyError("No such object %s" % sha)
160
192
 
161
193
 
162
 
class InterGitRepository(InterRepository):
 
194
class InterGitNonGitRepository(InterRepository):
163
195
 
164
196
    _matching_repo_format = GitFormat()
165
197
 
176
208
        if mapping is None:
177
209
            mapping = self.source.get_mapping()
178
210
        def progress(text):
179
 
            if pb is not None:
180
 
                pb.note("git: %s" % text)
181
 
            else:
182
 
                info("git: %s" % text)
 
211
            pb.update("git: %s" % text.rstrip("\r\n"), 0, 0)
183
212
        def determine_wants(heads):
184
213
            if revision_id is None:
185
214
                ret = heads.values()
186
215
            else:
187
 
                ret = [mapping.revision_id_bzr_to_foreign(revision_id)]
 
216
                ret = [mapping.revision_id_bzr_to_foreign(revision_id)[0]]
188
217
            return [rev for rev in ret if not self.target.has_revision(mapping.revision_id_foreign_to_bzr(rev))]
189
218
        graph_walker = BzrFetchGraphWalker(self.target, mapping)
190
 
        self.target.lock_write()
 
219
        create_pb = None
 
220
        if pb is None:
 
221
            create_pb = pb = ui.ui_factory.nested_progress_bar()
191
222
        try:
192
 
            self.target.start_write_group()
 
223
            self.target.lock_write()
193
224
            try:
194
 
                import_git_objects(self.target, mapping,
195
 
                    iter(self.source.fetch_objects(determine_wants, graph_walker, 
196
 
                        progress)))
 
225
                self.target.start_write_group()
 
226
                try:
 
227
                    (num_objects, objects_iter) = \
 
228
                            self.source.fetch_objects(determine_wants, 
 
229
                                graph_walker, progress)
 
230
                    import_git_objects(self.target, mapping, num_objects, 
 
231
                                       objects_iter, pb)
 
232
                finally:
 
233
                    self.target.commit_write_group()
197
234
            finally:
198
 
                self.target.commit_write_group()
 
235
                self.target.unlock()
199
236
        finally:
200
 
            self.target.unlock()
 
237
            if create_pb:
 
238
                create_pb.finished()
201
239
 
202
240
    @staticmethod
203
241
    def is_compatible(source, target):
204
242
        """Be compatible with GitRepository."""
205
243
        # FIXME: Also check target uses VersionedFile
206
 
        return (isinstance(source, LocalGitRepository) and 
207
 
                target.supports_rich_root())
 
244
        return (isinstance(source, GitRepository) and 
 
245
                target.supports_rich_root() and
 
246
                not isinstance(target, GitRepository))
 
247
 
 
248
 
 
249
class InterGitRepository(InterRepository):
 
250
 
 
251
    _matching_repo_format = GitFormat()
 
252
 
 
253
    @staticmethod
 
254
    def _get_repo_format_to_test():
 
255
        return None
 
256
 
 
257
    def copy_content(self, revision_id=None, pb=None):
 
258
        """See InterRepository.copy_content."""
 
259
        self.fetch(revision_id, pb, find_ghosts=False)
 
260
 
 
261
    def fetch(self, revision_id=None, pb=None, find_ghosts=False, 
 
262
              mapping=None):
 
263
        if mapping is None:
 
264
            mapping = self.source.get_mapping()
 
265
        def progress(text):
 
266
            info("git: %s", text)
 
267
        r = self.target._git
 
268
        if revision_id is None:
 
269
            determine_wants = lambda x: [y for y in x.values() if not y in r.object_store]
 
270
        else:
 
271
            args = [mapping.revision_id_bzr_to_foreign(revision_id)[0]]
 
272
            determine_wants = lambda x: [y for y in args if not y in r.object_store]
 
273
 
 
274
        graphwalker = SimpleFetchGraphWalker(r.heads().values(), r.get_parents)
 
275
        f, commit = r.object_store.add_pack()
 
276
        try:
 
277
            self.source._git.fetch_pack(path, determine_wants, graphwalker, f.write, progress)
 
278
            f.close()
 
279
            commit()
 
280
        except:
 
281
            f.close()
 
282
            raise
 
283
 
 
284
    @staticmethod
 
285
    def is_compatible(source, target):
 
286
        """Be compatible with GitRepository."""
 
287
        return (isinstance(source, GitRepository) and 
 
288
                isinstance(target, GitRepository))