import torch


def unet_add_concat_conds(unet, new_channels=4):
    """Widen ``unet.conv_in`` and hook ``unet.forward`` to take concat conditions.

    ``unet.conv_in`` is replaced by a convolution that accepts
    ``new_channels`` additional input channels. The weights for the original
    input channels are copied over, the weights for the added channels start
    at zero, and the original bias parameter object is reused, so the patched
    convolution initially produces the same output as the original one on the
    original channels.

    ``unet.forward`` is then replaced by a wrapper that pops ``'concat_conds'``
    from ``kwargs['cross_attention_kwargs']``, concatenates it to ``sample``
    along dim 1 and calls the original forward with the widened sample.

    Args:
        unet: Object exposing a ``conv_in`` convolution and a ``forward``
            callable taking ``(sample, timestep, encoder_hidden_states,
            **kwargs)``. It is modified in place.
        new_channels: Number of extra input channels to add to ``conv_in``.

    Returns:
        None. ``unet`` is patched in place.

    Note:
        The hooked forward raises ``KeyError`` if ``cross_attention_kwargs``
        is not passed or does not contain ``'concat_conds'``.
    """
    old_conv_in = unet.conv_in
    # Read the base channel count from the layer instead of assuming 4, so
    # convolutions with a different number of input channels also work.
    base_channels = old_conv_in.in_channels

    # In-place edits of parameters that require grad must run under no_grad.
    with torch.no_grad():
        new_conv_in = torch.nn.Conv2d(
            base_channels + new_channels,
            old_conv_in.out_channels,
            old_conv_in.kernel_size,
            old_conv_in.stride,
            old_conv_in.padding,
            dilation=old_conv_in.dilation,
            padding_mode=old_conv_in.padding_mode,
            # Build on the original weight's device/dtype so the new weight
            # matches the bias parameter that is reused below.
            device=old_conv_in.weight.device,
            dtype=old_conv_in.weight.dtype,
        )
        # Zero everything first so the added channels contribute nothing
        # until they are trained, then restore the original kernels.
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :base_channels, :, :].copy_(old_conv_in.weight)
        # Share the very same bias parameter object (may be None).
        new_conv_in.bias = old_conv_in.bias
        unet.conv_in = new_conv_in

    unet_original_forward = unet.forward

    def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
        """Append the concat conditions to ``sample`` and call the original forward."""
        # Shallow-copy so popping does not mutate the caller's dict.
        cross_attention_kwargs = dict(kwargs['cross_attention_kwargs'])
        c_concat = cross_attention_kwargs.pop('concat_conds')
        kwargs['cross_attention_kwargs'] = cross_attention_kwargs

        # Tile the conditions along the batch dim to match ``sample``.
        # NOTE(review): assumes sample's batch size is an integer multiple of
        # the conditions' batch size; otherwise the channel concat below fails.
        repeats = sample.shape[0] // c_concat.shape[0]
        c_concat = torch.cat([c_concat] * repeats, dim=0).to(sample)

        new_sample = torch.cat([sample, c_concat], dim=1)
        return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)

    unet.forward = hooked_unet_forward