diff --git a/hy3dgen/shapegen/bpt/model/miche_conditioner.py b/hy3dgen/shapegen/bpt/model/miche_conditioner.py index 1a744d5..c5909c2 100644 --- a/hy3dgen/shapegen/bpt/model/miche_conditioner.py +++ b/hy3dgen/shapegen/bpt/model/miche_conditioner.py @@ -32,8 +32,8 @@ class PointConditioner(torch.nn.Module): # open-source version of miche if model_name == 'miche-256-feature': - ckpt_path = None dir = os.path.dirname(os.path.abspath(__file__)) + ckpt_path = None#os.path.join(dir, '..\shapevae-256.ckpt') model_path = os.path.join(dir, '..\shapevae-256.yaml') config_path = model_path diff --git a/hy3dgen/shapegen/bpt/model/model.py b/hy3dgen/shapegen/bpt/model/model.py index 8ec2d4d..8832060 100644 --- a/hy3dgen/shapegen/bpt/model/model.py +++ b/hy3dgen/shapegen/bpt/model/model.py @@ -70,6 +70,7 @@ class MeshTransformer(Module): conditioned_on_pc = True, encoder_name = 'miche-256-feature', encoder_freeze = False, + cond_dim = 768 ): super().__init__() @@ -107,7 +108,7 @@ class MeshTransformer(Module): # load point_cloud encoder if conditioned_on_pc: print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}') - self.conditioner = PointConditioner(model_name=encoder_name, freeze=encoder_freeze) + self.conditioner = PointConditioner(cond_dim=cond_dim, model_name=encoder_name, freeze=encoder_freeze) cross_attn_dim_context = self.conditioner.dim_latent else: raise NotImplementedError