training_step with and without vmap

Created Diff never expires
67 removals
Lines
Total
Removed
Words
Total
Removed
To continue using this feature, upgrade to
Diffchecker logo
Diffchecker Pro
151 lines
67 additions
Lines
Total
Added
Words
Total
Added
To continue using this feature, upgrade to
Diffchecker logo
Diffchecker Pro
151 lines
HloModule jit_training_step.89, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias) }
HloModule jit_training_step.89, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias) }


%fused_computation (param_0.2: s32[], param_1.3: s32[], param_2.6: f32[], param_3.10: f32[]) -> pred[] {
%fused_computation (param_0.2: s32[], param_1.3: s32[], param_2.6: f32[], param_3.10: f32[]) -> pred[] {
%param_3.10 = f32[] parameter(3)
%param_3.10 = f32[] parameter(3)
%compare.128 = pred[] compare(f32[] %param_3.10, f32[] %param_3.10), direction=NE, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/ne" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.128 = pred[] compare(f32[] %param_3.10, f32[] %param_3.10), direction=NE, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/ne" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%constant_452 = s32[] constant(2143289344)
%constant_452 = s32[] constant(2143289344)
%constant_451 = f32[] constant(0)
%constant_451 = f32[] constant(0)
%compare.127 = pred[] compare(f32[] %param_3.10, f32[] %constant_451), direction=EQ, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/eq" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.127 = pred[] compare(f32[] %param_3.10, f32[] %constant_451), direction=EQ, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/eq" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%constant_450 = s32[] constant(0)
%constant_450 = s32[] constant(0)
%bitcast-convert.19 = s32[] bitcast-convert(f32[] %param_3.10), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.19 = s32[] bitcast-convert(f32[] %param_3.10), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.121 = s32[] select(pred[] %compare.127, s32[] %constant_450, s32[] %bitcast-convert.19), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.121 = s32[] select(pred[] %compare.127, s32[] %constant_450, s32[] %bitcast-convert.19), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.120 = s32[] select(pred[] %compare.128, s32[] %constant_452, s32[] %select.121), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.120 = s32[] select(pred[] %compare.128, s32[] %constant_452, s32[] %select.121), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.126 = pred[] compare(s32[] %select.120, s32[] %constant_450), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.126 = pred[] compare(s32[] %select.120, s32[] %constant_450), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%constant_449 = u32[] constant(2147483647)
%constant_449 = u32[] constant(2147483647)
%bitcast-convert.18 = u32[] bitcast-convert(f32[] %param_3.10), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=uint32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.18 = u32[] bitcast-convert(f32[] %param_3.10), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=uint32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%subtract.62 = u32[] subtract(u32[] %constant_449, u32[] %bitcast-convert.18), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/sub" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%subtract.62 = u32[] subtract(u32[] %constant_449, u32[] %bitcast-convert.18), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/sub" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.17 = s32[] bitcast-convert(u32[] %subtract.62), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.17 = s32[] bitcast-convert(u32[] %subtract.62), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.119 = s32[] select(pred[] %compare.126, s32[] %bitcast-convert.17, s32[] %select.120), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.119 = s32[] select(pred[] %compare.126, s32[] %bitcast-convert.17, s32[] %select.120), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%param_2.6 = f32[] parameter(2)
%param_2.6 = f32[] parameter(2)
%compare.125 = pred[] compare(f32[] %param_2.6, f32[] %param_2.6), direction=NE, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/ne" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.125 = pred[] compare(f32[] %param_2.6, f32[] %param_2.6), direction=NE, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/ne" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.124 = pred[] compare(f32[] %param_2.6, f32[] %constant_451), direction=EQ, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/eq" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.124 = pred[] compare(f32[] %param_2.6, f32[] %constant_451), direction=EQ, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/eq" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.16 = s32[] bitcast-convert(f32[] %param_2.6), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.16 = s32[] bitcast-convert(f32[] %param_2.6), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.118 = s32[] select(pred[] %compare.124, s32[] %constant_450, s32[] %bitcast-convert.16), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.118 = s32[] select(pred[] %compare.124, s32[] %constant_450, s32[] %bitcast-convert.16), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.117 = s32[] select(pred[] %compare.125, s32[] %constant_452, s32[] %select.118), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.117 = s32[] select(pred[] %compare.125, s32[] %constant_452, s32[] %select.118), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.123 = pred[] compare(s32[] %select.117, s32[] %constant_450), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.123 = pred[] compare(s32[] %select.117, s32[] %constant_450), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.15 = u32[] bitcast-convert(f32[] %param_2.6), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=uint32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.15 = u32[] bitcast-convert(f32[] %param_2.6), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=uint32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%subtract.61 = u32[] subtract(u32[] %constant_449, u32[] %bitcast-convert.15), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/sub" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%subtract.61 = u32[] subtract(u32[] %constant_449, u32[] %bitcast-convert.15), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/sub" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.14 = s32[] bitcast-convert(u32[] %subtract.61), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%bitcast-convert.14 = s32[] bitcast-convert(u32[] %subtract.61), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/bitcast_convert_type[new_dtype=int32]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.116 = s32[] select(pred[] %compare.123, s32[] %bitcast-convert.14, s32[] %select.117), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%select.116 = s32[] select(pred[] %compare.123, s32[] %bitcast-convert.14, s32[] %select.117), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/select_n" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.122 = pred[] compare(s32[] %select.119, s32[] %select.116), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.122 = pred[] compare(s32[] %select.119, s32[] %select.116), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.121 = pred[] compare(s32[] %select.116, s32[] %select.119), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.121 = pred[] compare(s32[] %select.116, s32[] %select.119), direction=LT, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(vmap(argsort)))/jit(jit_jvp(vmap(argsort)))/lt" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=410}
%compare.120 = pred[] compare(pred[] %compare.122, pred[] %compare.121), direction=EQ
%compare.120 = pred[] compare(pred[] %compare.122, pred[] %compare.121), direction=EQ
%param_0.2 = s32[] parameter(0)
%param_0.2 = s32[] parameter(0)
%param_1.3 = s32[] parameter(1)
%param_1.3 = s32[] parameter(1)
%compare.119 = pred[] compare(s32[] %param_0.2, s32[] %param_1.3), direction=LT
%compare.119 = pred[] compare(s32[] %param_0.2, s32[] %param_1.3), direction=LT
ROOT %select.115 = pred[] select(pred[] %compare.120, pred[] %compare.119, pred[] %compare.122)
ROOT %select.115 = pred[] select(pred[] %compare.120, pred[] %compare.119, pred[] %compare.122)
}
}


%region_5.625.clone (Arg_0.0: f32[], Arg_1.0: f32[], Arg_2.0: s32[], Arg_3.0: s32[], p.2.lhs.clone: s32[], p.2.rhs.clone: s32[]) -> pred[] {
%region_5.628.clone (Arg_0.0: f32[], Arg_1.0: f32[], Arg_2.0: s32[], Arg_3.0: s32[], p.2.lhs.clone: s32[], p.2.rhs.clone: s32[]) -> pred[] {
%Arg_2.0 = s32[] parameter(2)
%Arg_2.0 = s32[] parameter(2)
%Arg_3.0 = s32[] parameter(3)
%Arg_3.0 = s32[] parameter(3)
%p.2.lhs.clone = s32[] parameter(4)
%p.2.lhs.clone = s32[] parameter(4)
%p.2.rhs.clone = s32[] parameter(5)
%p.2.rhs.clone = s32[] parameter(5)
%Arg_1.0 = f32[] parameter(1)
%Arg_1.0 = f32[] parameter(1)
%Arg_0.0 = f32[] parameter(0)
%Arg_0.0 = f32[] parameter(0)
ROOT %fusion = pred[] fusion(s32[] %p.2.lhs.clone, s32[] %p.2.rhs.clone, f32[] %Arg_1.0, f32[] %Arg_0.0), kind=kLoop, calls=%fused_computation
ROOT %fusion = pred[] fusion(s32[] %p.2.lhs.clone, s32[] %p.2.rhs.clone, f32[] %Arg_1.0, f32[] %Arg_0.0), kind=kLoop, calls=%fused_computation
}
}


%region_27.1710 (Arg_0.1711: f32[], Arg_1.1712: f32[]) -> f32[] {
%region_27.1717 (Arg_0.1718: f32[], Arg_1.1719: f32[]) -> f32[] {
%Arg_0.1711 = f32[] parameter(0)
%Arg_0.1718 = f32[] parameter(0)
%Arg_1.1712 = f32[] parameter(1)
%Arg_1.1719 = f32[] parameter(1)
ROOT %add.1713 = f32[] add(f32[] %Arg_0.1711, f32[] %Arg_1.1712), metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/reduce_sum[axes=(0,)]" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/flax/linen/linear.py" source_line=195}
ROOT %add.1720 = f32[] add(f32[] %Arg_0.1718, f32[] %Arg_1.1719), metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/reduce_sum[axes=(0,)]" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/flax/linen/linear.py" source_line=195}
}
}


%fused_computation.1 (param_0.363: f32[]) -> (f32[], f32[]) {
%fused_computation.1 (param_0.369: f32[]) -> (f32[], f32[]) {
%param_0.363 = f32[] parameter(0)
%param_0.369 = f32[] parameter(0)
%constant_471 = f32[] constant(8.13802108e-05)
%constant_471 = f32[] constant(8.13802108e-05)
%multiply.396 = f32[] multiply(f32[] %param_0.363, f32[] %constant_471), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
%multiply.396 = f32[] multiply(f32[] %param_0.369, f32[] %constant_471), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
%constant_454 = f32[] constant(1e-10)
%constant_454 = f32[] constant(1e-10)
%maximum.59 = f32[] maximum(f32[] %multiply.396, f32[] %constant_454), metadata={op_name="jit(training_step)/jit(main)/max" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=8}
%maximum.59 = f32[] maximum(f32[] %multiply.396, f32[] %constant_454), metadata={op_name="jit(training_step)/jit(main)/max" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=8}
%log.3 = f32[] log(f32[] %maximum.59), metadata={op_name="jit(training_step)/jit(main)/log" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=17}
%log.3 = f32[] log(f32[] %maximum.59), metadata={op_name="jit(training_step)/jit(main)/log" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=17}
%constant_453 = f32[] constant(-4.34294462)
%constant_453 = f32[] constant(-4.34294462)
%multiply.232 = f32[] multiply(f32[] %log.3, f32[] %constant_453), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=17}
%multiply.232 = f32[] multiply(f32[] %log.3, f32[] %constant_453), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/utils.py" source_line=17}
ROOT %tuple.44 = (f32[], f32[]) tuple(f32[] %multiply.232, f32[] %multiply.396)
ROOT %tuple.44 = (f32[], f32[]) tuple(f32[] %multiply.232, f32[] %multiply.396)
}
}


%region_10.1466 (Arg_0.1467: f32[], Arg_1.1468: f32[]) -> f32[] {
%region_10.1472 (Arg_0.1473: f32[], Arg_1.1474: f32[]) -> f32[] {
%Arg_0.1467 = f32[] parameter(0)
%Arg_0.1473 = f32[] parameter(0)
%Arg_1.1468 = f32[] parameter(1)
%Arg_1.1474 = f32[] parameter(1)
ROOT %add.1469 = f32[] add(f32[] %Arg_0.1467, f32[] %Arg_1.1468), metadata={op_name="jit(training_step)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
ROOT %add.1475 = f32[] add(f32[] %Arg_0.1473, f32[] %Arg_1.1474), metadata={op_name="jit(training_step)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
}
}


%input_fused_computation_reduce (param_0.364: f32[4096,3]) -> f32[] {
%input_fused_computation_reduce (param_0.370: f32[4096,3]) -> f32[] {
%param_0.364 = f32[4096,3]{1,0} parameter(0)
%param_0.370 = f32[4096,3]{1,0} parameter(0)
%multiply.233 = f32[4096,3]{1,0} multiply(f32[4096,3]{1,0} %param_0.364, f32[4096,3]{1,0} %param_0.364), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
%multiply.233 = f32[4096,3]{1,0} multiply(f32[4096,3]{1,0} %param_0.370, f32[4096,3]{1,0} %param_0.370), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
%bitcast.74 = f32[12288]{0} bitcast(f32[4096,3]{1,0} %multiply.233)
%bitcast.73 = f32[12288]{0} bitcast(f32[4096,3]{1,0} %multiply.233)
%constant_472 = f32[] constant(0)
%constant_472 = f32[] constant(0)
ROOT %reduce.36 = f32[] reduce(f32[12288]{0} %bitcast.74, f32[] %constant_472), dimensions={0}, to_apply=%region_10.1466, metadata={op_name="jit(training_step)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
ROOT %reduce.36 = f32[] reduce(f32[12288]{0} %bitcast.73, f32[] %constant_472), dimensions={0}, to_apply=%region_10.1472, metadata={op_name="jit(training_step)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=291}
}
}


%fused_computation.2 (param_0.12: u32[2], param_1.10: u32[2]) -> (u32[2], u32[2], u32[2]) {
%fused_computation.2 (param_0.12: u32[2], param_1.10: u32[2]) -> (u32[2], u32[2], u32[2]) {
%param_0.12 = u32[2]{0} parameter(0)
%param_0.12 = u32[2]{0} parameter(0)
%param_1.10 = u32[2]{0} parameter(1)
%param_1.10 = u32[2]{0} parameter(1)
%concatenate.100 = u32[4]{0} concatenate(u32[2]{0} %param_0.12, u32[2]{0} %param_1.10), dimensions={0}, metadata={op_name="jit(training_step)/jit(main)/concatenate[dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%concatenate.100 = u32[4]{0} concatenate(u32[2]{0} %param_0.12, u32[2]{0} %param_1.10), dimensions={0}, metadata={op_name="jit(training_step)/jit(main)/concatenate[dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%bitcast.76 = u32[2,2]{1,0} bitcast(u32[4]{0} %concatenate.100), metadata={op_name="jit(training_step)/jit(main)/reshape[new_sizes=(2, 2) dimensions=None]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%bitcast.75 = u32[2,2]{1,0} bitcast(u32[4]{0} %concatenate.100), metadata={op_name="jit(training_step)/jit(main)/reshape[new_sizes=(2, 2) dimensions=None]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%slice.274 = u32[1,2]{1,0} slice(u32[2,2]{1,0} %bitcast.76), slice={[1:2], [0:2]}, metadata={op_name="jit(training_step)/jit(main)/slice[start_indices=(1, 0) limit_indices=(2, 2) strides=(1, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%slice.274 = u32[1,2]{1,0} slice(u32[2,2]{1,0} %bitcast.75), slice={[1:2], [0:2]}, metadata={op_name="jit(training_step)/jit(main)/slice[start_indices=(1, 0) limit_indices=(2, 2) strides=(1, 1)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%bitcast.75 = u32[2]{0} bitcast(u32[1,2]{1,0} %slice.274), metadata={op_name="jit(training_step)/jit(main)/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%bitcast.74 = u32[2]{0} bitcast(u32[1,2]{1,0} %slice.274), metadata={op_name="jit(training_step)/jit(main)/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=236}
%slice.358.clone.1 = u32[1]{0} slice(u32[2]{0} %param_0.12), slice={[1:2]}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%slice.358.clone.1 = u32[1]{0} slice(u32[2]{0} %param_0.12), slice={[1:2]}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%bitcast.135.clone.1 = u32[] bitcast(u32[1]{0} %slice.358.clone.1), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%bitcast.133.clone.1 = u32[] bitcast(u32[1]{0} %slice.358.clone.1), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%broadcast.954.clone.1 = u32[2]{0} broadcast(u32[] %bitcast.135.clone.1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/threefry2x32" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%broadcast.954.clone.1 = u32[2]{0} broadcast(u32[] %bitcast.133.clone.1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/threefry2x32" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%slice.360.clone.1 = u32[1]{0} slice(u32[2]{0} %param_0.12), slice={[0:1]}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%slice.360.clone.1 = u32[1]{0} slice(u32[2]{0} %param_0.12), slice={[0:1]}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%bitcast.136.clone.1 = u32[] bitcast(u32[1]{0} %slice.360.clone.1), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%bitcast.134.clone.1 = u32[] bitcast(u32[1]{0} %slice.360.clone.1), metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/squeeze[dimensions=(0,)]" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%broadcast.955.clone.1 = u32[2]{0} broadcast(u32[] %bitcast.136.clone.1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/threefry2x32" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
%broadcast.955.clone.1 = u32[2]{0} broadcast(u32[] %bitcast.134.clone.1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/threefry2x32" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=112}
ROOT %tuple.87 = (u32[2]{0}, u32[2]{0}, u32[2]{0}) tuple(u32[2]{0} %bitcast.75, u32[2]{0} %broadcast.954.clone.1, u32[2]{0} %broadcast.955.clone.1)
ROOT %tuple.87 = (u32[2]{0}, u32[2]{0}, u32[2]{0}) tuple(u32[2]{0} %bitcast.74, u32[2]{0} %broadcast.954.clone.1, u32[2]{0} %broadcast.955.clone.1)
}
}


%fused_computation.3 (param_0.13: f32[3,16,128,128], param_1.976: f32[], param_2.996: f32[], param_3.881: f32[], param_4.645: f32[3,16,128,128], param_5.599: f32[3,16,128,128], param_6.528: f32[3,16,128,128], param_7.475: f32[3,16,128,128]) -> (f32[3,16,128,128], f32[3,16,128,128], f32[3,16,128,128]) {
%fused_computation.3 (param_0.13: f32[3,16,128,128], param_1.967: f32[], param_2.991: f32[], param_3.874: f32[], param_4.627: f32[3,16,128,128], param_5.567: f32[3,16,128,128], param_6.501: f32[3,16,128,128], param_7.433: f32[3,16,128,128]) -> (f32[3,16,128,128], f32[3,16,128,128], f32[3,16,128,128]) {
%param_0.13 = f32[3,16,128,128]{3,2,1,0} parameter(0)
%param_0.13 = f32[3,16,128,128]{3,2,1,0} parameter(0)
%param_5.599 = f32[3,16,128,128]{3,2,1,0} parameter(5)
%param_5.567 = f32[3,16,128,128]{3,2,1,0} parameter(5)
%param_6.528 = f32[3,16,128,128]{3,2,1,0} parameter(6)
%param_6.501 = f32[3,16,128,128]{3,2,1,0} parameter(6)
%add.491.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %param_5.599, f32[3,16,128,128]{3,2,1,0} %param_6.528), metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/add_any" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%add.491.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %param_5.567, f32[3,16,128,128]{3,2,1,0} %param_6.501), metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/add_any" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%constant_478_clone_1 = f32[] constant(0.0100000007)
%constant_478_clone_1 = f32[] constant(0.0100000007)
%broadcast.800.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_478_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.799.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_478_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.241.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %add.491.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.800.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.241.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %add.491.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.799.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%param_7.475 = f32[3,16,128,128]{3,2,1,0} parameter(7)
%param_7.433 = f32[3,16,128,128]{3,2,1,0} parameter(7)
%constant_479_clone_1 = f32[] constant(0.9)
%constant_479_clone_1 = f32[] constant(0.9)
%broadcast.798.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_479_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.798.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_479_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.240.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %param_7.475, f32[3,16,128,128]{3,2,1,0} %broadcast.798.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.240.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %param_7.433, f32[3,16,128,128]{3,2,1,0} %broadcast.798.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%add.336.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %multiply.241.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.240.clone.1), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%add.336.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %multiply.241.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.240.clone.1), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%param_2.996 = f32[] parameter(2)
%param_2.991 = f32[] parameter(2)
%broadcast.794 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %param_2.996), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%broadcast.794 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %param_2.991), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%constant_477_clone_1 = f32[] constant(0.1)
%constant_477_clone_1 = f32[] constant(0.1)
%broadcast.797.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_477_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=325}
%broadcast.797.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_477_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=325}
%multiply.239.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %add.491.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.797.clone.1), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=325}
%multiply.239.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %add.491.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.797.clone.1), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=325}
%multiply.238.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %multiply.239.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.239.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.238.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %multiply.239.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.239.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%constant_475_clone_1 = f32[] constant(0.01)
%constant_475_clone_1 = f32[] constant(0.01)
%broadcast.796.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_475_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.796.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_475_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.237.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %multiply.238.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.796.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.237.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %multiply.238.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.796.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%param_4.645 = f32[3,16,128,128]{3,2,1,0} parameter(4)
%param_4.627 = f32[3,16,128,128]{3,2,1,0} parameter(4)
%constant_476_clone_1 = f32[] constant(0.99)
%constant_476_clone_1 = f32[] constant(0.99)
%broadcast.795.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_476_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.795.clone.1 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_476_clone_1), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.236.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %param_4.645, f32[3,16,128,128]{3,2,1,0} %broadcast.795.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%multiply.236.clone.1 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %param_4.627, f32[3,16,128,128]{3,2,1,0} %broadcast.795.clone.1), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%add.335.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %multiply.237.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.236.clone.1), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%add.335.clone.1 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %multiply.237.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.236.clone.1), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=81}
%param_1.976 = f32[] parameter(1)
%param_1.967 = f32[] parameter(1)
%broadcast.793 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %param_1.976), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%broadcast.793 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %param_1.967), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%divide.24 = f32[3,16,128,128]{3,2,1,0} divide(f32[3,16,128,128]{3,2,1,0} %add.335.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.793), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%divide.24 = f32[3,16,128,128]{3,2,1,0} divide(f32[3,16,128,128]{3,2,1,0} %add.335.clone.1, f32[3,16,128,128]{3,2,1,0} %broadcast.793), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.0 = f32[3,16,128,128]{3,2,1,0} sqrt(f32[3,16,128,128]{3,2,1,0} %divide.24), metadata={op_name="jit(training_step)/jit(main)/sqrt" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%sqrt.0 = f32[3,16,128,128]{3,2,1,0} sqrt(f32[3,16,128,128]{3,2,1,0} %divide.24), metadata={op_name="jit(training_step)/jit(main)/sqrt" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%constant_474 = f32[] constant(1e-08)
%constant_474 = f32[] constant(1e-08)
%broadcast.792 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_474), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.792 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_474), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%add.334 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %sqrt.0, f32[3,16,128,128]{3,2,1,0} %broadcast.792), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%add.334 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %sqrt.0, f32[3,16,128,128]{3,2,1,0} %broadcast.792), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%multiply.235 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %broadcast.794, f32[3,16,128,128]{3,2,1,0} %add.334)
%multiply.235 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %broadcast.794, f32[3,16,128,128]{3,2,1,0} %add.334)
%divide.23 = f32[3,16,128,128]{3,2,1,0} divide(f32[3,16,128,128]{3,2,1,0} %add.336.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.235), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%divide.23 = f32[3,16,128,128]{3,2,1,0} divide(f32[3,16,128,128]{3,2,1,0} %add.336.clone.1, f32[3,16,128,128]{3,2,1,0} %multiply.235), metadata={op_name="jit(training_step)/jit(main)/div" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/transform.py" source_line=302}
%param_3.881 = f32[] parameter(3)
%param_3.874 = f32[] parameter(3)
%constant_473 = f32[] constant(-0.02)
%constant_473 = f32[] constant(-0.02)
%multiply.397 = f32[] multiply(f32[] %param_3.881, f32[] %constant_473), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%multiply.397 = f32[] multiply(f32[] %param_3.874, f32[] %constant_473), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%broadcast.791 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %multiply.397), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%broadcast.791 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %multiply.397), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%multiply.234 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %divide.23, f32[3,16,128,128]{3,2,1,0} %broadcast.791), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%multiply.234 = f32[3,16,128,128]{3,2,1,0} multiply(f32[3,16,128,128]{3,2,1,0} %divide.23, f32[3,16,128,128]{3,2,1,0} %broadcast.791), metadata={op_name="jit(training_step)/jit(main)/mul" source_file="/home/brent/tensorf-jax/tensorf/training.py" source_line=347}
%add.333 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %param_0.13, f32[3,16,128,128]{3,2,1,0} %multiply.234), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/update.py" source_line=43}
%add.333 = f32[3,16,128,128]{3,2,1,0} add(f32[3,16,128,128]{3,2,1,0} %param_0.13, f32[3,16,128,128]{3,2,1,0} %multiply.234), metadata={op_name="jit(training_step)/jit(main)/add" source_file="/home/brent/miniconda/envs/py38/lib/python3.8/site-packages/optax/_src/update.py" source_line=43}
ROOT %tuple.46 = (f32[3,16,128,128]{3,2,1,0}, f32[3,16,128,128]{3,2,1,0}, f32[3,16,128,128]{3,2,1,0}) tuple(f32[3,16,128,128]{3,2,1,0} %add.333, f32[3,16,128,128]{3,2,1,0} %add.335.clone.1, f32[3,16,128,128]{3,2,1,0} %add.336.clone.1)
ROOT %tuple.46 = (f32[3,16,128,128]{3,2,1,0}, f32[3,16,128,128]{3,2,1,0}, f32[3,16,128,128]{3,2,1,0}) tuple(f32[3,16,128,128]{3,2,1,0} %add.333, f32[3,16,128,128]{3,2,1,0} %add.335.clone.1, f32[3,16,128,128]{3,2,1,0} %add.336.clone.1)
}
}


%region_24.1616 (Arg_0.1617: f32[], Arg_1.1618: f32[]) -> f32[] {
%region_24.1622 (Arg_0.1623: f32[], Arg_1.1624: f32[]) -> f32[] {
%Arg_0.1617 = f32[] parameter(0)
%Arg_0.1623 = f32[] parameter(0)
%Arg_1.1618 = f32[] parameter(1)
%Arg_1.1624 = f32[] parameter(1)
ROOT %add.1619 = f32[] add(f32[] %Arg_0.1617, f32[] %Arg_1.1618), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=452}
ROOT %add.1625 = f32[] add(f32[] %Arg_0.1623, f32[] %Arg_1.1624), metadata={op_name="add" source_file="/home/brent/tensorf-jax/tensorf/render.py" source_line=452}
}
}


%input_fused_computation_scatter (param_0.741: f32[3,4096,221], param_1.974: f32[3,4096,221], param_2.994: f32[3,16,4096,221], param_3.879: f32[4096,221], param_4.643: f32[4096,221], param_5.596: pred[4096], param_6.521: f32[4096], param_7.466: f32[4096], param_8.381: f32[4096,221], param_9.256: f32[4096,221], param_10.177: f32[4096,221], param_11.132: f32[4096,110], param_12.111: f32[4096,110]) -> f32[3,16,128,128] {
%input_fused_computation_scatter (param_0.553: f32[3,4096,16,221], param_1.751: f32[3,4096,221], param_2.713: f32[3,4096,221]) -> f32[3,16,128,128] {
%constant_480 = f32[] constant(0)
%constant_480 = f32[] constant(0)
%broadcast.962 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_480), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(_map_coordinates)))))/jit(jit_transpose(jvp(vmap(vmap(_map_coordinates)))))/broadcast_in_dim[shape=(3, 16, 128, 128) broadcast_dimensions=()]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%broadcast.962 = f32[3,16,128,128]{3,2,1,0} broadcast(f32[] %constant_480), dimensions={}, metadata={op_name="jit(training_step)/jit(main)/jit(transpose(jvp(render_rays)))/jit(jit_transpose(jvp(render_rays)))/jit(transpose(jvp(interpolate)))/jit(jit_transpose(jvp(interpolate)))/jit(transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/jit(jit_transpose(jvp(vmap(vmap(vmap(_map_coordinates))))))/broadcast_in_dim[shape=(3, 16, 128, 128) broadcast_dimensions=()]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.py" source_line=222}
%iota.127 = s32[3,4096,221,1]{2,1,0,3} iota(), iota_dimension=0, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(j
%iota.135 = s32[3,4096,221,1]{2,1,0,3} iota(), iota_dimension=0, metadata={op_name="jit(training_step)/jit(main)/jit(jvp(render_rays))/jit(jit_jvp(render_rays))/jit(jvp(interpolate))/jit(jit_jvp(interpolate))/jit(jvp(vmap(vmap(vmap(_map_coordinates)))))/jit(jit_jvp(vmap(vmap(vmap(_map_coordinates)))))/iota[dtype=int32 shape=(3, 4096, 221, 1) dimension=0]" source_file="/home/brent/tensorf-jax/tensorf/tensor_vm.p