diff --git a/hy3dgen/shapegen/bpt/README.md b/hy3dgen/shapegen/bpt/README.md new file mode 100644 index 0000000..589e125 --- /dev/null +++ b/hy3dgen/shapegen/bpt/README.md @@ -0,0 +1,10 @@ +# BPT Installation + +Original repo: https://github.com/whaohan/bpt + + +### Installation +pip install -r requirements.txt + +### Download weights (From main Hunyuan3D2 directory) +huggingface-cli download whaohan/bpt --local-dir ./weights \ No newline at end of file diff --git a/hy3dgen/shapegen/bpt/__pycache__/utils.cpython-312.pyc b/hy3dgen/shapegen/bpt/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000..246dccd Binary files /dev/null and b/hy3dgen/shapegen/bpt/__pycache__/utils.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/LICENSE b/hy3dgen/shapegen/bpt/miche/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. 
This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. 
+ + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. 
diff --git a/hy3dgen/shapegen/bpt/miche/__init__.py b/hy3dgen/shapegen/bpt/miche/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hy3dgen/shapegen/bpt/miche/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..97f5d22 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/__pycache__/__init__.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/__pycache__/encode.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/__pycache__/encode.cpython-312.pyc new file mode 100644 index 0000000..a44f74d Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/__pycache__/encode.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/encode.py b/hy3dgen/shapegen/bpt/miche/encode.py new file mode 100644 index 0000000..f755c7b --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/encode.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +import argparse +from omegaconf import OmegaConf +import numpy as np +import torch +from .michelangelo.utils.misc import instantiate_from_config + +def load_surface(fp): + + with np.load(fp) as input_pc: + surface = input_pc['points'] + normal = input_pc['normals'] + + rng = np.random.default_rng() + ind = rng.choice(surface.shape[0], 4096, replace=False) + surface = torch.FloatTensor(surface[ind]) + normal = torch.FloatTensor(normal[ind]) + + surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() + + return surface + +def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): + + surface = load_surface(args.pointcloud_path) + # old_surface = surface.clone() + + # surface[0,:,0]*=-1 + # surface[0,:,1]*=-1 + surface[0,:,2]*=-1 + + # encoding + shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) + shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) + + # decoding + latents = model.model.shape_model.decode(shape_zq) + # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) + + return 0 + +def load_model(ckpt_path="shapevae-256.ckpt", config_path="shapevae-256.yaml"): + model_config = OmegaConf.load(config_path) + print(model_config) + if hasattr(model_config, "model"): + model_config = model_config.model + + model = instantiate_from_config(model_config, ckpt_path=ckpt_path) + model = model.eval() + + return model +if __name__ == "__main__": + ''' + 1. Reconstruct point cloud + 2. Image-conditioned generation + 3. 
Text-conditioned generation
+    '''
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config_path", type=str, required=True)
+    parser.add_argument("--ckpt_path", type=str, required=True)
+    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz',
+                        help='Path to the input point cloud')
+    parser.add_argument("--image_path", type=str, help='Path to the input image')
+    parser.add_argument("--text", type=str,
+                        help='Input text in the format: A 3D model of motorcar; Porsche 911.')
+    parser.add_argument("--output_dir", type=str, default='./output')
+    parser.add_argument("-s", "--seed", type=int, default=0)
+    args = parser.parse_args()
+
+    print('-----------------------------------------------------------------------------')
+    print(f'>>> Output directory: {args.output_dir}')
+    print('-----------------------------------------------------------------------------')
+
+    reconstruction(args, load_model(ckpt_path=args.ckpt_path, config_path=args.config_path))
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/__init__.py
new file mode 100644
index 0000000..40a96af
--- /dev/null
+++ b/hy3dgen/shapegen/bpt/miche/michelangelo/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..dbe81bb
Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__init__.py
new file mode 100644
index 0000000..40a96af
--- /dev/null
+++ b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..a656886
Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__init__.py
new file mode 100644
index 0000000..49fc098
--- /dev/null
+++ b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+
+from .volume import generate_dense_grid_points
+
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000..f0be88f
Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/volume.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/volume.cpython-312.pyc
new file mode 100644
index 0000000..25d8c4b
Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/__pycache__/volume.cpython-312.pyc differ
diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/volume.py b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/volume.py
new file mode 100644
index
0000000..9c98418 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/graphics/primitives/volume.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +# produce dense points +def generate_dense_grid_points(bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_depth: int, + indexing: str = "ij"): + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..db553e8 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__init__.py new file mode 100644 index 0000000..0729b49 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import checkpoint diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..8003ebc Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/checkpoint.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/checkpoint.cpython-312.pyc new file mode 100644 index 0000000..bc25e88 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/checkpoint.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/distributions.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/distributions.cpython-312.pyc new file mode 100644 index 0000000..1699afe Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/distributions.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/embedder.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/embedder.cpython-312.pyc new file mode 100644 index 0000000..8c505d5 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/embedder.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-312.pyc 
b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-312.pyc new file mode 100644 index 0000000..301e004 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/checkpoint.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/checkpoint.py new file mode 100644 index 0000000..55775b0 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/checkpoint.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +import torch +from typing import Callable, Iterable, Sequence, Union + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + # Evaluate a function without caching intermediate activations, allowing for + # reduced memory at the expense of extra compute in the backward pass. + # :param func: the function to evaluate. + # :param inputs: the argument sequence to pass to `func`. + # :param params: a sequence of parameters `func` depends on but does not + # explicitly take as arguments. + # :param flag: if False, disable gradient checkpointing. + # :param use_deepspeed: if True, use deepspeed + if flag: + if use_deepspeed: + import deepspeed + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/distributions.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/distributions.py new file mode 100644 index 0000000..1115dcb --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/distributions.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- + +import torch +import numpy as np +from typing import Union, List + + +class DiagonalGaussianDistribution(object): + # Gaussian distribution + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + # sample from the guassian distribution + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + # Compute the KL divergence between two gaussians. + # Shapes are automatically broadcasted, so batches can be compared to + # scalars, among other use cases. + + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/embedder.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/embedder.py new file mode 100644 index 0000000..223de82 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/embedder.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import torch +import torch.nn as nn +import math + +VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + """ following @crowsonkb "s lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, in_channels, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // in_channels + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + def forward(self, x): + """ + + Args: + x (torch.FloatTensor): [..., c] + + Returns: + x (torch.FloatTensor): [..., d] + """ + + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class TriplaneLearnedFourierEmbedder(nn.Module): + def __init__(self, in_channels, dim): + super().__init__() + + self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + + self.out_dim = in_channels + dim + + def forward(self, x): + + yz_embed = self.yz_plane_embedder(x) + xz_embed = self.xz_plane_embedder(x) + xy_embed = self.xy_plane_embedder(x) + + embed = yz_embed + xz_embed + xy_embed + + return embed + + +def sequential_pos_embed(num_len, embed_dim): + assert embed_dim % 2 == 0 + + pos = torch.arange(num_len, dtype=torch.float32) + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return embeddings + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, + num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, + logspace=True, include_input=True) + return embedder_obj, embedder_obj.out_dim + + elif embed_type == "hashgrid": + raise NotImplementedError + + elif embed_type == "sphere_harmonic": + raise NotImplementedError + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/transformer_blocks.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/transformer_blocks.py new file mode 100644 index 0000000..8aaabd7 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/modules/transformer_blocks.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules.checkpoint import checkpoint + +# Initialize linear layers with normal distribution weights and zero biases +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +# Multihead attention module +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, # Context size + width: int, # Width of the input tensor + heads: int, # Number of attention heads + init_scale: float, # Initialization scale for weights + qkv_bias: bool, # Whether to use bias in QKV layers + flash: bool = False # Whether to use flash attention + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + +# QKV multihead attention module +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + 
"bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + +# Residual attention block module +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + use_checkpoint: bool = False, + n_ctx: int, # Context size + width: int, # Width of the input tensor + heads: int, # Number of attention heads + init_scale: float, # Initialization scale for weights + qkv_bias: bool, # Whether to use bias in QKV layers + flash: bool = False # Whether to use flash attention + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + +# Multihead cross attention module +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + width: int, # Width of the input tensor + heads: int, # Number of attention heads + init_scale: float, # Initialization scale for weights + qkv_bias: bool, # Whether to use bias in QKV layers + flash: bool = False # Whether to use flash attention + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + +# QKV multihead cross attention module +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight 
= torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + +# Residual cross attention block module +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + data_width: Optional[int] = None, + width: int, # Width of the input tensor + heads: int, # Number of attention heads + init_scale: float, # Initialization scale for weights + qkv_bias: bool, # Whether to use bias in QKV layers + flash: bool = False # Whether to use flash attention + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + +# MLP Module +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + +# Transformer Module +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + layers: int, + use_checkpoint: bool = False, + n_ctx: int, # Context size + width: int, # Width of the input tensor + heads: int, # Number of attention heads + init_scale: float, # Initialization scale for weights + qkv_bias: bool, # Whether to use bias in QKV layers + flash: bool = False # Whether to use flash attention + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..2d3ea6e Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-312.pyc new file mode 100644 index 0000000..b57e95a Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-312.pyc new file mode 100644 index 0000000..9f5e97d Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/inference_utils.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/inference_utils.cpython-312.pyc new file mode 100644 index 0000000..aff7098 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/inference_utils.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/loss.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000..b3db2a4 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/loss.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-312.pyc new file mode 100644 index 0000000..5f38f47 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/tsal_base.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/tsal_base.cpython-312.pyc new file mode 100644 index 0000000..f8c4262 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/__pycache__/tsal_base.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/asl_pl_module.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/asl_pl_module.py new file mode 100644 index 0000000..9b84bf0 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/asl_pl_module.py @@ -0,0 +1,383 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple, Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F +from torch import nn +from torch.optim import lr_scheduler +from typing import Union +from functools import partial + +from .....miche.michelangelo.utils import instantiate_from_config + +from .tsal_base import ( + AlignedShapeAsLatentModule, + ShapeAsLatentModule, + Latent2MeshOutput, + AlignedMeshOutput +) +from .....miche.michelangelo.models.tsal.inference_utils import extract_geometry +import trimesh + +class AlignedShapeAsLatentPLModule(nn.Module): + def __init__(self, *, + shape_module_cfg, + aligned_module_cfg, + loss_cfg, + optimizer_cfg: Optional[DictConfig] = None, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + shape_model: ShapeAsLatentModule = instantiate_from_config( + shape_module_cfg, device=None, dtype=None + ) + self.model: AlignedShapeAsLatentModule = 
instantiate_from_config( + aligned_module_cfg, shape_model=shape_model + ) + + self.loss = instantiate_from_config(loss_cfg) + + self.optimizer_cfg = optimizer_cfg + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def set_shape_model_only(self): + self.model.set_shape_model_only() + + @property + def latent_shape(self): + return self.model.shape_model.latent_shape + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def configure_optimizers(self) -> Tuple[List, List]: + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + def forward(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + volume_queries: torch.FloatTensor): + # Args: + # surface (torch.FloatTensor): + # image (torch.FloatTensor): + # text (torch.FloatTensor): + # volume_queries (torch.FloatTensor): + # + # Returns: + + embed_outputs, shape_z = self.model(surface, image, text) + + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + logits = self.model.shape_model.query_geometry(volume_queries, latents) + + return embed_outputs, logits, posterior + + def encode(self, surface: torch.FloatTensor, sample_posterior=True): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + shape_embed, shape_zq, posterior = self.model.shape_model.encode( + pc=pc, feats=feats, sample_posterior=sample_posterior + ) + + return shape_zq + + def encode_latents(self, surface: torch.FloatTensor): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + shape_embed, shape_latents = self.model.shape_model.encode_latents( + pc=pc, feats=feats + ) + shape_embed = shape_embed.unsqueeze(1) + assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256 + cat_latents = torch.cat([shape_embed, shape_latents], dim=1) + + return cat_latents + + def recon(self, surface): + cat_latents = self.encode_latents(surface) + shape_latents = cat_latents[:, 1:] + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_latents) + + # decoding + latents = self.model.shape_model.decode(shape_zq) + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # reconstruction + 
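Reading aid: `extract_geometry` (defined in `inference_utils.py` further down in this diff) only receives query points, because `partial` has already bound the decoded latents. A minimal sketch of that calling pattern, with made-up shapes and a stand-in query function:

```python
import torch
from functools import partial

def query_geometry(queries, latents):
    # stand-in for shape_model.query_geometry: one occupancy logit per query point
    return torch.randn(queries.shape[0], queries.shape[1])

latents = torch.randn(1, 256, 768)               # [batch, num_latents, width] (illustrative)
geometric_func = partial(query_geometry, latents=latents)

queries = torch.rand(1, 10_000, 3) * 2.5 - 1.25  # one chunk of grid points in [-1.25, 1.25]^3
logits = geometric_func(queries)                 # [batch, num_points]; the 0-level set is the surface
```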
mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=surface.device, + batch_size=surface.shape[0], + bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth=7, + num_chunks=10000, + ) + recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1]) + + return recon_mesh + + + def to_shape_latents(self, latents): + + shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False) + return self.model.shape_model.decode(shape_zq) + + def decode(self, + z_q, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] + outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) + + return outputs + + def training_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + #Args: + # batch (dict): the batch sample, and it contains: + # - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] + # - image (torch.FloatTensor): [bs, 3, 224, 224] + # - text (torch.FloatTensor): [bs, num_templates, 77] + # - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] + # + # batch_idx (int): + # + # optimizer_idx (int): + # + # Returns: + # loss (torch.FloatTensor): + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="train" + ) + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="val" + ) + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def visual_alignment(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + description: Optional[List[str]] = None, + bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000) -> List[AlignedMeshOutput]: + + """ + + Args: + surface: + image: + text: + description: + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. 
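+            Entries are None for batch items where marching cubes finds no surface.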
+ + """ + + outputs = [] + + device = surface.device + bs = surface.shape[0] + + embed_outputs, shape_z = self.model(surface, image, text) + + # calculate the similarity + image_embed = embed_outputs["image_embed"] + text_embed = embed_outputs["text_embed"] + shape_embed = embed_outputs["shape_embed"] + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # B x B + shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) + + # B x B + shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) + + # shape reconstruction + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=bs, + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = AlignedMeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + out.surface = surface[i].cpu().numpy() + out.image = image[i].cpu().numpy() + if description is not None: + out.text = description[i] + out.shape_text_similarity = shape_text_similarity[i, i] + out.shape_image_similarity = shape_image_similarity[i, i] + + outputs.append(out) + + return outputs + + def latent2mesh(self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + """ + + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + device = latents.device + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. 
decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + + outputs.append(out) + + return outputs diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/clip_asl_module.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/clip_asl_module.py new file mode 100644 index 0000000..a5c9562 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/clip_asl_module.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- + +import torch +from torch import nn +from einops import rearrange +from transformers import CLIPModel + +from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule + + +class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): + + def __init__(self, *, + shape_model, + clip_model_version: str = "openai/clip-vit-large-patch14"): + + super().__init__() + + # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) + # for params in self.clip_model.parameters(): + # params.requires_grad = False + self.clip_model = None + self.shape_model = shape_model + self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) + # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) + + def set_shape_model_only(self): + self.clip_model = None + + def encode_shape_embed(self, surface, return_latents: bool = False): + """ + + Args: + surface (torch.FloatTensor): [bs, n, 3 + c] + return_latents (bool): + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + shape_latents (torch.FloatTensor): [bs, m, d] + """ + + pc = surface[..., 0:3] + feats = surface[..., 3:] + + shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) + x = shape_embed @ self.shape_projection + + if return_latents: + return x, shape_latents + else: + return x + + def encode_image_embed(self, image): + """ + + Args: + image (torch.FloatTensor): [bs, 3, h, w] + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + """ + + x = self.clip_model.get_image_features(image) + + return x + + def encode_text_embed(self, text): + x = self.clip_model.get_text_features(text) + return x + + def forward(self, surface, image, text): + """ + + Args: + surface (torch.FloatTensor): + image (torch.FloatTensor): [bs, 3, 224, 224] + text (torch.LongTensor): [bs, num_templates, 77] + + Returns: + embed_outputs (dict): the embedding outputs, and it contains: + - image_embed (torch.FloatTensor): + - text_embed (torch.FloatTensor): + - shape_embed (torch.FloatTensor): + - logit_scale (float): + """ + + # # text embedding + # text_embed_all = [] + # for i in range(text.shape[0]): + # text_for_one_sample = text[i] + # text_embed = self.encode_text_embed(text_for_one_sample) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed = text_embed.mean(dim=0) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed_all.append(text_embed) + # text_embed_all = torch.stack(text_embed_all) + + b = text.shape[0] + text_tokens = rearrange(text, "b t l -> (b t) l") + text_embed = self.encode_text_embed(text_tokens) + text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) + text_embed = text_embed.mean(dim=1) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + + # image embedding + image_embed = self.encode_image_embed(image) + + # shape embedding + shape_embed, 
shape_latents = self.encode_shape_embed(surface, return_latents=True) + + embed_outputs = { + "image_embed": image_embed, + "text_embed": text_embed, + "shape_embed": shape_embed, + # "logit_scale": self.clip_model.logit_scale.exp() + } + + return embed_outputs, shape_latents diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/inference_utils.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/inference_utils.py new file mode 100644 index 0000000..1086a95 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/inference_utils.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +import torch +from tqdm import tqdm +from einops import repeat +import numpy as np +from typing import Callable, Tuple, List, Union, Optional +from skimage import measure + +from .....miche.michelangelo.graphics.primitives import generate_dense_grid_points + + +@torch.no_grad() +def extract_geometry(geometric_func: Callable, + device: torch.device, + batch_size: int = 1, + bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000, + disable: bool = True): + + # Args: + # geometric_func: + # device: + # bounds: + # octree_depth: + # batch_size: + # num_chunks: + # disable: + # Returns: + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_depth=octree_depth, + indexing="ij" + ) + xyz_samples = torch.FloatTensor(xyz_samples) + + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), + desc="Implicit Function:", disable=disable, leave=False): + queries = xyz_samples[start: start + num_chunks, :].to(device) + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + + logits = geometric_func(batch_queries) + batch_logits.append(logits.cpu()) + + grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() + + mesh_v_f = [] + has_surface = np.zeros((batch_size,), dtype=np.bool_) + for i in range(batch_size): + try: + vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") + vertices = vertices / grid_size * bbox_size + bbox_min + # vertices[:, [0, 1]] = vertices[:, [1, 0]] + mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) + has_surface[i] = True + + except ValueError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + except RuntimeError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + return mesh_v_f, has_surface diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/loss.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/loss.py new file mode 100644 index 0000000..a2aa24c --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/loss.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional + +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution +from hy3dgen.shapegen.bpt.miche.michelangelo.utils import misc + + +class ContrastKLNearFar(nn.Module): + def __init__(self, + contrast_weight: float = 1.0, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + 
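Reading aid for the loss implemented in `forward` below: `ContrastKLNearFar` combines a CLIP-style symmetric cross-entropy between shape/text and shape/image embeddings, a BCE occupancy term split into far (volume) and near-surface samples, and a KL penalty on the latent posterior. A self-contained sketch of just the contrastive part for the single-GPU case, with illustrative batch size and width:

```python
import torch
import torch.nn.functional as F

bs, d = 4, 768
shape_embed = F.normalize(torch.randn(bs, d), dim=-1)
text_embed = F.normalize(torch.randn(bs, d), dim=-1)
logit_scale = torch.tensor(100.0)

labels = torch.arange(bs)                           # i-th shape pairs with i-th text
logits = logit_scale * shape_embed @ text_embed.t()
contrast_loss = (F.cross_entropy(logits, labels) +
                 F.cross_entropy(logits.t(), labels)) / 2
```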
super().__init__() + + self.labels = None + self.last_local_batch_size = None + + self.contrast_weight = contrast_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + shape_embed: torch.FloatTensor, + text_embed: torch.FloatTensor, + image_embed: torch.FloatTensor, + logit_scale: torch.FloatTensor, + posteriors: Optional[DiagonalGaussianDistribution], + shape_logits: torch.FloatTensor, + shape_labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs): + + # shape_embed: torch.FloatTensor + # text_embed: torch.FloatTensor + # image_embed: torch.FloatTensor + # logit_scale: torch.FloatTensor + # posteriors: Optional[DiagonalGaussianDistribution] + # shape_logits: torch.FloatTensor + # shape_labels: torch.FloatTensor + + local_batch_size = shape_embed.size(0) + + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * misc.get_rank() + torch.arange( + local_batch_size, device=shape_embed.device + ).long() + self.last_local_batch_size = local_batch_size + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # gather features from all GPUs + shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( + [shape_embed, text_embed, image_embed] + ) + + # cosine similarity as logits + logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() + logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() + logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() + logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() + contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + + F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ + (F.cross_entropy(logits_per_shape_image, self.labels) + + F.cross_entropy(logits_per_image_shape, self.labels)) / 2 + + # shape reconstruction + if self.num_near_samples is None: + num_vol = shape_logits.shape[1] // 2 + else: + num_vol = shape_logits.shape[1] - self.num_near_samples + + vol_logits = shape_logits[:, 0:num_vol] + vol_labels = shape_labels[:, 0:num_vol] + + near_logits = shape_logits[:, num_vol:] + near_labels = shape_labels[:, num_vol:] + + # occupancy loss + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + + # compute accuracy + with torch.no_grad(): + pred = torch.argmax(logits_per_shape_text, dim=-1) + correct = pred.eq(self.labels).sum() + shape_text_acc = 100 * correct / local_batch_size + + pred = torch.argmax(logits_per_shape_image, dim=-1) + correct = pred.eq(self.labels).sum() + shape_image_acc = 100 * correct / local_batch_size + + preds = shape_logits >= 0 + accuracy = (preds == shape_labels).float() + accuracy = accuracy.mean() + + log = { + "{}/contrast".format(split): contrast_loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + 
"{}/shape_text_acc".format(split): shape_text_acc, + "{}/shape_image_acc".format(split): shape_image_acc, + "{}/total_loss".format(split): loss.clone().detach(), + "{}/accuracy".format(split): accuracy, + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/sal_perceiver.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/sal_perceiver.py new file mode 100644 index 0000000..82fe326 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/sal_perceiver.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional +from einops import repeat +import math + +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules import checkpoint +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules.embedder import FourierEmbedder +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution +from hy3dgen.shapegen.bpt.miche.michelangelo.models.modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) + +from .tsal_base import ShapeAsLatentModule + + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + + self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) + + self.fourier_embedder = fourier_embedder + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + + self.self_attn = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=False + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + + # Args: + # pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, C] + + bs = pc.shape[0] + + data = self.fourier_embedder(pc) + if feats is not None: + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + query = repeat(self.query, "m c -> b m c", b=bs) + latents = self.cross_attn(query, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents, pc + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + + # Args: + # pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, C] + + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + 
out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x = self.output_proj(x) + return x + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentPerceiver(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.embed_dim = embed_dim + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + else: + self.latent_shape = (num_latents, width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=num_decoder_layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + + + # Args: + # pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, C] + # sample_posterior (bool): + + # Returns: + # latents (torch.FloatTensor) + # center_pos (torch.FloatTensor or None): + # posterior (DiagonalGaussianDistribution or None): + + + latents, center_pos = 
self.encoder(pc, feats) + + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + + return latents, center_pos, posterior + + def decode(self, latents: torch.FloatTensor): + latents = self.post_kl(latents) + return self.transformer(latents) + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + logits = self.geo_decoder(queries, latents).squeeze(-1) + return logits + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + + # Args: + # pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, C] + # volume_queries (torch.FloatTensor): [B, P, 3] + # sample_posterior (bool): + + # Returns: + # logits (torch.FloatTensor): [B, P] + # center_pos (torch.FloatTensor): [B, M, 3] + # posterior (DiagonalGaussianDistribution or None). + + + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior + + +class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__( + device=device, + dtype=dtype, + num_latents=1 + num_latents, + point_feats=point_feats, + embed_dim=embed_dim, + num_freqs=num_freqs, + include_pi=include_pi, + width=width, + heads=heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.width = width + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + + # Args: + # pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, c] + # sample_posterior (bool): + + # Returns: + # shape_embed (torch.FloatTensor) + # kl_embed (torch.FloatTensor): + # posterior (DiagonalGaussianDistribution or None): + + + shape_embed, latents = self.encode_latents(pc, feats) + kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) + + return shape_embed, kl_embed, posterior + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _ = self.encoder(pc, feats) + + shape_embed = x[:, 0] + latents = x[:, 1:] + + return shape_embed, latents + + def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + + return kl_embed, posterior + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + + # Args: + # 
pc (torch.FloatTensor): [B, N, 3] + # feats (torch.FloatTensor or None): [B, N, C] + # volume_queries (torch.FloatTensor): [B, P, 3] + # sample_posterior (bool): + + # Returns: + # shape_embed (torch.FloatTensor): [B, projection_dim] + # logits (torch.FloatTensor): [B, M] + # posterior (DiagonalGaussianDistribution or None). + + + shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(kl_embed) + logits = self.query_geometry(volume_queries, latents) + + return shape_embed, logits, posterior diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/tsal_base.py b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/tsal_base.py new file mode 100644 index 0000000..0de0859 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/models/tsal/tsal_base.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from typing import Tuple, List, Optional + +# Base class for output of Point to Mesh transformation +class Point2MeshOutput(object): + def __init__(self): + self.mesh_v = None # Vertices of the mesh + self.mesh_f = None # Faces of the mesh + self.center = None # Center of the mesh + self.pc = None # Point cloud data + + +# Base class for output of Latent to Mesh transformation +class Latent2MeshOutput(object): + def __init__(self): + self.mesh_v = None # Vertices of the mesh + self.mesh_f = None # Faces of the mesh + + +# Base class for output of Aligned Mesh transformation +class AlignedMeshOutput(object): + def __init__(self): + self.mesh_v = None # Vertices of the mesh + self.mesh_f = None # Faces of the mesh + self.surface = None # Surface data + self.image = None # Aligned image data + self.text: Optional[str] = None # Aligned text data + self.shape_text_similarity: Optional[float] = None # Similarity between shape and text + self.shape_image_similarity: Optional[float] = None # Similarity between shape and image + + +# Base class for Shape as Latent with Point to Mesh transformation module +class ShapeAsLatentPLModule(nn.Module): + latent_shape: Tuple[int] # Shape of the latent space + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +# Base class for Shape as Latent module +class ShapeAsLatentModule(nn.Module): + latent_shape: Tuple[int, int] # Shape of the latent space + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + +# Base class for Aligned Shape as Latent with Point to Mesh transformation module +class AlignedShapeAsLatentPLModule(nn.Module): + latent_shape: Tuple[int] # Shape of the latent space + + def set_shape_model_only(self): + raise NotImplementedError + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +# Base class for Aligned Shape as Latent module +class 
AlignedShapeAsLatentModule(nn.Module): + shape_model: ShapeAsLatentModule # Shape model module + latent_shape: Tuple[int, int] # Shape of the latent space + + + def __init__(self, *args, **kwargs): + super().__init__() + + def set_shape_model_only(self): + raise NotImplementedError + + def encode_image_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_text_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_shape_embed(self, *args, **kwargs): + raise NotImplementedError + +# Base class for Textured Shape as Latent module +class TexturedShapeAsLatentModule(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + def query_color(self, *args, **kwargs): + raise NotImplementedError diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__init__.py b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__init__.py new file mode 100644 index 0000000..6e9efc9 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .misc import instantiate_from_config diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..8858c31 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/misc.cpython-312.pyc b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/misc.cpython-312.pyc new file mode 100644 index 0000000..584d7c4 Binary files /dev/null and b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/__pycache__/misc.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/miche/michelangelo/utils/misc.py b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/misc.py new file mode 100644 index 0000000..ca56d58 --- /dev/null +++ b/hy3dgen/shapegen/bpt/miche/michelangelo/utils/misc.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +import importlib + +import torch +import torch.distributed as dist + +import sys +sys.path.append(r"C:\Remade\ComfyUI_windows_portable\ComfyUI\custom_nodes\ComfyUI-Hunyuan3DWrapper-main") + +from hy3dgen.shapegen.bpt.miche.michelangelo.models.tsal import asl_pl_module + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def get_obj_from_config(config): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + return get_obj_from_str(config["target"]) + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + cls = get_obj_from_str(config["target"]) + + params = config.get("params", dict()) + # params.update(kwargs) + # instance = cls(**params) + kwargs.update(params) + instance = cls(**kwargs) + + return instance + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def 
get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather( + tensor_all, + tensor, + async_op=False # performance opt + ) + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor diff --git a/hy3dgen/shapegen/bpt/model/__init__.py b/hy3dgen/shapegen/bpt/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hy3dgen/shapegen/bpt/model/__pycache__/__init__.cpython-312.pyc b/hy3dgen/shapegen/bpt/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..71d9e54 Binary files /dev/null and b/hy3dgen/shapegen/bpt/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/model/__pycache__/data_utils.cpython-312.pyc b/hy3dgen/shapegen/bpt/model/__pycache__/data_utils.cpython-312.pyc new file mode 100644 index 0000000..45d70d8 Binary files /dev/null and b/hy3dgen/shapegen/bpt/model/__pycache__/data_utils.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/model/__pycache__/miche_conditioner.cpython-312.pyc b/hy3dgen/shapegen/bpt/model/__pycache__/miche_conditioner.cpython-312.pyc new file mode 100644 index 0000000..dfcd6a5 Binary files /dev/null and b/hy3dgen/shapegen/bpt/model/__pycache__/miche_conditioner.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/model/__pycache__/model.cpython-312.pyc b/hy3dgen/shapegen/bpt/model/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000..a594f4b Binary files /dev/null and b/hy3dgen/shapegen/bpt/model/__pycache__/model.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/model/__pycache__/serializaiton.cpython-312.pyc b/hy3dgen/shapegen/bpt/model/__pycache__/serializaiton.cpython-312.pyc new file mode 100644 index 0000000..a07fc57 Binary files /dev/null and b/hy3dgen/shapegen/bpt/model/__pycache__/serializaiton.cpython-312.pyc differ diff --git a/hy3dgen/shapegen/bpt/model/data_utils.py b/hy3dgen/shapegen/bpt/model/data_utils.py new file mode 100644 index 0000000..5572ed8 --- /dev/null +++ b/hy3dgen/shapegen/bpt/model/data_utils.py @@ -0,0 +1,194 @@ +"""Mesh data utilities.""" +import random +import networkx as nx +import numpy as np +# import pyrr +from six.moves import range +import trimesh +from scipy.spatial.transform import Rotation + + +def to_mesh(vertices, faces, transpose=True, post_process=False): + if transpose: + vertices = vertices[:, [1, 2, 0]] + + if faces.min() == 1: + faces = (np.array(faces) - 1).tolist() + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + + if post_process: + mesh.merge_vertices() + mesh.update_faces(mesh.unique_faces()) + mesh.fix_normals() + return mesh + + +def center_vertices(vertices): + """Translate the vertices so that bounding box is centered at zero.""" + vert_min = vertices.min(axis=0) + vert_max = vertices.max(axis=0) + vert_center = 0.5 * (vert_min + vert_max) + # vert_center = np.mean(vertices, axis=0) + return vertices - vert_center + + +def face_to_cycles(face): + """Find cycles in face.""" + g = nx.Graph() + for v in range(len(face) - 1): + 
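+        # connect consecutive vertices (the closing edge is added after the loop);
+        # cycles of this graph are the face's polygon loops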
g.add_edge(face[v], face[v + 1]) + g.add_edge(face[-1], face[0]) + return list(nx.cycle_basis(g)) + + +def block_index(vertex, block_size=32): + return (vertex[2] // block_size, vertex[1] // block_size, vertex[0] // block_size) + +def block_id(block_index, num_blocks=4): + return block_index[0] * num_blocks**2 + block_index[1] * num_blocks + block_index[2] + + +def normalize_vertices_scale(vertices, scale=0.95): + """Scale the vertices so that the long axis of the bounding box is one.""" + vert_min = vertices.min(axis=0) + vert_max = vertices.max(axis=0) + extents = (vert_max - vert_min).max() + return 2.0 * scale * vertices / (extents + 1e-6) + + +def quantize_process_mesh(vertices, faces, quantization_bits=8, block_first_order=True, block_size=32, num_blocks=4): + """Quantize vertices, remove resulting duplicates and reindex faces.""" + vertices = discretize(vertices, num_discrete=2**quantization_bits) + vertices, inv = np.unique(vertices, axis=0, return_inverse=True) + + if block_first_order: + block_indices = np.array([block_index(v, block_size) for v in vertices]) + block_ids = np.array([block_id(b, num_blocks) for b in block_indices]) + sort_inds = np.lexsort((vertices[:, 0], vertices[:, 1], vertices[:, 2], block_ids)) + else: + # Sort vertices by z then y then x. + sort_inds = np.lexsort(vertices.T) + + vertices = vertices[sort_inds] + faces = [np.argsort(sort_inds)[inv[f]] for f in faces] + + sub_faces = [] + for f in faces: + cliques = face_to_cycles(f) + for c in cliques: + c_length = len(c) + if c_length > 2: + d = np.argmin(f) + sub_faces.append([f[(d + i) % c_length] for i in range(c_length)]) + + faces = sub_faces + + # Sort faces by lowest vertex indices. If two faces have the same lowest + # index then sort by next lowest and so on. + faces.sort(key=lambda f: tuple(sorted(f))) + num_verts = vertices.shape[0] + vert_connected = np.equal( + np.arange(num_verts)[:, None], np.hstack(faces)[None] + ).any(axis=-1) + vertices = vertices[vert_connected] + + # Re-index faces to re-ordered vertices. + vert_indices = np.arange(num_verts) - np.cumsum(1 - vert_connected.astype("int")) + faces = [vert_indices[f].tolist() for f in faces] + + return vertices, faces + + +def process_mesh(vertices, faces, quantization_bits=8, augment=True, augment_dict=None): + """Process mesh vertices and faces.""" + + # Transpose so that z-axis is vertical. + vertices = vertices[:, [2, 0, 1]] + + # Translate the vertices so that bounding box is centered at zero. + vertices = center_vertices(vertices) + + if augment: + vertices = augment_mesh(vertices, **augment_dict) + + # Scale the vertices so that the long diagonal of the bounding box is equal + # to one. + vertices = normalize_vertices_scale(vertices) + + # Quantize and sort vertices, remove resulting duplicates, sort and reindex + # faces. + vertices, faces = quantize_process_mesh( + vertices, faces, quantization_bits=quantization_bits + ) + vertices = undiscretize(vertices, num_discrete=2**quantization_bits) + + + # Discard degenerate meshes without faces. 
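Reading aid: the quantize-then-undiscretize step above snaps every vertex onto a uniform 2**quantization_bits grid. A tiny round-trip sketch using condensed copies of the `discretize`/`undiscretize` helpers defined later in this file (sample values are arbitrary):

```python
import numpy as np

def discretize(t, continuous_range=(-1, 1), num_discrete=128):
    lo, hi = continuous_range
    t = (t - lo) / (hi - lo) * num_discrete - 0.5
    return t.round().astype(np.int32).clip(min=0, max=num_discrete - 1)

def undiscretize(t, continuous_range=(-1, 1), num_discrete=128):
    lo, hi = continuous_range
    return (t.astype(np.float32) + 0.5) / num_discrete * (hi - lo) + lo

v = np.array([[0.1234, -0.5678, 0.9012]])
v_snapped = undiscretize(discretize(v, num_discrete=2**8), num_discrete=2**8)
# v_snapped lies on the 256-level grid; further round trips leave it unchanged (up to float precision)
```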
+ return { + "vertices": vertices, + "faces": faces, + } + + +def load_process_mesh(mesh_obj_path, quantization_bits=8, augment=False, augment_dict=None): + """Load obj file and process.""" + # Load mesh + mesh = trimesh.load(mesh_obj_path, force='mesh', process=False) + return process_mesh(mesh.vertices, mesh.faces, quantization_bits, augment=augment, augment_dict=augment_dict) + + +def augment_mesh(vertices, scale_min=0.95, scale_max=1.05, rotation=0., jitter_strength=0.): + '''scale vertices by a factor in [0.75, 1.25]''' + + # vertices [nv, 3] + for i in range(3): + # Generate a random scale factor + scale = random.uniform(scale_min, scale_max) + + # independently applied scaling across each axis of vertices + vertices[:, i] *= scale + + if rotation != 0.: + axis = [random.uniform(-1, 1), random.uniform(-1, 1), random.uniform(-1, 1)] + radian = np.pi / 180 * rotation + rotation = Rotation.from_rotvec(radian * np.array(axis)) + vertices =rotation.apply(vertices) + + + if jitter_strength != 0.: + jitter_amount = np.random.uniform(-jitter_strength, jitter_strength) + vertices += jitter_amount + + + return vertices + + +def discretize( + t, + continuous_range = (-1, 1), + num_discrete: int = 128 +): + lo, hi = continuous_range + assert hi > lo + + t = (t - lo) / (hi - lo) + t *= num_discrete + t -= 0.5 + + return t.round().astype(np.int32).clip(min = 0, max = num_discrete - 1) + + +def undiscretize( + t, + continuous_range = (-1, 1), + num_discrete: int = 128 +): + lo, hi = continuous_range + assert hi > lo + + t = t.astype(np.float32) + + t += 0.5 + t /= num_discrete + return t * (hi - lo) + lo + diff --git a/hy3dgen/shapegen/bpt/model/miche_conditioner.py b/hy3dgen/shapegen/bpt/model/miche_conditioner.py new file mode 100644 index 0000000..1a744d5 --- /dev/null +++ b/hy3dgen/shapegen/bpt/model/miche_conditioner.py @@ -0,0 +1,90 @@ +import torch +import os +from torch import nn +from beartype import beartype +from ..miche.encode import load_model +from ..miche.michelangelo.models.tsal import asl_pl_module + +# helper functions + +def exists(val): + return val is not None + +def default(*values): + for value in values: + if exists(value): + return value + return None + + +# point-cloud encoder from Michelangelo +@beartype +class PointConditioner(torch.nn.Module): + def __init__( + self, + *, + dim_latent = None, + model_name = 'miche-256-feature', + cond_dim = 768, + freeze = True, + ): + super().__init__() + + # open-source version of miche + if model_name == 'miche-256-feature': + ckpt_path = None + dir = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(dir, '..\shapevae-256.yaml') + config_path = model_path + + self.feature_dim = 1024 # embedding dimension + self.cond_length = 257 # length of embedding + self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path) + + # additional layers to connect miche and GPT + self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim) + self.cond_proj = nn.Linear(cond_dim, self.feature_dim) + + else: + raise NotImplementedError + + # whether to finetuen point-cloud encoder + if freeze: + for parameter in self.point_encoder.parameters(): + parameter.requires_grad = False + + self.freeze = freeze + self.model_name = model_name + self.dim_latent = default(dim_latent, self.feature_dim) + + self.register_buffer('_device_param', torch.tensor(0.), persistent = False) + + + @property + def device(self): + return next(self.buffers()).device + + + def embed_pc(self, pc_normal): + # encode point cloud to embeddings + if 
self.model_name == 'miche-256-feature': + point_feature = self.point_encoder.encode_latents(pc_normal) + pc_embed_head = self.cond_head_proj(point_feature[:, 0:1]) + pc_embed = self.cond_proj(point_feature[:, 1:]) + pc_embed = torch.cat([pc_embed_head, pc_embed], dim=1) + + return pc_embed + + + def forward( + self, + pc = None, + pc_embeds = None, + ): + if pc_embeds is None: + pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype)) + + assert not torch.any(torch.isnan(pc_embeds)), 'NAN values in pc embedings' + + return pc_embeds + diff --git a/hy3dgen/shapegen/bpt/model/model.py b/hy3dgen/shapegen/bpt/model/model.py new file mode 100644 index 0000000..8ec2d4d --- /dev/null +++ b/hy3dgen/shapegen/bpt/model/model.py @@ -0,0 +1,382 @@ +import math +import torch +from torch import nn, Tensor +from torch.nn import Module +import torch.nn.functional as F +from einops import rearrange, repeat, pack +from pytorch_custom_utils import save_load +from beartype import beartype +from beartype.typing import Union, Tuple, Callable, Optional, Any +from einops import rearrange, repeat, pack +from x_transformers import Decoder +from x_transformers.x_transformers import LayerIntermediates +from x_transformers.autoregressive_wrapper import ( + eval_decorator, + top_k, +) +from .miche_conditioner import PointConditioner +from functools import partial +from tqdm import tqdm +from .data_utils import discretize + +# helper functions + +def exists(v): + return v is not None + +def default(v, d): + return v if exists(v) else d + +def first(it): + return it[0] + +def divisible_by(num, den): + return (num % den) == 0 + +def pad_at_dim(t, padding, dim = -1, value = 0): + ndim = t.ndim + right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1) + zeros = (0, 0) * right_dims + return F.pad(t, (*zeros, *padding), value = value) + + +# main class of auto-regressive Transformer +@save_load() +class MeshTransformer(Module): + @beartype + def __init__( + self, + *, + dim: Union[int, Tuple[int, int]] = 1024, # hidden size of Transformer + max_seq_len = 10000, # max sequence length + flash_attn = True, # wether to use flash attention + attn_depth = 24, # number of layers + attn_dim_head = 64, # dim for each head + attn_heads = 16, # number of heads + attn_kwargs: dict = dict( + ff_glu = True, + num_mem_kv = 4, + attn_qk_norm = True, + ), + dropout = 0.0, + pad_id = -1, + coor_continuous_range = (-1., 1.), + num_discrete_coors = 2**int(7), + block_size = 8, + offset_size = 16, + mode = 'vertices', + special_token = -2, + use_special_block = True, + conditioned_on_pc = True, + encoder_name = 'miche-256-feature', + encoder_freeze = False, + ): + super().__init__() + + if use_special_block: + # block_ids, offset_ids, special_block_ids + vocab_size = block_size**3 + offset_size**3 + block_size**3 + self.sp_block_embed = nn.Parameter(torch.randn(1, dim)) + else: + # block_ids, offset_ids, special_token + vocab_size = block_size**3 + offset_size**3 + 1 + self.special_token = special_token + self.special_token_cb = block_size**3 + offset_size**3 + + self.use_special_block = use_special_block + + self.sos_token = nn.Parameter(torch.randn(dim)) + self.eos_token_id = vocab_size + self.mode = mode + self.token_embed = nn.Embedding(vocab_size + 1, dim) + self.num_discrete_coors = num_discrete_coors + self.coor_continuous_range = coor_continuous_range + self.block_size = block_size + self.offset_size = offset_size + self.abs_pos_emb = nn.Embedding(max_seq_len, dim) + self.max_seq_len = max_seq_len + self.conditioner = None + 
self.conditioned_on_pc = conditioned_on_pc + cross_attn_dim_context = None + + self.block_embed = nn.Parameter(torch.randn(1, dim)) + self.offset_embed = nn.Parameter(torch.randn(1, dim)) + + assert self.block_size * self.offset_size == self.num_discrete_coors + + # load point_cloud encoder + if conditioned_on_pc: + print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}') + self.conditioner = PointConditioner(model_name=encoder_name, freeze=encoder_freeze) + cross_attn_dim_context = self.conditioner.dim_latent + else: + raise NotImplementedError + + # main autoregressive attention network + self.decoder = Decoder( + dim = dim, + depth = attn_depth, + dim_head = attn_dim_head, + heads = attn_heads, + attn_flash = flash_attn, + attn_dropout = dropout, + ff_dropout = dropout, + cross_attend = conditioned_on_pc, + cross_attn_dim_context = cross_attn_dim_context, + cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition + **attn_kwargs + ) + + self.to_logits = nn.Linear(dim, vocab_size + 1) + self.pad_id = pad_id + self.discretize_face_coords = partial( + discretize, + num_discrete = num_discrete_coors, + continuous_range = coor_continuous_range + ) + + @property + def device(self): + return next(self.parameters()).device + + + @eval_decorator + @torch.no_grad() + @beartype + def generate( + self, + prompt: Optional[Tensor] = None, + pc: Optional[Tensor] = None, + cond_embeds: Optional[Tensor] = None, + batch_size: Optional[int] = 1, + filter_logits_fn: Callable = top_k, + filter_kwargs: dict = dict(), + temperature = 0.5, + return_codes = False, + cache_kv = True, + max_seq_len = None, + face_coords_to_file: Optional[Callable[[Tensor], Any]] = None, + tqdm_position = 0, + ): + max_seq_len = default(max_seq_len, self.max_seq_len) + + if exists(prompt): + assert not exists(batch_size) + + prompt = rearrange(prompt, 'b ... -> b (...)') + assert prompt.shape[-1] <= self.max_seq_len + + batch_size = prompt.shape[0] + + # encode point cloud + if cond_embeds is None: + if self.conditioned_on_pc: + cond_embeds = self.conditioner(pc = pc) + + batch_size = default(batch_size, 1) + + codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device)) + + curr_length = codes.shape[-1] + + cache = None + eos_iter = None + # predict tokens auto-regressively + for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position, + desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False): + + output = self.forward_on_codes( + codes, + return_loss = False, + return_cache = cache_kv, + append_eos = False, + cond_embeds = cond_embeds, + cache = cache + ) + + if cache_kv: + logits, cache = output + else: + logits = output + + # sample code from logits + logits = logits[:, -1] + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + probs = F.softmax(filtered_logits / temperature, dim=-1) + sample = torch.multinomial(probs, 1) + codes, _ = pack([codes, sample], 'b *') + + # Check if all sequences have encountered EOS at least once + is_eos_codes = (codes == self.eos_token_id) + if is_eos_codes.any(dim=-1).all(): + # Record the iteration (i.e. 
current sequence length) when EOS is first detected in all sequences + if eos_iter is None: + eos_iter = codes.shape[-1] + # Once we've generated 20% more tokens than eos_iter, break out of the loop + if codes.shape[-1] >= int(eos_iter * 1.2): + break + + # mask out to padding anything after the first eos + + mask = is_eos_codes.float().cumsum(dim = -1) >= 1 + codes = codes.masked_fill(mask, self.pad_id) + + # early return of raw residual quantizer codes + + if return_codes: + # codes = rearrange(codes, 'b (n q) -> b n q', q = 2) + if not self.use_special_block: + codes[codes == self.special_token_cb] = self.special_token + return codes + + face_coords, face_mask = self.decode_codes(codes) + + if not exists(face_coords_to_file): + return face_coords, face_mask + + files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)] + return files + + + def forward( + self, + *, + codes: Optional[Tensor] = None, + cache: Optional[LayerIntermediates] = None, + **kwargs + ): + # convert special tokens + if not self.use_special_block: + codes[codes == self.special_token] = self.special_token_cb + + return self.forward_on_codes(codes, cache = cache, **kwargs) + + + def forward_on_codes( + self, + codes = None, + return_loss = True, + return_cache = False, + append_eos = True, + cache = None, + pc = None, + cond_embeds = None, + ): + # handle conditions + + attn_context_kwargs = dict() + + if self.conditioned_on_pc: + assert exists(pc) ^ exists(cond_embeds), 'point cloud should be given' + + # preprocess faces and vertices + if not exists(cond_embeds): + cond_embeds = self.conditioner( + pc = pc, + pc_embeds = cond_embeds, + ) + + attn_context_kwargs = dict( + context = cond_embeds, + context_mask = None, + ) + + # take care of codes that may be flattened + + if codes.ndim > 2: + codes = rearrange(codes, 'b ... -> b (...)') + + # prepare mask for position embedding of block and offset tokens + block_mask = (0 <= codes) & (codes < self.block_size**3) + offset_mask = (self.block_size**3 <= codes) & (codes < self.block_size**3 + self.offset_size**3) + if self.use_special_block: + sp_block_mask = ( + self.block_size**3 + self.offset_size**3 <= codes + ) & ( + codes < self.block_size**3 + self.offset_size**3 + self.block_size**3 + ) + + + # get some variable + + batch, seq_len, device = *codes.shape, codes.device + + assert seq_len <= self.max_seq_len, \ + f'received codes of length {seq_len} but needs to be less than {self.max_seq_len}' + + # auto append eos token + + if append_eos: + assert exists(codes) + + code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1) + + codes = F.pad(codes, (0, 1), value = 0) # value=-1 + + batch_arange = torch.arange(batch, device = device) + + batch_arange = rearrange(batch_arange, '... -> ... 1') + code_lens = rearrange(code_lens, '... -> ... 
1') + + codes[batch_arange, code_lens] = self.eos_token_id + + + # if returning loss, save the labels for cross entropy + + if return_loss: + assert seq_len > 0 + codes, labels = codes[:, :-1], codes + + # token embed + + codes = codes.masked_fill(codes == self.pad_id, 0) + codes = self.token_embed(codes) + + # codebook embed + absolute positions + + seq_arange = torch.arange(codes.shape[-2], device = device) + codes = codes + self.abs_pos_emb(seq_arange) + + # add positional embedding for block and offset token + block_embed = repeat(self.block_embed, '1 d -> b n d', n = seq_len, b = batch) + offset_embed = repeat(self.offset_embed, '1 d -> b n d', n = seq_len, b = batch) + codes[block_mask] += block_embed[block_mask] + codes[offset_mask] += offset_embed[offset_mask] + + if self.use_special_block: + sp_block_embed = repeat(self.sp_block_embed, '1 d -> b n d', n = seq_len, b = batch) + codes[sp_block_mask] += sp_block_embed[sp_block_mask] + + # auto prepend sos token + + sos = repeat(self.sos_token, 'd -> b d', b = batch) + codes, _ = pack([sos, codes], 'b * d') + + # attention + + attended, intermediates_with_cache = self.decoder( + codes, + cache = cache, + return_hiddens = True, + **attn_context_kwargs + ) + + # logits + + logits = self.to_logits(attended) + + if not return_loss: + if not return_cache: + return logits + + return logits, intermediates_with_cache + + # loss + + ce_loss = F.cross_entropy( + rearrange(logits, 'b n c -> b c n'), + labels, + ignore_index = self.pad_id + ) + + return ce_loss diff --git a/hy3dgen/shapegen/bpt/model/serializaiton.py b/hy3dgen/shapegen/bpt/model/serializaiton.py new file mode 100644 index 0000000..97c359d --- /dev/null +++ b/hy3dgen/shapegen/bpt/model/serializaiton.py @@ -0,0 +1,241 @@ +import trimesh +import numpy as np +from .data_utils import discretize, undiscretize + + +def patchified_mesh(mesh: trimesh.Trimesh, special_token = -2, fix_orient=True): + sequence = [] + unvisited = np.full(len(mesh.faces), True) + degrees = mesh.vertex_degree.copy() + + # with fix_orient=True, the normal would be correct. + # but this may increase the difficulty for learning. 
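+    # (editor's note) face_orient maps every cyclic rotation of a face's vertex
+    # ids to True and every reversed ordering to False; it is used further down
+    # to flip a reconstructed vertex fan so that the original winding (and thus
+    # the face normals) is preserved.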
+ if fix_orient: + face_orient = {} + for ind, face in enumerate(mesh.faces): + v0, v1, v2 = face[0], face[1], face[2] + face_orient['{}-{}-{}'.format(v0, v1, v2)] = True + face_orient['{}-{}-{}'.format(v1, v2, v0)] = True + face_orient['{}-{}-{}'.format(v2, v0, v1)] = True + face_orient['{}-{}-{}'.format(v2, v1, v0)] = False + face_orient['{}-{}-{}'.format(v1, v0, v2)] = False + face_orient['{}-{}-{}'.format(v0, v2, v1)] = False + + while sum(unvisited): + unvisited_faces = mesh.faces[unvisited] + + # select the patch center + cur_face = unvisited_faces[0] + max_deg_vertex_id = np.argmax(degrees[cur_face]) + max_deg_vertex = cur_face[max_deg_vertex_id] + + # find all connected faces + selected_faces = [] + for face_idx in mesh.vertex_faces[max_deg_vertex]: + if face_idx != -1 and unvisited[face_idx]: + face = mesh.faces[face_idx] + u, v = sorted([vertex for vertex in face if vertex != max_deg_vertex]) + selected_faces.append([u, v, face_idx]) + + face_patch = set() + selected_faces = sorted(selected_faces) + + # select the start vertex, select it if it only appears once (the start or end), + # else select the lowest index + cnt = {} + for u, v, _ in selected_faces: + cnt[u] = cnt.get(u, 0) + 1 + cnt[v] = cnt.get(v, 0) + 1 + starts = [] + for vertex, num in cnt.items(): + if num == 1: + starts.append(vertex) + start_idx = min(starts) if len(starts) else selected_faces[0][0] + + res = [start_idx] + while len(res) <= len(selected_faces): + vertex = res[-1] + for u_i, v_i, face_idx_i in selected_faces: + if face_idx_i not in face_patch and vertex in (u_i, v_i): + u_i, v_i = (u_i, v_i) if vertex == u_i else (v_i, u_i) + res.append(v_i) + face_patch.add(face_idx_i) + break + + if res[-1] == vertex: + break + + if fix_orient and len(res) >= 2 and not face_orient['{}-{}-{}'.format(max_deg_vertex, res[0], res[1])]: + res = res[::-1] + + # reduce the degree of related vertices and mark the visited faces + degrees[max_deg_vertex] = len(selected_faces) - len(res) + 1 + for pos_idx, vertex in enumerate(res): + if pos_idx in [0, len(res) - 1]: + degrees[vertex] -= 1 + else: + degrees[vertex] -= 2 + for face_idx in face_patch: + unvisited[face_idx] = False + sequence.extend( + [mesh.vertices[max_deg_vertex]] + + [mesh.vertices[vertex_idx] for vertex_idx in res] + + [[special_token] * 3] + ) + + assert sum(degrees) == 0, 'All degrees should be zero' + + return np.array(sequence) + + + +def get_block_representation( + sequence, + block_size=8, + offset_size=16, + block_compressed=True, + special_token=-2, + use_special_block=True + ): + ''' + convert coordinates from Cartesian system to block indexes. 
+ ''' + special_block_base = block_size**3 + offset_size**3 + # prepare coordinates + sp_mask = sequence != special_token + sp_mask = np.all(sp_mask, axis=1) + coords = sequence[sp_mask].reshape(-1, 3) + coords = discretize(coords) + + # convert [x, y, z] to [block_id, offset_id] + block_id = coords // offset_size + block_id = block_id[:, 0] * block_size**2 + block_id[:, 1] * block_size + block_id[:, 2] + offset_id = coords % offset_size + offset_id = offset_id[:, 0] * offset_size**2 + offset_id[:, 1] * offset_size + offset_id[:, 2] + offset_id += block_size**3 + block_coords = np.concatenate([block_id[..., None], offset_id[..., None]], axis=-1).astype(np.int64) + sequence[:, :2][sp_mask] = block_coords + sequence = sequence[:, :2] + + # convert to codes + codes = [] + cur_block_id = sequence[0, 0] + codes.append(cur_block_id) + for i in range(len(sequence)): + if sequence[i, 0] == special_token: + if not use_special_block: + codes.append(special_token) + cur_block_id = special_token + + elif sequence[i, 0] == cur_block_id: + if block_compressed: + codes.append(sequence[i, 1]) + else: + codes.extend([sequence[i, 0], sequence[i, 1]]) + + else: + if use_special_block and cur_block_id == special_token: + block_id = sequence[i, 0] + special_block_base + else: + block_id = sequence[i, 0] + codes.extend([block_id, sequence[i, 1]]) + cur_block_id = block_id + + codes = np.array(codes).astype(np.int64) + sequence = codes + + return sequence.flatten() + + +def BPT_serialize(mesh: trimesh.Trimesh): + # serialize mesh with BPT + + # 1. patchify faces into patches + sequence = patchified_mesh(mesh, special_token=-2) + + # 2. convert coordinates to block-wise indexes + codes = get_block_representation( + sequence, block_size=8, offset_size=16, + block_compressed=True, special_token=-2, use_special_block=True + ) + return codes + + +def decode_block(sequence, compressed=True, block_size=8, offset_size=16): + + # decode from compressed representation + if compressed: + res = [] + res_block = 0 + for token_id in range(len(sequence)): + if block_size**3 + offset_size**3 > sequence[token_id] >= block_size**3: + res.append([res_block, sequence[token_id]]) + elif block_size**3 > sequence[token_id] >= 0: + res_block = sequence[token_id] + else: + print('[Warning] too large offset idx!', token_id, sequence[token_id]) + sequence = np.array(res) + + block_id, offset_id = np.array_split(sequence, 2, axis=-1) + + # from hash representation to xyz + coords = [] + offset_id -= block_size**3 + for i in [2, 1, 0]: + axis = (block_id // block_size**i) * offset_size + (offset_id // offset_size**i) + block_id %= block_size**i + offset_id %= offset_size**i + coords.append(axis) + + coords = np.concatenate(coords, axis=-1) # (nf 3) + + # back to continuous space + coords = undiscretize(coords) + + return coords + + +def BPT_deserialize(sequence, block_size=8, offset_size=16, compressed=True, special_token=-2, use_special_block=True): + # decode codes back to coordinates + + special_block_base = block_size**3 + offset_size**3 + start_idx = 0 + vertices = [] + for i in range(len(sequence)): + sub_seq = [] + if not use_special_block and (sequence[i] == special_token or i == len(sequence) - 1): + sub_seq = sequence[start_idx:i] + sub_seq = decode_block(sub_seq, compressed=compressed, block_size=block_size, offset_size=offset_size) + start_idx = i + 1 + + elif use_special_block and \ + (special_block_base <= sequence[i] < special_block_base + block_size**3 or i == len(sequence)-1): + if i != 0: + sub_seq = sequence[start_idx:i] 
if i != len(sequence) - 1 else sequence[start_idx: i+1] + if special_block_base <= sub_seq[0] < special_block_base + block_size**3: + sub_seq[0] -= special_block_base + sub_seq = decode_block(sub_seq, compressed=compressed, block_size=block_size, offset_size=offset_size) + start_idx = i + + if len(sub_seq): + center, sub_seq = sub_seq[0], sub_seq[1:] + for j in range(len(sub_seq) - 1): + vertices.extend([center.reshape(1, 3), sub_seq[j].reshape(1, 3), sub_seq[j+1].reshape(1, 3)]) + + # (nf, 3) + return np.concatenate(vertices, axis=0) + + +if __name__ == '__main__': + # a simple demo for serialize and deserialize mesh with bpt + from data_utils import load_process_mesh, to_mesh + import torch + mesh = load_process_mesh('/path/to/your/mesh', quantization_bits=7) + mesh['faces'] = np.array(mesh['faces']) + mesh = to_mesh(mesh['vertices'], mesh['faces'], transpose=True) + mesh.export('gt.obj') + codes = BPT_serialize(mesh) + coordinates = BPT_deserialize(codes) + faces = torch.arange(1, len(coordinates) + 1).view(-1, 3) + mesh = to_mesh(coordinates, faces, transpose=False, post_process=False) + mesh.export('reconstructed.obj') diff --git a/hy3dgen/shapegen/bpt/requirements.txt b/hy3dgen/shapegen/bpt/requirements.txt new file mode 100644 index 0000000..3769a05 --- /dev/null +++ b/hy3dgen/shapegen/bpt/requirements.txt @@ -0,0 +1,30 @@ +meshgpt_pytorch==0.6.7 +pytorch-custom-utils==0.0.21 +accelerate>=0.25.0 +beartype +classifier-free-guidance-pytorch==0.5.1 +einops>=0.7.0 +ema-pytorch +pytorch-warmup +torch_geometric +torchtyping +vector-quantize-pytorch==1.12.8 +x-transformers==1.26.6 +tqdm +matplotlib +wandb +pyrr +trimesh +opencv-python +pyrender +open3d-python +easydict +chardet +deepspeed +omegaconf +scikit-image +setuptools +pytorch_lightning +mesh2sdf +numpy +point-cloud-utils \ No newline at end of file diff --git a/hy3dgen/shapegen/bpt/utils.py b/hy3dgen/shapegen/bpt/utils.py new file mode 100644 index 0000000..48a5101 --- /dev/null +++ b/hy3dgen/shapegen/bpt/utils.py @@ -0,0 +1,86 @@ +import trimesh +import numpy as np +from x_transformers.autoregressive_wrapper import top_p, top_k + + +class Dataset: + ''' + A toy dataset for inference + ''' + def __init__(self, input_type, input_list): + super().__init__() + self.data = [] + if input_type == 'pc_normal': + for input_path in input_list: + # load npy + cur_data = np.load(input_path) + # sample 4096 + assert cur_data.shape[0] >= 4096, "input pc_normal should have at least 4096 points" + idx = np.random.choice(cur_data.shape[0], 4096, replace=False) + cur_data = cur_data[idx] + self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) + + elif input_type == 'mesh': + mesh_list, pc_list = [], [] + for input_path in input_list: + # sample point cloud and normal from mesh + cur_data = trimesh.load(input_path, force='mesh') + cur_data = apply_normalize(cur_data) + mesh_list.append(cur_data) + pc_list.append(sample_pc(cur_data, pc_num=4096, with_normal=True)) + + for input_path, cur_data in zip(input_list, pc_list): + self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) + + print(f"dataset total data samples: {len(self.data)}") + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + data_dict = {} + data_dict['pc_normal'] = self.data[idx]['pc_normal'] + data_dict['uid'] = self.data[idx]['uid'] + + return data_dict + + +def joint_filter(logits, k = 50, p=0.95): + logits = top_k(logits, k = k) + logits = top_p(logits, thres = p) + return logits + + 
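+# Note: joint_filter is intended to be passed to MeshTransformer.generate as
+# filter_logits_fn, with k/p forwarded through filter_kwargs, e.g.
+#     model.generate(pc=pc, filter_logits_fn=joint_filter,
+#                    filter_kwargs=dict(k=50, p=0.95))
+# (illustrative call only; `model` and `pc` are assumed to exist.)
+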
+def apply_normalize(mesh):
+    '''
+    normalize the mesh so that its longest extent fits in [-0.95, 0.95]
+    (i.e. roughly [-1, 1] with a small safety margin)
+    '''
+    bbox = mesh.bounds
+    center = (bbox[1] + bbox[0]) / 2
+    scale = (bbox[1] - bbox[0]).max()
+
+    mesh.apply_translation(-center)
+    mesh.apply_scale(1 / scale * 2 * 0.95)
+
+    return mesh
+
+
+def sample_pc(input_mesh, pc_num, with_normal=False):
+    # normalize first so sampled points match the coordinate range used elsewhere
+    mesh = apply_normalize(input_mesh)
+
+    if not with_normal:
+        points, _ = mesh.sample(pc_num, return_index=True)
+        return points
+
+    # oversample, attach per-point face normals, then randomly keep pc_num points
+    points, face_idx = mesh.sample(50000, return_index=True)
+    normals = mesh.face_normals[face_idx]
+    pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+
+    # random sample point cloud
+    ind = np.random.choice(pc_normal.shape[0], pc_num, replace=False)
+    pc_normal = pc_normal[ind]
+
+    return pc_normal
+
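
The following is a minimal inference sketch tying the pieces of this patch together (point-cloud sampling from utils.py, the autoregressive MeshTransformer, and BPT deserialization). It is illustrative only, not part of the patch: the input/output paths are placeholders, checkpoint loading is omitted, and the import roots assume the directory layout introduced above.

# illustrative only -- paths, device choice and weight loading are assumptions
import numpy as np
import torch
import trimesh

from hy3dgen.shapegen.bpt.model.model import MeshTransformer
from hy3dgen.shapegen.bpt.model.serializaiton import BPT_deserialize  # module name matches the (misspelled) filename in this patch
from hy3dgen.shapegen.bpt.model.data_utils import to_mesh
from hy3dgen.shapegen.bpt.utils import sample_pc, joint_filter

# 1. sample a 4096-point cloud with normals from an input mesh (as Dataset does)
mesh = trimesh.load('input.obj', force='mesh')          # placeholder path
pc_normal = sample_pc(mesh, pc_num=4096, with_normal=True)
pc = torch.from_numpy(pc_normal).float().unsqueeze(0).cuda()

# 2. generate block/offset tokens conditioned on the point cloud
model = MeshTransformer().cuda().eval()                 # load trained weights here (not covered by this patch)
codes = model.generate(pc=pc, filter_logits_fn=joint_filter,
                       filter_kwargs=dict(k=50, p=0.95),
                       temperature=0.5, return_codes=True)

# 3. strip padding / EOS and decode tokens back to face coordinates
codes = codes[0].cpu().numpy()
codes = codes[(codes != model.pad_id) & (codes != model.eos_token_id)]
coords = BPT_deserialize(codes, compressed=True, use_special_block=True)

# 4. rebuild and export the mesh, mirroring the __main__ demo in serializaiton.py
faces = torch.arange(1, len(coords) + 1).view(-1, 3)
out = to_mesh(coords, faces, transpose=False, post_process=False)
out.export('generated.obj')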