sjc / app.py
amankishore's picture
Updated app.py
7a11626
raw
history blame
No virus
5.27 kB
import numpy as np
import torch
from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view
# Device that the voxel field is moved to (via ``vox.to(device_glb)``) before
# rendering and optimisation.
device_glb = torch.device(type="cuda")
@torch.no_grad()
def evaluate(score_model, vox, poser, n_views=100):
    """Render the voxel field along the poser's test trajectory.

    For each pose returned by ``poser.sample_test(n_views)`` the scene is
    rendered with ``render_one_view``. If ``score_model`` is a
    ``StableDiffusion`` instance, the render is passed through
    ``score_model.decode`` first. Each view is then converted with
    ``vis_routine``.

    Side effects: ``vox.eval()`` is called and is not undone on return, and
    ``vox.to(device_glb)`` is applied.

    Args:
        score_model: model used during training; only its type and its
            ``decode`` method are used here.
        vox: voxel field to render (must provide ``aabb``, ``eval``, ``to``).
        poser: pose sampler providing ``H``, ``W`` and ``sample_test``.
        n_views: number of test poses requested from ``poser.sample_test``
            (defaults to 100, the previously hard-coded value).

    Returns:
        A list with one ``(pane, img, depth)`` tuple per test pose, in pose
        order, exactly as returned by ``vis_routine``.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(n_views)

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    frames = []
    for i in tqdm(range(len(poses))):
        y, depth = render_one_view(vox, aabb, H, W, K, poses[i])
        if isinstance(score_model, StableDiffusion):
            y = score_model.decode(y)
        # Keep the results so callers can display them or stitch them into a
        # video; previously they were computed and discarded.
        frames.append(vis_routine(y, depth))
    return frames
def vis_routine(y, depth):
    """Convert one rendered view into displayable artifacts.

    Args:
        y: rendered image batch; only the first sample is turned into an
            image (``torch_samps_to_imgs(y)[0]``).
        depth: depth tensor rendered for the same view.

    Returns:
        ``(pane, img, depth)``: the ``nerf_vis`` visualisation pane rendered
        with ``final_H=256``, the first sample as an image, and ``depth``
        converted to a NumPy array.
    """
    # Detach up front. ``Tensor.numpy()`` refuses tensors that require grad,
    # and the ``.cpu()`` copy only drops grad when done under torch.no_grad()
    # on a non-CPU tensor. Detaching makes this helper safe in any context.
    y = y.detach()
    depth = depth.detach()
    pane = nerf_vis(y, depth, final_H=256)
    img = torch_samps_to_imgs(y)[0]
    return pane, img, depth.cpu().numpy()
if __name__ == "__main__":
    # Optimise a voxel field so that its renders are pushed along the score
    # estimated by the diffusion model (hard-coded run configuration below).
    pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
    poser = pose.make()
    sd_model = SD(variant='v1', v2_highres=False, prompt='A high quality photo of a delicious burger', scale=100.0, precision='autocast')
    model = sd_model.make()
    vox = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
        blend_bg_texture=True, bg_texture_hw=4,
        bbox_len=1.0)
    vox = vox.make()

    lr = 0.05
    n_steps = 10000
    emptiness_scale = 10
    emptiness_weight = 10000
    emptiness_step = 0.5        # fraction of n_steps after which emptiness_multiplier applies
    emptiness_multiplier = 20.0
    depth_weight = 0            # 0 disables the center-vs-border depth loss
    var_red = True              # score estimate uses (Ds - y) instead of (Ds - zs)

    assert model.samps_centered()
    _, target_H, target_W = model.data_shape()
    bs = 1

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks, poses, prompt_prefixes = poser.sample_train(n_steps)
    # Candidate noise levels: model.us without its first 30 and last 10
    # entries (presumably the extreme noise levels — confirm ordering of us).
    ts = model.us[30:-10]

    with tqdm(total=n_steps) as pbar:
        for i in range(n_steps):
            p = f"{prompt_prefixes[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
            # Only non-StableDiffusion models get the render resized to the
            # model's data shape; the SD branch uses the render as-is.
            if not isinstance(model, StableDiffusion):
                y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

            opt.zero_grad()

            # Score estimate at a randomly chosen noise level; no graph needed.
            with torch.no_grad():
                chosen_σs = np.random.choice(ts, bs, replace=False)
                chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)
                zs = y + chosen_σs * noise
                Ds = model.denoise(zs, chosen_σs, **score_conds)

                if var_red:
                    grad = (Ds - y) / chosen_σs
                else:
                    grad = (Ds - zs) / chosen_σs
                grad = grad.mean(0, keepdim=True)

            # Backpropagate the (negated) score estimate through the render.
            # The graph is retained because the depth/emptiness losses below
            # reuse it.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                # Reward a larger mean depth in the 7-pixel-inset center than
                # on the surrounding border. The border pixel count is derived
                # from the tensors instead of assuming a 64x64 render.
                center_depth = depth[7:-7, 7:-7]
                n_border = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            # Penalty on the render weights ws. It is scaled up by
            # emptiness_multiplier once i reaches emptiness_step * n_steps.
            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            # Periodic visualisation (every(..., percent=1) presumably fires
            # once per 1% of progress).
            if every(pbar, percent=1):
                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        y = model.decode(y)
                    pane, img, depth = vis_routine(y, depth)
                # TODO: Output pane, img and depth to Gradio

            pbar.update()
            pbar.set_description(p)

    # TODO: Save checkpoint (ckpt is currently only collected, not written).
    ckpt = vox.state_dict()
    # TODO: Run evaluate(model, vox, poser) and stitch the rendered views
    # together into a video.