interpolate with and without vmap

Created Diff never expires
108 removals
Lines
Total
Removed
Words
Total
Removed
To continue using this feature, upgrade to
Diffchecker logo
Diffchecker Pro
98 lines
105 additions
Lines
Total
Added
Words
Total
Added
To continue using this feature, upgrade to
Diffchecker logo
Diffchecker Pro
96 lines
HloModule jit_interpolate_backward.17
HloModule jit_interpolate_backward.17


%region_0.283 (Arg_0.284: f32[], Arg_1.285: f32[]) -> f32[] {
%region_0.286 (Arg_0.287: f32[], Arg_1.288: f32[]) -> f32[] {
%Arg_0.284 = f32[] parameter(0)
%Arg_0.287 = f32[] parameter(0)
%Arg_1.285 = f32[] parameter(1)
%Arg_1.288 = f32[] parameter(1)
ROOT %add.286 = f32[] add(f32[] %Arg_0.284, f32[] %Arg_1.285), metadata={op_name="jit(interpolate_backward)/jit(main)/reduce_sum[axes=(0, 1, 2)]" source_file="time_vmap_and_generate_hlo_just_interpolate.py" source_line=46}
ROOT %add.289 = f32[] add(f32[] %Arg_0.287, f32[] %Arg_1.288), metadata={op_name="jit(interpolate_backward)/jit(main)/reduce_sum[axes=(0, 1, 2)]" source_file="time_vmap_and_generate_hlo_just_interpolate.py" source_line=46}
}
}


%region_4.303 (Arg_0.304: f32[], Arg_1.305: f32[]) -> f32[] {
%region_4.306 (Arg_0.307: f32[], Arg_1.308: f32[]) -> f32[] {
%Arg_0.304 = f32[] parameter(0)
%Arg_0.307 = f32[] parameter(0)
%Arg_1.305 = f32[] parameter(1)
%Arg_1.308 = f32[] parameter(1)
ROOT %add.306 = f32[] add(f32[] %Arg_0.304, f32[] %Arg_1.305), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
ROOT %add.309 = f32[] add(f32[] %Arg_0.307, f32[] %Arg_1.308), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
}
}


%input_fused_computation_scatter (param_0.150: f32[3,48,4096,443], param_1.215: f32[3,4096,443], param_2.216: f32[3,4096,443]) -> f32[3,48,128,128] {
%input_fused_computation_scatter (param_0.154: f32[3,4096,48,443], param_1.215: f32[3,4096,443], param_2.216: f32[3,4096,443]) -> f32[3,48,128,128] {
%constant_54 = f32[] constant(0)
%constant_54 = f32[] constant(0)
%broadcast.169 = f32[3,48,128,128]{3,2,1,0} broadcast(f32[] %constant_54), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 48, 128, 128) broadcast_dimensions=()]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.171 = f32[3,48,128,128]{3,2,1,0} broadcast(f32[] %constant_54), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/broadcast_in_dim[shape=(3, 48, 128, 128) broadcast_dimensions=()]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%iota.56 = s32[3,4096,443,1]{2,1,0,3} iota(), iota_dimension=0, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/iota[dtype=int32 shape=(3, 4096, 443, 1) dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%iota.56 = s32[3,4096,443,1]{2,1,0,3} iota(), iota_dimension=0, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/iota[dtype=int32 shape=(3, 4096, 443, 1) dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_387 = s32[] constant(0)
%constant_387 = s32[] constant(0)
%broadcast.511 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_387), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/max" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.512 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_387), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/max" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%param_2.216 = f32[3,4096,443]{2,1,0} parameter(2)
%param_2.216 = f32[3,4096,443]{2,1,0} parameter(2)
%floor.133 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_2.216), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%floor.133 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_2.216), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.76 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.133), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.76 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.133), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_386 = s32[] constant(127)
%constant_386 = s32[] constant(127)
%broadcast.510 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_386), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.511 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_386), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.75 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.511, s32[3,4096,443]{2,1,0} %convert.76, s32[3,4096,443]{2,1,0} %broadcast.510), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.75 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.512, s32[3,4096,443]{2,1,0} %convert.76, s32[3,4096,443]{2,1,0} %broadcast.511), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.93 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.75, s32[3,4096,443]{2,1,0} %broadcast.511), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.93 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.75, s32[3,4096,443]{2,1,0} %broadcast.512), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_385 = s32[] constant(128)
%constant_385 = s32[] constant(128)
%broadcast.509 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_385), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.510 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_385), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.180 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.75, s32[3,4096,443]{2,1,0} %broadcast.509), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.180 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.75, s32[3,4096,443]{2,1,0} %broadcast.510), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.93 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.93, s32[3,4096,443]{2,1,0} %add.180, s32[3,4096,443]{2,1,0} %clamp.75), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.93 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.93, s32[3,4096,443]{2,1,0} %add.180, s32[3,4096,443]{2,1,0} %clamp.75), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.107 = s32[3,4096,443,1]{2,1,0,3} bitcast(s32[3,4096,443]{2,1,0} %select.93), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.107 = s32[3,4096,443,1]{2,1,0,3} bitcast(s32[3,4096,443]{2,1,0} %select.93), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%param_1.215 = f32[3,4096,443]{2,1,0} parameter(1)
%param_1.215 = f32[3,4096,443]{2,1,0} parameter(1)
%floor.132 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_1.215), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%floor.132 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_1.215), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.75 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.132), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.75 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.132), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.74 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.511, s32[3,4096,443]{2,1,0} %convert.75, s32[3,4096,443]{2,1,0} %broadcast.510), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.74 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.512, s32[3,4096,443]{2,1,0} %convert.75, s32[3,4096,443]{2,1,0} %broadcast.511), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.92 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.74, s32[3,4096,443]{2,1,0} %broadcast.511), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.92 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.74, s32[3,4096,443]{2,1,0} %broadcast.512), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.179 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.74, s32[3,4096,443]{2,1,0} %broadcast.509), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.179 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.74, s32[3,4096,443]{2,1,0} %broadcast.510), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.92 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.92, s32[3,4096,443]{2,1,0} %add.179, s32[3,4096,443]{2,1,0} %clamp.74), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.92 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.92, s32[3,4096,443]{2,1,0} %add.179, s32[3,4096,443]{2,1,0} %clamp.74), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.106 = s32[3,4096,443,1]{2,1,0,3} bitcast(s32[3,4096,443]{2,1,0} %select.92), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.106 = s32[3,4096,443,1]{2,1,0,3} bitcast(s32[3,4096,443]{2,1,0} %select.92), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%concatenate.58 = s32[3,4096,443,3]{2,1,0,3} concatenate(s32[3,4096,443,1]{2,1,0,3} %iota.56, s32[3,4096,443,1]{2,1,0,3} %bitcast.107, s32[3,4096,443,1]{2,1,0,3} %bitcast.106), dimensions={3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%concatenate.58 = s32[3,4096,443,3]{2,1,0,3} concatenate(s32[3,4096,443,1]{2,1,0,3} %iota.56, s32[3,4096,443,1]{2,1,0,3} %bitcast.107, s32[3,4096,443,1]{2,1,0,3} %bitcast.106), dimensions={3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%copy.8 = s32[3,4096,443,3]{3,2,1,0} copy(s32[3,4096,443,3]{2,1,0,3} %concatenate.58), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%copy.8 = s32[3,4096,443,3]{3,2,1,0} copy(s32[3,4096,443,3]{2,1,0,3} %concatenate.58), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_364 = f32[] constant(1)
%constant_364 = f32[] constant(1)
%broadcast.487 = f32[3,4096,443]{2,1,0} broadcast(f32[] %constant_364), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.488 = f32[3,4096,443]{2,1,0} broadcast(f32[] %constant_364), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.97 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_2.216, f32[3,4096,443]{2,1,0} %floor.133), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.97 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_2.216, f32[3,4096,443]{2,1,0} %floor.133), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.96 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %broadcast.487, f32[3,4096,443]{2,1,0} %subtract.97), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.96 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %broadcast.488, f32[3,4096,443]{2,1,0} %subtract.97), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.95 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_1.215, f32[3,4096,443]{2,1,0} %floor.132), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.95 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_1.215, f32[3,4096,443]{2,1,0} %floor.132), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.94 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %broadcast.487, f32[3,4096,443]{2,1,0} %subtract.95), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.94 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %broadcast.488, f32[3,4096,443]{2,1,0} %subtract.95), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%multiply.93 = f32[3,4096,443]{2,1,0} multiply(f32[3,4096,443]{2,1,0} %subtract.96, f32[3,4096,443]{2,1,0} %subtract.94), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%multiply.94 = f32[3,4096,443]{2,1,0} multiply(f32[3,4096,443]{2,1,0} %subtract.96, f32[3,4096,443]{2,1,0} %subtract.94), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.486 = f32[3,48,4096,443]{3,2,1,0} broadcast(f32[3,4096,443]{2,1,0} %multiply.93), dimensions={0,2,3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.487 = f32[3,4096,48,443]{3,2,1,0} broadcast(f32[3,4096,443]{2,1,0} %multiply.94), dimensions={0,1,3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%param_0.150 = f32[3,48,4096,443]{3,2,1,0} parameter(0)
%param_0.154 = f32[3,4096,48,443]{3,2,1,0} parameter(0)
%multiply.40 = f32[3,48,4096,443]{3,2,1,0} multiply(f32[3,48,4096,443]{3,2,1,0} %broadcast.486, f32[3,48,4096,443]{3,2,1,0} %param_0.150), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%multiply.41 = f32[3,4096,48,443]{3,2,1,0} multiply(f32[3,4096,48,443]{3,2,1,0} %broadcast.487, f32[3,4096,48,443]{3,2,1,0} %param_0.154), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/mul" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
ROOT %scatter.12 = f32[3,48,128,128]{3,2,1,0} scatter(f32[3,48,128,128]{3,2,1,0} %broadcast.169, s32[3,4096,443,3]{3,2,1,0} %copy.8, f32[3,48,4096,443]{3,2,1,0} %multiply.40), update_window_dims={1}, inserted_window_dims={0,2,3}, scatter_dims_to_operand_dims={0,2,3}, index_vector_dim=3, to_apply=%region_4.303, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0, 2, 3), scatter_dims_to_operand_dims=(0, 2, 3)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
ROOT %scatter.12 = f32[3,48,128,128]{3,2,1,0} scatter(f32[3,48,128,128]{3,2,1,0} %broadcast.171, s32[3,4096,443,3]{3,2,1,0} %copy.8, f32[3,4096,48,443]{3,2,1,0} %multiply.41), update_window_dims={2}, inserted_window_dims={0,2,3}, scatter_dims_to_operand_dims={0,2,3}, index_vector_dim=3, to_apply=%region_4.306, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/scatter-add[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(2,), inserted_window_dims=(0, 2, 3), scatter_dims_to_operand_dims=(0, 2, 3)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
}
}


%region_3.299 (Arg_0.300: f32[], Arg_1.301: f32[]) -> f32[] {
%region_3.302 (Arg_0.303: f32[], Arg_1.304: f32[]) -> f32[] {
%Arg_0.300 = f32[] parameter(0)
%Arg_0.303 = f32[] parameter(0)
%Arg_1.301 = f32[] parameter(1)
%Arg_1.304 = f32[] parameter(1)
ROOT %add.302 = f32[] add(f32[] %Arg_0.300, f32[] %Arg_1.301), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
ROOT %add.305 = f32[] add(f32[] %Arg_0.303, f32[] %Arg_1.304), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
}
}


%input_fused_computation_scatter.1 (param_0.147: f32[3,48,4096,443], param_1.207: f32[3,48,128,128], param_2.209: f32[3,48,128,128], param_3.173: f32[3,4096,443], param_4.100: f32[3,4096,443]) -> f32[3,48,128,128] {
%input_fused_computation_scatter.1 (param_0.151: f32[3,4096,48,443], param_1.207: f32[3,48,128,128], param_2.209: f32[3,48,128,128], param_3.173: f32[3,4096,443], param_4.100: f32[3,4096,443]) -> f32[3,48,128,128] {
%param_1.207 = f32[3,48,128,128]{3,2,1,0} parameter(1)
%param_1.207 = f32[3,48,128,128]{3,2,1,0} parameter(1)
%param_2.209 = f32[3,48,128,128]{3,2,1,0} parameter(2)
%param_2.209 = f32[3,48,128,128]{3,2,1,0} parameter(2)
%add.50 = f32[3,48,128,128]{3,2,1,0} add(f32[3,48,128,128]{3,2,1,0} %param_1.207, f32[3,48,128,128]{3,2,1,0} %param_2.209), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/add_any" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.50 = f32[3,48,128,128]{3,2,1,0} add(f32[3,48,128,128]{3,2,1,0} %param_1.207, f32[3,48,128,128]{3,2,1,0} %param_2.209), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/add_any" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%iota.52 = s32[3,4096,443,1]{3,2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/iota[dtype=int32 shape=(3, 4096, 443, 1) dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%iota.52 = s32[3,4096,443,1]{3,2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/iota[dtype=int32 shape=(3, 4096, 443, 1) dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_354 = s32[] constant(0)
%constant_354 = s32[] constant(0)
%broadcast.479 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_354), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/max" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.480 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_354), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/max" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%param_4.100 = f32[3,4096,443]{2,1,0} parameter(4)
%param_4.100 = f32[3,4096,443]{2,1,0} parameter(4)
%floor.115 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_4.100), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%floor.115 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_4.100), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.68 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.115), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.68 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.115), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_352 = s32[] constant(127)
%constant_352 = s32[] constant(127)
%broadcast.478 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_352), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.479 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_352), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.67 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.479, s32[3,4096,443]{2,1,0} %convert.68, s32[3,4096,443]{2,1,0} %broadcast.478), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.67 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.480, s32[3,4096,443]{2,1,0} %convert.68, s32[3,4096,443]{2,1,0} %broadcast.479), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.85 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.67, s32[3,4096,443]{2,1,0} %broadcast.479), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.85 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.67, s32[3,4096,443]{2,1,0} %broadcast.480), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_351 = s32[] constant(128)
%constant_351 = s32[] constant(128)
%broadcast.477 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_351), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.478 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_351), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.170 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.67, s32[3,4096,443]{2,1,0} %broadcast.477), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.170 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.67, s32[3,4096,443]{2,1,0} %broadcast.478), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.85 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.85, s32[3,4096,443]{2,1,0} %add.170, s32[3,4096,443]{2,1,0} %clamp.67), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.85 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.85, s32[3,4096,443]{2,1,0} %add.170, s32[3,4096,443]{2,1,0} %clamp.67), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.99 = s32[3,4096,443,1]{3,2,1,0} bitcast(s32[3,4096,443]{2,1,0} %select.85), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.99 = s32[3,4096,443,1]{3,2,1,0} bitcast(s32[3,4096,443]{2,1,0} %select.85), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%param_3.173 = f32[3,4096,443]{2,1,0} parameter(3)
%param_3.173 = f32[3,4096,443]{2,1,0} parameter(3)
%floor.114 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_3.173), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%floor.114 = f32[3,4096,443]{2,1,0} floor(f32[3,4096,443]{2,1,0} %param_3.173), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/floor" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.67 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.114), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%convert.67 = s32[3,4096,443]{2,1,0} convert(f32[3,4096,443]{2,1,0} %floor.114), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_353 = s32[] constant(1)
%constant_353 = s32[] constant(1)
%broadcast.475 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_353), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.476 = s32[3,4096,443]{2,1,0} broadcast(s32[] %constant_353), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.169 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %convert.67, s32[3,4096,443]{2,1,0} %broadcast.475), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.169 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %convert.67, s32[3,4096,443]{2,1,0} %broadcast.476), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.66 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.479, s32[3,4096,443]{2,1,0} %add.169, s32[3,4096,443]{2,1,0} %broadcast.478), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/jit(vmap(clip))/jit(jit_vmap(clip))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%clamp.66 = s32[3,4096,443]{2,1,0} clamp(s32[3,4096,443]{2,1,0} %broadcast.480, s32[3,4096,443]{2,1,0} %add.169, s32[3,4096,443]{2,1,0} %broadcast.479), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(vmap(vmap(clip)))/jit(jit_vmap(vmap(clip)))/min" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.84 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.66, s32[3,4096,443]{2,1,0} %broadcast.479), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%compare.84 = pred[3,4096,443]{2,1,0} compare(s32[3,4096,443]{2,1,0} %clamp.66, s32[3,4096,443]{2,1,0} %broadcast.480), direction=LT, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/lt" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.168 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.66, s32[3,4096,443]{2,1,0} %broadcast.477), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.168 = s32[3,4096,443]{2,1,0} add(s32[3,4096,443]{2,1,0} %clamp.66, s32[3,4096,443]{2,1,0} %broadcast.478), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/add" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.84 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.84, s32[3,4096,443]{2,1,0} %add.168, s32[3,4096,443]{2,1,0} %clamp.66), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%select.84 = s32[3,4096,443]{2,1,0} select(pred[3,4096,443]{2,1,0} %compare.84, s32[3,4096,443]{2,1,0} %add.168, s32[3,4096,443]{2,1,0} %clamp.66), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/select_n" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.98 = s32[3,4096,443,1]{3,2,1,0} bitcast(s32[3,4096,443]{2,1,0} %select.84), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%bitcast.98 = s32[3,4096,443,1]{3,2,1,0} bitcast(s32[3,4096,443]{2,1,0} %select.84), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 4096, 443, 1) broadcast_dimensions=(0, 1, 2)]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%concatenate.54 = s32[3,4096,443,3]{3,2,1,0} concatenate(s32[3,4096,443,1]{3,2,1,0} %iota.52, s32[3,4096,443,1]{3,2,1,0} %bitcast.99, s32[3,4096,443,1]{3,2,1,0} %bitcast.98), dimensions={3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%concatenate.54 = s32[3,4096,443,3]{3,2,1,0} concatenate(s32[3,4096,443,1]{3,2,1,0} %iota.52, s32[3,4096,443,1]{3,2,1,0} %bitcast.99, s32[3,4096,443,1]{3,2,1,0} %bitcast.98), dimensions={3}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/concatenate[dimension=3]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_312 = f32[] constant(1)
%constant_312 = f32[] constant(1)
%broadcast.439 = f32[3,4096,443]{2,1,0} broadcast(f32[] %constant_312), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.440 = f32[3,4096,443]{2,1,0} broadcast(f32[] %constant_312), dimensions={}, metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.81 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_4.100, f32[3,4096,443]{2,1,0} %floor.115), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.81 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_4.100, f32[3,4096,443]{2,1,0} %floor.115), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.80 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %broadcast.439, f32[3,4096,443]{2,1,0} %subtract.81), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%subtract.80 = f32[3,4096,443]{2,1,0}
%subtract.79 = f32[3,4096,443]{2,1,0} subtract(f32[3,4096,443]{2,1,0} %param_3.173, f32[3,4096,443]{2,1,0} %floor.114), metadata={op_name="jit(interpolate_backward)/jit(main)/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(_map_coordinates))))/jit(jit_jvp(vmap(vmap(_map_coordinates))))/sub" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%