Add files via upload

This commit is contained in:
Easymode 2025-02-19 00:37:29 +00:00 committed by GitHub
parent a9c6b5cbf0
commit faef87a80a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View File

@ -32,8 +32,8 @@ class PointConditioner(torch.nn.Module):
# open-source version of miche # open-source version of miche
if model_name == 'miche-256-feature': if model_name == 'miche-256-feature':
ckpt_path = None
dir = os.path.dirname(os.path.abspath(__file__)) dir = os.path.dirname(os.path.abspath(__file__))
ckpt_path = None#os.path.join(dir, '..\shapevae-256.ckpt')
model_path = os.path.join(dir, '..\shapevae-256.yaml') model_path = os.path.join(dir, '..\shapevae-256.yaml')
config_path = model_path config_path = model_path

View File

@ -70,6 +70,7 @@ class MeshTransformer(Module):
conditioned_on_pc = True, conditioned_on_pc = True,
encoder_name = 'miche-256-feature', encoder_name = 'miche-256-feature',
encoder_freeze = False, encoder_freeze = False,
cond_dim = 768
): ):
super().__init__() super().__init__()
@ -107,7 +108,7 @@ class MeshTransformer(Module):
# load point_cloud encoder # load point_cloud encoder
if conditioned_on_pc: if conditioned_on_pc:
print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}') print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}')
self.conditioner = PointConditioner(model_name=encoder_name, freeze=encoder_freeze) self.conditioner = PointConditioner(cond_dim=cond_dim, model_name=encoder_name, freeze=encoder_freeze)
cross_attn_dim_context = self.conditioner.dim_latent cross_attn_dim_context = self.conditioner.dim_latent
else: else:
raise NotImplementedError raise NotImplementedError